diff options
Diffstat (limited to 'subprojects/store-query-interpreter/src/main/java/tools/refinery/store/query/interpreter/internal/pquery/Dnf2PQuery.java')
-rw-r--r-- | subprojects/store-query-interpreter/src/main/java/tools/refinery/store/query/interpreter/internal/pquery/Dnf2PQuery.java | 253 |
1 files changed, 253 insertions, 0 deletions
diff --git a/subprojects/store-query-interpreter/src/main/java/tools/refinery/store/query/interpreter/internal/pquery/Dnf2PQuery.java b/subprojects/store-query-interpreter/src/main/java/tools/refinery/store/query/interpreter/internal/pquery/Dnf2PQuery.java new file mode 100644 index 00000000..73ce4043 --- /dev/null +++ b/subprojects/store-query-interpreter/src/main/java/tools/refinery/store/query/interpreter/internal/pquery/Dnf2PQuery.java | |||
@@ -0,0 +1,253 @@ | |||
1 | /* | ||
2 | * SPDX-FileCopyrightText: 2021-2023 The Refinery Authors <https://refinery.tools/> | ||
3 | * | ||
4 | * SPDX-License-Identifier: EPL-2.0 | ||
5 | */ | ||
6 | package tools.refinery.store.query.interpreter.internal.pquery; | ||
7 | |||
8 | import tools.refinery.interpreter.matchers.psystem.basicdeferred.*; | ||
9 | import tools.refinery.interpreter.matchers.psystem.basicenumerables.*; | ||
10 | import tools.refinery.interpreter.matchers.psystem.basicenumerables.Connectivity; | ||
11 | import tools.refinery.store.query.Constraint; | ||
12 | import tools.refinery.store.query.dnf.Dnf; | ||
13 | import tools.refinery.store.query.dnf.DnfClause; | ||
14 | import tools.refinery.store.query.dnf.SymbolicParameter; | ||
15 | import tools.refinery.store.query.literal.*; | ||
16 | import tools.refinery.store.query.term.ConstantTerm; | ||
17 | import tools.refinery.store.query.term.StatefulAggregator; | ||
18 | import tools.refinery.store.query.term.StatelessAggregator; | ||
19 | import tools.refinery.store.query.term.Variable; | ||
20 | import tools.refinery.store.query.view.AnySymbolView; | ||
21 | import tools.refinery.store.util.CycleDetectingMapper; | ||
22 | import tools.refinery.interpreter.matchers.backend.IQueryBackendFactory; | ||
23 | import tools.refinery.interpreter.matchers.backend.QueryEvaluationHint; | ||
24 | import tools.refinery.interpreter.matchers.context.IInputKey; | ||
25 | import tools.refinery.interpreter.matchers.psystem.PBody; | ||
26 | import tools.refinery.interpreter.matchers.psystem.PVariable; | ||
27 | import tools.refinery.interpreter.matchers.psystem.aggregations.BoundAggregator; | ||
28 | import tools.refinery.interpreter.matchers.psystem.aggregations.IMultisetAggregationOperator; | ||
29 | import tools.refinery.interpreter.matchers.psystem.annotations.PAnnotation; | ||
30 | import tools.refinery.interpreter.matchers.psystem.queries.PParameter; | ||
31 | import tools.refinery.interpreter.matchers.psystem.queries.PParameterDirection; | ||
32 | import tools.refinery.interpreter.matchers.psystem.queries.PQuery; | ||
33 | import tools.refinery.interpreter.matchers.tuple.Tuple; | ||
34 | import tools.refinery.interpreter.matchers.tuple.Tuples; | ||
35 | |||
36 | import java.util.ArrayList; | ||
37 | import java.util.HashMap; | ||
38 | import java.util.List; | ||
39 | import java.util.Map; | ||
40 | import java.util.function.Function; | ||
41 | |||
42 | public class Dnf2PQuery { | ||
43 | private final CycleDetectingMapper<Dnf, RawPQuery> mapper = new CycleDetectingMapper<>(Dnf::name, | ||
44 | this::doTranslate); | ||
45 | private final QueryWrapperFactory wrapperFactory = new QueryWrapperFactory(this); | ||
46 | private Function<Dnf, QueryEvaluationHint> computeHint = dnf -> new QueryEvaluationHint(null, | ||
47 | (IQueryBackendFactory) null); | ||
48 | |||
49 | public void setComputeHint(Function<Dnf, QueryEvaluationHint> computeHint) { | ||
50 | this.computeHint = computeHint; | ||
51 | } | ||
52 | |||
53 | public RawPQuery translate(Dnf dnfQuery) { | ||
54 | return mapper.map(dnfQuery); | ||
55 | } | ||
56 | |||
57 | public Map<AnySymbolView, IInputKey> getSymbolViews() { | ||
58 | return wrapperFactory.getSymbolViews(); | ||
59 | } | ||
60 | |||
61 | private RawPQuery doTranslate(Dnf dnfQuery) { | ||
62 | var pQuery = new RawPQuery(dnfQuery.getUniqueName()); | ||
63 | pQuery.setEvaluationHints(computeHint.apply(dnfQuery)); | ||
64 | |||
65 | Map<SymbolicParameter, PParameter> parameters = new HashMap<>(); | ||
66 | List<PParameter> parameterList = new ArrayList<>(); | ||
67 | for (var parameter : dnfQuery.getSymbolicParameters()) { | ||
68 | var direction = switch (parameter.getDirection()) { | ||
69 | case OUT -> PParameterDirection.INOUT; | ||
70 | case IN -> throw new IllegalArgumentException("Query %s with input parameter %s is not supported" | ||
71 | .formatted(dnfQuery, parameter.getVariable())); | ||
72 | }; | ||
73 | var pParameter = new PParameter(parameter.getVariable().getUniqueName(), null, null, direction); | ||
74 | parameters.put(parameter, pParameter); | ||
75 | parameterList.add(pParameter); | ||
76 | } | ||
77 | |||
78 | pQuery.setParameters(parameterList); | ||
79 | |||
80 | for (var functionalDependency : dnfQuery.getFunctionalDependencies()) { | ||
81 | var functionalDependencyAnnotation = new PAnnotation("FunctionalDependency"); | ||
82 | for (var forEachVariable : functionalDependency.forEach()) { | ||
83 | functionalDependencyAnnotation.addAttribute("forEach", forEachVariable.getUniqueName()); | ||
84 | } | ||
85 | for (var uniqueVariable : functionalDependency.unique()) { | ||
86 | functionalDependencyAnnotation.addAttribute("unique", uniqueVariable.getUniqueName()); | ||
87 | } | ||
88 | pQuery.addAnnotation(functionalDependencyAnnotation); | ||
89 | } | ||
90 | |||
91 | for (DnfClause clause : dnfQuery.getClauses()) { | ||
92 | PBody body = new PBody(pQuery); | ||
93 | List<ExportedParameter> parameterExports = new ArrayList<>(); | ||
94 | for (var parameter : dnfQuery.getSymbolicParameters()) { | ||
95 | PVariable pVar = body.getOrCreateVariableByName(parameter.getVariable().getUniqueName()); | ||
96 | parameterExports.add(new ExportedParameter(body, pVar, parameters.get(parameter))); | ||
97 | } | ||
98 | body.setSymbolicParameters(parameterExports); | ||
99 | pQuery.addBody(body); | ||
100 | for (Literal literal : clause.literals()) { | ||
101 | translateLiteral(literal, body); | ||
102 | } | ||
103 | } | ||
104 | |||
105 | return pQuery; | ||
106 | } | ||
107 | |||
108 | private void translateLiteral(Literal literal, PBody body) { | ||
109 | if (literal instanceof EquivalenceLiteral equivalenceLiteral) { | ||
110 | translateEquivalenceLiteral(equivalenceLiteral, body); | ||
111 | } else if (literal instanceof CallLiteral callLiteral) { | ||
112 | translateCallLiteral(callLiteral, body); | ||
113 | } else if (literal instanceof ConstantLiteral constantLiteral) { | ||
114 | translateConstantLiteral(constantLiteral, body); | ||
115 | } else if (literal instanceof AssignLiteral<?> assignLiteral) { | ||
116 | translateAssignLiteral(assignLiteral, body); | ||
117 | } else if (literal instanceof CheckLiteral checkLiteral) { | ||
118 | translateCheckLiteral(checkLiteral, body); | ||
119 | } else if (literal instanceof CountLiteral countLiteral) { | ||
120 | translateCountLiteral(countLiteral, body); | ||
121 | } else if (literal instanceof AggregationLiteral<?, ?> aggregationLiteral) { | ||
122 | translateAggregationLiteral(aggregationLiteral, body); | ||
123 | } else if (literal instanceof RepresentativeElectionLiteral representativeElectionLiteral) { | ||
124 | translateRepresentativeElectionLiteral(representativeElectionLiteral, body); | ||
125 | } else { | ||
126 | throw new IllegalArgumentException("Unknown literal: " + literal.toString()); | ||
127 | } | ||
128 | } | ||
129 | |||
130 | private void translateEquivalenceLiteral(EquivalenceLiteral equivalenceLiteral, PBody body) { | ||
131 | PVariable varSource = body.getOrCreateVariableByName(equivalenceLiteral.getLeft().getUniqueName()); | ||
132 | PVariable varTarget = body.getOrCreateVariableByName(equivalenceLiteral.getRight().getUniqueName()); | ||
133 | if (equivalenceLiteral.isPositive()) { | ||
134 | new Equality(body, varSource, varTarget); | ||
135 | } else { | ||
136 | new Inequality(body, varSource, varTarget); | ||
137 | } | ||
138 | } | ||
139 | |||
140 | private void translateCallLiteral(CallLiteral callLiteral, PBody body) { | ||
141 | var polarity = callLiteral.getPolarity(); | ||
142 | switch (polarity) { | ||
143 | case POSITIVE -> { | ||
144 | var substitution = translateSubstitution(callLiteral.getArguments(), body); | ||
145 | var constraint = callLiteral.getTarget(); | ||
146 | if (constraint instanceof Dnf dnf) { | ||
147 | var pattern = translate(dnf); | ||
148 | new PositivePatternCall(body, substitution, pattern); | ||
149 | } else if (constraint instanceof AnySymbolView symbolView) { | ||
150 | var inputKey = wrapperFactory.getInputKey(symbolView); | ||
151 | new TypeConstraint(body, substitution, inputKey); | ||
152 | } else { | ||
153 | throw new IllegalArgumentException("Unknown Constraint: " + constraint); | ||
154 | } | ||
155 | } | ||
156 | case TRANSITIVE -> { | ||
157 | var substitution = translateSubstitution(callLiteral.getArguments(), body); | ||
158 | var pattern = wrapConstraintWithIdentityArguments(callLiteral.getTarget()); | ||
159 | new BinaryTransitiveClosure(body, substitution, pattern); | ||
160 | } | ||
161 | case NEGATIVE -> { | ||
162 | var wrappedCall = wrapperFactory.maybeWrapConstraint(callLiteral); | ||
163 | var substitution = translateSubstitution(wrappedCall.remappedArguments(), body); | ||
164 | var pattern = wrappedCall.pattern(); | ||
165 | new NegativePatternCall(body, substitution, pattern); | ||
166 | } | ||
167 | default -> throw new IllegalArgumentException("Unknown polarity: " + polarity); | ||
168 | } | ||
169 | } | ||
170 | |||
171 | private PQuery wrapConstraintWithIdentityArguments(Constraint constraint) { | ||
172 | if (constraint instanceof Dnf dnf) { | ||
173 | return translate(dnf); | ||
174 | } else if (constraint instanceof AnySymbolView symbolView) { | ||
175 | return wrapperFactory.wrapSymbolViewIdentityArguments(symbolView); | ||
176 | } else { | ||
177 | throw new IllegalArgumentException("Unknown Constraint: " + constraint); | ||
178 | } | ||
179 | } | ||
180 | |||
181 | private static Tuple translateSubstitution(List<Variable> substitution, PBody body) { | ||
182 | int arity = substitution.size(); | ||
183 | Object[] variables = new Object[arity]; | ||
184 | for (int i = 0; i < arity; i++) { | ||
185 | var variable = substitution.get(i); | ||
186 | variables[i] = body.getOrCreateVariableByName(variable.getUniqueName()); | ||
187 | } | ||
188 | return Tuples.flatTupleOf(variables); | ||
189 | } | ||
190 | |||
191 | private void translateConstantLiteral(ConstantLiteral constantLiteral, PBody body) { | ||
192 | var variable = body.getOrCreateVariableByName(constantLiteral.getVariable().getUniqueName()); | ||
193 | new ConstantValue(body, variable, tools.refinery.store.tuple.Tuple.of(constantLiteral.getNodeId())); | ||
194 | } | ||
195 | |||
196 | private <T> void translateAssignLiteral(AssignLiteral<T> assignLiteral, PBody body) { | ||
197 | var variable = body.getOrCreateVariableByName(assignLiteral.getVariable().getUniqueName()); | ||
198 | var term = assignLiteral.getTerm(); | ||
199 | if (term instanceof ConstantTerm<T> constantTerm) { | ||
200 | new ConstantValue(body, variable, constantTerm.getValue()); | ||
201 | } else { | ||
202 | var evaluator = new TermEvaluator<>(term); | ||
203 | new ExpressionEvaluation(body, evaluator, variable); | ||
204 | } | ||
205 | } | ||
206 | |||
207 | private void translateCheckLiteral(CheckLiteral checkLiteral, PBody body) { | ||
208 | var evaluator = new CheckEvaluator(checkLiteral.getTerm()); | ||
209 | new ExpressionEvaluation(body, evaluator, null); | ||
210 | } | ||
211 | |||
212 | private void translateCountLiteral(CountLiteral countLiteral, PBody body) { | ||
213 | var wrappedCall = wrapperFactory.maybeWrapConstraint(countLiteral); | ||
214 | var substitution = translateSubstitution(wrappedCall.remappedArguments(), body); | ||
215 | var resultVariable = body.getOrCreateVariableByName(countLiteral.getResultVariable().getUniqueName()); | ||
216 | new PatternMatchCounter(body, substitution, wrappedCall.pattern(), resultVariable); | ||
217 | } | ||
218 | |||
219 | private <R, T> void translateAggregationLiteral(AggregationLiteral<R, T> aggregationLiteral, PBody body) { | ||
220 | var aggregator = aggregationLiteral.getAggregator(); | ||
221 | IMultisetAggregationOperator<T, ?, R> aggregationOperator; | ||
222 | if (aggregator instanceof StatelessAggregator<R, T> statelessAggregator) { | ||
223 | aggregationOperator = new StatelessMultisetAggregator<>(statelessAggregator); | ||
224 | } else if (aggregator instanceof StatefulAggregator<R, T> statefulAggregator) { | ||
225 | aggregationOperator = new StatefulMultisetAggregator<>(statefulAggregator); | ||
226 | } else { | ||
227 | throw new IllegalArgumentException("Unknown aggregator: " + aggregator); | ||
228 | } | ||
229 | var wrappedCall = wrapperFactory.maybeWrapConstraint(aggregationLiteral); | ||
230 | var substitution = translateSubstitution(wrappedCall.remappedArguments(), body); | ||
231 | var inputVariable = body.getOrCreateVariableByName(aggregationLiteral.getInputVariable().getUniqueName()); | ||
232 | var aggregatedColumn = substitution.invertIndex().get(inputVariable); | ||
233 | if (aggregatedColumn == null) { | ||
234 | throw new IllegalStateException("Input variable %s not found in substitution %s".formatted(inputVariable, | ||
235 | substitution)); | ||
236 | } | ||
237 | var boundAggregator = new BoundAggregator(aggregationOperator, aggregator.getInputType(), | ||
238 | aggregator.getResultType()); | ||
239 | var resultVariable = body.getOrCreateVariableByName(aggregationLiteral.getResultVariable().getUniqueName()); | ||
240 | new AggregatorConstraint(boundAggregator, body, substitution, wrappedCall.pattern(), resultVariable, | ||
241 | aggregatedColumn); | ||
242 | } | ||
243 | |||
244 | private void translateRepresentativeElectionLiteral(RepresentativeElectionLiteral literal, PBody body) { | ||
245 | var substitution = translateSubstitution(literal.getArguments(), body); | ||
246 | var pattern = wrapConstraintWithIdentityArguments(literal.getTarget()); | ||
247 | var connectivity = switch (literal.getConnectivity()) { | ||
248 | case WEAK -> Connectivity.WEAK; | ||
249 | case STRONG -> Connectivity.STRONG; | ||
250 | }; | ||
251 | new RepresentativeElectionConstraint(body, substitution, pattern, connectivity); | ||
252 | } | ||
253 | } | ||