aboutsummaryrefslogtreecommitdiffstats
path: root/subprojects/store-query-interpreter/src/main/java/tools/refinery/store/query/interpreter/internal/pquery/Dnf2PQuery.java
diff options
context:
space:
mode:
Diffstat (limited to 'subprojects/store-query-interpreter/src/main/java/tools/refinery/store/query/interpreter/internal/pquery/Dnf2PQuery.java')
-rw-r--r--subprojects/store-query-interpreter/src/main/java/tools/refinery/store/query/interpreter/internal/pquery/Dnf2PQuery.java253
1 files changed, 253 insertions, 0 deletions
diff --git a/subprojects/store-query-interpreter/src/main/java/tools/refinery/store/query/interpreter/internal/pquery/Dnf2PQuery.java b/subprojects/store-query-interpreter/src/main/java/tools/refinery/store/query/interpreter/internal/pquery/Dnf2PQuery.java
new file mode 100644
index 00000000..73ce4043
--- /dev/null
+++ b/subprojects/store-query-interpreter/src/main/java/tools/refinery/store/query/interpreter/internal/pquery/Dnf2PQuery.java
@@ -0,0 +1,253 @@
1/*
2 * SPDX-FileCopyrightText: 2021-2023 The Refinery Authors <https://refinery.tools/>
3 *
4 * SPDX-License-Identifier: EPL-2.0
5 */
6package tools.refinery.store.query.interpreter.internal.pquery;
7
8import tools.refinery.interpreter.matchers.psystem.basicdeferred.*;
9import tools.refinery.interpreter.matchers.psystem.basicenumerables.*;
10import tools.refinery.interpreter.matchers.psystem.basicenumerables.Connectivity;
11import tools.refinery.store.query.Constraint;
12import tools.refinery.store.query.dnf.Dnf;
13import tools.refinery.store.query.dnf.DnfClause;
14import tools.refinery.store.query.dnf.SymbolicParameter;
15import tools.refinery.store.query.literal.*;
16import tools.refinery.store.query.term.ConstantTerm;
17import tools.refinery.store.query.term.StatefulAggregator;
18import tools.refinery.store.query.term.StatelessAggregator;
19import tools.refinery.store.query.term.Variable;
20import tools.refinery.store.query.view.AnySymbolView;
21import tools.refinery.store.util.CycleDetectingMapper;
22import tools.refinery.interpreter.matchers.backend.IQueryBackendFactory;
23import tools.refinery.interpreter.matchers.backend.QueryEvaluationHint;
24import tools.refinery.interpreter.matchers.context.IInputKey;
25import tools.refinery.interpreter.matchers.psystem.PBody;
26import tools.refinery.interpreter.matchers.psystem.PVariable;
27import tools.refinery.interpreter.matchers.psystem.aggregations.BoundAggregator;
28import tools.refinery.interpreter.matchers.psystem.aggregations.IMultisetAggregationOperator;
29import tools.refinery.interpreter.matchers.psystem.annotations.PAnnotation;
30import tools.refinery.interpreter.matchers.psystem.queries.PParameter;
31import tools.refinery.interpreter.matchers.psystem.queries.PParameterDirection;
32import tools.refinery.interpreter.matchers.psystem.queries.PQuery;
33import tools.refinery.interpreter.matchers.tuple.Tuple;
34import tools.refinery.interpreter.matchers.tuple.Tuples;
35
36import java.util.ArrayList;
37import java.util.HashMap;
38import java.util.List;
39import java.util.Map;
40import java.util.function.Function;
41
42public class Dnf2PQuery {
43 private final CycleDetectingMapper<Dnf, RawPQuery> mapper = new CycleDetectingMapper<>(Dnf::name,
44 this::doTranslate);
45 private final QueryWrapperFactory wrapperFactory = new QueryWrapperFactory(this);
46 private Function<Dnf, QueryEvaluationHint> computeHint = dnf -> new QueryEvaluationHint(null,
47 (IQueryBackendFactory) null);
48
49 public void setComputeHint(Function<Dnf, QueryEvaluationHint> computeHint) {
50 this.computeHint = computeHint;
51 }
52
53 public RawPQuery translate(Dnf dnfQuery) {
54 return mapper.map(dnfQuery);
55 }
56
57 public Map<AnySymbolView, IInputKey> getSymbolViews() {
58 return wrapperFactory.getSymbolViews();
59 }
60
61 private RawPQuery doTranslate(Dnf dnfQuery) {
62 var pQuery = new RawPQuery(dnfQuery.getUniqueName());
63 pQuery.setEvaluationHints(computeHint.apply(dnfQuery));
64
65 Map<SymbolicParameter, PParameter> parameters = new HashMap<>();
66 List<PParameter> parameterList = new ArrayList<>();
67 for (var parameter : dnfQuery.getSymbolicParameters()) {
68 var direction = switch (parameter.getDirection()) {
69 case OUT -> PParameterDirection.INOUT;
70 case IN -> throw new IllegalArgumentException("Query %s with input parameter %s is not supported"
71 .formatted(dnfQuery, parameter.getVariable()));
72 };
73 var pParameter = new PParameter(parameter.getVariable().getUniqueName(), null, null, direction);
74 parameters.put(parameter, pParameter);
75 parameterList.add(pParameter);
76 }
77
78 pQuery.setParameters(parameterList);
79
80 for (var functionalDependency : dnfQuery.getFunctionalDependencies()) {
81 var functionalDependencyAnnotation = new PAnnotation("FunctionalDependency");
82 for (var forEachVariable : functionalDependency.forEach()) {
83 functionalDependencyAnnotation.addAttribute("forEach", forEachVariable.getUniqueName());
84 }
85 for (var uniqueVariable : functionalDependency.unique()) {
86 functionalDependencyAnnotation.addAttribute("unique", uniqueVariable.getUniqueName());
87 }
88 pQuery.addAnnotation(functionalDependencyAnnotation);
89 }
90
91 for (DnfClause clause : dnfQuery.getClauses()) {
92 PBody body = new PBody(pQuery);
93 List<ExportedParameter> parameterExports = new ArrayList<>();
94 for (var parameter : dnfQuery.getSymbolicParameters()) {
95 PVariable pVar = body.getOrCreateVariableByName(parameter.getVariable().getUniqueName());
96 parameterExports.add(new ExportedParameter(body, pVar, parameters.get(parameter)));
97 }
98 body.setSymbolicParameters(parameterExports);
99 pQuery.addBody(body);
100 for (Literal literal : clause.literals()) {
101 translateLiteral(literal, body);
102 }
103 }
104
105 return pQuery;
106 }
107
108 private void translateLiteral(Literal literal, PBody body) {
109 if (literal instanceof EquivalenceLiteral equivalenceLiteral) {
110 translateEquivalenceLiteral(equivalenceLiteral, body);
111 } else if (literal instanceof CallLiteral callLiteral) {
112 translateCallLiteral(callLiteral, body);
113 } else if (literal instanceof ConstantLiteral constantLiteral) {
114 translateConstantLiteral(constantLiteral, body);
115 } else if (literal instanceof AssignLiteral<?> assignLiteral) {
116 translateAssignLiteral(assignLiteral, body);
117 } else if (literal instanceof CheckLiteral checkLiteral) {
118 translateCheckLiteral(checkLiteral, body);
119 } else if (literal instanceof CountLiteral countLiteral) {
120 translateCountLiteral(countLiteral, body);
121 } else if (literal instanceof AggregationLiteral<?, ?> aggregationLiteral) {
122 translateAggregationLiteral(aggregationLiteral, body);
123 } else if (literal instanceof RepresentativeElectionLiteral representativeElectionLiteral) {
124 translateRepresentativeElectionLiteral(representativeElectionLiteral, body);
125 } else {
126 throw new IllegalArgumentException("Unknown literal: " + literal.toString());
127 }
128 }
129
130 private void translateEquivalenceLiteral(EquivalenceLiteral equivalenceLiteral, PBody body) {
131 PVariable varSource = body.getOrCreateVariableByName(equivalenceLiteral.getLeft().getUniqueName());
132 PVariable varTarget = body.getOrCreateVariableByName(equivalenceLiteral.getRight().getUniqueName());
133 if (equivalenceLiteral.isPositive()) {
134 new Equality(body, varSource, varTarget);
135 } else {
136 new Inequality(body, varSource, varTarget);
137 }
138 }
139
140 private void translateCallLiteral(CallLiteral callLiteral, PBody body) {
141 var polarity = callLiteral.getPolarity();
142 switch (polarity) {
143 case POSITIVE -> {
144 var substitution = translateSubstitution(callLiteral.getArguments(), body);
145 var constraint = callLiteral.getTarget();
146 if (constraint instanceof Dnf dnf) {
147 var pattern = translate(dnf);
148 new PositivePatternCall(body, substitution, pattern);
149 } else if (constraint instanceof AnySymbolView symbolView) {
150 var inputKey = wrapperFactory.getInputKey(symbolView);
151 new TypeConstraint(body, substitution, inputKey);
152 } else {
153 throw new IllegalArgumentException("Unknown Constraint: " + constraint);
154 }
155 }
156 case TRANSITIVE -> {
157 var substitution = translateSubstitution(callLiteral.getArguments(), body);
158 var pattern = wrapConstraintWithIdentityArguments(callLiteral.getTarget());
159 new BinaryTransitiveClosure(body, substitution, pattern);
160 }
161 case NEGATIVE -> {
162 var wrappedCall = wrapperFactory.maybeWrapConstraint(callLiteral);
163 var substitution = translateSubstitution(wrappedCall.remappedArguments(), body);
164 var pattern = wrappedCall.pattern();
165 new NegativePatternCall(body, substitution, pattern);
166 }
167 default -> throw new IllegalArgumentException("Unknown polarity: " + polarity);
168 }
169 }
170
171 private PQuery wrapConstraintWithIdentityArguments(Constraint constraint) {
172 if (constraint instanceof Dnf dnf) {
173 return translate(dnf);
174 } else if (constraint instanceof AnySymbolView symbolView) {
175 return wrapperFactory.wrapSymbolViewIdentityArguments(symbolView);
176 } else {
177 throw new IllegalArgumentException("Unknown Constraint: " + constraint);
178 }
179 }
180
181 private static Tuple translateSubstitution(List<Variable> substitution, PBody body) {
182 int arity = substitution.size();
183 Object[] variables = new Object[arity];
184 for (int i = 0; i < arity; i++) {
185 var variable = substitution.get(i);
186 variables[i] = body.getOrCreateVariableByName(variable.getUniqueName());
187 }
188 return Tuples.flatTupleOf(variables);
189 }
190
191 private void translateConstantLiteral(ConstantLiteral constantLiteral, PBody body) {
192 var variable = body.getOrCreateVariableByName(constantLiteral.getVariable().getUniqueName());
193 new ConstantValue(body, variable, tools.refinery.store.tuple.Tuple.of(constantLiteral.getNodeId()));
194 }
195
196 private <T> void translateAssignLiteral(AssignLiteral<T> assignLiteral, PBody body) {
197 var variable = body.getOrCreateVariableByName(assignLiteral.getVariable().getUniqueName());
198 var term = assignLiteral.getTerm();
199 if (term instanceof ConstantTerm<T> constantTerm) {
200 new ConstantValue(body, variable, constantTerm.getValue());
201 } else {
202 var evaluator = new TermEvaluator<>(term);
203 new ExpressionEvaluation(body, evaluator, variable);
204 }
205 }
206
207 private void translateCheckLiteral(CheckLiteral checkLiteral, PBody body) {
208 var evaluator = new CheckEvaluator(checkLiteral.getTerm());
209 new ExpressionEvaluation(body, evaluator, null);
210 }
211
212 private void translateCountLiteral(CountLiteral countLiteral, PBody body) {
213 var wrappedCall = wrapperFactory.maybeWrapConstraint(countLiteral);
214 var substitution = translateSubstitution(wrappedCall.remappedArguments(), body);
215 var resultVariable = body.getOrCreateVariableByName(countLiteral.getResultVariable().getUniqueName());
216 new PatternMatchCounter(body, substitution, wrappedCall.pattern(), resultVariable);
217 }
218
219 private <R, T> void translateAggregationLiteral(AggregationLiteral<R, T> aggregationLiteral, PBody body) {
220 var aggregator = aggregationLiteral.getAggregator();
221 IMultisetAggregationOperator<T, ?, R> aggregationOperator;
222 if (aggregator instanceof StatelessAggregator<R, T> statelessAggregator) {
223 aggregationOperator = new StatelessMultisetAggregator<>(statelessAggregator);
224 } else if (aggregator instanceof StatefulAggregator<R, T> statefulAggregator) {
225 aggregationOperator = new StatefulMultisetAggregator<>(statefulAggregator);
226 } else {
227 throw new IllegalArgumentException("Unknown aggregator: " + aggregator);
228 }
229 var wrappedCall = wrapperFactory.maybeWrapConstraint(aggregationLiteral);
230 var substitution = translateSubstitution(wrappedCall.remappedArguments(), body);
231 var inputVariable = body.getOrCreateVariableByName(aggregationLiteral.getInputVariable().getUniqueName());
232 var aggregatedColumn = substitution.invertIndex().get(inputVariable);
233 if (aggregatedColumn == null) {
234 throw new IllegalStateException("Input variable %s not found in substitution %s".formatted(inputVariable,
235 substitution));
236 }
237 var boundAggregator = new BoundAggregator(aggregationOperator, aggregator.getInputType(),
238 aggregator.getResultType());
239 var resultVariable = body.getOrCreateVariableByName(aggregationLiteral.getResultVariable().getUniqueName());
240 new AggregatorConstraint(boundAggregator, body, substitution, wrappedCall.pattern(), resultVariable,
241 aggregatedColumn);
242 }
243
244 private void translateRepresentativeElectionLiteral(RepresentativeElectionLiteral literal, PBody body) {
245 var substitution = translateSubstitution(literal.getArguments(), body);
246 var pattern = wrapConstraintWithIdentityArguments(literal.getTarget());
247 var connectivity = switch (literal.getConnectivity()) {
248 case WEAK -> Connectivity.WEAK;
249 case STRONG -> Connectivity.STRONG;
250 };
251 new RepresentativeElectionConstraint(body, substitution, pattern, connectivity);
252 }
253}