From 9666818c58bb4c30ef6b0c88cc680bc559b123c6 Mon Sep 17 00:00:00 2001 From: Kristóf Marussy Date: Mon, 10 Jul 2023 18:38:11 +0200 Subject: feat: DNF rewriting * DuplicateDnfRewriter replaces DNF with their canonical representatives * ClauseInputParameterResolver removes input parameters by demand set transformation * CompositeRewriter for rewriter stacks --- .../java/tools/refinery/store/query/dnf/Dnf.java | 7 + .../tools/refinery/store/query/dnf/DnfBuilder.java | 40 +++++- .../refinery/store/query/dnf/FunctionalQuery.java | 5 + .../java/tools/refinery/store/query/dnf/Query.java | 20 +++ .../refinery/store/query/dnf/RelationalQuery.java | 5 + .../SubstitutingLiteralHashCodeHelper.java | 12 ++ .../store/query/literal/AbstractCallLiteral.java | 9 ++ .../store/query/literal/AggregationLiteral.java | 5 + .../refinery/store/query/literal/CallLiteral.java | 5 + .../refinery/store/query/literal/CountLiteral.java | 5 + .../query/rewriter/AbstractRecursiveRewriter.java | 26 ++++ .../rewriter/ClauseInputParameterResolver.java | 160 +++++++++++++++++++++ .../store/query/rewriter/CompositeRewriter.java | 29 ++++ .../refinery/store/query/rewriter/DnfRewriter.java | 24 ++++ .../store/query/rewriter/DuplicateDnfRemover.java | 98 +++++++++++++ .../query/rewriter/InputParameterResolver.java | 51 +++++++ .../tools/refinery/store/query/term/Parameter.java | 6 +- 17 files changed, 503 insertions(+), 4 deletions(-) create mode 100644 subprojects/store-query/src/main/java/tools/refinery/store/query/rewriter/AbstractRecursiveRewriter.java create mode 100644 subprojects/store-query/src/main/java/tools/refinery/store/query/rewriter/ClauseInputParameterResolver.java create mode 100644 subprojects/store-query/src/main/java/tools/refinery/store/query/rewriter/CompositeRewriter.java create mode 100644 subprojects/store-query/src/main/java/tools/refinery/store/query/rewriter/DnfRewriter.java create mode 100644 subprojects/store-query/src/main/java/tools/refinery/store/query/rewriter/DuplicateDnfRemover.java create mode 100644 subprojects/store-query/src/main/java/tools/refinery/store/query/rewriter/InputParameterResolver.java (limited to 'subprojects/store-query/src/main/java') diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/Dnf.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/Dnf.java index e3c8924b..55f1aae5 100644 --- a/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/Dnf.java +++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/Dnf.java @@ -215,6 +215,13 @@ public final class Dnf implements Constraint { return new DnfBuilder(name); } + public static DnfBuilder builderFrom(Dnf original) { + var builder = builder(original.name()); + builder.symbolicParameters(original.getSymbolicParameters()); + builder.functionalDependencies(original.getFunctionalDependencies()); + return builder; + } + public static Dnf of(Consumer callback) { return of(null, callback); } diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/DnfBuilder.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/DnfBuilder.java index 3d3b5198..0538427f 100644 --- a/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/DnfBuilder.java +++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/DnfBuilder.java @@ -6,6 +6,9 @@ package tools.refinery.store.query.dnf; import tools.refinery.store.query.dnf.callback.*; +import tools.refinery.store.query.equality.DnfEqualityChecker; +import tools.refinery.store.query.equality.SubstitutingLiteralEqualityHelper; +import tools.refinery.store.query.equality.SubstitutingLiteralHashCodeHelper; import tools.refinery.store.query.literal.Literal; import tools.refinery.store.query.term.*; @@ -223,12 +226,12 @@ public final class DnfBuilder { private List postProcessClauses() { var parameterInfoMap = getParameterInfoMap(); - var postProcessedClauses = new ArrayList(clauses.size()); + var postProcessedClauses = new LinkedHashSet(clauses.size()); for (var literals : clauses) { var postProcessor = new ClausePostProcessor(parameterInfoMap, literals); var result = postProcessor.postProcessClause(); if (result instanceof ClausePostProcessor.ClauseResult clauseResult) { - postProcessedClauses.add(clauseResult.clause()); + postProcessedClauses.add(new CanonicalClause(clauseResult.clause())); } else if (result instanceof ClausePostProcessor.ConstantResult constantResult) { switch (constantResult) { case ALWAYS_TRUE -> { @@ -245,7 +248,7 @@ public final class DnfBuilder { throw new IllegalStateException("Unexpected ClausePostProcessor.Result: " + result); } } - return postProcessedClauses; + return postProcessedClauses.stream().map(CanonicalClause::getDnfClause).toList(); } private Map getParameterInfoMap() { @@ -268,4 +271,35 @@ public final class DnfBuilder { } return Collections.unmodifiableSet(inputParameters); } + + private class CanonicalClause { + private final DnfClause dnfClause; + + public CanonicalClause(DnfClause dnfClause) { + this.dnfClause = dnfClause; + } + + public DnfClause getDnfClause() { + return dnfClause; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + var otherCanonicalClause = (CanonicalClause) obj; + var helper = new SubstitutingLiteralEqualityHelper(DnfEqualityChecker.DEFAULT, parameters, parameters); + return dnfClause.equalsWithSubstitution(helper, otherCanonicalClause.dnfClause); + } + + @Override + public int hashCode() { + var helper = new SubstitutingLiteralHashCodeHelper(parameters); + return dnfClause.hashCodeWithSubstitution(helper); + } + } } diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/FunctionalQuery.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/FunctionalQuery.java index 5a32b1ba..aaebfcc2 100644 --- a/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/FunctionalQuery.java +++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/FunctionalQuery.java @@ -54,6 +54,11 @@ public final class FunctionalQuery extends Query { return null; } + @Override + protected Query withDnfInternal(Dnf newDnf) { + return newDnf.asFunction(type); + } + public AssignedValue call(List arguments) { return targetVariable -> { var argumentsWithTarget = new ArrayList(arguments.size() + 1); diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/Query.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/Query.java index aaa52ce6..55f748da 100644 --- a/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/Query.java +++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/Query.java @@ -43,6 +43,26 @@ public abstract sealed class Query implements AnyQuery permits FunctionalQuer public abstract T defaultValue(); + public Query withDnf(Dnf newDnf) { + int arity = dnf.arity(); + if (newDnf.arity() != arity) { + throw new IllegalArgumentException("Arity of %s and %s do not match".formatted(dnf, newDnf)); + } + var parameters = dnf.getParameters(); + var newParameters = newDnf.getParameters(); + for (int i = 0; i < arity; i++) { + var parameter = parameters.get(i); + var newParameter = newParameters.get(i); + if (!parameter.matches(newParameter)) { + throw new IllegalArgumentException("Parameter #%d mismatch: %s does not match %s" + .formatted(i, parameter, newParameter)); + } + } + return withDnfInternal(newDnf); + } + + protected abstract Query withDnfInternal(Dnf newDnf); + @Override public boolean equals(Object o) { if (this == o) return true; diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/RelationalQuery.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/RelationalQuery.java index d34a7ace..c1892ee1 100644 --- a/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/RelationalQuery.java +++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/RelationalQuery.java @@ -40,6 +40,11 @@ public final class RelationalQuery extends Query { return false; } + @Override + protected Query withDnfInternal(Dnf newDnf) { + return newDnf.asRelation(); + } + public CallLiteral call(CallPolarity polarity, List arguments) { return getDnf().call(polarity, Collections.unmodifiableList(arguments)); } diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/equality/SubstitutingLiteralHashCodeHelper.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/equality/SubstitutingLiteralHashCodeHelper.java index a40ecd58..754f6976 100644 --- a/subprojects/store-query/src/main/java/tools/refinery/store/query/equality/SubstitutingLiteralHashCodeHelper.java +++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/equality/SubstitutingLiteralHashCodeHelper.java @@ -5,9 +5,11 @@ */ package tools.refinery.store.query.equality; +import tools.refinery.store.query.dnf.SymbolicParameter; import tools.refinery.store.query.term.Variable; import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; public class SubstitutingLiteralHashCodeHelper implements LiteralHashCodeHelper { @@ -16,6 +18,16 @@ public class SubstitutingLiteralHashCodeHelper implements LiteralHashCodeHelper // 0 is for {@code null}, so we start with 1. private int next = 1; + public SubstitutingLiteralHashCodeHelper() { + this(List.of()); + } + + public SubstitutingLiteralHashCodeHelper(List parameters) { + for (var parameter : parameters) { + getVariableHashCode(parameter.getVariable()); + } + } + @Override public int getVariableHashCode(Variable variable) { if (variable == null) { diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AbstractCallLiteral.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AbstractCallLiteral.java index 263c2e20..b309f24b 100644 --- a/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AbstractCallLiteral.java +++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AbstractCallLiteral.java @@ -96,6 +96,15 @@ public abstract class AbstractCallLiteral extends AbstractLiteral { protected abstract Literal doSubstitute(Substitution substitution, List substitutedArguments); + public AbstractCallLiteral withTarget(Constraint newTarget) { + if (Objects.equals(target, newTarget)) { + return this; + } + return internalWithTarget(newTarget); + } + + protected abstract AbstractCallLiteral internalWithTarget(Constraint newTarget); + @Override public boolean equalsWithSubstitution(LiteralEqualityHelper helper, Literal other) { if (!super.equalsWithSubstitution(helper, other)) { diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AggregationLiteral.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AggregationLiteral.java index 615fd493..dac34332 100644 --- a/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AggregationLiteral.java +++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/AggregationLiteral.java @@ -91,6 +91,11 @@ public class AggregationLiteral extends AbstractCallLiteral { substitution.getTypeSafeSubstitute(inputVariable), getTarget(), substitutedArguments); } + @Override + protected AbstractCallLiteral internalWithTarget(Constraint newTarget) { + return new AggregationLiteral<>(resultVariable, aggregator, inputVariable, newTarget, getArguments()); + } + @Override public boolean equalsWithSubstitution(LiteralEqualityHelper helper, Literal other) { if (!super.equalsWithSubstitution(helper, other)) { diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/CallLiteral.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/CallLiteral.java index a311dada..b1585c77 100644 --- a/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/CallLiteral.java +++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/CallLiteral.java @@ -97,6 +97,11 @@ public final class CallLiteral extends AbstractCallLiteral implements CanNegate< return new CallLiteral(polarity.negate(), getTarget(), getArguments()); } + @Override + protected AbstractCallLiteral internalWithTarget(Constraint newTarget) { + return new CallLiteral(polarity, newTarget, getArguments()); + } + @Override public String toString() { var builder = new StringBuilder(); diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/CountLiteral.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/CountLiteral.java index ac4b8788..77b77389 100644 --- a/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/CountLiteral.java +++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/CountLiteral.java @@ -61,6 +61,11 @@ public class CountLiteral extends AbstractCallLiteral { return new CountLiteral(substitution.getTypeSafeSubstitute(resultVariable), getTarget(), substitutedArguments); } + @Override + protected AbstractCallLiteral internalWithTarget(Constraint newTarget) { + return new CountLiteral(resultVariable, newTarget, getArguments()); + } + @Override public boolean equalsWithSubstitution(LiteralEqualityHelper helper, Literal other) { if (!super.equalsWithSubstitution(helper, other)) { diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/rewriter/AbstractRecursiveRewriter.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/rewriter/AbstractRecursiveRewriter.java new file mode 100644 index 00000000..fb4c14a7 --- /dev/null +++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/rewriter/AbstractRecursiveRewriter.java @@ -0,0 +1,26 @@ +/* + * SPDX-FileCopyrightText: 2023 The Refinery Authors + * + * SPDX-License-Identifier: EPL-2.0 + */ +package tools.refinery.store.query.rewriter; + +import tools.refinery.store.query.dnf.Dnf; +import tools.refinery.store.query.equality.DnfEqualityChecker; +import tools.refinery.store.util.CycleDetectingMapper; + +public abstract class AbstractRecursiveRewriter implements DnfRewriter { + private final CycleDetectingMapper mapper = new CycleDetectingMapper<>(Dnf::name, this::map); + + @Override + public Dnf rewrite(Dnf dnf) { + return mapper.map(dnf); + } + + protected Dnf map(Dnf dnf) { + var result = doRewrite(dnf); + return dnf.equalsWithSubstitution(DnfEqualityChecker.DEFAULT, result) ? dnf : result; + } + + protected abstract Dnf doRewrite(Dnf dnf); +} diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/rewriter/ClauseInputParameterResolver.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/rewriter/ClauseInputParameterResolver.java new file mode 100644 index 00000000..bdd07f19 --- /dev/null +++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/rewriter/ClauseInputParameterResolver.java @@ -0,0 +1,160 @@ +/* + * SPDX-FileCopyrightText: 2023 The Refinery Authors + * + * SPDX-License-Identifier: EPL-2.0 + */ +package tools.refinery.store.query.rewriter; + +import tools.refinery.store.query.dnf.Dnf; +import tools.refinery.store.query.dnf.DnfClause; +import tools.refinery.store.query.literal.*; +import tools.refinery.store.query.substitution.Substitution; +import tools.refinery.store.query.term.ParameterDirection; +import tools.refinery.store.query.term.Variable; + +import java.util.*; + +class ClauseInputParameterResolver { + private final InputParameterResolver rewriter; + private final String dnfName; + private final int clauseIndex; + private final Set positiveVariables = new LinkedHashSet<>(); + private final List inlinedLiterals = new ArrayList<>(); + private final Deque workList; + private int helperIndex = 0; + + public ClauseInputParameterResolver(InputParameterResolver rewriter, List context, DnfClause clause, + String dnfName, int clauseIndex) { + this.rewriter = rewriter; + this.dnfName = dnfName; + this.clauseIndex = clauseIndex; + workList = new ArrayDeque<>(clause.literals().size() + context.size()); + for (var literal : context) { + workList.addLast(literal); + } + for (var literal : clause.literals()) { + workList.addLast(literal); + } + } + + public List rewriteClause() { + while (!workList.isEmpty()) { + var literal = workList.removeFirst(); + processLiteral(literal); + } + return inlinedLiterals; + } + + private void processLiteral(Literal literal) { + if (!(literal instanceof AbstractCallLiteral abstractCallLiteral) || + !(abstractCallLiteral.getTarget() instanceof Dnf targetDnf)) { + markAsDone(literal); + return; + } + boolean hasInputParameter = hasInputParameter(targetDnf); + if (!hasInputParameter) { + targetDnf = rewriter.doRewrite(targetDnf); + } + if (inlinePositiveClause(abstractCallLiteral, targetDnf)) { + return; + } + if (eliminateDoubleNegation(abstractCallLiteral, targetDnf)) { + return; + } + if (hasInputParameter) { + rewriteWithCurrentContext(abstractCallLiteral, targetDnf); + return; + } + markAsDone(abstractCallLiteral.withTarget(targetDnf)); + } + + private void markAsDone(Literal literal) { + positiveVariables.addAll(literal.getOutputVariables()); + inlinedLiterals.add(literal); + } + + private boolean inlinePositiveClause(AbstractCallLiteral abstractCallLiteral, Dnf targetDnf) { + var targetLiteral = getSingleLiteral(abstractCallLiteral, targetDnf, CallPolarity.POSITIVE); + if (targetLiteral == null) { + return false; + } + var substitution = asSubstitution(abstractCallLiteral, targetDnf); + var substitutedLiteral = targetLiteral.substitute(substitution); + workList.addFirst(substitutedLiteral); + return true; + } + + private boolean eliminateDoubleNegation(AbstractCallLiteral abstractCallLiteral, Dnf targetDnf) { + var targetLiteral = getSingleLiteral(abstractCallLiteral, targetDnf, CallPolarity.NEGATIVE); + if (!(targetLiteral instanceof CallLiteral targetCallLiteral) || + targetCallLiteral.getPolarity() != CallPolarity.NEGATIVE) { + return false; + } + var substitution = asSubstitution(abstractCallLiteral, targetDnf); + var substitutedLiteral = (CallLiteral) targetCallLiteral.substitute(substitution); + workList.addFirst(substitutedLiteral.negate()); + return true; + } + + private void rewriteWithCurrentContext(AbstractCallLiteral abstractCallLiteral, Dnf targetDnf) { + var contextBuilder = Dnf.builder("%s#clause%d#helper%d".formatted(dnfName, clauseIndex, helperIndex)); + helperIndex++; + contextBuilder.parameters(positiveVariables, ParameterDirection.OUT); + contextBuilder.clause(inlinedLiterals); + var contextDnf = contextBuilder.build(); + var contextCall = new CallLiteral(CallPolarity.POSITIVE, contextDnf, List.copyOf(positiveVariables)); + inlinedLiterals.clear(); + var substitution = Substitution.builder().renewing().build(); + var context = new ArrayList(); + context.add(contextCall.substitute(substitution)); + int arity = targetDnf.arity(); + for (int i = 0; i < arity; i++) { + var parameter = targetDnf.getSymbolicParameters().get(i).getVariable(); + var argument = abstractCallLiteral.getArguments().get(i); + context.add(new EquivalenceLiteral(true, parameter, substitution.getSubstitute(argument))); + } + var rewrittenDnf = rewriter.rewriteWithContext(context, targetDnf); + workList.addFirst(abstractCallLiteral.withTarget(rewrittenDnf)); + workList.addFirst(contextCall); + } + + private static boolean hasInputParameter(Dnf targetDnf) { + for (var parameter : targetDnf.getParameters()) { + if (parameter.getDirection() != ParameterDirection.OUT) { + return true; + } + } + return false; + } + + private static Literal getSingleLiteral(AbstractCallLiteral abstractCallLiteral, Dnf targetDnf, + CallPolarity polarity) { + if (!(abstractCallLiteral instanceof CallLiteral callLiteral) || + callLiteral.getPolarity() != polarity) { + return null; + } + var clauses = targetDnf.getClauses(); + if (clauses.size() != 1) { + return null; + } + var targetLiterals = clauses.get(0).literals(); + if (targetLiterals.size() != 1) { + return null; + } + return targetLiterals.get(0); + } + + private static Substitution asSubstitution(AbstractCallLiteral callLiteral, Dnf targetDnf) { + var builder = Substitution.builder().renewing(); + var arguments = callLiteral.getArguments(); + var parameters = targetDnf.getSymbolicParameters(); + int arity = arguments.size(); + if (parameters.size() != arity) { + throw new IllegalArgumentException("Call %s of %s arity mismatch".formatted(callLiteral, targetDnf)); + } + for (int i = 0; i < arity; i++) { + builder.putChecked(parameters.get(i).getVariable(), arguments.get(i)); + } + return builder.build(); + } +} diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/rewriter/CompositeRewriter.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/rewriter/CompositeRewriter.java new file mode 100644 index 00000000..5b4f65e5 --- /dev/null +++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/rewriter/CompositeRewriter.java @@ -0,0 +1,29 @@ +/* + * SPDX-FileCopyrightText: 2023 The Refinery Authors + * + * SPDX-License-Identifier: EPL-2.0 + */ +package tools.refinery.store.query.rewriter; + +import tools.refinery.store.query.dnf.Dnf; + +import java.util.ArrayList; +import java.util.List; + +public class CompositeRewriter implements DnfRewriter { + private final List rewriterList = new ArrayList<>(); + + public void addFirst(DnfRewriter rewriter) { + rewriterList.add(rewriter); + } + + @Override + public Dnf rewrite(Dnf dnf) { + Dnf rewrittenDnf = dnf; + for (int i = rewriterList.size() - 1; i >= 0; i--) { + var rewriter = rewriterList.get(i); + rewrittenDnf = rewriter.rewrite(rewrittenDnf); + } + return rewrittenDnf; + } +} diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/rewriter/DnfRewriter.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/rewriter/DnfRewriter.java new file mode 100644 index 00000000..5d8359d1 --- /dev/null +++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/rewriter/DnfRewriter.java @@ -0,0 +1,24 @@ +/* + * SPDX-FileCopyrightText: 2023 The Refinery Authors + * + * SPDX-License-Identifier: EPL-2.0 + */ +package tools.refinery.store.query.rewriter; + +import tools.refinery.store.query.dnf.AnyQuery; +import tools.refinery.store.query.dnf.Dnf; +import tools.refinery.store.query.dnf.Query; + +@FunctionalInterface +public interface DnfRewriter { + Dnf rewrite(Dnf dnf); + + default AnyQuery rewrite(AnyQuery query) { + return rewrite((Query) query); + } + + default Query rewrite(Query query) { + var rewrittenDnf = rewrite(query.getDnf()); + return query.withDnf(rewrittenDnf); + } +} diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/rewriter/DuplicateDnfRemover.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/rewriter/DuplicateDnfRemover.java new file mode 100644 index 00000000..0c786470 --- /dev/null +++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/rewriter/DuplicateDnfRemover.java @@ -0,0 +1,98 @@ +/* + * SPDX-FileCopyrightText: 2023 The Refinery Authors + * + * SPDX-License-Identifier: EPL-2.0 + */ +package tools.refinery.store.query.rewriter; + +import tools.refinery.store.query.dnf.Dnf; +import tools.refinery.store.query.dnf.DnfClause; +import tools.refinery.store.query.dnf.Query; +import tools.refinery.store.query.equality.DnfEqualityChecker; +import tools.refinery.store.query.literal.AbstractCallLiteral; +import tools.refinery.store.query.literal.Literal; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class DuplicateDnfRemover extends AbstractRecursiveRewriter { + private final Map dnfCache = new HashMap<>(); + private final Map> queryCache = new HashMap<>(); + + @Override + protected Dnf map(Dnf dnf) { + var result = super.map(dnf); + return dnfCache.computeIfAbsent(new CanonicalDnf(result), CanonicalDnf::getDnf); + } + + @Override + protected Dnf doRewrite(Dnf dnf) { + var builder = Dnf.builderFrom(dnf); + for (var clause : dnf.getClauses()) { + builder.clause(rewriteClause(clause)); + } + return builder.build(); + } + + private List rewriteClause(DnfClause clause) { + var originalLiterals = clause.literals(); + var literals = new ArrayList(originalLiterals.size()); + for (var literal : originalLiterals) { + var rewrittenLiteral = literal; + if (literal instanceof AbstractCallLiteral abstractCallLiteral && + abstractCallLiteral.getTarget() instanceof Dnf targetDnf) { + var rewrittenTarget = rewrite(targetDnf); + rewrittenLiteral = abstractCallLiteral.withTarget(rewrittenTarget); + } + literals.add(rewrittenLiteral); + } + return literals; + } + + @Override + public Query rewrite(Query query) { + var rewrittenDnf = rewrite(query.getDnf()); + // {@code withDnf} will always return the appropriate type. + @SuppressWarnings("unchecked") + var rewrittenQuery = (Query) queryCache.computeIfAbsent(rewrittenDnf, query::withDnf); + return rewrittenQuery; + } + + private static class CanonicalDnf { + private final Dnf dnf; + private final int hash; + + public CanonicalDnf(Dnf dnf) { + this.dnf = dnf; + hash = dnf.hashCodeWithSubstitution(); + } + + public Dnf getDnf() { + return dnf; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + var otherCanonicalDnf = (CanonicalDnf) obj; + return dnf.equalsWithSubstitution(DnfEqualityChecker.DEFAULT, otherCanonicalDnf.dnf); + } + + @Override + public int hashCode() { + return hash; + } + + @Override + public String toString() { + return dnf.name(); + } + } +} diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/rewriter/InputParameterResolver.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/rewriter/InputParameterResolver.java new file mode 100644 index 00000000..cd8a2e7d --- /dev/null +++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/rewriter/InputParameterResolver.java @@ -0,0 +1,51 @@ +/* + * SPDX-FileCopyrightText: 2023 The Refinery Authors + * + * SPDX-License-Identifier: EPL-2.0 + */ +package tools.refinery.store.query.rewriter; + +import tools.refinery.store.query.dnf.Dnf; +import tools.refinery.store.query.dnf.DnfBuilder; +import tools.refinery.store.query.literal.Literal; +import tools.refinery.store.query.term.ParameterDirection; +import tools.refinery.store.query.term.Variable; + +import java.util.HashSet; +import java.util.List; + +public class InputParameterResolver extends AbstractRecursiveRewriter { + @Override + protected Dnf doRewrite(Dnf dnf) { + return rewriteWithContext(List.of(), dnf); + } + + Dnf rewriteWithContext(List context, Dnf dnf) { + var dnfName = dnf.name(); + var builder = Dnf.builder(dnfName); + createSymbolicParameters(context, dnf, builder); + builder.functionalDependencies(dnf.getFunctionalDependencies()); + var clauses = dnf.getClauses(); + int clauseCount = clauses.size(); + for (int i = 0; i < clauseCount; i++) { + var clause = clauses.get(i); + var clauseRewriter = new ClauseInputParameterResolver(this, context, clause, dnfName, i); + builder.clause(clauseRewriter.rewriteClause()); + } + return builder.build(); + } + + private static void createSymbolicParameters(List context, Dnf dnf, DnfBuilder builder) { + var positiveInContext = new HashSet(); + for (var literal : context) { + positiveInContext.addAll(literal.getOutputVariables()); + } + for (var symbolicParameter : dnf.getSymbolicParameters()) { + var variable = symbolicParameter.getVariable(); + var isOutput = symbolicParameter.getDirection() == ParameterDirection.OUT || + positiveInContext.contains(variable); + var direction = isOutput ? ParameterDirection.OUT : ParameterDirection.IN; + builder.parameter(variable, direction); + } + } +} diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/term/Parameter.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/term/Parameter.java index e5a0cdf1..dbb76177 100644 --- a/subprojects/store-query/src/main/java/tools/refinery/store/query/term/Parameter.java +++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/term/Parameter.java @@ -35,6 +35,10 @@ public class Parameter { return direction; } + public boolean matches(Parameter other) { + return Objects.equals(dataType, other.dataType) && direction == other.direction; + } + public boolean isAssignable(Variable variable) { if (variable instanceof AnyDataVariable dataVariable) { return dataVariable.getType().equals(dataType); @@ -50,7 +54,7 @@ public class Parameter { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; Parameter parameter = (Parameter) o; - return Objects.equals(dataType, parameter.dataType) && direction == parameter.direction; + return matches(parameter); } @Override -- cgit v1.2.3-54-g00ecf