diff options
author | Kristóf Marussy <kristof@marussy.com> | 2023-07-25 16:06:36 +0200 |
---|---|---|
committer | Kristóf Marussy <kristof@marussy.com> | 2023-07-25 16:06:36 +0200 |
commit | 6a25ba145844c79d3507f8eabdbed854be2b8097 (patch) | |
tree | 0ea9d4c7a9b5b94a0d4341eaa25eeb7e4d3f4f56 /subprojects/store-query/src | |
parent | feat: custom connected component RETE node (diff) | |
download | refinery-6a25ba145844c79d3507f8eabdbed854be2b8097.tar.gz refinery-6a25ba145844c79d3507f8eabdbed854be2b8097.tar.zst refinery-6a25ba145844c79d3507f8eabdbed854be2b8097.zip |
feat: concrete count in partial models
Diffstat (limited to 'subprojects/store-query/src')
3 files changed, 118 insertions, 67 deletions
diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AbstractCountLiteral.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AbstractCountLiteral.java new file mode 100644 index 00000000..75f4bd49 --- /dev/null +++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AbstractCountLiteral.java | |||
@@ -0,0 +1,106 @@ | |||
1 | /* | ||
2 | * SPDX-FileCopyrightText: 2023 The Refinery Authors <https://refinery.tools/> | ||
3 | * | ||
4 | * SPDX-License-Identifier: EPL-2.0 | ||
5 | */ | ||
6 | package tools.refinery.store.query.literal; | ||
7 | |||
8 | import tools.refinery.store.query.Constraint; | ||
9 | import tools.refinery.store.query.equality.LiteralEqualityHelper; | ||
10 | import tools.refinery.store.query.equality.LiteralHashCodeHelper; | ||
11 | import tools.refinery.store.query.term.ConstantTerm; | ||
12 | import tools.refinery.store.query.term.DataVariable; | ||
13 | import tools.refinery.store.query.term.Variable; | ||
14 | |||
15 | import java.util.List; | ||
16 | import java.util.Objects; | ||
17 | import java.util.Set; | ||
18 | |||
19 | // {@link Object#equals(Object)} is implemented by {@link AbstractLiteral}. | ||
20 | @SuppressWarnings("squid:S2160") | ||
21 | public abstract class AbstractCountLiteral<T> extends AbstractCallLiteral { | ||
22 | private final Class<T> resultType; | ||
23 | private final DataVariable<T> resultVariable; | ||
24 | |||
25 | protected AbstractCountLiteral(Class<T> resultType, DataVariable<T> resultVariable, Constraint target, | ||
26 | List<Variable> arguments) { | ||
27 | super(target, arguments); | ||
28 | if (!resultVariable.getType().equals(resultType)) { | ||
29 | throw new IllegalArgumentException("Count result variable %s must be of type %s, got %s instead".formatted( | ||
30 | resultVariable, resultType, resultVariable.getType().getName())); | ||
31 | } | ||
32 | if (arguments.contains(resultVariable)) { | ||
33 | throw new IllegalArgumentException("Count result variable %s must not appear in the argument list" | ||
34 | .formatted(resultVariable)); | ||
35 | } | ||
36 | this.resultType = resultType; | ||
37 | this.resultVariable = resultVariable; | ||
38 | } | ||
39 | |||
40 | public Class<T> getResultType() { | ||
41 | return resultType; | ||
42 | } | ||
43 | |||
44 | public DataVariable<T> getResultVariable() { | ||
45 | return resultVariable; | ||
46 | } | ||
47 | |||
48 | @Override | ||
49 | public Set<Variable> getOutputVariables() { | ||
50 | return Set.of(resultVariable); | ||
51 | } | ||
52 | |||
53 | protected abstract T zero(); | ||
54 | |||
55 | protected abstract T one(); | ||
56 | |||
57 | @Override | ||
58 | public Literal reduce() { | ||
59 | var reduction = getTarget().getReduction(); | ||
60 | return switch (reduction) { | ||
61 | case ALWAYS_FALSE -> getResultVariable().assign(new ConstantTerm<>(resultType, zero())); | ||
62 | // The only way a constant {@code true} predicate can be called in a negative position is to have all of | ||
63 | // its arguments bound as input variables. Thus, there will only be a single match. | ||
64 | case ALWAYS_TRUE -> getResultVariable().assign(new ConstantTerm<>(resultType, one())); | ||
65 | case NOT_REDUCIBLE -> this; | ||
66 | }; | ||
67 | } | ||
68 | |||
69 | @Override | ||
70 | public boolean equalsWithSubstitution(LiteralEqualityHelper helper, Literal other) { | ||
71 | if (!super.equalsWithSubstitution(helper, other)) { | ||
72 | return false; | ||
73 | } | ||
74 | var otherCountLiteral = (AbstractCountLiteral<?>) other; | ||
75 | return Objects.equals(resultType, otherCountLiteral.resultType) && | ||
76 | helper.variableEqual(resultVariable, otherCountLiteral.resultVariable); | ||
77 | } | ||
78 | |||
79 | @Override | ||
80 | public int hashCodeWithSubstitution(LiteralHashCodeHelper helper) { | ||
81 | return Objects.hash(super.hashCodeWithSubstitution(helper), resultType, | ||
82 | helper.getVariableHashCode(resultVariable)); | ||
83 | } | ||
84 | |||
85 | protected abstract String operatorName(); | ||
86 | |||
87 | @Override | ||
88 | public String toString() { | ||
89 | var builder = new StringBuilder(); | ||
90 | builder.append(resultVariable); | ||
91 | builder.append(" is "); | ||
92 | builder.append(operatorName()); | ||
93 | builder.append(' '); | ||
94 | builder.append(getTarget().toReferenceString()); | ||
95 | builder.append('('); | ||
96 | var argumentIterator = getArguments().iterator(); | ||
97 | if (argumentIterator.hasNext()) { | ||
98 | builder.append(argumentIterator.next()); | ||
99 | while (argumentIterator.hasNext()) { | ||
100 | builder.append(", ").append(argumentIterator.next()); | ||
101 | } | ||
102 | } | ||
103 | builder.append(')'); | ||
104 | return builder.toString(); | ||
105 | } | ||
106 | } | ||
diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/CountLiteral.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/CountLiteral.java index 77b77389..e5f6ac0c 100644 --- a/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/CountLiteral.java +++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/CountLiteral.java | |||
@@ -6,95 +6,40 @@ | |||
6 | package tools.refinery.store.query.literal; | 6 | package tools.refinery.store.query.literal; |
7 | 7 | ||
8 | import tools.refinery.store.query.Constraint; | 8 | import tools.refinery.store.query.Constraint; |
9 | import tools.refinery.store.query.equality.LiteralEqualityHelper; | ||
10 | import tools.refinery.store.query.equality.LiteralHashCodeHelper; | ||
11 | import tools.refinery.store.query.substitution.Substitution; | 9 | import tools.refinery.store.query.substitution.Substitution; |
12 | import tools.refinery.store.query.term.DataVariable; | 10 | import tools.refinery.store.query.term.DataVariable; |
13 | import tools.refinery.store.query.term.Variable; | 11 | import tools.refinery.store.query.term.Variable; |
14 | import tools.refinery.store.query.term.int_.IntTerms; | ||
15 | 12 | ||
16 | import java.util.List; | 13 | import java.util.List; |
17 | import java.util.Objects; | ||
18 | import java.util.Set; | ||
19 | |||
20 | // {@link Object#equals(Object)} is implemented by {@link AbstractLiteral}. | ||
21 | @SuppressWarnings("squid:S2160") | ||
22 | public class CountLiteral extends AbstractCallLiteral { | ||
23 | private final DataVariable<Integer> resultVariable; | ||
24 | 14 | ||
15 | public class CountLiteral extends AbstractCountLiteral<Integer> { | ||
25 | public CountLiteral(DataVariable<Integer> resultVariable, Constraint target, List<Variable> arguments) { | 16 | public CountLiteral(DataVariable<Integer> resultVariable, Constraint target, List<Variable> arguments) { |
26 | super(target, arguments); | 17 | super(Integer.class, resultVariable, target, arguments); |
27 | if (!resultVariable.getType().equals(Integer.class)) { | ||
28 | throw new IllegalArgumentException("Count result variable %s must be of type %s, got %s instead".formatted( | ||
29 | resultVariable, Integer.class.getName(), resultVariable.getType().getName())); | ||
30 | } | ||
31 | if (arguments.contains(resultVariable)) { | ||
32 | throw new IllegalArgumentException("Count result variable %s must not appear in the argument list" | ||
33 | .formatted(resultVariable)); | ||
34 | } | ||
35 | this.resultVariable = resultVariable; | ||
36 | } | ||
37 | |||
38 | public DataVariable<Integer> getResultVariable() { | ||
39 | return resultVariable; | ||
40 | } | 18 | } |
41 | 19 | ||
42 | @Override | 20 | @Override |
43 | public Set<Variable> getOutputVariables() { | 21 | protected Integer zero() { |
44 | return Set.of(resultVariable); | 22 | return 0; |
45 | } | 23 | } |
46 | 24 | ||
47 | @Override | 25 | @Override |
48 | public Literal reduce() { | 26 | protected Integer one() { |
49 | var reduction = getTarget().getReduction(); | 27 | return 1; |
50 | return switch (reduction) { | ||
51 | case ALWAYS_FALSE -> getResultVariable().assign(IntTerms.constant(0)); | ||
52 | // The only way a constant {@code true} predicate can be called in a negative position is to have all of | ||
53 | // its arguments bound as input variables. Thus, there will only be a single match. | ||
54 | case ALWAYS_TRUE -> getResultVariable().assign(IntTerms.constant(1)); | ||
55 | case NOT_REDUCIBLE -> this; | ||
56 | }; | ||
57 | } | 28 | } |
58 | 29 | ||
59 | @Override | 30 | @Override |
60 | protected Literal doSubstitute(Substitution substitution, List<Variable> substitutedArguments) { | 31 | protected Literal doSubstitute(Substitution substitution, List<Variable> substitutedArguments) { |
61 | return new CountLiteral(substitution.getTypeSafeSubstitute(resultVariable), getTarget(), substitutedArguments); | 32 | return new CountLiteral(substitution.getTypeSafeSubstitute(getResultVariable()), getTarget(), |
33 | substitutedArguments); | ||
62 | } | 34 | } |
63 | 35 | ||
64 | @Override | 36 | @Override |
65 | protected AbstractCallLiteral internalWithTarget(Constraint newTarget) { | 37 | protected AbstractCallLiteral internalWithTarget(Constraint newTarget) { |
66 | return new CountLiteral(resultVariable, newTarget, getArguments()); | 38 | return new CountLiteral(getResultVariable(), newTarget, getArguments()); |
67 | } | ||
68 | |||
69 | @Override | ||
70 | public boolean equalsWithSubstitution(LiteralEqualityHelper helper, Literal other) { | ||
71 | if (!super.equalsWithSubstitution(helper, other)) { | ||
72 | return false; | ||
73 | } | ||
74 | var otherCountLiteral = (CountLiteral) other; | ||
75 | return helper.variableEqual(resultVariable, otherCountLiteral.resultVariable); | ||
76 | } | ||
77 | |||
78 | @Override | ||
79 | public int hashCodeWithSubstitution(LiteralHashCodeHelper helper) { | ||
80 | return Objects.hash(super.hashCodeWithSubstitution(helper), helper.getVariableHashCode(resultVariable)); | ||
81 | } | 39 | } |
82 | 40 | ||
83 | @Override | 41 | @Override |
84 | public String toString() { | 42 | protected String operatorName() { |
85 | var builder = new StringBuilder(); | 43 | return "count"; |
86 | builder.append(resultVariable); | ||
87 | builder.append(" is count "); | ||
88 | builder.append(getTarget().toReferenceString()); | ||
89 | builder.append("("); | ||
90 | var argumentIterator = getArguments().iterator(); | ||
91 | if (argumentIterator.hasNext()) { | ||
92 | builder.append(argumentIterator.next()); | ||
93 | while (argumentIterator.hasNext()) { | ||
94 | builder.append(", ").append(argumentIterator.next()); | ||
95 | } | ||
96 | } | ||
97 | builder.append(")"); | ||
98 | return builder.toString(); | ||
99 | } | 44 | } |
100 | } | 45 | } |
diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/term/uppercardinality/UpperCardinalitySumAggregator.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/term/uppercardinality/UpperCardinalitySumAggregator.java index 5bbd3081..d31f00a2 100644 --- a/subprojects/store-query/src/main/java/tools/refinery/store/query/term/uppercardinality/UpperCardinalitySumAggregator.java +++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/term/uppercardinality/UpperCardinalitySumAggregator.java | |||
@@ -70,7 +70,7 @@ public class UpperCardinalitySumAggregator implements StatefulAggregator<UpperCa | |||
70 | 70 | ||
71 | @Override | 71 | @Override |
72 | public UpperCardinality getResult() { | 72 | public UpperCardinality getResult() { |
73 | return countUnbounded > 0 ? UpperCardinalities.UNBOUNDED : UpperCardinalities.valueOf(sumFiniteUpperBounds); | 73 | return countUnbounded > 0 ? UpperCardinalities.UNBOUNDED : UpperCardinalities.atMost(sumFiniteUpperBounds); |
74 | } | 74 | } |
75 | 75 | ||
76 | @Override | 76 | @Override |