aboutsummaryrefslogtreecommitdiffstats
path: root/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AggregationLiteral.java
diff options
context:
space:
mode:
Diffstat (limited to 'subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AggregationLiteral.java')
-rw-r--r--subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AggregationLiteral.java138
1 files changed, 138 insertions, 0 deletions
diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AggregationLiteral.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AggregationLiteral.java
new file mode 100644
index 00000000..3a5eb5c7
--- /dev/null
+++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AggregationLiteral.java
@@ -0,0 +1,138 @@
1/*
2 * SPDX-FileCopyrightText: 2021-2023 The Refinery Authors <https://refinery.tools/>
3 *
4 * SPDX-License-Identifier: EPL-2.0
5 */
6package tools.refinery.store.query.literal;
7
8import tools.refinery.store.query.Constraint;
9import tools.refinery.store.query.equality.LiteralEqualityHelper;
10import tools.refinery.store.query.substitution.Substitution;
11import tools.refinery.store.query.term.*;
12
13import java.util.List;
14import java.util.Objects;
15import java.util.Set;
16
17public class AggregationLiteral<R, T> extends AbstractCallLiteral {
18 private final DataVariable<R> resultVariable;
19 private final DataVariable<T> inputVariable;
20 private final Aggregator<R, T> aggregator;
21
22 public AggregationLiteral(DataVariable<R> resultVariable, Aggregator<R, T> aggregator,
23 DataVariable<T> inputVariable, Constraint target, List<Variable> arguments) {
24 super(target, arguments);
25 if (!inputVariable.getType().equals(aggregator.getInputType())) {
26 throw new IllegalArgumentException("Input variable %s must of type %s, got %s instead".formatted(
27 inputVariable, aggregator.getInputType().getName(), inputVariable.getType().getName()));
28 }
29 if (!getArgumentsOfDirection(ParameterDirection.OUT).contains(inputVariable)) {
30 throw new IllegalArgumentException("Input variable %s must be bound with direction %s in the argument list"
31 .formatted(inputVariable, ParameterDirection.OUT));
32 }
33 if (!resultVariable.getType().equals(aggregator.getResultType())) {
34 throw new IllegalArgumentException("Result variable %s must of type %s, got %s instead".formatted(
35 resultVariable, aggregator.getResultType().getName(), resultVariable.getType().getName()));
36 }
37 if (arguments.contains(resultVariable)) {
38 throw new IllegalArgumentException("Result variable %s must not appear in the argument list".formatted(
39 resultVariable));
40 }
41 this.resultVariable = resultVariable;
42 this.inputVariable = inputVariable;
43 this.aggregator = aggregator;
44 }
45
46 public DataVariable<R> getResultVariable() {
47 return resultVariable;
48 }
49
50 public DataVariable<T> getInputVariable() {
51 return inputVariable;
52 }
53
54 public Aggregator<R, T> getAggregator() {
55 return aggregator;
56 }
57
58 @Override
59 public Set<Variable> getOutputVariables() {
60 return Set.of(resultVariable);
61 }
62
63 @Override
64 public Set<Variable> getInputVariables(Set<? extends Variable> positiveVariablesInClause) {
65 if (positiveVariablesInClause.contains(inputVariable)) {
66 throw new IllegalArgumentException("Aggregation variable %s must not be bound".formatted(inputVariable));
67 }
68 return super.getInputVariables(positiveVariablesInClause);
69 }
70
71 @Override
72 public Literal reduce() {
73 var reduction = getTarget().getReduction();
74 return switch (reduction) {
75 case ALWAYS_FALSE -> {
76 var emptyValue = aggregator.getEmptyResult();
77 yield emptyValue == null ? BooleanLiteral.FALSE :
78 resultVariable.assign(new ConstantTerm<>(resultVariable.getType(), emptyValue));
79 }
80 case ALWAYS_TRUE -> throw new IllegalArgumentException("Trying to aggregate over an infinite set");
81 case NOT_REDUCIBLE -> this;
82 };
83 }
84
85 @Override
86 protected Literal doSubstitute(Substitution substitution, List<Variable> substitutedArguments) {
87 return new AggregationLiteral<>(substitution.getTypeSafeSubstitute(resultVariable), aggregator,
88 substitution.getTypeSafeSubstitute(inputVariable), getTarget(), substitutedArguments);
89 }
90
91 @Override
92 public boolean equalsWithSubstitution(LiteralEqualityHelper helper, Literal other) {
93 if (!super.equalsWithSubstitution(helper, other)) {
94 return false;
95 }
96 var otherAggregationLiteral = (AggregationLiteral<?, ?>) other;
97 return helper.variableEqual(resultVariable, otherAggregationLiteral.resultVariable) &&
98 aggregator.equals(otherAggregationLiteral.aggregator) &&
99 helper.variableEqual(inputVariable, otherAggregationLiteral.inputVariable);
100 }
101
102 @Override
103 public boolean equals(Object o) {
104 if (this == o) return true;
105 if (o == null || getClass() != o.getClass()) return false;
106 if (!super.equals(o)) return false;
107 AggregationLiteral<?, ?> that = (AggregationLiteral<?, ?>) o;
108 return resultVariable.equals(that.resultVariable) && inputVariable.equals(that.inputVariable) &&
109 aggregator.equals(that.aggregator);
110 }
111
112 @Override
113 public int hashCode() {
114 return Objects.hash(super.hashCode(), resultVariable, inputVariable, aggregator);
115 }
116
117 @Override
118 public String toString() {
119 var builder = new StringBuilder();
120 builder.append(resultVariable);
121 builder.append(" is ");
122 builder.append(getTarget().toReferenceString());
123 builder.append("(");
124 var argumentIterator = getArguments().iterator();
125 if (argumentIterator.hasNext()) {
126 var argument = argumentIterator.next();
127 if (inputVariable.equals(argument)) {
128 builder.append("@Aggregate(\"").append(aggregator).append("\") ");
129 }
130 builder.append(argument);
131 while (argumentIterator.hasNext()) {
132 builder.append(", ").append(argumentIterator.next());
133 }
134 }
135 builder.append(")");
136 return builder.toString();
137 }
138}