From 960af83c7c1cb871da03b9ac4ec6f44c94e78a1d Mon Sep 17 00:00:00 2001 From: Kristóf Marussy Date: Sun, 30 Oct 2022 19:27:34 -0400 Subject: refactor: DNF atoms Restore count != capability. Still needs semantics and tests for count atoms over partial models. --- .../query/viatra/ViatraQueryableModelStore.java | 30 +++++--- .../internal/pquery/CountExpressionEvaluator.java | 38 ++++++++++ .../pquery/CountNotEqualsExpressionEvaluator.java | 30 ++++++++ .../query/viatra/internal/pquery/DNF2PQuery.java | 84 +++++++++++++++------- 4 files changed, 146 insertions(+), 36 deletions(-) create mode 100644 subprojects/store-query-viatra/src/main/java/tools/refinery/store/query/viatra/internal/pquery/CountExpressionEvaluator.java create mode 100644 subprojects/store-query-viatra/src/main/java/tools/refinery/store/query/viatra/internal/pquery/CountNotEqualsExpressionEvaluator.java (limited to 'subprojects/store-query-viatra/src/main/java') diff --git a/subprojects/store-query-viatra/src/main/java/tools/refinery/store/query/viatra/ViatraQueryableModelStore.java b/subprojects/store-query-viatra/src/main/java/tools/refinery/store/query/viatra/ViatraQueryableModelStore.java index 702eb659..59fb1171 100644 --- a/subprojects/store-query-viatra/src/main/java/tools/refinery/store/query/viatra/ViatraQueryableModelStore.java +++ b/subprojects/store-query-viatra/src/main/java/tools/refinery/store/query/viatra/ViatraQueryableModelStore.java @@ -4,12 +4,10 @@ import org.eclipse.viatra.query.runtime.api.GenericQuerySpecification; import tools.refinery.store.model.ModelDiffCursor; import tools.refinery.store.model.ModelStore; import tools.refinery.store.model.ModelStoreImpl; +import tools.refinery.store.model.RelationLike; import tools.refinery.store.model.representation.DataRepresentation; import tools.refinery.store.query.*; -import tools.refinery.store.query.atom.DNFAtom; -import tools.refinery.store.query.atom.DNFCallAtom; -import tools.refinery.store.query.atom.EquivalenceAtom; -import tools.refinery.store.query.atom.RelationViewAtom; +import tools.refinery.store.query.atom.*; import tools.refinery.store.query.viatra.internal.RawPatternMatcher; import tools.refinery.store.query.viatra.internal.ViatraQueryableModel; import tools.refinery.store.query.viatra.internal.pquery.DNF2PQuery; @@ -57,9 +55,11 @@ public class ViatraQueryableModelStore implements QueryableModelStore { for (DNFAtom atom : clause.constraints()) { if (atom instanceof RelationViewAtom relationViewAtom) { validateRelationAtom(relationViews, dnfPredicate, relationViewAtom); - } else if (atom instanceof DNFCallAtom queryCallAtom) { + } else if (atom instanceof CallAtom queryCallAtom) { validatePredicateAtom(predicates, dnfPredicate, queryCallAtom); - } else if (!(atom instanceof EquivalenceAtom)) { + } else if (atom instanceof CountNotEqualsAtom countNotEqualsAtom) { + validateCountNotEqualsAtom(predicates, dnfPredicate, countNotEqualsAtom); + } else if (!(atom instanceof EquivalenceAtom || atom instanceof ConstantAtom)) { throw new IllegalArgumentException("Unknown constraint: " + atom.toString()); } } @@ -77,16 +77,24 @@ public class ViatraQueryableModelStore implements QueryableModelStore { } } - private void validatePredicateAtom(Set predicates, DNF dnfPredicate, - DNFCallAtom queryCallAtom) { - if (!predicates.contains(queryCallAtom.getTarget())) { + private void validatePredicateReference(Set predicates, DNF dnfPredicate, RelationLike target) { + if (!(target instanceof DNF dnfTarget) || !predicates.contains(dnfTarget)) { throw new IllegalArgumentException( "%s %s contains reference to a predicate %s that is not in the model.".formatted( - DNF.class.getSimpleName(), dnfPredicate.getUniqueName(), - queryCallAtom.getTarget().getName())); + DNF.class.getSimpleName(), dnfPredicate.getUniqueName(), target.getName())); } } + private void validatePredicateAtom(Set predicates, DNF dnfPredicate, CallAtom queryCallAtom) { + validatePredicateReference(predicates, dnfPredicate, queryCallAtom.getTarget()); + } + + private void validateCountNotEqualsAtom(Set predicates, DNF dnfPredicate, + CountNotEqualsAtom countNotEqualsAtom) { + validatePredicateReference(predicates, dnfPredicate, countNotEqualsAtom.mayTarget()); + validatePredicateReference(predicates, dnfPredicate, countNotEqualsAtom.mustTarget()); + } + private Map> initPredicates(Set predicates) { Map> result = new HashMap<>(); var dnf2PQuery = new DNF2PQuery(); diff --git a/subprojects/store-query-viatra/src/main/java/tools/refinery/store/query/viatra/internal/pquery/CountExpressionEvaluator.java b/subprojects/store-query-viatra/src/main/java/tools/refinery/store/query/viatra/internal/pquery/CountExpressionEvaluator.java new file mode 100644 index 00000000..6fc96c05 --- /dev/null +++ b/subprojects/store-query-viatra/src/main/java/tools/refinery/store/query/viatra/internal/pquery/CountExpressionEvaluator.java @@ -0,0 +1,38 @@ +package tools.refinery.store.query.viatra.internal.pquery; + +import org.eclipse.viatra.query.runtime.matchers.psystem.IExpressionEvaluator; +import org.eclipse.viatra.query.runtime.matchers.psystem.IValueProvider; +import tools.refinery.store.query.atom.ComparisonOperator; +import tools.refinery.store.query.atom.CountCallKind; + +import java.util.List; + +public record CountExpressionEvaluator(String variableName, ComparisonOperator operator, + int threshold) implements IExpressionEvaluator { + public CountExpressionEvaluator(String variableName, CountCallKind callKind) { + this(variableName, callKind.operator(), callKind.threshold()); + } + + @Override + public String getShortDescription() { + return "%s %s %d".formatted(variableName, operator, threshold); + } + + @Override + public Iterable getInputParameterNames() { + return List.of(variableName); + } + + @Override + public Object evaluateExpression(IValueProvider provider) { + int value = (Integer) provider.getValue(variableName); + return switch (operator) { + case EQUALS -> value == threshold; + case NOT_EQUALS -> value != threshold; + case LESS -> value < threshold; + case LESS_EQUALS -> value <= threshold; + case GREATER -> value > threshold; + case GREATER_EQUALS -> value >= threshold; + }; + } +} diff --git a/subprojects/store-query-viatra/src/main/java/tools/refinery/store/query/viatra/internal/pquery/CountNotEqualsExpressionEvaluator.java b/subprojects/store-query-viatra/src/main/java/tools/refinery/store/query/viatra/internal/pquery/CountNotEqualsExpressionEvaluator.java new file mode 100644 index 00000000..6f333a06 --- /dev/null +++ b/subprojects/store-query-viatra/src/main/java/tools/refinery/store/query/viatra/internal/pquery/CountNotEqualsExpressionEvaluator.java @@ -0,0 +1,30 @@ +package tools.refinery.store.query.viatra.internal.pquery; + +import org.eclipse.viatra.query.runtime.matchers.psystem.IExpressionEvaluator; +import org.eclipse.viatra.query.runtime.matchers.psystem.IValueProvider; + +import java.util.List; + +public record CountNotEqualsExpressionEvaluator(boolean must, int threshold, String mayVariableName, + String mustVariableName) implements IExpressionEvaluator { + @Override + public String getShortDescription() { + return "%d %s not in [%s; %s]".formatted(threshold, must ? "must" : "may", mustVariableName, mayVariableName); + } + + @Override + public Iterable getInputParameterNames() { + return List.of(mayVariableName, mustVariableName); + } + + @Override + public Object evaluateExpression(IValueProvider provider) throws Exception { + int mayCount = (Integer) provider.getValue(mayVariableName); + int mustCount = (Integer) provider.getValue(mustVariableName); + if (must) { + return mayCount < threshold || mustCount > threshold; + } else { + return mayCount > threshold || mustCount < threshold; + } + } +} diff --git a/subprojects/store-query-viatra/src/main/java/tools/refinery/store/query/viatra/internal/pquery/DNF2PQuery.java b/subprojects/store-query-viatra/src/main/java/tools/refinery/store/query/viatra/internal/pquery/DNF2PQuery.java index e3c586a0..61b984ae 100644 --- a/subprojects/store-query-viatra/src/main/java/tools/refinery/store/query/viatra/internal/pquery/DNF2PQuery.java +++ b/subprojects/store-query-viatra/src/main/java/tools/refinery/store/query/viatra/internal/pquery/DNF2PQuery.java @@ -2,20 +2,16 @@ package tools.refinery.store.query.viatra.internal.pquery; import org.eclipse.viatra.query.runtime.matchers.psystem.PBody; import org.eclipse.viatra.query.runtime.matchers.psystem.PVariable; -import org.eclipse.viatra.query.runtime.matchers.psystem.basicdeferred.Equality; -import org.eclipse.viatra.query.runtime.matchers.psystem.basicdeferred.ExportedParameter; -import org.eclipse.viatra.query.runtime.matchers.psystem.basicdeferred.Inequality; -import org.eclipse.viatra.query.runtime.matchers.psystem.basicdeferred.NegativePatternCall; +import org.eclipse.viatra.query.runtime.matchers.psystem.basicdeferred.*; import org.eclipse.viatra.query.runtime.matchers.psystem.basicenumerables.BinaryTransitiveClosure; +import org.eclipse.viatra.query.runtime.matchers.psystem.basicenumerables.ConstantValue; import org.eclipse.viatra.query.runtime.matchers.psystem.basicenumerables.PositivePatternCall; import org.eclipse.viatra.query.runtime.matchers.psystem.basicenumerables.TypeConstraint; import org.eclipse.viatra.query.runtime.matchers.psystem.queries.PParameter; +import org.eclipse.viatra.query.runtime.matchers.tuple.Tuple; import org.eclipse.viatra.query.runtime.matchers.tuple.Tuples; import tools.refinery.store.query.*; -import tools.refinery.store.query.atom.DNFAtom; -import tools.refinery.store.query.atom.EquivalenceAtom; -import tools.refinery.store.query.atom.DNFCallAtom; -import tools.refinery.store.query.atom.RelationViewAtom; +import tools.refinery.store.query.atom.*; import tools.refinery.store.query.view.RelationView; import java.util.*; @@ -85,8 +81,12 @@ public class DNF2PQuery { translateEquivalenceAtom(equivalenceAtom, body); } else if (constraint instanceof RelationViewAtom relationViewAtom) { translateRelationViewAtom(relationViewAtom, body); - } else if (constraint instanceof DNFCallAtom dnfCallAtom) { - translateDNFCallAtom(dnfCallAtom, body); + } else if (constraint instanceof CallAtom callAtom) { + translateCallAtom(callAtom, body); + } else if (constraint instanceof ConstantAtom constantAtom) { + translateConstantAtom(constantAtom, body); + } else if (constraint instanceof CountNotEqualsAtom countNotEqualsAtom) { + translateCountNotEqualsAtom(countNotEqualsAtom, body); } else { throw new IllegalArgumentException("Unknown constraint: " + constraint.toString()); } @@ -103,32 +103,66 @@ public class DNF2PQuery { } private void translateRelationViewAtom(RelationViewAtom relationViewAtom, PBody body) { - int arity = relationViewAtom.getSubstitution().size(); + new TypeConstraint(body, translateSubstitution(relationViewAtom.getSubstitution(), body), + wrapView(relationViewAtom.getTarget())); + } + + private static Tuple translateSubstitution(List substitution, PBody body) { + int arity = substitution.size(); Object[] variables = new Object[arity]; for (int i = 0; i < arity; i++) { - var variable = relationViewAtom.getSubstitution().get(i); + var variable = substitution.get(i); variables[i] = body.getOrCreateVariableByName(variable.getUniqueName()); } - new TypeConstraint(body, Tuples.flatTupleOf(variables), wrapView(relationViewAtom.getTarget())); + return Tuples.flatTupleOf(variables); } private RelationViewWrapper wrapView(RelationView relationView) { return view2WrapperMap.computeIfAbsent(relationView, RelationViewWrapper::new); } - private void translateDNFCallAtom(DNFCallAtom queryCallAtom, PBody body) { - int arity = queryCallAtom.getSubstitution().size(); - Object[] variables = new Object[arity]; - for (int i = 0; i < arity; i++) { - var variable = queryCallAtom.getSubstitution().get(i); - variables[i] = body.getOrCreateVariableByName(variable.getUniqueName()); + private void translateCallAtom(CallAtom callAtom, PBody body) { + if (!(callAtom.getTarget() instanceof DNF target)) { + throw new IllegalArgumentException("Only calls to DNF are supported"); + } + var variablesTuple = translateSubstitution(callAtom.getSubstitution(), body); + var translatedReferred = translate(target); + var callKind = callAtom.getKind(); + if (callKind instanceof BasicCallKind basicCallKind) { + switch (basicCallKind) { + case POSITIVE -> new PositivePatternCall(body, variablesTuple, translatedReferred); + case TRANSITIVE -> new BinaryTransitiveClosure(body, variablesTuple, translatedReferred); + case NEGATIVE -> new NegativePatternCall(body, variablesTuple, translatedReferred); + default -> throw new IllegalArgumentException("Unknown BasicCallKind: " + basicCallKind); + } + } else if (callKind instanceof CountCallKind countCallKind) { + var countVariableName = DNFUtils.generateUniqueName("count"); + var countPVariable = body.getOrCreateVariableByName(countVariableName); + new PatternMatchCounter(body, variablesTuple, translatedReferred, countPVariable); + new ExpressionEvaluation(body, new CountExpressionEvaluator(countVariableName, countCallKind), null); + } else { + throw new IllegalArgumentException("Unknown CallKind: " + callKind); } - var variablesTuple = Tuples.flatTupleOf(variables); - var translatedReferred = translate(queryCallAtom.getTarget()); - switch (queryCallAtom.getKind()) { - case POSITIVE -> new PositivePatternCall(body, variablesTuple, translatedReferred); - case TRANSITIVE -> new BinaryTransitiveClosure(body, variablesTuple, translatedReferred); - case NEGATIVE -> new NegativePatternCall(body, variablesTuple, translatedReferred); + } + + private void translateConstantAtom(ConstantAtom constantAtom, PBody body) { + var variable = body.getOrCreateVariableByName(constantAtom.variable().getUniqueName()); + new ConstantValue(body, variable, constantAtom.nodeId()); + } + + private void translateCountNotEqualsAtom(CountNotEqualsAtom countNotEqualsAtom, PBody body) { + if (!(countNotEqualsAtom.mayTarget() instanceof DNF mayTarget) || + !(countNotEqualsAtom.mustTarget() instanceof DNF mustTarget)) { + throw new IllegalArgumentException("Only calls to DNF are supported"); } + var variablesTuple = translateSubstitution(countNotEqualsAtom.substitution(), body); + var mayCountName = DNFUtils.generateUniqueName("countMay"); + var mayCountVariable = body.getOrCreateVariableByName(mayCountName); + new PatternMatchCounter(body, variablesTuple, translate(mayTarget), mayCountVariable); + var mustCountName = DNFUtils.generateUniqueName("countMust"); + var mustCountVariable = body.getOrCreateVariableByName(mustCountName); + new PatternMatchCounter(body, variablesTuple, translate(mustTarget), mustCountVariable); + new ExpressionEvaluation(body, new CountNotEqualsExpressionEvaluator(countNotEqualsAtom.must(), + countNotEqualsAtom.threshold(), mayCountName, mustCountName), null); } } -- cgit v1.2.3-54-g00ecf