diff options
Diffstat (limited to 'subprojects/logic/src/main/java/tools/refinery/logic/literal/AggregationLiteral.java')
-rw-r--r-- | subprojects/logic/src/main/java/tools/refinery/logic/literal/AggregationLiteral.java | 143 |
1 files changed, 143 insertions, 0 deletions
diff --git a/subprojects/logic/src/main/java/tools/refinery/logic/literal/AggregationLiteral.java b/subprojects/logic/src/main/java/tools/refinery/logic/literal/AggregationLiteral.java new file mode 100644 index 00000000..d2cc23f9 --- /dev/null +++ b/subprojects/logic/src/main/java/tools/refinery/logic/literal/AggregationLiteral.java | |||
@@ -0,0 +1,143 @@ | |||
1 | /* | ||
2 | * SPDX-FileCopyrightText: 2021-2024 The Refinery Authors <https://refinery.tools/> | ||
3 | * | ||
4 | * SPDX-License-Identifier: EPL-2.0 | ||
5 | */ | ||
6 | package tools.refinery.logic.literal; | ||
7 | |||
8 | import tools.refinery.logic.Constraint; | ||
9 | import tools.refinery.logic.InvalidQueryException; | ||
10 | import tools.refinery.logic.equality.LiteralEqualityHelper; | ||
11 | import tools.refinery.logic.equality.LiteralHashCodeHelper; | ||
12 | import tools.refinery.logic.substitution.Substitution; | ||
13 | import tools.refinery.logic.term.*; | ||
14 | |||
15 | import java.util.List; | ||
16 | import java.util.Objects; | ||
17 | import java.util.Set; | ||
18 | |||
19 | // {@link Object#equals(Object)} is implemented by {@link AbstractLiteral}. | ||
20 | @SuppressWarnings("squid:S2160") | ||
21 | public class AggregationLiteral<R, T> extends AbstractCallLiteral { | ||
22 | private final DataVariable<R> resultVariable; | ||
23 | private final DataVariable<T> inputVariable; | ||
24 | private final Aggregator<R, T> aggregator; | ||
25 | |||
26 | public AggregationLiteral(DataVariable<R> resultVariable, Aggregator<R, T> aggregator, | ||
27 | DataVariable<T> inputVariable, Constraint target, List<Variable> arguments) { | ||
28 | super(target, arguments); | ||
29 | if (!inputVariable.getType().equals(aggregator.getInputType())) { | ||
30 | throw new InvalidQueryException("Input variable %s must of type %s, got %s instead".formatted( | ||
31 | inputVariable, aggregator.getInputType().getName(), inputVariable.getType().getName())); | ||
32 | } | ||
33 | if (!getArgumentsOfDirection(ParameterDirection.OUT).contains(inputVariable)) { | ||
34 | throw new InvalidQueryException("Input variable %s must be bound with direction %s in the argument list" | ||
35 | .formatted(inputVariable, ParameterDirection.OUT)); | ||
36 | } | ||
37 | if (!resultVariable.getType().equals(aggregator.getResultType())) { | ||
38 | throw new InvalidQueryException("Result variable %s must of type %s, got %s instead".formatted( | ||
39 | resultVariable, aggregator.getResultType().getName(), resultVariable.getType().getName())); | ||
40 | } | ||
41 | if (arguments.contains(resultVariable)) { | ||
42 | throw new InvalidQueryException("Result variable %s must not appear in the argument list".formatted( | ||
43 | resultVariable)); | ||
44 | } | ||
45 | this.resultVariable = resultVariable; | ||
46 | this.inputVariable = inputVariable; | ||
47 | this.aggregator = aggregator; | ||
48 | } | ||
49 | |||
50 | public DataVariable<R> getResultVariable() { | ||
51 | return resultVariable; | ||
52 | } | ||
53 | |||
54 | public DataVariable<T> getInputVariable() { | ||
55 | return inputVariable; | ||
56 | } | ||
57 | |||
58 | public Aggregator<R, T> getAggregator() { | ||
59 | return aggregator; | ||
60 | } | ||
61 | |||
62 | @Override | ||
63 | public Set<Variable> getOutputVariables() { | ||
64 | return Set.of(resultVariable); | ||
65 | } | ||
66 | |||
67 | @Override | ||
68 | public Set<Variable> getInputVariables(Set<? extends Variable> positiveVariablesInClause) { | ||
69 | if (positiveVariablesInClause.contains(inputVariable)) { | ||
70 | throw new InvalidQueryException("Aggregation variable %s must not be bound".formatted(inputVariable)); | ||
71 | } | ||
72 | return super.getInputVariables(positiveVariablesInClause); | ||
73 | } | ||
74 | |||
75 | @Override | ||
76 | public Literal reduce() { | ||
77 | var reduction = getTarget().getReduction(); | ||
78 | return switch (reduction) { | ||
79 | case ALWAYS_FALSE -> { | ||
80 | var emptyValue = aggregator.getEmptyResult(); | ||
81 | yield emptyValue == null ? BooleanLiteral.FALSE : | ||
82 | resultVariable.assign(new ConstantTerm<>(resultVariable.getType(), emptyValue)); | ||
83 | } | ||
84 | case ALWAYS_TRUE -> throw new InvalidQueryException("Trying to aggregate over an infinite set"); | ||
85 | case NOT_REDUCIBLE -> this; | ||
86 | }; | ||
87 | } | ||
88 | |||
89 | @Override | ||
90 | protected Literal doSubstitute(Substitution substitution, List<Variable> substitutedArguments) { | ||
91 | return new AggregationLiteral<>(substitution.getTypeSafeSubstitute(resultVariable), aggregator, | ||
92 | substitution.getTypeSafeSubstitute(inputVariable), getTarget(), substitutedArguments); | ||
93 | } | ||
94 | |||
95 | @Override | ||
96 | public AbstractCallLiteral withArguments(Constraint newTarget, List<Variable> newArguments) { | ||
97 | return new AggregationLiteral<>(resultVariable, aggregator, inputVariable, newTarget, newArguments); | ||
98 | } | ||
99 | |||
100 | @Override | ||
101 | public boolean equalsWithSubstitution(LiteralEqualityHelper helper, Literal other) { | ||
102 | if (!super.equalsWithSubstitution(helper, other)) { | ||
103 | return false; | ||
104 | } | ||
105 | var otherAggregationLiteral = (AggregationLiteral<?, ?>) other; | ||
106 | return helper.variableEqual(resultVariable, otherAggregationLiteral.resultVariable) && | ||
107 | aggregator.equals(otherAggregationLiteral.aggregator) && | ||
108 | helper.variableEqual(inputVariable, otherAggregationLiteral.inputVariable); | ||
109 | } | ||
110 | |||
111 | @Override | ||
112 | public int hashCodeWithSubstitution(LiteralHashCodeHelper helper) { | ||
113 | return Objects.hash(super.hashCodeWithSubstitution(helper), helper.getVariableHashCode(resultVariable), | ||
114 | helper.getVariableHashCode(inputVariable), aggregator); | ||
115 | } | ||
116 | |||
117 | @Override | ||
118 | public String toString() { | ||
119 | var builder = new StringBuilder(); | ||
120 | builder.append(resultVariable); | ||
121 | builder.append(" is "); | ||
122 | builder.append(getTarget().toReferenceString()); | ||
123 | builder.append("("); | ||
124 | var argumentIterator = getArguments().iterator(); | ||
125 | if (argumentIterator.hasNext()) { | ||
126 | var argument = argumentIterator.next(); | ||
127 | if (inputVariable.equals(argument)) { | ||
128 | builder.append("@Aggregate(\"").append(aggregator).append("\") "); | ||
129 | } | ||
130 | builder.append(argument); | ||
131 | while (argumentIterator.hasNext()) { | ||
132 | builder.append(", "); | ||
133 | argument = argumentIterator.next(); | ||
134 | if (inputVariable.equals(argument)) { | ||
135 | builder.append("@Aggregate(\"").append(aggregator).append("\") "); | ||
136 | } | ||
137 | builder.append(argument); | ||
138 | } | ||
139 | } | ||
140 | builder.append(")"); | ||
141 | return builder.toString(); | ||
142 | } | ||
143 | } | ||