aboutsummaryrefslogtreecommitdiffstats
path: root/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/DnfBuilder.java
diff options
context:
space:
mode:
Diffstat (limited to 'subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/DnfBuilder.java')
-rw-r--r--subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/DnfBuilder.java183
1 files changed, 142 insertions, 41 deletions
diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/DnfBuilder.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/DnfBuilder.java
index a42ae558..3fac4627 100644
--- a/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/DnfBuilder.java
+++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/DnfBuilder.java
@@ -7,20 +7,16 @@ package tools.refinery.store.query.dnf;
7 7
8import tools.refinery.store.query.dnf.callback.*; 8import tools.refinery.store.query.dnf.callback.*;
9import tools.refinery.store.query.literal.Literal; 9import tools.refinery.store.query.literal.Literal;
10import tools.refinery.store.query.term.DataVariable; 10import tools.refinery.store.query.term.*;
11import tools.refinery.store.query.term.NodeVariable;
12import tools.refinery.store.query.term.Variable;
13 11
14import java.util.*; 12import java.util.*;
15 13
16@SuppressWarnings("UnusedReturnValue") 14@SuppressWarnings("UnusedReturnValue")
17public final class DnfBuilder { 15public final class DnfBuilder {
18 private final String name; 16 private final String name;
19
20 private final List<Variable> parameters = new ArrayList<>(); 17 private final List<Variable> parameters = new ArrayList<>();
21 18 private final Map<Variable, ParameterDirection> directions = new HashMap<>();
22 private final List<FunctionalDependency<Variable>> functionalDependencies = new ArrayList<>(); 19 private final List<FunctionalDependency<Variable>> functionalDependencies = new ArrayList<>();
23
24 private final List<List<Literal>> clauses = new ArrayList<>(); 20 private final List<List<Literal>> clauses = new ArrayList<>();
25 21
26 DnfBuilder(String name) { 22 DnfBuilder(String name) {
@@ -37,6 +33,18 @@ public final class DnfBuilder {
37 return variable; 33 return variable;
38 } 34 }
39 35
36 public NodeVariable parameter(ParameterDirection direction) {
37 var variable = parameter();
38 putDirection(variable, direction);
39 return variable;
40 }
41
42 public NodeVariable parameter(String name, ParameterDirection direction) {
43 var variable = parameter(name);
44 putDirection(variable, direction);
45 return variable;
46 }
47
40 public <T> DataVariable<T> parameter(Class<T> type) { 48 public <T> DataVariable<T> parameter(Class<T> type) {
41 return parameter(null, type); 49 return parameter(null, type);
42 } 50 }
@@ -47,6 +55,18 @@ public final class DnfBuilder {
47 return variable; 55 return variable;
48 } 56 }
49 57
58 public <T> DataVariable<T> parameter(Class<T> type, ParameterDirection direction) {
59 var variable = parameter(type);
60 putDirection(variable, direction);
61 return variable;
62 }
63
64 public <T> DataVariable<T> parameter(String name, Class<T> type, ParameterDirection direction) {
65 var variable = parameter(name, type);
66 putDirection(variable, direction);
67 return variable;
68 }
69
50 public DnfBuilder parameter(Variable variable) { 70 public DnfBuilder parameter(Variable variable) {
51 if (parameters.contains(variable)) { 71 if (parameters.contains(variable)) {
52 throw new IllegalArgumentException("Duplicate parameter: " + variable); 72 throw new IllegalArgumentException("Duplicate parameter: " + variable);
@@ -55,12 +75,49 @@ public final class DnfBuilder {
55 return this; 75 return this;
56 } 76 }
57 77
78 public DnfBuilder parameter(Variable variable, ParameterDirection direction) {
79 parameter(variable);
80 putDirection(variable, direction);
81 return this;
82 }
83
84 private void putDirection(Variable variable, ParameterDirection direction) {
85 if (variable.tryGetType().isPresent()) {
86 if (direction == ParameterDirection.IN_OUT) {
87 throw new IllegalArgumentException("%s direction is forbidden for data variable %s"
88 .formatted(direction, variable));
89 }
90 } else {
91 if (direction == ParameterDirection.OUT) {
92 throw new IllegalArgumentException("%s direction is forbidden for node variable %s"
93 .formatted(direction, variable));
94 }
95 }
96 directions.put(variable, direction);
97 }
98
58 public DnfBuilder parameters(Variable... variables) { 99 public DnfBuilder parameters(Variable... variables) {
59 return parameters(List.of(variables)); 100 return parameters(List.of(variables));
60 } 101 }
61 102
62 public DnfBuilder parameters(Collection<? extends Variable> variables) { 103 public DnfBuilder parameters(Collection<? extends Variable> variables) {
63 parameters.addAll(variables); 104 for (var variable : variables) {
105 parameter(variable);
106 }
107 return this;
108 }
109
110 public DnfBuilder parameters(Collection<? extends Variable> variables, ParameterDirection direction) {
111 for (var variable : variables) {
112 parameter(variable, direction);
113 }
114 return this;
115 }
116
117 public DnfBuilder symbolicParameters(Collection<SymbolicParameter> parameters) {
118 for (var parameter : parameters) {
119 parameter(parameter.getVariable(), parameter.getDirection());
120 }
64 return this; 121 return this;
65 } 122 }
66 123
@@ -152,54 +209,98 @@ public final class DnfBuilder {
152 } 209 }
153 210
154 public DnfBuilder clause(Collection<? extends Literal> literals) { 211 public DnfBuilder clause(Collection<? extends Literal> literals) {
155 // Remove duplicates by using a hashed data structure. 212 clauses.add(List.copyOf(literals));
156 var filteredLiterals = new LinkedHashSet<Literal>(literals.size());
157 for (var literal : literals) {
158 var reduction = literal.getReduction();
159 switch (reduction) {
160 case NOT_REDUCIBLE -> filteredLiterals.add(literal);
161 case ALWAYS_TRUE -> {
162 // Literals reducible to {@code true} can be omitted, because the model is always assumed to have at
163 // least on object.
164 }
165 case ALWAYS_FALSE -> {
166 // Clauses with {@code false} literals can be omitted entirely.
167 return this;
168 }
169 default -> throw new IllegalArgumentException("Invalid reduction: " + reduction);
170 }
171 }
172 clauses.add(List.copyOf(filteredLiterals));
173 return this; 213 return this;
174 } 214 }
175 215
216 <T> void output(DataVariable<T> outputVariable) {
217 var fromParameters = Set.copyOf(parameters);
218 parameter(outputVariable, ParameterDirection.OUT);
219 functionalDependency(fromParameters, Set.of(outputVariable));
220 }
221
176 public Dnf build() { 222 public Dnf build() {
177 var postProcessedClauses = postProcessClauses(); 223 var postProcessedClauses = postProcessClauses();
178 return new Dnf(name, Collections.unmodifiableList(parameters), 224 return new Dnf(name, createParameterList(postProcessedClauses),
179 Collections.unmodifiableList(functionalDependencies), 225 Collections.unmodifiableList(functionalDependencies),
180 Collections.unmodifiableList(postProcessedClauses)); 226 Collections.unmodifiableList(postProcessedClauses));
181 } 227 }
182 228
183 <T> void output(DataVariable<T> outputVariable) {
184 functionalDependency(Set.copyOf(parameters), Set.of(outputVariable));
185 parameter(outputVariable);
186 }
187
188 private List<DnfClause> postProcessClauses() { 229 private List<DnfClause> postProcessClauses() {
230 var parameterSet = Collections.unmodifiableSet(new LinkedHashSet<>(parameters));
231 var parameterWeights = getParameterWeights();
189 var postProcessedClauses = new ArrayList<DnfClause>(clauses.size()); 232 var postProcessedClauses = new ArrayList<DnfClause>(clauses.size());
190 for (var literals : clauses) { 233 for (var literals : clauses) {
191 if (literals.isEmpty()) { 234 var postProcessor = new ClausePostProcessor(parameterSet, parameterWeights, literals);
192 // Predicate will always match, the other clauses are irrelevant. 235 var result = postProcessor.postProcessClause();
193 return List.of(new DnfClause(Set.of(), List.of())); 236 if (result instanceof ClausePostProcessor.ClauseResult clauseResult) {
194 } 237 postProcessedClauses.add(clauseResult.clause());
195 var variables = new HashSet<Variable>(); 238 } else if (result instanceof ClausePostProcessor.ConstantResult constantResult) {
196 for (var literal : literals) { 239 switch (constantResult) {
197 variables.addAll(literal.getBoundVariables()); 240 case ALWAYS_TRUE -> {
241 return List.of(new DnfClause(Set.of(), List.of()));
242 }
243 case ALWAYS_FALSE -> {
244 // Skip this clause because it can never match.
245 }
246 default -> throw new IllegalStateException("Unexpected ClausePostProcessor.ConstantResult: " +
247 constantResult);
248 }
249 } else {
250 throw new IllegalStateException("Unexpected ClausePostProcessor.Result: " + result);
198 } 251 }
199 parameters.forEach(variables::remove);
200 postProcessedClauses.add(new DnfClause(Collections.unmodifiableSet(variables),
201 Collections.unmodifiableList(literals)));
202 } 252 }
203 return postProcessedClauses; 253 return postProcessedClauses;
204 } 254 }
255
256 private Map<Variable, Integer> getParameterWeights() {
257 var mutableParameterWeights = new HashMap<Variable, Integer>();
258 int arity = parameters.size();
259 for (int i = 0; i < arity; i++) {
260 mutableParameterWeights.put(parameters.get(i), i);
261 }
262 return Collections.unmodifiableMap(mutableParameterWeights);
263 }
264
265 private List<SymbolicParameter> createParameterList(List<DnfClause> postProcessedClauses) {
266 var outputParameterVariables = new HashSet<>(parameters);
267 for (var clause : postProcessedClauses) {
268 outputParameterVariables.retainAll(clause.positiveVariables());
269 }
270 var parameterList = new ArrayList<SymbolicParameter>(parameters.size());
271 for (var parameter : parameters) {
272 ParameterDirection direction = getDirection(outputParameterVariables, parameter);
273 parameterList.add(new SymbolicParameter(parameter, direction));
274 }
275 return Collections.unmodifiableList(parameterList);
276 }
277
278 private ParameterDirection getDirection(HashSet<Variable> outputParameterVariables, Variable parameter) {
279 var direction = getInferredDirection(outputParameterVariables, parameter);
280 var expectedDirection = directions.get(parameter);
281 if (expectedDirection == ParameterDirection.IN && direction == ParameterDirection.IN_OUT) {
282 // Parameters may be explicitly marked as {@code @In} even if they are bound in all clauses.
283 return expectedDirection;
284 }
285 if (expectedDirection != null && expectedDirection != direction) {
286 throw new IllegalArgumentException("Expected parameter %s to have direction %s, but got %s instead"
287 .formatted(parameter, expectedDirection, direction));
288 }
289 return direction;
290 }
291
292 private static ParameterDirection getInferredDirection(HashSet<Variable> outputParameterVariables,
293 Variable parameter) {
294 if (outputParameterVariables.contains(parameter)) {
295 if (parameter instanceof NodeVariable) {
296 return ParameterDirection.IN_OUT;
297 } else if (parameter instanceof AnyDataVariable) {
298 return ParameterDirection.OUT;
299 } else {
300 throw new IllegalArgumentException("Unknown parameter: " + parameter);
301 }
302 } else {
303 return ParameterDirection.IN;
304 }
305 }
205} 306}