From de351479641196a85b21875fcab4f6527779b2ff Mon Sep 17 00:00:00 2001 From: Kristóf Marussy Date: Sat, 17 Jun 2023 02:10:30 +0200 Subject: fix: further Dnf tests and fixes --- .../tools/refinery/store/query/Constraint.java | 4 +- .../store/query/dnf/ClausePostProcessor.java | 4 +- .../tools/refinery/store/query/dnf/DnfBuilder.java | 13 +- .../IncompatibleParameterDirectionException.java | 16 - .../store/query/literal/AbstractCallLiteral.java | 15 +- .../store/query/literal/AggregationLiteral.java | 8 + .../refinery/store/query/literal/CallLiteral.java | 16 + .../tools/refinery/store/query/term/Parameter.java | 2 +- .../store/query/view/AbstractFunctionView.java | 2 +- .../store/query/view/NodeFunctionView.java | 2 +- .../store/query/view/TuplePreservingView.java | 2 +- .../store/query/dnf/TopologicalSortTest.java | 112 ++++++ .../store/query/dnf/VariableDirectionTest.java | 428 +++++++++++++++++++++ .../query/literal/AggregationLiteralTest.java | 88 +++++ .../store/query/literal/CallLiteralTest.java | 94 +++++ .../reasoning/representation/PartialRelation.java | 2 +- 16 files changed, 779 insertions(+), 29 deletions(-) delete mode 100644 subprojects/store-query/src/main/java/tools/refinery/store/query/exceptions/IncompatibleParameterDirectionException.java create mode 100644 subprojects/store-query/src/test/java/tools/refinery/store/query/dnf/TopologicalSortTest.java create mode 100644 subprojects/store-query/src/test/java/tools/refinery/store/query/dnf/VariableDirectionTest.java create mode 100644 subprojects/store-query/src/test/java/tools/refinery/store/query/literal/AggregationLiteralTest.java create mode 100644 subprojects/store-query/src/test/java/tools/refinery/store/query/literal/CallLiteralTest.java diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/Constraint.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/Constraint.java index e841da9e..c0995e53 100644 --- a/subprojects/store-query/src/main/java/tools/refinery/store/query/Constraint.java +++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/Constraint.java @@ -32,7 +32,9 @@ public interface Constraint { return equals(other); } - String toReferenceString(); + default String toReferenceString() { + return name(); + } default CallLiteral call(CallPolarity polarity, List arguments) { return new CallLiteral(polarity, this, arguments); diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/ClausePostProcessor.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/ClausePostProcessor.java index dd45ecd4..b5e7092b 100644 --- a/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/ClausePostProcessor.java +++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/ClausePostProcessor.java @@ -106,7 +106,7 @@ class ClausePostProcessor { private Set getEquivalentVariables(NodeVariable variable) { var representative = getRepresentative(variable); if (!representative.equals(variable)) { - throw new IllegalStateException("NodeVariable %s already has a representative %s" + throw new AssertionError("NodeVariable %s already has a representative %s" .formatted(variable, representative)); } return equivalencePartition.computeIfAbsent(variable, key -> { @@ -249,7 +249,7 @@ class ClausePostProcessor { private void bindVariable(Variable input) { if (!remainingInputs.remove(input)) { - throw new IllegalStateException("Already processed input %s of literal %s".formatted(input, literal)); + throw new AssertionError("Already processed input %s of literal %s".formatted(input, literal)); } if (allInputsBound()) { addToAllInputsBoundQueue(); diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/DnfBuilder.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/DnfBuilder.java index dcf7611d..8e38ca6b 100644 --- a/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/DnfBuilder.java +++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/DnfBuilder.java @@ -223,7 +223,8 @@ public final class DnfBuilder { } else if (result instanceof ClausePostProcessor.ConstantResult constantResult) { switch (constantResult) { case ALWAYS_TRUE -> { - return List.of(new DnfClause(Set.of(), List.of())); + var inputVariables = getInputVariables(); + return List.of(new DnfClause(inputVariables, List.of())); } case ALWAYS_FALSE -> { // Skip this clause because it can never match. @@ -248,4 +249,14 @@ public final class DnfBuilder { } return Collections.unmodifiableMap(mutableParameterInfoMap); } + + private Set getInputVariables() { + var inputParameters = new LinkedHashSet(); + for (var parameter : parameters) { + if (parameter.getDirection() == ParameterDirection.IN) { + inputParameters.add(parameter.getVariable()); + } + } + return Collections.unmodifiableSet(inputParameters); + } } diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/exceptions/IncompatibleParameterDirectionException.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/exceptions/IncompatibleParameterDirectionException.java deleted file mode 100644 index 52da20ae..00000000 --- a/subprojects/store-query/src/main/java/tools/refinery/store/query/exceptions/IncompatibleParameterDirectionException.java +++ /dev/null @@ -1,16 +0,0 @@ -/* - * SPDX-FileCopyrightText: 2023 The Refinery Authors - * - * SPDX-License-Identifier: EPL-2.0 - */ -package tools.refinery.store.query.exceptions; - -public class IncompatibleParameterDirectionException extends RuntimeException { - public IncompatibleParameterDirectionException(String message) { - super(message); - } - - public IncompatibleParameterDirectionException(String message, Throwable cause) { - super(message, cause); - } -} diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AbstractCallLiteral.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AbstractCallLiteral.java index ed7d3401..8ef8e8b4 100644 --- a/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AbstractCallLiteral.java +++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AbstractCallLiteral.java @@ -19,6 +19,8 @@ public abstract class AbstractCallLiteral implements Literal { private final Set inArguments; private final Set outArguments; + // Use exhaustive switch over enums. + @SuppressWarnings("squid:S1301") protected AbstractCallLiteral(Constraint target, List arguments) { int arity = target.arity(); if (arguments.size() != arity) { @@ -59,14 +61,14 @@ public abstract class AbstractCallLiteral implements Literal { private static void checkInOutUnifiable(Variable argument) { if (!argument.isUnifiable()) { - throw new IllegalArgumentException("Arguments %s cannot appear with both %s and %s direction" + throw new IllegalArgumentException("Argument %s cannot appear with both %s and %s direction" .formatted(argument, ParameterDirection.IN, ParameterDirection.OUT)); } } private static void checkDuplicateOutUnifiable(Variable argument) { if (!argument.isUnifiable()) { - throw new IllegalArgumentException("Arguments %s cannot be bound multiple times".formatted(argument)); + throw new IllegalArgumentException("Argument %s cannot be bound multiple times".formatted(argument)); } } @@ -87,12 +89,17 @@ public abstract class AbstractCallLiteral implements Literal { @Override public Set getInputVariables(Set positiveVariablesInClause) { - return getArgumentsOfDirection(ParameterDirection.IN); + var inputVariables = new LinkedHashSet<>(getArgumentsOfDirection(ParameterDirection.OUT)); + inputVariables.retainAll(positiveVariablesInClause); + inputVariables.addAll(getArgumentsOfDirection(ParameterDirection.IN)); + return Collections.unmodifiableSet(inputVariables); } @Override public Set getPrivateVariables(Set positiveVariablesInClause) { - return Set.of(); + var privateVariables = new LinkedHashSet<>(getArgumentsOfDirection(ParameterDirection.OUT)); + privateVariables.removeAll(positiveVariablesInClause); + return Collections.unmodifiableSet(privateVariables); } @Override diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AggregationLiteral.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AggregationLiteral.java index b2fec430..3a5eb5c7 100644 --- a/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AggregationLiteral.java +++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AggregationLiteral.java @@ -60,6 +60,14 @@ public class AggregationLiteral extends AbstractCallLiteral { return Set.of(resultVariable); } + @Override + public Set getInputVariables(Set positiveVariablesInClause) { + if (positiveVariablesInClause.contains(inputVariable)) { + throw new IllegalArgumentException("Aggregation variable %s must not be bound".formatted(inputVariable)); + } + return super.getInputVariables(positiveVariablesInClause); + } + @Override public Literal reduce() { var reduction = getTarget().getReduction(); diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/CallLiteral.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/CallLiteral.java index 27d8ad60..29772aee 100644 --- a/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/CallLiteral.java +++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/CallLiteral.java @@ -48,6 +48,22 @@ public final class CallLiteral extends AbstractCallLiteral implements CanNegate< return Set.of(); } + @Override + public Set getInputVariables(Set positiveVariablesInClause) { + if (polarity.isPositive()) { + return getArgumentsOfDirection(ParameterDirection.IN); + } + return super.getInputVariables(positiveVariablesInClause); + } + + @Override + public Set getPrivateVariables(Set positiveVariablesInClause) { + if (polarity.isPositive()) { + return Set.of(); + } + return super.getPrivateVariables(positiveVariablesInClause); + } + @Override public Literal reduce() { var reduction = getTarget().getReduction(); diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/term/Parameter.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/term/Parameter.java index 0fe297ab..e5a0cdf1 100644 --- a/subprojects/store-query/src/main/java/tools/refinery/store/query/term/Parameter.java +++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/term/Parameter.java @@ -9,7 +9,7 @@ import java.util.Objects; import java.util.Optional; public class Parameter { - public static final Parameter NODE_IN_OUT = new Parameter(null, ParameterDirection.OUT); + public static final Parameter NODE_OUT = new Parameter(null, ParameterDirection.OUT); private final Class dataType; private final ParameterDirection direction; diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/view/AbstractFunctionView.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/view/AbstractFunctionView.java index c6f3dd43..c1f9d688 100644 --- a/subprojects/store-query/src/main/java/tools/refinery/store/query/view/AbstractFunctionView.java +++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/view/AbstractFunctionView.java @@ -105,7 +105,7 @@ public abstract class AbstractFunctionView extends SymbolView { private static List createParameters(int symbolArity, Parameter outParameter) { var parameters = new Parameter[symbolArity + 1]; - Arrays.fill(parameters, Parameter.NODE_IN_OUT); + Arrays.fill(parameters, Parameter.NODE_OUT); parameters[symbolArity] = outParameter; return List.of(parameters); } diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/view/NodeFunctionView.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/view/NodeFunctionView.java index e9785c67..fcf11506 100644 --- a/subprojects/store-query/src/main/java/tools/refinery/store/query/view/NodeFunctionView.java +++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/view/NodeFunctionView.java @@ -11,7 +11,7 @@ import tools.refinery.store.tuple.Tuple1; public final class NodeFunctionView extends AbstractFunctionView { public NodeFunctionView(Symbol symbol, String name) { - super(symbol, name, Parameter.NODE_IN_OUT); + super(symbol, name, Parameter.NODE_OUT); } public NodeFunctionView(Symbol symbol) { diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/view/TuplePreservingView.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/view/TuplePreservingView.java index 7e5b7788..6bc5a708 100644 --- a/subprojects/store-query/src/main/java/tools/refinery/store/query/view/TuplePreservingView.java +++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/view/TuplePreservingView.java @@ -76,7 +76,7 @@ public abstract class TuplePreservingView extends SymbolView { private static List createParameters(int arity) { var parameters = new Parameter[arity]; - Arrays.fill(parameters, Parameter.NODE_IN_OUT); + Arrays.fill(parameters, Parameter.NODE_OUT); return List.of(parameters); } } diff --git a/subprojects/store-query/src/test/java/tools/refinery/store/query/dnf/TopologicalSortTest.java b/subprojects/store-query/src/test/java/tools/refinery/store/query/dnf/TopologicalSortTest.java new file mode 100644 index 00000000..6d53f184 --- /dev/null +++ b/subprojects/store-query/src/test/java/tools/refinery/store/query/dnf/TopologicalSortTest.java @@ -0,0 +1,112 @@ +/* + * SPDX-FileCopyrightText: 2023 The Refinery Authors + * + * SPDX-License-Identifier: EPL-2.0 + */ +package tools.refinery.store.query.dnf; + +import org.junit.jupiter.api.Test; +import tools.refinery.store.query.term.NodeVariable; +import tools.refinery.store.query.term.ParameterDirection; +import tools.refinery.store.query.term.Variable; +import tools.refinery.store.query.view.AnySymbolView; +import tools.refinery.store.query.view.KeyOnlyView; +import tools.refinery.store.representation.Symbol; + +import java.util.List; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static tools.refinery.store.query.literal.Literals.not; +import static tools.refinery.store.query.tests.QueryMatchers.structurallyEqualTo; + +class TopologicalSortTest { + private static final Symbol friend = new Symbol<>("friend", 2, Boolean.class, false); + private static final AnySymbolView friendView = new KeyOnlyView<>(friend); + private static final Dnf example = Dnf.of("example", builder -> { + var a = builder.parameter("a", ParameterDirection.IN); + var b = builder.parameter("b", ParameterDirection.IN); + var c = builder.parameter("c", ParameterDirection.OUT); + var d = builder.parameter("d", ParameterDirection.OUT); + builder.clause( + friendView.call(a, b), + friendView.call(b, c), + friendView.call(c, d) + ); + }); + private static final NodeVariable p = Variable.of("p"); + private static final NodeVariable q = Variable.of("q"); + private static final NodeVariable r = Variable.of("r"); + private static final NodeVariable s = Variable.of("s"); + private static final NodeVariable t = Variable.of("t"); + + @Test + void topologicalSortTest() { + var actual = Dnf.builder("Actual") + .parameter(p, ParameterDirection.IN) + .parameter(q, ParameterDirection.OUT) + .clause( + not(friendView.call(p, q)), + example.call(p, q, r, s), + example.call(r, t, q, s), + friendView.call(r, t) + ) + .build(); + + assertThat(actual, structurallyEqualTo( + List.of( + new SymbolicParameter(p, ParameterDirection.IN), + new SymbolicParameter(q, ParameterDirection.OUT) + ), + List.of( + List.of( + friendView.call(r, t), + example.call(r, t, q, s), + not(friendView.call(p, q)), + example.call(p, q, r, s) + ) + ) + )); + } + + @Test + void missingInputTest() { + var builder = Dnf.builder("Actual") + .parameter(p, ParameterDirection.OUT) + .parameter(q, ParameterDirection.OUT) + .clause( + not(friendView.call(p, q)), + example.call(p, q, r, s), + example.call(r, t, q, s), + friendView.call(r, t) + ); + assertThrows(IllegalArgumentException.class, builder::build); + } + + @Test + void missingVariableTest() { + var builder = Dnf.builder("Actual") + .parameter(p, ParameterDirection.IN) + .parameter(q, ParameterDirection.OUT) + .clause( + not(friendView.call(p, q)), + example.call(p, q, r, s), + example.call(r, t, q, s) + ); + assertThrows(IllegalArgumentException.class, builder::build); + } + + @Test + void circularDependencyTest() { + var builder = Dnf.builder("Actual") + .parameter(p, ParameterDirection.IN) + .parameter(q, ParameterDirection.OUT) + .clause( + not(friendView.call(p, q)), + example.call(p, q, r, s), + example.call(r, t, q, s), + example.call(p, q, r, t) + ); + assertThrows(IllegalArgumentException.class, builder::build); + } +} diff --git a/subprojects/store-query/src/test/java/tools/refinery/store/query/dnf/VariableDirectionTest.java b/subprojects/store-query/src/test/java/tools/refinery/store/query/dnf/VariableDirectionTest.java new file mode 100644 index 00000000..0a44664e --- /dev/null +++ b/subprojects/store-query/src/test/java/tools/refinery/store/query/dnf/VariableDirectionTest.java @@ -0,0 +1,428 @@ +/* + * SPDX-FileCopyrightText: 2023 The Refinery Authors + * + * SPDX-License-Identifier: EPL-2.0 + */ +package tools.refinery.store.query.dnf; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import tools.refinery.store.query.literal.BooleanLiteral; +import tools.refinery.store.query.literal.Literal; +import tools.refinery.store.query.term.DataVariable; +import tools.refinery.store.query.term.NodeVariable; +import tools.refinery.store.query.term.ParameterDirection; +import tools.refinery.store.query.term.Variable; +import tools.refinery.store.query.view.AnySymbolView; +import tools.refinery.store.query.view.FunctionView; +import tools.refinery.store.query.view.KeyOnlyView; +import tools.refinery.store.representation.Symbol; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Stream; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.hasItem; +import static org.hamcrest.Matchers.not; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static tools.refinery.store.query.literal.Literals.assume; +import static tools.refinery.store.query.literal.Literals.not; +import static tools.refinery.store.query.term.int_.IntTerms.*; + +class VariableDirectionTest { + private static final Symbol person = new Symbol<>("Person", 1, Boolean.class, false); + private static final Symbol friend = new Symbol<>("friend", 2, Boolean.class, false); + private static final Symbol age = new Symbol<>("age", 1, Integer.class, null); + private static final AnySymbolView personView = new KeyOnlyView<>(person); + private static final AnySymbolView friendView = new KeyOnlyView<>(friend); + private static final AnySymbolView ageView = new FunctionView<>(age); + private static final NodeVariable p = Variable.of("p"); + private static final NodeVariable q = Variable.of("q"); + private static final DataVariable x = Variable.of("x", Integer.class); + private static final DataVariable y = Variable.of("y", Integer.class); + private static final DataVariable z = Variable.of("z", Integer.class); + + @ParameterizedTest + @MethodSource("clausesWithVariableInput") + void unboundOutVariableTest(List clause) { + var builder = Dnf.builder().parameter(p, ParameterDirection.OUT).clause(clause); + assertThrows(IllegalArgumentException.class, builder::build); + } + + @ParameterizedTest + @MethodSource("clausesWithVariableInput") + void unboundInVariableTest(List clause) { + var builder = Dnf.builder().parameter(p, ParameterDirection.IN).clause(clause); + var dnf = assertDoesNotThrow(builder::build); + var clauses = dnf.getClauses(); + if (clauses.size() > 0) { + assertThat(clauses.get(0).positiveVariables(), hasItem(p)); + } + } + + @ParameterizedTest + @MethodSource("clausesWithVariableInput") + void boundPrivateVariableTest(List clause) { + var clauseWithBinding = new ArrayList(clause); + clauseWithBinding.add(personView.call(p)); + var builder = Dnf.builder().clause(clauseWithBinding); + var dnf = assertDoesNotThrow(builder::build); + var clauses = dnf.getClauses(); + if (clauses.size() > 0) { + assertThat(clauses.get(0).positiveVariables(), hasItem(p)); + } + } + + static Stream clausesWithVariableInput() { + return Stream.concat( + clausesNotBindingVariable(), + literalToClauseArgumentStream(literalsWithRequiredVariableInput()) + ); + } + + @ParameterizedTest + @MethodSource("clausesNotBindingVariable") + void unboundPrivateVariableTest(List clause) { + var builder = Dnf.builder().clause(clause); + var dnf = assertDoesNotThrow(builder::build); + var clauses = dnf.getClauses(); + if (clauses.size() > 0) { + assertThat(clauses.get(0).positiveVariables(), not(hasItem(p))); + } + } + + @ParameterizedTest + @MethodSource("clausesNotBindingVariable") + void unboundByEquivalencePrivateVariableTest(List clause) { + var r = Variable.of("r"); + var clauseWithEquivalence = new ArrayList(clause); + clauseWithEquivalence.add(r.isEquivalent(p)); + var builder = Dnf.builder().clause(clauseWithEquivalence); + assertThrows(IllegalArgumentException.class, builder::build); + } + + static Stream clausesNotBindingVariable() { + return Stream.concat( + Stream.of( + Arguments.of(List.of()), + Arguments.of(List.of(BooleanLiteral.TRUE)), + Arguments.of(List.of(BooleanLiteral.FALSE)) + ), + literalToClauseArgumentStream(literalsWithPrivateVariable()) + ); + } + + @ParameterizedTest + @MethodSource("literalsWithPrivateVariable") + void unboundTwicePrivateVariableTest(Literal literal) { + var builder = Dnf.builder().clause(not(personView.call(p)), literal); + assertThrows(IllegalArgumentException.class, builder::build); + } + + @ParameterizedTest + @MethodSource("literalsWithPrivateVariable") + void unboundTwiceByEquivalencePrivateVariableTest(Literal literal) { + var r = Variable.of("r"); + var builder = Dnf.builder().clause(not(personView.call(r)), r.isEquivalent(p), literal); + assertThrows(IllegalArgumentException.class, builder::build); + } + + static Stream literalsWithPrivateVariable() { + var dnfWithOutput = Dnf.builder("WithOutput") + .parameter(p, ParameterDirection.OUT) + .parameter(q, ParameterDirection.OUT) + .clause(friendView.call(p, q)) + .build(); + var dnfWithOutputToAggregate = Dnf.builder("WithOutputToAggregate") + .parameter(p, ParameterDirection.OUT) + .parameter(q, ParameterDirection.OUT) + .parameter(x, ParameterDirection.OUT) + .clause( + friendView.call(p, q), + ageView.call(q, x) + ) + .build(); + + return Stream.of( + Arguments.of(not(friendView.call(p, q))), + Arguments.of(y.assign(friendView.count(p, q))), + Arguments.of(y.assign(ageView.aggregate(z, INT_SUM, p, z))), + Arguments.of(not(dnfWithOutput.call(p, q))), + Arguments.of(y.assign(dnfWithOutput.count(p, q))), + Arguments.of(y.assign(dnfWithOutputToAggregate.aggregate(z, INT_SUM, p, q, z))) + ); + } + + @ParameterizedTest + @MethodSource("literalsWithRequiredVariableInput") + void unboundPrivateVariableTest(Literal literal) { + var builder = Dnf.builder().clause(literal); + assertThrows(IllegalArgumentException.class, builder::build); + } + + @ParameterizedTest + @MethodSource("literalsWithRequiredVariableInput") + void boundPrivateVariableInputTest(Literal literal) { + var builder = Dnf.builder().clause(personView.call(p), literal); + var dnf = assertDoesNotThrow(builder::build); + assertThat(dnf.getClauses().get(0).positiveVariables(), hasItem(p)); + } + + static Stream literalsWithRequiredVariableInput() { + var dnfWithInput = Dnf.builder("WithInput") + .parameter(p, ParameterDirection.IN) + .parameter(q, ParameterDirection.OUT) + .clause(friendView.call(p, q)).build(); + var dnfWithInputToAggregate = Dnf.builder("WithInputToAggregate") + .parameter(p, ParameterDirection.IN) + .parameter(q, ParameterDirection.OUT) + .parameter(x, ParameterDirection.OUT) + .clause( + friendView.call(p, q), + ageView.call(q, x) + ).build(); + + return Stream.of( + Arguments.of(dnfWithInput.call(p, q)), + Arguments.of(dnfWithInput.call(p, p)), + Arguments.of(not(dnfWithInput.call(p, q))), + Arguments.of(not(dnfWithInput.call(p, p))), + Arguments.of(y.assign(dnfWithInput.count(p, q))), + Arguments.of(y.assign(dnfWithInput.count(p, p))), + Arguments.of(y.assign(dnfWithInputToAggregate.aggregate(z, INT_SUM, p, q, z))), + Arguments.of(y.assign(dnfWithInputToAggregate.aggregate(z, INT_SUM, p, p, z))) + ); + } + + @ParameterizedTest + @MethodSource("literalsWithVariableOutput") + void boundParameterTest(Literal literal) { + var builder = Dnf.builder().parameter(p, ParameterDirection.OUT).clause(literal); + var dnf = assertDoesNotThrow(builder::build); + assertThat(dnf.getClauses().get(0).positiveVariables(), hasItem(p)); + } + + @ParameterizedTest + @MethodSource("literalsWithVariableOutput") + void boundTwiceParameterTest(Literal literal) { + var builder = Dnf.builder().parameter(p, ParameterDirection.IN).clause(literal); + var dnf = assertDoesNotThrow(builder::build); + assertThat(dnf.getClauses().get(0).positiveVariables(), hasItem(p)); + } + + @ParameterizedTest + @MethodSource("literalsWithVariableOutput") + void boundPrivateVariableOutputTest(Literal literal) { + var dnfWithInput = Dnf.builder("WithInput") + .parameter(p, ParameterDirection.IN) + .clause(personView.call(p)) + .build(); + var builder = Dnf.builder().clause(dnfWithInput.call(p), literal); + var dnf = assertDoesNotThrow(builder::build); + assertThat(dnf.getClauses().get(0).positiveVariables(), hasItem(p)); + } + + @ParameterizedTest + @MethodSource("literalsWithVariableOutput") + void boundTwicePrivateVariableOutputTest(Literal literal) { + var builder = Dnf.builder().clause(personView.call(p), literal); + var dnf = assertDoesNotThrow(builder::build); + assertThat(dnf.getClauses().get(0).positiveVariables(), hasItem(p)); + } + + static Stream literalsWithVariableOutput() { + var dnfWithOutput = Dnf.builder("WithOutput") + .parameter(p, ParameterDirection.OUT) + .parameter(q, ParameterDirection.OUT) + .clause(friendView.call(p, q)) + .build(); + + return Stream.of( + Arguments.of(friendView.call(p, q)), + Arguments.of(dnfWithOutput.call(p, q)) + ); + } + + @ParameterizedTest + @MethodSource("clausesWithDataVariableInput") + void unboundOutDataVariableTest(List clause) { + var builder = Dnf.builder().parameter(x, ParameterDirection.OUT).clause(clause); + assertThrows(IllegalArgumentException.class, builder::build); + } + + @ParameterizedTest + @MethodSource("clausesWithDataVariableInput") + void unboundInDataVariableTest(List clause) { + var builder = Dnf.builder().parameter(x, ParameterDirection.IN).clause(clause); + var dnf = assertDoesNotThrow(builder::build); + var clauses = dnf.getClauses(); + if (clauses.size() > 0) { + assertThat(clauses.get(0).positiveVariables(), hasItem(x)); + } + } + + @ParameterizedTest + @MethodSource("clausesWithDataVariableInput") + void boundPrivateDataVariableTest(List clause) { + var clauseWithBinding = new ArrayList(clause); + clauseWithBinding.add(x.assign(constant(27))); + var builder = Dnf.builder().clause(clauseWithBinding); + var dnf = assertDoesNotThrow(builder::build); + var clauses = dnf.getClauses(); + if (clauses.size() > 0) { + assertThat(clauses.get(0).positiveVariables(), hasItem(x)); + } + } + + static Stream clausesWithDataVariableInput() { + return Stream.concat( + clausesNotBindingDataVariable(), + literalToClauseArgumentStream(literalsWithRequiredDataVariableInput()) + ); + } + + @ParameterizedTest + @MethodSource("clausesNotBindingDataVariable") + void unboundPrivateDataVariableTest(List clause) { + var builder = Dnf.builder().clause(clause); + var dnf = assertDoesNotThrow(builder::build); + var clauses = dnf.getClauses(); + if (clauses.size() > 0) { + assertThat(clauses.get(0).positiveVariables(), not(hasItem(x))); + } + } + + static Stream clausesNotBindingDataVariable() { + return Stream.concat( + Stream.of( + Arguments.of(List.of()), + Arguments.of(List.of(BooleanLiteral.TRUE)), + Arguments.of(List.of(BooleanLiteral.FALSE)) + ), + literalToClauseArgumentStream(literalsWithPrivateDataVariable()) + ); + } + + @ParameterizedTest + @MethodSource("literalsWithPrivateDataVariable") + void unboundTwicePrivateDataVariableTest(Literal literal) { + var builder = Dnf.builder().clause(not(ageView.call(p, x)), literal); + assertThrows(IllegalArgumentException.class, builder::build); + } + + static Stream literalsWithPrivateDataVariable() { + var dnfWithOutput = Dnf.builder("WithDataOutput") + .parameter(y, ParameterDirection.OUT) + .parameter(q, ParameterDirection.OUT) + .clause(ageView.call(q, y)) + .build(); + + return Stream.of( + Arguments.of(not(ageView.call(q, x))), + Arguments.of(y.assign(ageView.count(q, x))), + Arguments.of(not(dnfWithOutput.call(x, q))) + ); + } + + @ParameterizedTest + @MethodSource("literalsWithRequiredDataVariableInput") + void unboundPrivateDataVariableTest(Literal literal) { + var builder = Dnf.builder().clause(literal); + assertThrows(IllegalArgumentException.class, builder::build); + } + + static Stream literalsWithRequiredDataVariableInput() { + var dnfWithInput = Dnf.builder("WithDataInput") + .parameter(y, ParameterDirection.IN) + .parameter(q, ParameterDirection.OUT) + .clause(ageView.call(q, x)) + .build(); + // We are passing {@code y} to the parameter named {@code right} of {@code greaterEq}. + @SuppressWarnings("SuspiciousNameCombination") + var dnfWithInputToAggregate = Dnf.builder("WithDataInputToAggregate") + .parameter(y, ParameterDirection.IN) + .parameter(q, ParameterDirection.OUT) + .parameter(x, ParameterDirection.OUT) + .clause( + friendView.call(p, q), + ageView.call(q, x), + assume(greaterEq(x, y)) + ) + .build(); + + return Stream.of( + Arguments.of(dnfWithInput.call(x, q)), + Arguments.of(not(dnfWithInput.call(x, q))), + Arguments.of(y.assign(dnfWithInput.count(x, q))), + Arguments.of(y.assign(dnfWithInputToAggregate.aggregate(z, INT_SUM, x, q, z))) + ); + } + + @ParameterizedTest + @MethodSource("literalsWithDataVariableOutput") + void boundDataParameterTest(Literal literal) { + var builder = Dnf.builder().parameter(x, ParameterDirection.OUT).clause(literal); + var dnf = assertDoesNotThrow(builder::build); + assertThat(dnf.getClauses().get(0).positiveVariables(), hasItem(x)); + } + + @ParameterizedTest + @MethodSource("literalsWithDataVariableOutput") + void boundTwiceDataParameterTest(Literal literal) { + var builder = Dnf.builder().parameter(x, ParameterDirection.IN).clause(literal); + assertThrows(IllegalArgumentException.class, builder::build); + } + + @ParameterizedTest + @MethodSource("literalsWithDataVariableOutput") + void boundPrivateDataVariableOutputTest(Literal literal) { + var dnfWithInput = Dnf.builder("WithInput") + .parameter(x, ParameterDirection.IN) + .clause(assume(greaterEq(x, constant(24)))) + .build(); + var builder = Dnf.builder().clause(dnfWithInput.call(x), literal); + var dnf = assertDoesNotThrow(builder::build); + assertThat(dnf.getClauses().get(0).positiveVariables(), hasItem(x)); + } + + @ParameterizedTest + @MethodSource("literalsWithDataVariableOutput") + void boundTwicePrivateDataVariableOutputTest(Literal literal) { + var builder = Dnf.builder().clause(x.assign(constant(27)), literal); + assertThrows(IllegalArgumentException.class, builder::build); + } + + static Stream literalsWithDataVariableOutput() { + var dnfWithOutput = Dnf.builder("WithOutput") + .parameter(q, ParameterDirection.OUT) + .clause(personView.call(q)) + .build(); + var dnfWithDataOutput = Dnf.builder("WithDataOutput") + .parameter(y, ParameterDirection.OUT) + .parameter(q, ParameterDirection.OUT) + .clause(ageView.call(q, y)) + .build(); + var dnfWithOutputToAggregate = Dnf.builder("WithDataOutputToAggregate") + .parameter(q, ParameterDirection.OUT) + .parameter(y, ParameterDirection.OUT) + .clause(ageView.call(q, y)) + .build(); + + return Stream.of( + Arguments.of(x.assign(constant(24))), + Arguments.of(ageView.call(q, x)), + Arguments.of(x.assign(personView.count(q))), + Arguments.of(x.assign(ageView.aggregate(z, INT_SUM, q, z))), + Arguments.of(dnfWithDataOutput.call(x, q)), + Arguments.of(x.assign(dnfWithOutput.count(q))), + Arguments.of(x.assign(dnfWithOutputToAggregate.aggregate(z, INT_SUM, q, z))) + ); + } + + private static Stream literalToClauseArgumentStream(Stream literalArgumentsStream) { + return literalArgumentsStream.map(arguments -> Arguments.of(List.of(arguments.get()[0]))); + } +} diff --git a/subprojects/store-query/src/test/java/tools/refinery/store/query/literal/AggregationLiteralTest.java b/subprojects/store-query/src/test/java/tools/refinery/store/query/literal/AggregationLiteralTest.java new file mode 100644 index 00000000..5293b273 --- /dev/null +++ b/subprojects/store-query/src/test/java/tools/refinery/store/query/literal/AggregationLiteralTest.java @@ -0,0 +1,88 @@ +/* + * SPDX-FileCopyrightText: 2023 The Refinery Authors + * + * SPDX-License-Identifier: EPL-2.0 + */ +package tools.refinery.store.query.literal; + +import org.junit.jupiter.api.Test; +import tools.refinery.store.query.Constraint; +import tools.refinery.store.query.dnf.Dnf; +import tools.refinery.store.query.term.*; + +import java.util.List; +import java.util.Set; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.empty; +import static org.junit.jupiter.api.Assertions.assertAll; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static tools.refinery.store.query.literal.Literals.not; +import static tools.refinery.store.query.term.int_.IntTerms.INT_SUM; +import static tools.refinery.store.query.term.int_.IntTerms.constant; + +class AggregationLiteralTest { + private static final NodeVariable p = Variable.of("p"); + private static final DataVariable x = Variable.of("x", Integer.class); + private static final DataVariable y = Variable.of("y", Integer.class); + private static final DataVariable z = Variable.of("z", Integer.class); + private static final Constraint fakeConstraint = new Constraint() { + @Override + public String name() { + return getClass().getName(); + } + + @Override + public List getParameters() { + return List.of( + new Parameter(null, ParameterDirection.OUT), + new Parameter(Integer.class, ParameterDirection.OUT) + ); + } + }; + + @Test + void parameterDirectionTest() { + var literal = x.assign(fakeConstraint.aggregate(y, INT_SUM, p, y)); + assertAll( + () -> assertThat(literal.getOutputVariables(), containsInAnyOrder(x)), + () -> assertThat(literal.getInputVariables(Set.of()), empty()), + () -> assertThat(literal.getInputVariables(Set.of(p)), containsInAnyOrder(p)), + () -> assertThat(literal.getPrivateVariables(Set.of()), containsInAnyOrder(p, y)), + () -> assertThat(literal.getPrivateVariables(Set.of(p)), containsInAnyOrder(y)) + ); + } + + @Test + void missingAggregationVariableTest() { + var aggregation = fakeConstraint.aggregate(y, INT_SUM, p, z); + assertThrows(IllegalArgumentException.class, () -> x.assign(aggregation)); + } + + @Test + void circularAggregationVariableTest() { + var aggregation = fakeConstraint.aggregate(x, INT_SUM, p, x); + assertThrows(IllegalArgumentException.class, () -> x.assign(aggregation)); + } + + @Test + void unboundTwiceVariableTest() { + var builder = Dnf.builder() + .clause( + not(fakeConstraint.call(p, y)), + x.assign(fakeConstraint.aggregate(y, INT_SUM, p, y)) + ); + assertThrows(IllegalArgumentException.class, builder::build); + } + + @Test + void unboundBoundVariableTest() { + var builder = Dnf.builder() + .clause( + y.assign(constant(27)), + x.assign(fakeConstraint.aggregate(y, INT_SUM, p, y)) + ); + assertThrows(IllegalArgumentException.class, builder::build); + } +} diff --git a/subprojects/store-query/src/test/java/tools/refinery/store/query/literal/CallLiteralTest.java b/subprojects/store-query/src/test/java/tools/refinery/store/query/literal/CallLiteralTest.java new file mode 100644 index 00000000..a01c6586 --- /dev/null +++ b/subprojects/store-query/src/test/java/tools/refinery/store/query/literal/CallLiteralTest.java @@ -0,0 +1,94 @@ +/* + * SPDX-FileCopyrightText: 2023 The Refinery Authors + * + * SPDX-License-Identifier: EPL-2.0 + */ +package tools.refinery.store.query.literal; + +import org.junit.jupiter.api.Test; +import tools.refinery.store.query.Constraint; +import tools.refinery.store.query.term.NodeVariable; +import tools.refinery.store.query.term.Parameter; +import tools.refinery.store.query.term.ParameterDirection; +import tools.refinery.store.query.term.Variable; + +import java.util.List; +import java.util.Set; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.empty; +import static org.junit.jupiter.api.Assertions.assertAll; +import static tools.refinery.store.query.literal.Literals.not; + +class CallLiteralTest { + private static final NodeVariable p = Variable.of("p"); + private static final NodeVariable q = Variable.of("q"); + private static final NodeVariable r = Variable.of("r"); + private static final NodeVariable s = Variable.of("s"); + + private static final Constraint fakeConstraint = new Constraint() { + @Override + public String name() { + return getClass().getName(); + } + + @Override + public List getParameters() { + return List.of( + new Parameter(null, ParameterDirection.IN), + new Parameter(null, ParameterDirection.IN), + new Parameter(null, ParameterDirection.OUT), + new Parameter(null, ParameterDirection.OUT) + ); + } + }; + + @Test + void notRepeatedPositiveDirectionTest() { + var literal = fakeConstraint.call(p, q, r, s); + assertAll( + () -> assertThat(literal.getOutputVariables(), containsInAnyOrder(r, s)), + () -> assertThat(literal.getInputVariables(Set.of()), containsInAnyOrder(p, q)), + () -> assertThat(literal.getInputVariables(Set.of(p, q, r)), containsInAnyOrder(p, q)), + () -> assertThat(literal.getPrivateVariables(Set.of()), empty()), + () -> assertThat(literal.getPrivateVariables(Set.of(p, q, r)), empty()) + ); + } + + @Test + void notRepeatedNegativeDirectionTest() { + var literal = not(fakeConstraint.call(p, q, r, s)); + assertAll( + () -> assertThat(literal.getOutputVariables(), empty()), + () -> assertThat(literal.getInputVariables(Set.of()), containsInAnyOrder(p, q)), + () -> assertThat(literal.getInputVariables(Set.of(p, q, r)), containsInAnyOrder(p, q, r)), + () -> assertThat(literal.getPrivateVariables(Set.of()), containsInAnyOrder(r, s)), + () -> assertThat(literal.getPrivateVariables(Set.of(p, q, r)), containsInAnyOrder(s)) + ); + } + + @Test + void repeatedPositiveDirectionTest() { + var literal = fakeConstraint.call(p, p, q, q); + assertAll( + () -> assertThat(literal.getOutputVariables(), containsInAnyOrder(q)), + () -> assertThat(literal.getInputVariables(Set.of()), containsInAnyOrder(p)), + () -> assertThat(literal.getInputVariables(Set.of(p, q)), containsInAnyOrder(p)), + () -> assertThat(literal.getPrivateVariables(Set.of()), empty()), + () -> assertThat(literal.getPrivateVariables(Set.of(p, q)), empty()) + ); + } + + @Test + void repeatedNegativeDirectionTest() { + var literal = not(fakeConstraint.call(p, p, q, q)); + assertAll( + () -> assertThat(literal.getOutputVariables(), empty()), + () -> assertThat(literal.getInputVariables(Set.of()), containsInAnyOrder(p)), + () -> assertThat(literal.getInputVariables(Set.of(p, q)), containsInAnyOrder(p, q)), + () -> assertThat(literal.getPrivateVariables(Set.of()), containsInAnyOrder(q)), + () -> assertThat(literal.getPrivateVariables(Set.of(p, q)), empty()) + ); + } +} diff --git a/subprojects/store-reasoning/src/main/java/tools/refinery/store/reasoning/representation/PartialRelation.java b/subprojects/store-reasoning/src/main/java/tools/refinery/store/reasoning/representation/PartialRelation.java index 1f74ce38..6b2f050b 100644 --- a/subprojects/store-reasoning/src/main/java/tools/refinery/store/reasoning/representation/PartialRelation.java +++ b/subprojects/store-reasoning/src/main/java/tools/refinery/store/reasoning/representation/PartialRelation.java @@ -33,7 +33,7 @@ public record PartialRelation(String name, int arity) implements PartialSymbol getParameters() { var parameters = new Parameter[arity]; - Arrays.fill(parameters, Parameter.NODE_IN_OUT); + Arrays.fill(parameters, Parameter.NODE_OUT); return List.of(parameters); } -- cgit v1.2.3-54-g00ecf