From 6a25ba145844c79d3507f8eabdbed854be2b8097 Mon Sep 17 00:00:00 2001 From: Kristóf Marussy Date: Tue, 25 Jul 2023 16:06:36 +0200 Subject: feat: concrete count in partial models --- .../store/query/literal/AbstractCountLiteral.java | 106 +++++++++++++++++++++ .../refinery/store/query/literal/CountLiteral.java | 77 +++------------ .../UpperCardinalitySumAggregator.java | 2 +- 3 files changed, 118 insertions(+), 67 deletions(-) create mode 100644 subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AbstractCountLiteral.java (limited to 'subprojects/store-query/src/main/java') diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AbstractCountLiteral.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AbstractCountLiteral.java new file mode 100644 index 00000000..75f4bd49 --- /dev/null +++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AbstractCountLiteral.java @@ -0,0 +1,106 @@ +/* + * SPDX-FileCopyrightText: 2023 The Refinery Authors + * + * SPDX-License-Identifier: EPL-2.0 + */ +package tools.refinery.store.query.literal; + +import tools.refinery.store.query.Constraint; +import tools.refinery.store.query.equality.LiteralEqualityHelper; +import tools.refinery.store.query.equality.LiteralHashCodeHelper; +import tools.refinery.store.query.term.ConstantTerm; +import tools.refinery.store.query.term.DataVariable; +import tools.refinery.store.query.term.Variable; + +import java.util.List; +import java.util.Objects; +import java.util.Set; + +// {@link Object#equals(Object)} is implemented by {@link AbstractLiteral}. +@SuppressWarnings("squid:S2160") +public abstract class AbstractCountLiteral extends AbstractCallLiteral { + private final Class resultType; + private final DataVariable resultVariable; + + protected AbstractCountLiteral(Class resultType, DataVariable resultVariable, Constraint target, + List arguments) { + super(target, arguments); + if (!resultVariable.getType().equals(resultType)) { + throw new IllegalArgumentException("Count result variable %s must be of type %s, got %s instead".formatted( + resultVariable, resultType, resultVariable.getType().getName())); + } + if (arguments.contains(resultVariable)) { + throw new IllegalArgumentException("Count result variable %s must not appear in the argument list" + .formatted(resultVariable)); + } + this.resultType = resultType; + this.resultVariable = resultVariable; + } + + public Class getResultType() { + return resultType; + } + + public DataVariable getResultVariable() { + return resultVariable; + } + + @Override + public Set getOutputVariables() { + return Set.of(resultVariable); + } + + protected abstract T zero(); + + protected abstract T one(); + + @Override + public Literal reduce() { + var reduction = getTarget().getReduction(); + return switch (reduction) { + case ALWAYS_FALSE -> getResultVariable().assign(new ConstantTerm<>(resultType, zero())); + // The only way a constant {@code true} predicate can be called in a negative position is to have all of + // its arguments bound as input variables. Thus, there will only be a single match. + case ALWAYS_TRUE -> getResultVariable().assign(new ConstantTerm<>(resultType, one())); + case NOT_REDUCIBLE -> this; + }; + } + + @Override + public boolean equalsWithSubstitution(LiteralEqualityHelper helper, Literal other) { + if (!super.equalsWithSubstitution(helper, other)) { + return false; + } + var otherCountLiteral = (AbstractCountLiteral) other; + return Objects.equals(resultType, otherCountLiteral.resultType) && + helper.variableEqual(resultVariable, otherCountLiteral.resultVariable); + } + + @Override + public int hashCodeWithSubstitution(LiteralHashCodeHelper helper) { + return Objects.hash(super.hashCodeWithSubstitution(helper), resultType, + helper.getVariableHashCode(resultVariable)); + } + + protected abstract String operatorName(); + + @Override + public String toString() { + var builder = new StringBuilder(); + builder.append(resultVariable); + builder.append(" is "); + builder.append(operatorName()); + builder.append(' '); + builder.append(getTarget().toReferenceString()); + builder.append('('); + var argumentIterator = getArguments().iterator(); + if (argumentIterator.hasNext()) { + builder.append(argumentIterator.next()); + while (argumentIterator.hasNext()) { + builder.append(", ").append(argumentIterator.next()); + } + } + builder.append(')'); + return builder.toString(); + } +} diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/CountLiteral.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/CountLiteral.java index 77b77389..e5f6ac0c 100644 --- a/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/CountLiteral.java +++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/CountLiteral.java @@ -6,95 +6,40 @@ package tools.refinery.store.query.literal; import tools.refinery.store.query.Constraint; -import tools.refinery.store.query.equality.LiteralEqualityHelper; -import tools.refinery.store.query.equality.LiteralHashCodeHelper; import tools.refinery.store.query.substitution.Substitution; import tools.refinery.store.query.term.DataVariable; import tools.refinery.store.query.term.Variable; -import tools.refinery.store.query.term.int_.IntTerms; import java.util.List; -import java.util.Objects; -import java.util.Set; - -// {@link Object#equals(Object)} is implemented by {@link AbstractLiteral}. -@SuppressWarnings("squid:S2160") -public class CountLiteral extends AbstractCallLiteral { - private final DataVariable resultVariable; +public class CountLiteral extends AbstractCountLiteral { public CountLiteral(DataVariable resultVariable, Constraint target, List arguments) { - super(target, arguments); - if (!resultVariable.getType().equals(Integer.class)) { - throw new IllegalArgumentException("Count result variable %s must be of type %s, got %s instead".formatted( - resultVariable, Integer.class.getName(), resultVariable.getType().getName())); - } - if (arguments.contains(resultVariable)) { - throw new IllegalArgumentException("Count result variable %s must not appear in the argument list" - .formatted(resultVariable)); - } - this.resultVariable = resultVariable; - } - - public DataVariable getResultVariable() { - return resultVariable; + super(Integer.class, resultVariable, target, arguments); } @Override - public Set getOutputVariables() { - return Set.of(resultVariable); + protected Integer zero() { + return 0; } @Override - public Literal reduce() { - var reduction = getTarget().getReduction(); - return switch (reduction) { - case ALWAYS_FALSE -> getResultVariable().assign(IntTerms.constant(0)); - // The only way a constant {@code true} predicate can be called in a negative position is to have all of - // its arguments bound as input variables. Thus, there will only be a single match. - case ALWAYS_TRUE -> getResultVariable().assign(IntTerms.constant(1)); - case NOT_REDUCIBLE -> this; - }; + protected Integer one() { + return 1; } @Override protected Literal doSubstitute(Substitution substitution, List substitutedArguments) { - return new CountLiteral(substitution.getTypeSafeSubstitute(resultVariable), getTarget(), substitutedArguments); + return new CountLiteral(substitution.getTypeSafeSubstitute(getResultVariable()), getTarget(), + substitutedArguments); } @Override protected AbstractCallLiteral internalWithTarget(Constraint newTarget) { - return new CountLiteral(resultVariable, newTarget, getArguments()); - } - - @Override - public boolean equalsWithSubstitution(LiteralEqualityHelper helper, Literal other) { - if (!super.equalsWithSubstitution(helper, other)) { - return false; - } - var otherCountLiteral = (CountLiteral) other; - return helper.variableEqual(resultVariable, otherCountLiteral.resultVariable); - } - - @Override - public int hashCodeWithSubstitution(LiteralHashCodeHelper helper) { - return Objects.hash(super.hashCodeWithSubstitution(helper), helper.getVariableHashCode(resultVariable)); + return new CountLiteral(getResultVariable(), newTarget, getArguments()); } @Override - public String toString() { - var builder = new StringBuilder(); - builder.append(resultVariable); - builder.append(" is count "); - builder.append(getTarget().toReferenceString()); - builder.append("("); - var argumentIterator = getArguments().iterator(); - if (argumentIterator.hasNext()) { - builder.append(argumentIterator.next()); - while (argumentIterator.hasNext()) { - builder.append(", ").append(argumentIterator.next()); - } - } - builder.append(")"); - return builder.toString(); + protected String operatorName() { + return "count"; } } diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/term/uppercardinality/UpperCardinalitySumAggregator.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/term/uppercardinality/UpperCardinalitySumAggregator.java index 5bbd3081..d31f00a2 100644 --- a/subprojects/store-query/src/main/java/tools/refinery/store/query/term/uppercardinality/UpperCardinalitySumAggregator.java +++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/term/uppercardinality/UpperCardinalitySumAggregator.java @@ -70,7 +70,7 @@ public class UpperCardinalitySumAggregator implements StatefulAggregator 0 ? UpperCardinalities.UNBOUNDED : UpperCardinalities.valueOf(sumFiniteUpperBounds); + return countUnbounded > 0 ? UpperCardinalities.UNBOUNDED : UpperCardinalities.atMost(sumFiniteUpperBounds); } @Override -- cgit v1.2.3-54-g00ecf