aboutsummaryrefslogtreecommitdiffstats
path: root/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AggregationLiteral.java
diff options
context:
space:
mode:
Diffstat (limited to 'subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AggregationLiteral.java')
-rw-r--r--subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AggregationLiteral.java36
1 files changed, 18 insertions, 18 deletions
diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AggregationLiteral.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AggregationLiteral.java
index 3a5eb5c7..e3acfacc 100644
--- a/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AggregationLiteral.java
+++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AggregationLiteral.java
@@ -6,7 +6,9 @@
6package tools.refinery.store.query.literal; 6package tools.refinery.store.query.literal;
7 7
8import tools.refinery.store.query.Constraint; 8import tools.refinery.store.query.Constraint;
9import tools.refinery.store.query.InvalidQueryException;
9import tools.refinery.store.query.equality.LiteralEqualityHelper; 10import tools.refinery.store.query.equality.LiteralEqualityHelper;
11import tools.refinery.store.query.equality.LiteralHashCodeHelper;
10import tools.refinery.store.query.substitution.Substitution; 12import tools.refinery.store.query.substitution.Substitution;
11import tools.refinery.store.query.term.*; 13import tools.refinery.store.query.term.*;
12 14
@@ -14,6 +16,8 @@ import java.util.List;
14import java.util.Objects; 16import java.util.Objects;
15import java.util.Set; 17import java.util.Set;
16 18
19// {@link Object#equals(Object)} is implemented by {@link AbstractLiteral}.
20@SuppressWarnings("squid:S2160")
17public class AggregationLiteral<R, T> extends AbstractCallLiteral { 21public class AggregationLiteral<R, T> extends AbstractCallLiteral {
18 private final DataVariable<R> resultVariable; 22 private final DataVariable<R> resultVariable;
19 private final DataVariable<T> inputVariable; 23 private final DataVariable<T> inputVariable;
@@ -23,19 +27,19 @@ public class AggregationLiteral<R, T> extends AbstractCallLiteral {
23 DataVariable<T> inputVariable, Constraint target, List<Variable> arguments) { 27 DataVariable<T> inputVariable, Constraint target, List<Variable> arguments) {
24 super(target, arguments); 28 super(target, arguments);
25 if (!inputVariable.getType().equals(aggregator.getInputType())) { 29 if (!inputVariable.getType().equals(aggregator.getInputType())) {
26 throw new IllegalArgumentException("Input variable %s must of type %s, got %s instead".formatted( 30 throw new InvalidQueryException("Input variable %s must of type %s, got %s instead".formatted(
27 inputVariable, aggregator.getInputType().getName(), inputVariable.getType().getName())); 31 inputVariable, aggregator.getInputType().getName(), inputVariable.getType().getName()));
28 } 32 }
29 if (!getArgumentsOfDirection(ParameterDirection.OUT).contains(inputVariable)) { 33 if (!getArgumentsOfDirection(ParameterDirection.OUT).contains(inputVariable)) {
30 throw new IllegalArgumentException("Input variable %s must be bound with direction %s in the argument list" 34 throw new InvalidQueryException("Input variable %s must be bound with direction %s in the argument list"
31 .formatted(inputVariable, ParameterDirection.OUT)); 35 .formatted(inputVariable, ParameterDirection.OUT));
32 } 36 }
33 if (!resultVariable.getType().equals(aggregator.getResultType())) { 37 if (!resultVariable.getType().equals(aggregator.getResultType())) {
34 throw new IllegalArgumentException("Result variable %s must of type %s, got %s instead".formatted( 38 throw new InvalidQueryException("Result variable %s must of type %s, got %s instead".formatted(
35 resultVariable, aggregator.getResultType().getName(), resultVariable.getType().getName())); 39 resultVariable, aggregator.getResultType().getName(), resultVariable.getType().getName()));
36 } 40 }
37 if (arguments.contains(resultVariable)) { 41 if (arguments.contains(resultVariable)) {
38 throw new IllegalArgumentException("Result variable %s must not appear in the argument list".formatted( 42 throw new InvalidQueryException("Result variable %s must not appear in the argument list".formatted(
39 resultVariable)); 43 resultVariable));
40 } 44 }
41 this.resultVariable = resultVariable; 45 this.resultVariable = resultVariable;
@@ -63,7 +67,7 @@ public class AggregationLiteral<R, T> extends AbstractCallLiteral {
63 @Override 67 @Override
64 public Set<Variable> getInputVariables(Set<? extends Variable> positiveVariablesInClause) { 68 public Set<Variable> getInputVariables(Set<? extends Variable> positiveVariablesInClause) {
65 if (positiveVariablesInClause.contains(inputVariable)) { 69 if (positiveVariablesInClause.contains(inputVariable)) {
66 throw new IllegalArgumentException("Aggregation variable %s must not be bound".formatted(inputVariable)); 70 throw new InvalidQueryException("Aggregation variable %s must not be bound".formatted(inputVariable));
67 } 71 }
68 return super.getInputVariables(positiveVariablesInClause); 72 return super.getInputVariables(positiveVariablesInClause);
69 } 73 }
@@ -77,7 +81,7 @@ public class AggregationLiteral<R, T> extends AbstractCallLiteral {
77 yield emptyValue == null ? BooleanLiteral.FALSE : 81 yield emptyValue == null ? BooleanLiteral.FALSE :
78 resultVariable.assign(new ConstantTerm<>(resultVariable.getType(), emptyValue)); 82 resultVariable.assign(new ConstantTerm<>(resultVariable.getType(), emptyValue));
79 } 83 }
80 case ALWAYS_TRUE -> throw new IllegalArgumentException("Trying to aggregate over an infinite set"); 84 case ALWAYS_TRUE -> throw new InvalidQueryException("Trying to aggregate over an infinite set");
81 case NOT_REDUCIBLE -> this; 85 case NOT_REDUCIBLE -> this;
82 }; 86 };
83 } 87 }
@@ -89,6 +93,11 @@ public class AggregationLiteral<R, T> extends AbstractCallLiteral {
89 } 93 }
90 94
91 @Override 95 @Override
96 public AbstractCallLiteral withArguments(Constraint newTarget, List<Variable> newArguments) {
97 return new AggregationLiteral<>(resultVariable, aggregator, inputVariable, newTarget, newArguments);
98 }
99
100 @Override
92 public boolean equalsWithSubstitution(LiteralEqualityHelper helper, Literal other) { 101 public boolean equalsWithSubstitution(LiteralEqualityHelper helper, Literal other) {
93 if (!super.equalsWithSubstitution(helper, other)) { 102 if (!super.equalsWithSubstitution(helper, other)) {
94 return false; 103 return false;
@@ -100,18 +109,9 @@ public class AggregationLiteral<R, T> extends AbstractCallLiteral {
100 } 109 }
101 110
102 @Override 111 @Override
103 public boolean equals(Object o) { 112 public int hashCodeWithSubstitution(LiteralHashCodeHelper helper) {
104 if (this == o) return true; 113 return Objects.hash(super.hashCodeWithSubstitution(helper), helper.getVariableHashCode(resultVariable),
105 if (o == null || getClass() != o.getClass()) return false; 114 helper.getVariableHashCode(inputVariable), aggregator);
106 if (!super.equals(o)) return false;
107 AggregationLiteral<?, ?> that = (AggregationLiteral<?, ?>) o;
108 return resultVariable.equals(that.resultVariable) && inputVariable.equals(that.inputVariable) &&
109 aggregator.equals(that.aggregator);
110 }
111
112 @Override
113 public int hashCode() {
114 return Objects.hash(super.hashCode(), resultVariable, inputVariable, aggregator);
115 } 115 }
116 116
117 @Override 117 @Override