aboutsummaryrefslogtreecommitdiffstats
path: root/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AggregationLiteral.java
diff options
context:
space:
mode:
Diffstat (limited to 'subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AggregationLiteral.java')
-rw-r--r--subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AggregationLiteral.java18
1 files changed, 11 insertions, 7 deletions
diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AggregationLiteral.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AggregationLiteral.java
index 93e59291..2aa0a0d5 100644
--- a/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AggregationLiteral.java
+++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AggregationLiteral.java
@@ -14,12 +14,12 @@ import tools.refinery.store.query.term.Variable;
14 14
15import java.util.List; 15import java.util.List;
16import java.util.Objects; 16import java.util.Objects;
17import java.util.Set;
18 17
19public class AggregationLiteral<R, T> extends AbstractCallLiteral { 18public class AggregationLiteral<R, T> extends AbstractCallLiteral {
20 private final DataVariable<R> resultVariable; 19 private final DataVariable<R> resultVariable;
21 private final DataVariable<T> inputVariable; 20 private final DataVariable<T> inputVariable;
22 private final Aggregator<R, T> aggregator; 21 private final Aggregator<R, T> aggregator;
22 private final VariableBinder variableBinder;
23 23
24 public AggregationLiteral(DataVariable<R> resultVariable, Aggregator<R, T> aggregator, 24 public AggregationLiteral(DataVariable<R> resultVariable, Aggregator<R, T> aggregator,
25 DataVariable<T> inputVariable, Constraint target, List<Variable> arguments) { 25 DataVariable<T> inputVariable, Constraint target, List<Variable> arguments) {
@@ -32,10 +32,6 @@ public class AggregationLiteral<R, T> extends AbstractCallLiteral {
32 throw new IllegalArgumentException("Result variable %s must of type %s, got %s instead".formatted( 32 throw new IllegalArgumentException("Result variable %s must of type %s, got %s instead".formatted(
33 resultVariable, aggregator.getResultType().getName(), resultVariable.getType().getName())); 33 resultVariable, aggregator.getResultType().getName(), resultVariable.getType().getName()));
34 } 34 }
35 if (!arguments.contains(inputVariable)) {
36 throw new IllegalArgumentException("Input variable %s must appear in the argument list".formatted(
37 inputVariable));
38 }
39 if (arguments.contains(resultVariable)) { 35 if (arguments.contains(resultVariable)) {
40 throw new IllegalArgumentException("Result variable %s must not appear in the argument list".formatted( 36 throw new IllegalArgumentException("Result variable %s must not appear in the argument list".formatted(
41 resultVariable)); 37 resultVariable));
@@ -43,6 +39,14 @@ public class AggregationLiteral<R, T> extends AbstractCallLiteral {
43 this.resultVariable = resultVariable; 39 this.resultVariable = resultVariable;
44 this.inputVariable = inputVariable; 40 this.inputVariable = inputVariable;
45 this.aggregator = aggregator; 41 this.aggregator = aggregator;
42 variableBinder = VariableBinder.builder()
43 .variable(resultVariable, VariableDirection.OUT)
44 .parameterList(false, target.getParameters(), arguments)
45 .build();
46 if (variableBinder.getDirection(inputVariable) != VariableDirection.CLOSURE) {
47 throw new IllegalArgumentException("Input variable %s must appear in the argument list".formatted(
48 inputVariable));
49 }
46 } 50 }
47 51
48 public DataVariable<R> getResultVariable() { 52 public DataVariable<R> getResultVariable() {
@@ -58,8 +62,8 @@ public class AggregationLiteral<R, T> extends AbstractCallLiteral {
58 } 62 }
59 63
60 @Override 64 @Override
61 public Set<Variable> getBoundVariables() { 65 public VariableBinder getVariableBinder() {
62 return Set.of(resultVariable); 66 return variableBinder;
63 } 67 }
64 68
65 @Override 69 @Override