From 4e698774925468062974b990143c1091e23ed63b Mon Sep 17 00:00:00 2001 From: Kristóf Marussy Date: Mon, 1 May 2023 02:07:23 +0200 Subject: feat: query parameter binding validation * Introduce parameter directions for constraints and DNF * Introduce variable directions for literals * Infer and check variable directions in DNF and topologically sort literals by their input variables --- .../context/RelationalQueryMetaContext.java | 13 ++-- .../query/viatra/internal/pquery/Dnf2PQuery.java | 70 ++++++++++++---------- .../internal/pquery/QueryWrapperFactory.java | 35 +++++++---- 3 files changed, 67 insertions(+), 51 deletions(-) (limited to 'subprojects/store-query-viatra/src/main/java') diff --git a/subprojects/store-query-viatra/src/main/java/tools/refinery/store/query/viatra/internal/context/RelationalQueryMetaContext.java b/subprojects/store-query-viatra/src/main/java/tools/refinery/store/query/viatra/internal/context/RelationalQueryMetaContext.java index cf96b7fd..211eacb4 100644 --- a/subprojects/store-query-viatra/src/main/java/tools/refinery/store/query/viatra/internal/context/RelationalQueryMetaContext.java +++ b/subprojects/store-query-viatra/src/main/java/tools/refinery/store/query/viatra/internal/context/RelationalQueryMetaContext.java @@ -9,7 +9,6 @@ import org.eclipse.viatra.query.runtime.matchers.context.AbstractQueryMetaContex import org.eclipse.viatra.query.runtime.matchers.context.IInputKey; import org.eclipse.viatra.query.runtime.matchers.context.InputKeyImplication; import org.eclipse.viatra.query.runtime.matchers.context.common.JavaTransitiveInstancesKey; -import tools.refinery.store.query.term.DataSort; import tools.refinery.store.query.viatra.internal.pquery.SymbolViewWrapper; import tools.refinery.store.query.view.AnySymbolView; @@ -62,14 +61,14 @@ public class RelationalQueryMetaContext extends AbstractQueryMetaContext { relationViewImplication.impliedIndices())); } } - var sorts = symbolView.getSorts(); + var parameters = symbolView.getParameters(); int arity = symbolView.arity(); for (int i = 0; i < arity; i++) { - var sort = sorts.get(i); - if (sort instanceof DataSort dataSort) { - var javaTransitiveInstancesKey = new JavaTransitiveInstancesKey(dataSort.type()); - var javaImplication = new InputKeyImplication(implyingKey, javaTransitiveInstancesKey, - List.of(i)); + var parameter = parameters.get(i); + var parameterType = parameter.tryGetType(); + if (parameterType.isPresent()) { + var javaTransitiveInstancesKey = new JavaTransitiveInstancesKey(parameterType.get()); + var javaImplication = new InputKeyImplication(implyingKey, javaTransitiveInstancesKey, List.of(i)); inputKeyImplications.add(javaImplication); } } diff --git a/subprojects/store-query-viatra/src/main/java/tools/refinery/store/query/viatra/internal/pquery/Dnf2PQuery.java b/subprojects/store-query-viatra/src/main/java/tools/refinery/store/query/viatra/internal/pquery/Dnf2PQuery.java index b511a5c7..ec880435 100644 --- a/subprojects/store-query-viatra/src/main/java/tools/refinery/store/query/viatra/internal/pquery/Dnf2PQuery.java +++ b/subprojects/store-query-viatra/src/main/java/tools/refinery/store/query/viatra/internal/pquery/Dnf2PQuery.java @@ -19,11 +19,13 @@ import org.eclipse.viatra.query.runtime.matchers.psystem.basicenumerables.Consta import org.eclipse.viatra.query.runtime.matchers.psystem.basicenumerables.PositivePatternCall; import org.eclipse.viatra.query.runtime.matchers.psystem.basicenumerables.TypeConstraint; import org.eclipse.viatra.query.runtime.matchers.psystem.queries.PParameter; +import org.eclipse.viatra.query.runtime.matchers.psystem.queries.PParameterDirection; import org.eclipse.viatra.query.runtime.matchers.psystem.queries.PQuery; import org.eclipse.viatra.query.runtime.matchers.tuple.Tuple; import org.eclipse.viatra.query.runtime.matchers.tuple.Tuples; import tools.refinery.store.query.dnf.Dnf; import tools.refinery.store.query.dnf.DnfClause; +import tools.refinery.store.query.dnf.SymbolicParameter; import tools.refinery.store.query.literal.*; import tools.refinery.store.query.term.ConstantTerm; import tools.refinery.store.query.term.StatefulAggregator; @@ -82,15 +84,20 @@ public class Dnf2PQuery { var pQuery = new RawPQuery(dnfQuery.getUniqueName()); pQuery.setEvaluationHints(consumeHint(dnfQuery)); - Map parameters = new HashMap<>(); - for (Variable variable : dnfQuery.getParameters()) { - parameters.put(variable, new PParameter(variable.getUniqueName())); - } - + Map parameters = new HashMap<>(); List parameterList = new ArrayList<>(); - for (var param : dnfQuery.getParameters()) { - parameterList.add(parameters.get(param)); + for (var parameter : dnfQuery.getSymbolicParameters()) { + var direction = switch (parameter.getDirection()) { + case IN_OUT -> PParameterDirection.INOUT; + case OUT -> PParameterDirection.OUT; + case IN -> throw new IllegalArgumentException("Query %s with input parameter %s is not supported" + .formatted(dnfQuery, parameter.getVariable())); + }; + var pParameter = new PParameter(parameter.getVariable().getUniqueName(), null, null, direction); + parameters.put(parameter, pParameter); + parameterList.add(pParameter); } + pQuery.setParameters(parameterList); for (var functionalDependency : dnfQuery.getFunctionalDependencies()) { @@ -110,15 +117,15 @@ public class Dnf2PQuery { synchronized (P_CONSTRAINT_LOCK) { for (DnfClause clause : dnfQuery.getClauses()) { PBody body = new PBody(pQuery); - List symbolicParameters = new ArrayList<>(); - for (var param : dnfQuery.getParameters()) { - PVariable pVar = body.getOrCreateVariableByName(param.getUniqueName()); - symbolicParameters.add(new ExportedParameter(body, pVar, parameters.get(param))); + List parameterExports = new ArrayList<>(); + for (var parameter : dnfQuery.getSymbolicParameters()) { + PVariable pVar = body.getOrCreateVariableByName(parameter.getVariable().getUniqueName()); + parameterExports.add(new ExportedParameter(body, pVar, parameters.get(parameter))); } - body.setSymbolicParameters(symbolicParameters); + body.setSymbolicParameters(parameterExports); pQuery.addBody(body); for (Literal literal : clause.literals()) { - translateLiteral(literal, clause, body); + translateLiteral(literal, body); } } } @@ -126,11 +133,11 @@ public class Dnf2PQuery { return pQuery; } - private void translateLiteral(Literal literal, DnfClause clause, PBody body) { + private void translateLiteral(Literal literal, PBody body) { if (literal instanceof EquivalenceLiteral equivalenceLiteral) { translateEquivalenceLiteral(equivalenceLiteral, body); } else if (literal instanceof CallLiteral callLiteral) { - translateCallLiteral(callLiteral, clause, body); + translateCallLiteral(callLiteral, body); } else if (literal instanceof ConstantLiteral constantLiteral) { translateConstantLiteral(constantLiteral, body); } else if (literal instanceof AssignLiteral assignLiteral) { @@ -138,25 +145,25 @@ public class Dnf2PQuery { } else if (literal instanceof AssumeLiteral assumeLiteral) { translateAssumeLiteral(assumeLiteral, body); } else if (literal instanceof CountLiteral countLiteral) { - translateCountLiteral(countLiteral, clause, body); + translateCountLiteral(countLiteral, body); } else if (literal instanceof AggregationLiteral aggregationLiteral) { - translateAggregationLiteral(aggregationLiteral, clause, body); + translateAggregationLiteral(aggregationLiteral, body); } else { throw new IllegalArgumentException("Unknown literal: " + literal.toString()); } } private void translateEquivalenceLiteral(EquivalenceLiteral equivalenceLiteral, PBody body) { - PVariable varSource = body.getOrCreateVariableByName(equivalenceLiteral.left().getUniqueName()); - PVariable varTarget = body.getOrCreateVariableByName(equivalenceLiteral.right().getUniqueName()); - if (equivalenceLiteral.positive()) { + PVariable varSource = body.getOrCreateVariableByName(equivalenceLiteral.getLeft().getUniqueName()); + PVariable varTarget = body.getOrCreateVariableByName(equivalenceLiteral.getRight().getUniqueName()); + if (equivalenceLiteral.isPositive()) { new Equality(body, varSource, varTarget); } else { new Inequality(body, varSource, varTarget); } } - private void translateCallLiteral(CallLiteral callLiteral, DnfClause clause, PBody body) { + private void translateCallLiteral(CallLiteral callLiteral, PBody body) { var polarity = callLiteral.getPolarity(); switch (polarity) { case POSITIVE -> { @@ -186,7 +193,7 @@ public class Dnf2PQuery { new BinaryTransitiveClosure(body, substitution, pattern); } case NEGATIVE -> { - var wrappedCall = wrapperFactory.maybeWrapConstraint(callLiteral, clause); + var wrappedCall = wrapperFactory.maybeWrapConstraint(callLiteral); var substitution = translateSubstitution(wrappedCall.remappedArguments(), body); var pattern = wrappedCall.pattern(); new NegativePatternCall(body, substitution, pattern); @@ -206,13 +213,13 @@ public class Dnf2PQuery { } private void translateConstantLiteral(ConstantLiteral constantLiteral, PBody body) { - var variable = body.getOrCreateVariableByName(constantLiteral.variable().getUniqueName()); - new ConstantValue(body, variable, constantLiteral.nodeId()); + var variable = body.getOrCreateVariableByName(constantLiteral.getVariable().getUniqueName()); + new ConstantValue(body, variable, constantLiteral.getNodeId()); } private void translateAssignLiteral(AssignLiteral assignLiteral, PBody body) { - var variable = body.getOrCreateVariableByName(assignLiteral.variable().getUniqueName()); - var term = assignLiteral.term(); + var variable = body.getOrCreateVariableByName(assignLiteral.getTargetVariable().getUniqueName()); + var term = assignLiteral.getTerm(); if (term instanceof ConstantTerm constantTerm) { new ConstantValue(body, variable, constantTerm.getValue()); } else { @@ -222,19 +229,18 @@ public class Dnf2PQuery { } private void translateAssumeLiteral(AssumeLiteral assumeLiteral, PBody body) { - var evaluator = new AssumptionEvaluator(assumeLiteral.term()); + var evaluator = new AssumptionEvaluator(assumeLiteral.getTerm()); new ExpressionEvaluation(body, evaluator, null); } - private void translateCountLiteral(CountLiteral countLiteral, DnfClause clause, PBody body) { - var wrappedCall = wrapperFactory.maybeWrapConstraint(countLiteral, clause); + private void translateCountLiteral(CountLiteral countLiteral, PBody body) { + var wrappedCall = wrapperFactory.maybeWrapConstraint(countLiteral); var substitution = translateSubstitution(wrappedCall.remappedArguments(), body); var resultVariable = body.getOrCreateVariableByName(countLiteral.getResultVariable().getUniqueName()); new PatternMatchCounter(body, substitution, wrappedCall.pattern(), resultVariable); } - private void translateAggregationLiteral(AggregationLiteral aggregationLiteral, DnfClause clause, - PBody body) { + private void translateAggregationLiteral(AggregationLiteral aggregationLiteral, PBody body) { var aggregator = aggregationLiteral.getAggregator(); IMultisetAggregationOperator aggregationOperator; if (aggregator instanceof StatelessAggregator statelessAggregator) { @@ -244,7 +250,7 @@ public class Dnf2PQuery { } else { throw new IllegalArgumentException("Unknown aggregator: " + aggregator); } - var wrappedCall = wrapperFactory.maybeWrapConstraint(aggregationLiteral, clause); + var wrappedCall = wrapperFactory.maybeWrapConstraint(aggregationLiteral); var substitution = translateSubstitution(wrappedCall.remappedArguments(), body); var inputVariable = body.getOrCreateVariableByName(aggregationLiteral.getInputVariable().getUniqueName()); var aggregatedColumn = substitution.invertIndex().get(inputVariable); diff --git a/subprojects/store-query-viatra/src/main/java/tools/refinery/store/query/viatra/internal/pquery/QueryWrapperFactory.java b/subprojects/store-query-viatra/src/main/java/tools/refinery/store/query/viatra/internal/pquery/QueryWrapperFactory.java index 0d046455..2b7280f2 100644 --- a/subprojects/store-query-viatra/src/main/java/tools/refinery/store/query/viatra/internal/pquery/QueryWrapperFactory.java +++ b/subprojects/store-query-viatra/src/main/java/tools/refinery/store/query/viatra/internal/pquery/QueryWrapperFactory.java @@ -14,12 +14,13 @@ import org.eclipse.viatra.query.runtime.matchers.psystem.basicenumerables.TypeCo import org.eclipse.viatra.query.runtime.matchers.psystem.queries.PParameter; import org.eclipse.viatra.query.runtime.matchers.psystem.queries.PQuery; import org.eclipse.viatra.query.runtime.matchers.psystem.queries.PVisibility; +import org.eclipse.viatra.query.runtime.matchers.tuple.Tuple; import org.eclipse.viatra.query.runtime.matchers.tuple.Tuples; import tools.refinery.store.query.Constraint; import tools.refinery.store.query.dnf.Dnf; -import tools.refinery.store.query.dnf.DnfClause; import tools.refinery.store.query.dnf.DnfUtils; import tools.refinery.store.query.literal.AbstractCallLiteral; +import tools.refinery.store.query.term.ParameterDirection; import tools.refinery.store.query.term.Variable; import tools.refinery.store.query.view.AnySymbolView; import tools.refinery.store.query.view.SymbolView; @@ -45,21 +46,17 @@ class QueryWrapperFactory { } return maybeWrapConstraint(symbolView, identity); } - public WrappedCall maybeWrapConstraint(AbstractCallLiteral callLiteral, DnfClause clause) { + + public WrappedCall maybeWrapConstraint(AbstractCallLiteral callLiteral) { var arguments = callLiteral.getArguments(); int arity = arguments.size(); var remappedParameters = new int[arity]; - var boundVariables = clause.boundVariables(); var unboundVariableIndices = new HashMap(); var appendVariable = new VariableAppender(); for (int i = 0; i < arity; i++) { var variable = arguments.get(i); - if (boundVariables.contains(variable)) { - // Do not join bound variable to make sure that the embedded pattern stays as general as possible. - remappedParameters[i] = appendVariable.applyAsInt(variable); - } else { - remappedParameters[i] = unboundVariableIndices.computeIfAbsent(variable, appendVariable::applyAsInt); - } + // Unify all variables to avoid VIATRA bugs, even if they're bound in the containing clause. + remappedParameters[i] = unboundVariableIndices.computeIfAbsent(variable, appendVariable::applyAsInt); } var pattern = maybeWrapConstraint(callLiteral.getTarget(), remappedParameters); return new WrappedCall(pattern, appendVariable.getRemappedArguments()); @@ -89,6 +86,8 @@ class QueryWrapperFactory { var constraint = remappedConstraint.constraint(); var remappedParameters = remappedConstraint.remappedParameters(); + checkNoInputParameters(constraint); + var embeddedPQuery = new RawPQuery(DnfUtils.generateUniqueName(constraint.name()), PVisibility.EMBEDDED); var body = new PBody(embeddedPQuery); int arity = Arrays.stream(remappedParameters).max().orElse(-1) + 1; @@ -112,6 +111,21 @@ class QueryWrapperFactory { } var argumentTuple = Tuples.flatTupleOf(arguments); + addPositiveConstraint(constraint, body, argumentTuple); + embeddedPQuery.addBody(body); + return embeddedPQuery; + } + + private static void checkNoInputParameters(Constraint constraint) { + for (var constraintParameter : constraint.getParameters()) { + if (constraintParameter.getDirection() == ParameterDirection.IN) { + throw new IllegalArgumentException("Input parameter %s of %s is not supported" + .formatted(constraintParameter, constraint)); + } + } + } + + private void addPositiveConstraint(Constraint constraint, PBody body, Tuple argumentTuple) { if (constraint instanceof SymbolView view) { new TypeConstraint(body, argumentTuple, getInputKey(view)); } else if (constraint instanceof Dnf dnf) { @@ -120,9 +134,6 @@ class QueryWrapperFactory { } else { throw new IllegalArgumentException("Unknown Constraint: " + constraint); } - - embeddedPQuery.addBody(body); - return embeddedPQuery; } public IInputKey getInputKey(AnySymbolView symbolView) { -- cgit v1.2.3-54-g00ecf