aboutsummaryrefslogtreecommitdiffstats
path: root/subprojects/store-query-viatra/src/main/java/tools/refinery/store/query/viatra/internal/pquery/Dnf2PQuery.java
diff options
context:
space:
mode:
Diffstat (limited to 'subprojects/store-query-viatra/src/main/java/tools/refinery/store/query/viatra/internal/pquery/Dnf2PQuery.java')
-rw-r--r--subprojects/store-query-viatra/src/main/java/tools/refinery/store/query/viatra/internal/pquery/Dnf2PQuery.java201
1 files changed, 126 insertions, 75 deletions
diff --git a/subprojects/store-query-viatra/src/main/java/tools/refinery/store/query/viatra/internal/pquery/Dnf2PQuery.java b/subprojects/store-query-viatra/src/main/java/tools/refinery/store/query/viatra/internal/pquery/Dnf2PQuery.java
index 201e0ed0..7afeb977 100644
--- a/subprojects/store-query-viatra/src/main/java/tools/refinery/store/query/viatra/internal/pquery/Dnf2PQuery.java
+++ b/subprojects/store-query-viatra/src/main/java/tools/refinery/store/query/viatra/internal/pquery/Dnf2PQuery.java
@@ -1,45 +1,44 @@
1package tools.refinery.store.query.viatra.internal.pquery; 1package tools.refinery.store.query.viatra.internal.pquery;
2 2
3import org.eclipse.viatra.query.runtime.matchers.backend.IQueryBackendFactory;
3import org.eclipse.viatra.query.runtime.matchers.backend.QueryEvaluationHint; 4import org.eclipse.viatra.query.runtime.matchers.backend.QueryEvaluationHint;
4import org.eclipse.viatra.query.runtime.matchers.context.IInputKey; 5import org.eclipse.viatra.query.runtime.matchers.context.IInputKey;
5import org.eclipse.viatra.query.runtime.matchers.psystem.PBody; 6import org.eclipse.viatra.query.runtime.matchers.psystem.PBody;
6import org.eclipse.viatra.query.runtime.matchers.psystem.PVariable; 7import org.eclipse.viatra.query.runtime.matchers.psystem.PVariable;
8import org.eclipse.viatra.query.runtime.matchers.psystem.aggregations.BoundAggregator;
9import org.eclipse.viatra.query.runtime.matchers.psystem.aggregations.IMultisetAggregationOperator;
7import org.eclipse.viatra.query.runtime.matchers.psystem.annotations.PAnnotation; 10import org.eclipse.viatra.query.runtime.matchers.psystem.annotations.PAnnotation;
8import org.eclipse.viatra.query.runtime.matchers.psystem.basicdeferred.Equality; 11import org.eclipse.viatra.query.runtime.matchers.psystem.basicdeferred.*;
9import org.eclipse.viatra.query.runtime.matchers.psystem.basicdeferred.ExportedParameter;
10import org.eclipse.viatra.query.runtime.matchers.psystem.basicdeferred.Inequality;
11import org.eclipse.viatra.query.runtime.matchers.psystem.basicdeferred.NegativePatternCall;
12import org.eclipse.viatra.query.runtime.matchers.psystem.basicenumerables.BinaryTransitiveClosure; 12import org.eclipse.viatra.query.runtime.matchers.psystem.basicenumerables.BinaryTransitiveClosure;
13import org.eclipse.viatra.query.runtime.matchers.psystem.basicenumerables.ConstantValue; 13import org.eclipse.viatra.query.runtime.matchers.psystem.basicenumerables.ConstantValue;
14import org.eclipse.viatra.query.runtime.matchers.psystem.basicenumerables.PositivePatternCall; 14import org.eclipse.viatra.query.runtime.matchers.psystem.basicenumerables.PositivePatternCall;
15import org.eclipse.viatra.query.runtime.matchers.psystem.basicenumerables.TypeConstraint; 15import org.eclipse.viatra.query.runtime.matchers.psystem.basicenumerables.TypeConstraint;
16import org.eclipse.viatra.query.runtime.matchers.psystem.queries.PParameter; 16import org.eclipse.viatra.query.runtime.matchers.psystem.queries.PParameter;
17import org.eclipse.viatra.query.runtime.matchers.psystem.queries.PVisibility; 17import org.eclipse.viatra.query.runtime.matchers.psystem.queries.PQuery;
18import org.eclipse.viatra.query.runtime.matchers.tuple.Tuple; 18import org.eclipse.viatra.query.runtime.matchers.tuple.Tuple;
19import org.eclipse.viatra.query.runtime.matchers.tuple.Tuples; 19import org.eclipse.viatra.query.runtime.matchers.tuple.Tuples;
20import tools.refinery.store.query.Dnf; 20import tools.refinery.store.query.dnf.Dnf;
21import tools.refinery.store.query.DnfClause; 21import tools.refinery.store.query.dnf.DnfClause;
22import tools.refinery.store.query.DnfUtils;
23import tools.refinery.store.query.Variable;
24import tools.refinery.store.query.literal.*; 22import tools.refinery.store.query.literal.*;
23import tools.refinery.store.query.term.ConstantTerm;
24import tools.refinery.store.query.term.StatefulAggregator;
25import tools.refinery.store.query.term.StatelessAggregator;
26import tools.refinery.store.query.term.Variable;
25import tools.refinery.store.query.view.AnyRelationView; 27import tools.refinery.store.query.view.AnyRelationView;
26import tools.refinery.store.util.CycleDetectingMapper; 28import tools.refinery.store.util.CycleDetectingMapper;
27 29
28import java.util.*; 30import java.util.*;
29import java.util.function.Function; 31import java.util.function.Function;
32import java.util.stream.Collectors;
30 33
31public class Dnf2PQuery { 34public class Dnf2PQuery {
32 private static final Object P_CONSTRAINT_LOCK = new Object(); 35 private static final Object P_CONSTRAINT_LOCK = new Object();
33
34 private final CycleDetectingMapper<Dnf, RawPQuery> mapper = new CycleDetectingMapper<>(Dnf::name, 36 private final CycleDetectingMapper<Dnf, RawPQuery> mapper = new CycleDetectingMapper<>(Dnf::name,
35 this::doTranslate); 37 this::doTranslate);
36 38 private final QueryWrapperFactory wrapperFactory = new QueryWrapperFactory(this);
37 private final Map<AnyRelationView, RelationViewWrapper> view2WrapperMap = new LinkedHashMap<>(); 39 private final Map<Dnf, QueryEvaluationHint> hintOverrides = new LinkedHashMap<>();
38
39 private final Map<AnyRelationView, RawPQuery> view2EmbeddedMap = new HashMap<>();
40
41 private Function<Dnf, QueryEvaluationHint> computeHint = dnf -> new QueryEvaluationHint(null, 40 private Function<Dnf, QueryEvaluationHint> computeHint = dnf -> new QueryEvaluationHint(null,
42 QueryEvaluationHint.BackendRequirement.UNSPECIFIED); 41 (IQueryBackendFactory) null);
43 42
44 public void setComputeHint(Function<Dnf, QueryEvaluationHint> computeHint) { 43 public void setComputeHint(Function<Dnf, QueryEvaluationHint> computeHint) {
45 this.computeHint = computeHint; 44 this.computeHint = computeHint;
@@ -50,16 +49,33 @@ public class Dnf2PQuery {
50 } 49 }
51 50
52 public Map<AnyRelationView, IInputKey> getRelationViews() { 51 public Map<AnyRelationView, IInputKey> getRelationViews() {
53 return Collections.unmodifiableMap(view2WrapperMap); 52 return wrapperFactory.getRelationViews();
54 } 53 }
55 54
56 public RawPQuery getAlreadyTranslated(Dnf dnfQuery) { 55 public void hint(Dnf dnf, QueryEvaluationHint hint) {
57 return mapper.getAlreadyMapped(dnfQuery); 56 hintOverrides.compute(dnf, (ignoredKey, existingHint) ->
57 existingHint == null ? hint : existingHint.overrideBy(hint));
58 }
59
60 private QueryEvaluationHint consumeHint(Dnf dnf) {
61 var defaultHint = computeHint.apply(dnf);
62 var existingHint = hintOverrides.remove(dnf);
63 return defaultHint.overrideBy(existingHint);
64 }
65
66 public void assertNoUnusedHints() {
67 if (hintOverrides.isEmpty()) {
68 return;
69 }
70 var unusedHints = hintOverrides.keySet().stream().map(Dnf::name).collect(Collectors.joining(", "));
71 throw new IllegalStateException(
72 "Unused query evaluation hints for %s. Hints must be set before a query is added to the engine"
73 .formatted(unusedHints));
58 } 74 }
59 75
60 private RawPQuery doTranslate(Dnf dnfQuery) { 76 private RawPQuery doTranslate(Dnf dnfQuery) {
61 var pQuery = new RawPQuery(dnfQuery.getUniqueName()); 77 var pQuery = new RawPQuery(dnfQuery.getUniqueName());
62 pQuery.setEvaluationHints(computeHint.apply(dnfQuery)); 78 pQuery.setEvaluationHints(consumeHint(dnfQuery));
63 79
64 Map<Variable, PParameter> parameters = new HashMap<>(); 80 Map<Variable, PParameter> parameters = new HashMap<>();
65 for (Variable variable : dnfQuery.getParameters()) { 81 for (Variable variable : dnfQuery.getParameters()) {
@@ -97,7 +113,7 @@ public class Dnf2PQuery {
97 body.setSymbolicParameters(symbolicParameters); 113 body.setSymbolicParameters(symbolicParameters);
98 pQuery.addBody(body); 114 pQuery.addBody(body);
99 for (Literal literal : clause.literals()) { 115 for (Literal literal : clause.literals()) {
100 translateLiteral(literal, body); 116 translateLiteral(literal, clause, body);
101 } 117 }
102 } 118 }
103 } 119 }
@@ -105,15 +121,21 @@ public class Dnf2PQuery {
105 return pQuery; 121 return pQuery;
106 } 122 }
107 123
108 private void translateLiteral(Literal literal, PBody body) { 124 private void translateLiteral(Literal literal, DnfClause clause, PBody body) {
109 if (literal instanceof EquivalenceLiteral equivalenceLiteral) { 125 if (literal instanceof EquivalenceLiteral equivalenceLiteral) {
110 translateEquivalenceLiteral(equivalenceLiteral, body); 126 translateEquivalenceLiteral(equivalenceLiteral, body);
111 } else if (literal instanceof RelationViewLiteral relationViewLiteral) { 127 } else if (literal instanceof CallLiteral callLiteral) {
112 translateRelationViewLiteral(relationViewLiteral, body); 128 translateCallLiteral(callLiteral, clause, body);
113 } else if (literal instanceof DnfCallLiteral dnfCallLiteral) {
114 translateDnfCallLiteral(dnfCallLiteral, body);
115 } else if (literal instanceof ConstantLiteral constantLiteral) { 129 } else if (literal instanceof ConstantLiteral constantLiteral) {
116 translateConstantLiteral(constantLiteral, body); 130 translateConstantLiteral(constantLiteral, body);
131 } else if (literal instanceof AssignLiteral<?> assignLiteral) {
132 translateAssignLiteral(assignLiteral, body);
133 } else if (literal instanceof AssumeLiteral assumeLiteral) {
134 translateAssumeLiteral(assumeLiteral, body);
135 } else if (literal instanceof CountLiteral countLiteral) {
136 translateCountLiteral(countLiteral, clause, body);
137 } else if (literal instanceof AggregationLiteral<?, ?> aggregationLiteral) {
138 translateAggregationLiteral(aggregationLiteral, clause, body);
117 } else { 139 } else {
118 throw new IllegalArgumentException("Unknown literal: " + literal.toString()); 140 throw new IllegalArgumentException("Unknown literal: " + literal.toString());
119 } 141 }
@@ -129,20 +151,43 @@ public class Dnf2PQuery {
129 } 151 }
130 } 152 }
131 153
132 private void translateRelationViewLiteral(RelationViewLiteral relationViewLiteral, PBody body) { 154 private void translateCallLiteral(CallLiteral callLiteral, DnfClause clause, PBody body) {
133 var substitution = translateSubstitution(relationViewLiteral.getArguments(), body); 155 var polarity = callLiteral.getPolarity();
134 var polarity = relationViewLiteral.getPolarity(); 156 switch (polarity) {
135 var relationView = relationViewLiteral.getTarget(); 157 case POSITIVE -> {
136 if (polarity == CallPolarity.POSITIVE) { 158 var substitution = translateSubstitution(callLiteral.getArguments(), body);
137 new TypeConstraint(body, substitution, wrapView(relationView)); 159 var constraint = callLiteral.getTarget();
138 } else { 160 if (constraint instanceof Dnf dnf) {
139 var embeddedPQuery = translateEmbeddedRelationViewPQuery(relationView); 161 var pattern = translate(dnf);
140 switch (polarity) { 162 new PositivePatternCall(body, substitution, pattern);
141 case TRANSITIVE -> new BinaryTransitiveClosure(body, substitution, embeddedPQuery); 163 } else if (constraint instanceof AnyRelationView relationView) {
142 case NEGATIVE -> new NegativePatternCall(body, substitution, embeddedPQuery); 164 var inputKey = wrapperFactory.getInputKey(relationView);
143 default -> throw new IllegalArgumentException("Unknown polarity: " + polarity); 165 new TypeConstraint(body, substitution, inputKey);
166 } else {
167 throw new IllegalArgumentException("Unknown Constraint: " + constraint);
144 } 168 }
145 } 169 }
170 case TRANSITIVE -> {
171 var substitution = translateSubstitution(callLiteral.getArguments(), body);
172 var constraint = callLiteral.getTarget();
173 PQuery pattern;
174 if (constraint instanceof Dnf dnf) {
175 pattern = translate(dnf);
176 } else if (constraint instanceof AnyRelationView relationView) {
177 pattern = wrapperFactory.wrapRelationViewIdentityArguments(relationView);
178 } else {
179 throw new IllegalArgumentException("Unknown Constraint: " + constraint);
180 }
181 new BinaryTransitiveClosure(body, substitution, pattern);
182 }
183 case NEGATIVE -> {
184 var wrappedCall = wrapperFactory.maybeWrapConstraint(callLiteral, clause);
185 var substitution = translateSubstitution(wrappedCall.remappedArguments(), body);
186 var pattern = wrappedCall.pattern();
187 new NegativePatternCall(body, substitution, pattern);
188 }
189 default -> throw new IllegalArgumentException("Unknown polarity: " + polarity);
190 }
146 } 191 }
147 192
148 private static Tuple translateSubstitution(List<Variable> substitution, PBody body) { 193 private static Tuple translateSubstitution(List<Variable> substitution, PBody body) {
@@ -155,51 +200,57 @@ public class Dnf2PQuery {
155 return Tuples.flatTupleOf(variables); 200 return Tuples.flatTupleOf(variables);
156 } 201 }
157 202
158 private RawPQuery translateEmbeddedRelationViewPQuery(AnyRelationView relationView) { 203 private void translateConstantLiteral(ConstantLiteral constantLiteral, PBody body) {
159 return view2EmbeddedMap.computeIfAbsent(relationView, this::doTranslateEmbeddedRelationViewPQuery); 204 var variable = body.getOrCreateVariableByName(constantLiteral.variable().getUniqueName());
205 new ConstantValue(body, variable, constantLiteral.nodeId());
160 } 206 }
161 207
162 private RawPQuery doTranslateEmbeddedRelationViewPQuery(AnyRelationView relationView) { 208 private <T> void translateAssignLiteral(AssignLiteral<T> assignLiteral, PBody body) {
163 var embeddedPQuery = new RawPQuery(DnfUtils.generateUniqueName(relationView.name()), PVisibility.EMBEDDED); 209 var variable = body.getOrCreateVariableByName(assignLiteral.variable().getUniqueName());
164 var body = new PBody(embeddedPQuery); 210 var term = assignLiteral.term();
165 int arity = relationView.arity(); 211 if (term instanceof ConstantTerm<T> constantTerm) {
166 var parameters = new ArrayList<PParameter>(arity); 212 new ConstantValue(body, variable, constantTerm.getValue());
167 var arguments = new Object[arity]; 213 } else {
168 var symbolicParameters = new ArrayList<ExportedParameter>(arity); 214 var evaluator = new TermEvaluator<>(term);
169 for (int i = 0; i < arity; i++) { 215 new ExpressionEvaluation(body, evaluator, variable);
170 var parameterName = "p" + i;
171 var parameter = new PParameter(parameterName);
172 parameters.add(parameter);
173 var variable = body.getOrCreateVariableByName(parameterName);
174 arguments[i] = variable;
175 symbolicParameters.add(new ExportedParameter(body, variable, parameter));
176 } 216 }
177 embeddedPQuery.setParameters(parameters);
178 body.setSymbolicParameters(symbolicParameters);
179 var argumentTuple = Tuples.flatTupleOf(arguments);
180 new TypeConstraint(body, argumentTuple, wrapView(relationView));
181 embeddedPQuery.addBody(body);
182 return embeddedPQuery;
183 } 217 }
184 218
185 private RelationViewWrapper wrapView(AnyRelationView relationView) { 219 private void translateAssumeLiteral(AssumeLiteral assumeLiteral, PBody body) {
186 return view2WrapperMap.computeIfAbsent(relationView, RelationViewWrapper::new); 220 var evaluator = new AssumptionEvaluator(assumeLiteral.term());
221 new ExpressionEvaluation(body, evaluator, null);
187 } 222 }
188 223
189 private void translateDnfCallLiteral(DnfCallLiteral dnfCallLiteral, PBody body) { 224 private void translateCountLiteral(CountLiteral countLiteral, DnfClause clause, PBody body) {
190 var variablesTuple = translateSubstitution(dnfCallLiteral.getArguments(), body); 225 var wrappedCall = wrapperFactory.maybeWrapConstraint(countLiteral, clause);
191 var translatedReferred = translate(dnfCallLiteral.getTarget()); 226 var substitution = translateSubstitution(wrappedCall.remappedArguments(), body);
192 var polarity = dnfCallLiteral.getPolarity(); 227 var resultVariable = body.getOrCreateVariableByName(countLiteral.getResultVariable().getUniqueName());
193 switch (polarity) { 228 new PatternMatchCounter(body, substitution, wrappedCall.pattern(), resultVariable);
194 case POSITIVE -> new PositivePatternCall(body, variablesTuple, translatedReferred);
195 case TRANSITIVE -> new BinaryTransitiveClosure(body, variablesTuple, translatedReferred);
196 case NEGATIVE -> new NegativePatternCall(body, variablesTuple, translatedReferred);
197 default -> throw new IllegalArgumentException("Unknown polarity: " + polarity);
198 }
199 } 229 }
200 230
201 private void translateConstantLiteral(ConstantLiteral constantLiteral, PBody body) { 231 private <R, T> void translateAggregationLiteral(AggregationLiteral<R, T> aggregationLiteral, DnfClause clause,
202 var variable = body.getOrCreateVariableByName(constantLiteral.variable().getUniqueName()); 232 PBody body) {
203 new ConstantValue(body, variable, constantLiteral.nodeId()); 233 var aggregator = aggregationLiteral.getAggregator();
234 IMultisetAggregationOperator<T, ?, R> aggregationOperator;
235 if (aggregator instanceof StatelessAggregator<R, T> statelessAggregator) {
236 aggregationOperator = new StatelessMultisetAggregator<>(statelessAggregator);
237 } else if (aggregator instanceof StatefulAggregator<R, T> statefulAggregator) {
238 aggregationOperator = new StatefulMultisetAggregator<>(statefulAggregator);
239 } else {
240 throw new IllegalArgumentException("Unknown aggregator: " + aggregator);
241 }
242 var wrappedCall = wrapperFactory.maybeWrapConstraint(aggregationLiteral, clause);
243 var substitution = translateSubstitution(wrappedCall.remappedArguments(), body);
244 var inputVariable = body.getOrCreateVariableByName(aggregationLiteral.getInputVariable().getUniqueName());
245 var aggregatedColumn = substitution.invertIndex().get(inputVariable);
246 if (aggregatedColumn == null) {
247 throw new IllegalStateException("Input variable %s not found in substitution %s".formatted(inputVariable,
248 substitution));
249 }
250 var boundAggregator = new BoundAggregator(aggregationOperator, aggregator.getInputType(),
251 aggregator.getResultType());
252 var resultVariable = body.getOrCreateVariableByName(aggregationLiteral.getResultVariable().getUniqueName());
253 new AggregatorConstraint(boundAggregator, body, substitution, wrappedCall.pattern(), resultVariable,
254 aggregatedColumn);
204 } 255 }
205} 256}