diff options
Diffstat (limited to 'subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/DnfBuilder.java')
-rw-r--r-- | subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/DnfBuilder.java | 183 |
1 files changed, 142 insertions, 41 deletions
diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/DnfBuilder.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/DnfBuilder.java index a42ae558..3fac4627 100644 --- a/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/DnfBuilder.java +++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/DnfBuilder.java | |||
@@ -7,20 +7,16 @@ package tools.refinery.store.query.dnf; | |||
7 | 7 | ||
8 | import tools.refinery.store.query.dnf.callback.*; | 8 | import tools.refinery.store.query.dnf.callback.*; |
9 | import tools.refinery.store.query.literal.Literal; | 9 | import tools.refinery.store.query.literal.Literal; |
10 | import tools.refinery.store.query.term.DataVariable; | 10 | import tools.refinery.store.query.term.*; |
11 | import tools.refinery.store.query.term.NodeVariable; | ||
12 | import tools.refinery.store.query.term.Variable; | ||
13 | 11 | ||
14 | import java.util.*; | 12 | import java.util.*; |
15 | 13 | ||
16 | @SuppressWarnings("UnusedReturnValue") | 14 | @SuppressWarnings("UnusedReturnValue") |
17 | public final class DnfBuilder { | 15 | public final class DnfBuilder { |
18 | private final String name; | 16 | private final String name; |
19 | |||
20 | private final List<Variable> parameters = new ArrayList<>(); | 17 | private final List<Variable> parameters = new ArrayList<>(); |
21 | 18 | private final Map<Variable, ParameterDirection> directions = new HashMap<>(); | |
22 | private final List<FunctionalDependency<Variable>> functionalDependencies = new ArrayList<>(); | 19 | private final List<FunctionalDependency<Variable>> functionalDependencies = new ArrayList<>(); |
23 | |||
24 | private final List<List<Literal>> clauses = new ArrayList<>(); | 20 | private final List<List<Literal>> clauses = new ArrayList<>(); |
25 | 21 | ||
26 | DnfBuilder(String name) { | 22 | DnfBuilder(String name) { |
@@ -37,6 +33,18 @@ public final class DnfBuilder { | |||
37 | return variable; | 33 | return variable; |
38 | } | 34 | } |
39 | 35 | ||
36 | public NodeVariable parameter(ParameterDirection direction) { | ||
37 | var variable = parameter(); | ||
38 | putDirection(variable, direction); | ||
39 | return variable; | ||
40 | } | ||
41 | |||
42 | public NodeVariable parameter(String name, ParameterDirection direction) { | ||
43 | var variable = parameter(name); | ||
44 | putDirection(variable, direction); | ||
45 | return variable; | ||
46 | } | ||
47 | |||
40 | public <T> DataVariable<T> parameter(Class<T> type) { | 48 | public <T> DataVariable<T> parameter(Class<T> type) { |
41 | return parameter(null, type); | 49 | return parameter(null, type); |
42 | } | 50 | } |
@@ -47,6 +55,18 @@ public final class DnfBuilder { | |||
47 | return variable; | 55 | return variable; |
48 | } | 56 | } |
49 | 57 | ||
58 | public <T> DataVariable<T> parameter(Class<T> type, ParameterDirection direction) { | ||
59 | var variable = parameter(type); | ||
60 | putDirection(variable, direction); | ||
61 | return variable; | ||
62 | } | ||
63 | |||
64 | public <T> DataVariable<T> parameter(String name, Class<T> type, ParameterDirection direction) { | ||
65 | var variable = parameter(name, type); | ||
66 | putDirection(variable, direction); | ||
67 | return variable; | ||
68 | } | ||
69 | |||
50 | public DnfBuilder parameter(Variable variable) { | 70 | public DnfBuilder parameter(Variable variable) { |
51 | if (parameters.contains(variable)) { | 71 | if (parameters.contains(variable)) { |
52 | throw new IllegalArgumentException("Duplicate parameter: " + variable); | 72 | throw new IllegalArgumentException("Duplicate parameter: " + variable); |
@@ -55,12 +75,49 @@ public final class DnfBuilder { | |||
55 | return this; | 75 | return this; |
56 | } | 76 | } |
57 | 77 | ||
78 | public DnfBuilder parameter(Variable variable, ParameterDirection direction) { | ||
79 | parameter(variable); | ||
80 | putDirection(variable, direction); | ||
81 | return this; | ||
82 | } | ||
83 | |||
84 | private void putDirection(Variable variable, ParameterDirection direction) { | ||
85 | if (variable.tryGetType().isPresent()) { | ||
86 | if (direction == ParameterDirection.IN_OUT) { | ||
87 | throw new IllegalArgumentException("%s direction is forbidden for data variable %s" | ||
88 | .formatted(direction, variable)); | ||
89 | } | ||
90 | } else { | ||
91 | if (direction == ParameterDirection.OUT) { | ||
92 | throw new IllegalArgumentException("%s direction is forbidden for node variable %s" | ||
93 | .formatted(direction, variable)); | ||
94 | } | ||
95 | } | ||
96 | directions.put(variable, direction); | ||
97 | } | ||
98 | |||
58 | public DnfBuilder parameters(Variable... variables) { | 99 | public DnfBuilder parameters(Variable... variables) { |
59 | return parameters(List.of(variables)); | 100 | return parameters(List.of(variables)); |
60 | } | 101 | } |
61 | 102 | ||
62 | public DnfBuilder parameters(Collection<? extends Variable> variables) { | 103 | public DnfBuilder parameters(Collection<? extends Variable> variables) { |
63 | parameters.addAll(variables); | 104 | for (var variable : variables) { |
105 | parameter(variable); | ||
106 | } | ||
107 | return this; | ||
108 | } | ||
109 | |||
110 | public DnfBuilder parameters(Collection<? extends Variable> variables, ParameterDirection direction) { | ||
111 | for (var variable : variables) { | ||
112 | parameter(variable, direction); | ||
113 | } | ||
114 | return this; | ||
115 | } | ||
116 | |||
117 | public DnfBuilder symbolicParameters(Collection<SymbolicParameter> parameters) { | ||
118 | for (var parameter : parameters) { | ||
119 | parameter(parameter.getVariable(), parameter.getDirection()); | ||
120 | } | ||
64 | return this; | 121 | return this; |
65 | } | 122 | } |
66 | 123 | ||
@@ -152,54 +209,98 @@ public final class DnfBuilder { | |||
152 | } | 209 | } |
153 | 210 | ||
154 | public DnfBuilder clause(Collection<? extends Literal> literals) { | 211 | public DnfBuilder clause(Collection<? extends Literal> literals) { |
155 | // Remove duplicates by using a hashed data structure. | 212 | clauses.add(List.copyOf(literals)); |
156 | var filteredLiterals = new LinkedHashSet<Literal>(literals.size()); | ||
157 | for (var literal : literals) { | ||
158 | var reduction = literal.getReduction(); | ||
159 | switch (reduction) { | ||
160 | case NOT_REDUCIBLE -> filteredLiterals.add(literal); | ||
161 | case ALWAYS_TRUE -> { | ||
162 | // Literals reducible to {@code true} can be omitted, because the model is always assumed to have at | ||
163 | // least on object. | ||
164 | } | ||
165 | case ALWAYS_FALSE -> { | ||
166 | // Clauses with {@code false} literals can be omitted entirely. | ||
167 | return this; | ||
168 | } | ||
169 | default -> throw new IllegalArgumentException("Invalid reduction: " + reduction); | ||
170 | } | ||
171 | } | ||
172 | clauses.add(List.copyOf(filteredLiterals)); | ||
173 | return this; | 213 | return this; |
174 | } | 214 | } |
175 | 215 | ||
216 | <T> void output(DataVariable<T> outputVariable) { | ||
217 | var fromParameters = Set.copyOf(parameters); | ||
218 | parameter(outputVariable, ParameterDirection.OUT); | ||
219 | functionalDependency(fromParameters, Set.of(outputVariable)); | ||
220 | } | ||
221 | |||
176 | public Dnf build() { | 222 | public Dnf build() { |
177 | var postProcessedClauses = postProcessClauses(); | 223 | var postProcessedClauses = postProcessClauses(); |
178 | return new Dnf(name, Collections.unmodifiableList(parameters), | 224 | return new Dnf(name, createParameterList(postProcessedClauses), |
179 | Collections.unmodifiableList(functionalDependencies), | 225 | Collections.unmodifiableList(functionalDependencies), |
180 | Collections.unmodifiableList(postProcessedClauses)); | 226 | Collections.unmodifiableList(postProcessedClauses)); |
181 | } | 227 | } |
182 | 228 | ||
183 | <T> void output(DataVariable<T> outputVariable) { | ||
184 | functionalDependency(Set.copyOf(parameters), Set.of(outputVariable)); | ||
185 | parameter(outputVariable); | ||
186 | } | ||
187 | |||
188 | private List<DnfClause> postProcessClauses() { | 229 | private List<DnfClause> postProcessClauses() { |
230 | var parameterSet = Collections.unmodifiableSet(new LinkedHashSet<>(parameters)); | ||
231 | var parameterWeights = getParameterWeights(); | ||
189 | var postProcessedClauses = new ArrayList<DnfClause>(clauses.size()); | 232 | var postProcessedClauses = new ArrayList<DnfClause>(clauses.size()); |
190 | for (var literals : clauses) { | 233 | for (var literals : clauses) { |
191 | if (literals.isEmpty()) { | 234 | var postProcessor = new ClausePostProcessor(parameterSet, parameterWeights, literals); |
192 | // Predicate will always match, the other clauses are irrelevant. | 235 | var result = postProcessor.postProcessClause(); |
193 | return List.of(new DnfClause(Set.of(), List.of())); | 236 | if (result instanceof ClausePostProcessor.ClauseResult clauseResult) { |
194 | } | 237 | postProcessedClauses.add(clauseResult.clause()); |
195 | var variables = new HashSet<Variable>(); | 238 | } else if (result instanceof ClausePostProcessor.ConstantResult constantResult) { |
196 | for (var literal : literals) { | 239 | switch (constantResult) { |
197 | variables.addAll(literal.getBoundVariables()); | 240 | case ALWAYS_TRUE -> { |
241 | return List.of(new DnfClause(Set.of(), List.of())); | ||
242 | } | ||
243 | case ALWAYS_FALSE -> { | ||
244 | // Skip this clause because it can never match. | ||
245 | } | ||
246 | default -> throw new IllegalStateException("Unexpected ClausePostProcessor.ConstantResult: " + | ||
247 | constantResult); | ||
248 | } | ||
249 | } else { | ||
250 | throw new IllegalStateException("Unexpected ClausePostProcessor.Result: " + result); | ||
198 | } | 251 | } |
199 | parameters.forEach(variables::remove); | ||
200 | postProcessedClauses.add(new DnfClause(Collections.unmodifiableSet(variables), | ||
201 | Collections.unmodifiableList(literals))); | ||
202 | } | 252 | } |
203 | return postProcessedClauses; | 253 | return postProcessedClauses; |
204 | } | 254 | } |
255 | |||
256 | private Map<Variable, Integer> getParameterWeights() { | ||
257 | var mutableParameterWeights = new HashMap<Variable, Integer>(); | ||
258 | int arity = parameters.size(); | ||
259 | for (int i = 0; i < arity; i++) { | ||
260 | mutableParameterWeights.put(parameters.get(i), i); | ||
261 | } | ||
262 | return Collections.unmodifiableMap(mutableParameterWeights); | ||
263 | } | ||
264 | |||
265 | private List<SymbolicParameter> createParameterList(List<DnfClause> postProcessedClauses) { | ||
266 | var outputParameterVariables = new HashSet<>(parameters); | ||
267 | for (var clause : postProcessedClauses) { | ||
268 | outputParameterVariables.retainAll(clause.positiveVariables()); | ||
269 | } | ||
270 | var parameterList = new ArrayList<SymbolicParameter>(parameters.size()); | ||
271 | for (var parameter : parameters) { | ||
272 | ParameterDirection direction = getDirection(outputParameterVariables, parameter); | ||
273 | parameterList.add(new SymbolicParameter(parameter, direction)); | ||
274 | } | ||
275 | return Collections.unmodifiableList(parameterList); | ||
276 | } | ||
277 | |||
278 | private ParameterDirection getDirection(HashSet<Variable> outputParameterVariables, Variable parameter) { | ||
279 | var direction = getInferredDirection(outputParameterVariables, parameter); | ||
280 | var expectedDirection = directions.get(parameter); | ||
281 | if (expectedDirection == ParameterDirection.IN && direction == ParameterDirection.IN_OUT) { | ||
282 | // Parameters may be explicitly marked as {@code @In} even if they are bound in all clauses. | ||
283 | return expectedDirection; | ||
284 | } | ||
285 | if (expectedDirection != null && expectedDirection != direction) { | ||
286 | throw new IllegalArgumentException("Expected parameter %s to have direction %s, but got %s instead" | ||
287 | .formatted(parameter, expectedDirection, direction)); | ||
288 | } | ||
289 | return direction; | ||
290 | } | ||
291 | |||
292 | private static ParameterDirection getInferredDirection(HashSet<Variable> outputParameterVariables, | ||
293 | Variable parameter) { | ||
294 | if (outputParameterVariables.contains(parameter)) { | ||
295 | if (parameter instanceof NodeVariable) { | ||
296 | return ParameterDirection.IN_OUT; | ||
297 | } else if (parameter instanceof AnyDataVariable) { | ||
298 | return ParameterDirection.OUT; | ||
299 | } else { | ||
300 | throw new IllegalArgumentException("Unknown parameter: " + parameter); | ||
301 | } | ||
302 | } else { | ||
303 | return ParameterDirection.IN; | ||
304 | } | ||
305 | } | ||
205 | } | 306 | } |