aboutsummaryrefslogtreecommitdiffstats
path: root/subprojects/logic/src/main/java/tools/refinery/logic/literal/AggregationLiteral.java
diff options
context:
space:
mode:
Diffstat (limited to 'subprojects/logic/src/main/java/tools/refinery/logic/literal/AggregationLiteral.java')
-rw-r--r--subprojects/logic/src/main/java/tools/refinery/logic/literal/AggregationLiteral.java143
1 files changed, 143 insertions, 0 deletions
diff --git a/subprojects/logic/src/main/java/tools/refinery/logic/literal/AggregationLiteral.java b/subprojects/logic/src/main/java/tools/refinery/logic/literal/AggregationLiteral.java
new file mode 100644
index 00000000..d2cc23f9
--- /dev/null
+++ b/subprojects/logic/src/main/java/tools/refinery/logic/literal/AggregationLiteral.java
@@ -0,0 +1,143 @@
1/*
2 * SPDX-FileCopyrightText: 2021-2024 The Refinery Authors <https://refinery.tools/>
3 *
4 * SPDX-License-Identifier: EPL-2.0
5 */
6package tools.refinery.logic.literal;
7
8import tools.refinery.logic.Constraint;
9import tools.refinery.logic.InvalidQueryException;
10import tools.refinery.logic.equality.LiteralEqualityHelper;
11import tools.refinery.logic.equality.LiteralHashCodeHelper;
12import tools.refinery.logic.substitution.Substitution;
13import tools.refinery.logic.term.*;
14
15import java.util.List;
16import java.util.Objects;
17import java.util.Set;
18
19// {@link Object#equals(Object)} is implemented by {@link AbstractLiteral}.
20@SuppressWarnings("squid:S2160")
21public class AggregationLiteral<R, T> extends AbstractCallLiteral {
22 private final DataVariable<R> resultVariable;
23 private final DataVariable<T> inputVariable;
24 private final Aggregator<R, T> aggregator;
25
26 public AggregationLiteral(DataVariable<R> resultVariable, Aggregator<R, T> aggregator,
27 DataVariable<T> inputVariable, Constraint target, List<Variable> arguments) {
28 super(target, arguments);
29 if (!inputVariable.getType().equals(aggregator.getInputType())) {
30 throw new InvalidQueryException("Input variable %s must of type %s, got %s instead".formatted(
31 inputVariable, aggregator.getInputType().getName(), inputVariable.getType().getName()));
32 }
33 if (!getArgumentsOfDirection(ParameterDirection.OUT).contains(inputVariable)) {
34 throw new InvalidQueryException("Input variable %s must be bound with direction %s in the argument list"
35 .formatted(inputVariable, ParameterDirection.OUT));
36 }
37 if (!resultVariable.getType().equals(aggregator.getResultType())) {
38 throw new InvalidQueryException("Result variable %s must of type %s, got %s instead".formatted(
39 resultVariable, aggregator.getResultType().getName(), resultVariable.getType().getName()));
40 }
41 if (arguments.contains(resultVariable)) {
42 throw new InvalidQueryException("Result variable %s must not appear in the argument list".formatted(
43 resultVariable));
44 }
45 this.resultVariable = resultVariable;
46 this.inputVariable = inputVariable;
47 this.aggregator = aggregator;
48 }
49
50 public DataVariable<R> getResultVariable() {
51 return resultVariable;
52 }
53
54 public DataVariable<T> getInputVariable() {
55 return inputVariable;
56 }
57
58 public Aggregator<R, T> getAggregator() {
59 return aggregator;
60 }
61
62 @Override
63 public Set<Variable> getOutputVariables() {
64 return Set.of(resultVariable);
65 }
66
67 @Override
68 public Set<Variable> getInputVariables(Set<? extends Variable> positiveVariablesInClause) {
69 if (positiveVariablesInClause.contains(inputVariable)) {
70 throw new InvalidQueryException("Aggregation variable %s must not be bound".formatted(inputVariable));
71 }
72 return super.getInputVariables(positiveVariablesInClause);
73 }
74
75 @Override
76 public Literal reduce() {
77 var reduction = getTarget().getReduction();
78 return switch (reduction) {
79 case ALWAYS_FALSE -> {
80 var emptyValue = aggregator.getEmptyResult();
81 yield emptyValue == null ? BooleanLiteral.FALSE :
82 resultVariable.assign(new ConstantTerm<>(resultVariable.getType(), emptyValue));
83 }
84 case ALWAYS_TRUE -> throw new InvalidQueryException("Trying to aggregate over an infinite set");
85 case NOT_REDUCIBLE -> this;
86 };
87 }
88
89 @Override
90 protected Literal doSubstitute(Substitution substitution, List<Variable> substitutedArguments) {
91 return new AggregationLiteral<>(substitution.getTypeSafeSubstitute(resultVariable), aggregator,
92 substitution.getTypeSafeSubstitute(inputVariable), getTarget(), substitutedArguments);
93 }
94
95 @Override
96 public AbstractCallLiteral withArguments(Constraint newTarget, List<Variable> newArguments) {
97 return new AggregationLiteral<>(resultVariable, aggregator, inputVariable, newTarget, newArguments);
98 }
99
100 @Override
101 public boolean equalsWithSubstitution(LiteralEqualityHelper helper, Literal other) {
102 if (!super.equalsWithSubstitution(helper, other)) {
103 return false;
104 }
105 var otherAggregationLiteral = (AggregationLiteral<?, ?>) other;
106 return helper.variableEqual(resultVariable, otherAggregationLiteral.resultVariable) &&
107 aggregator.equals(otherAggregationLiteral.aggregator) &&
108 helper.variableEqual(inputVariable, otherAggregationLiteral.inputVariable);
109 }
110
111 @Override
112 public int hashCodeWithSubstitution(LiteralHashCodeHelper helper) {
113 return Objects.hash(super.hashCodeWithSubstitution(helper), helper.getVariableHashCode(resultVariable),
114 helper.getVariableHashCode(inputVariable), aggregator);
115 }
116
117 @Override
118 public String toString() {
119 var builder = new StringBuilder();
120 builder.append(resultVariable);
121 builder.append(" is ");
122 builder.append(getTarget().toReferenceString());
123 builder.append("(");
124 var argumentIterator = getArguments().iterator();
125 if (argumentIterator.hasNext()) {
126 var argument = argumentIterator.next();
127 if (inputVariable.equals(argument)) {
128 builder.append("@Aggregate(\"").append(aggregator).append("\") ");
129 }
130 builder.append(argument);
131 while (argumentIterator.hasNext()) {
132 builder.append(", ");
133 argument = argumentIterator.next();
134 if (inputVariable.equals(argument)) {
135 builder.append("@Aggregate(\"").append(aggregator).append("\") ");
136 }
137 builder.append(argument);
138 }
139 }
140 builder.append(")");
141 return builder.toString();
142 }
143}