aboutsummaryrefslogtreecommitdiffstats
path: root/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/ClausePostProcessor.java
diff options
context:
space:
mode:
Diffstat (limited to 'subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/ClausePostProcessor.java')
-rw-r--r--subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/ClausePostProcessor.java102
1 files changed, 70 insertions, 32 deletions
diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/ClausePostProcessor.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/ClausePostProcessor.java
index b5e7092b..8800a155 100644
--- a/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/ClausePostProcessor.java
+++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/dnf/ClausePostProcessor.java
@@ -6,13 +6,12 @@
6package tools.refinery.store.query.dnf; 6package tools.refinery.store.query.dnf;
7 7
8import org.jetbrains.annotations.NotNull; 8import org.jetbrains.annotations.NotNull;
9import tools.refinery.store.query.literal.BooleanLiteral; 9import tools.refinery.store.query.Constraint;
10import tools.refinery.store.query.literal.EquivalenceLiteral; 10import tools.refinery.store.query.InvalidQueryException;
11import tools.refinery.store.query.literal.Literal; 11import tools.refinery.store.query.literal.*;
12import tools.refinery.store.query.substitution.MapBasedSubstitution; 12import tools.refinery.store.query.substitution.MapBasedSubstitution;
13import tools.refinery.store.query.substitution.StatelessSubstitution; 13import tools.refinery.store.query.substitution.StatelessSubstitution;
14import tools.refinery.store.query.substitution.Substitution; 14import tools.refinery.store.query.substitution.Substitution;
15import tools.refinery.store.query.term.NodeVariable;
16import tools.refinery.store.query.term.ParameterDirection; 15import tools.refinery.store.query.term.ParameterDirection;
17import tools.refinery.store.query.term.Variable; 16import tools.refinery.store.query.term.Variable;
18 17
@@ -22,8 +21,8 @@ import java.util.function.Function;
22class ClausePostProcessor { 21class ClausePostProcessor {
23 private final Map<Variable, ParameterInfo> parameters; 22 private final Map<Variable, ParameterInfo> parameters;
24 private final List<Literal> literals; 23 private final List<Literal> literals;
25 private final Map<NodeVariable, NodeVariable> representatives = new LinkedHashMap<>(); 24 private final Map<Variable, Variable> representatives = new LinkedHashMap<>();
26 private final Map<NodeVariable, Set<NodeVariable>> equivalencePartition = new HashMap<>(); 25 private final Map<Variable, Set<Variable>> equivalencePartition = new HashMap<>();
27 private List<Literal> substitutedLiterals; 26 private List<Literal> substitutedLiterals;
28 private final Set<Variable> existentiallyQuantifiedVariables = new LinkedHashSet<>(); 27 private final Set<Variable> existentiallyQuantifiedVariables = new LinkedHashSet<>();
29 private Set<Variable> positiveVariables; 28 private Set<Variable> positiveVariables;
@@ -58,6 +57,9 @@ class ClausePostProcessor {
58 if (filteredLiterals.isEmpty()) { 57 if (filteredLiterals.isEmpty()) {
59 return ConstantResult.ALWAYS_TRUE; 58 return ConstantResult.ALWAYS_TRUE;
60 } 59 }
60 if (hasContradictoryCall(filteredLiterals)) {
61 return ConstantResult.ALWAYS_FALSE;
62 }
61 var clause = new DnfClause(Collections.unmodifiableSet(positiveVariables), 63 var clause = new DnfClause(Collections.unmodifiableSet(positiveVariables),
62 Collections.unmodifiableList(filteredLiterals)); 64 Collections.unmodifiableList(filteredLiterals));
63 return new ClauseResult(clause); 65 return new ClauseResult(clause);
@@ -67,16 +69,16 @@ class ClausePostProcessor {
67 for (var literal : literals) { 69 for (var literal : literals) {
68 if (isPositiveEquivalence(literal)) { 70 if (isPositiveEquivalence(literal)) {
69 var equivalenceLiteral = (EquivalenceLiteral) literal; 71 var equivalenceLiteral = (EquivalenceLiteral) literal;
70 mergeVariables(equivalenceLiteral.left(), equivalenceLiteral.right()); 72 mergeVariables(equivalenceLiteral.getLeft(), equivalenceLiteral.getRight());
71 } 73 }
72 } 74 }
73 } 75 }
74 76
75 private static boolean isPositiveEquivalence(Literal literal) { 77 private static boolean isPositiveEquivalence(Literal literal) {
76 return literal instanceof EquivalenceLiteral equivalenceLiteral && equivalenceLiteral.positive(); 78 return literal instanceof EquivalenceLiteral equivalenceLiteral && equivalenceLiteral.isPositive();
77 } 79 }
78 80
79 private void mergeVariables(NodeVariable left, NodeVariable right) { 81 private void mergeVariables(Variable left, Variable right) {
80 var leftRepresentative = getRepresentative(left); 82 var leftRepresentative = getRepresentative(left);
81 var rightRepresentative = getRepresentative(right); 83 var rightRepresentative = getRepresentative(right);
82 var leftInfo = parameters.get(leftRepresentative); 84 var leftInfo = parameters.get(leftRepresentative);
@@ -89,7 +91,7 @@ class ClausePostProcessor {
89 } 91 }
90 } 92 }
91 93
92 private void doMergeVariables(NodeVariable parentRepresentative, NodeVariable newChildRepresentative) { 94 private void doMergeVariables(Variable parentRepresentative, Variable newChildRepresentative) {
93 var parentSet = getEquivalentVariables(parentRepresentative); 95 var parentSet = getEquivalentVariables(parentRepresentative);
94 var childSet = getEquivalentVariables(newChildRepresentative); 96 var childSet = getEquivalentVariables(newChildRepresentative);
95 parentSet.addAll(childSet); 97 parentSet.addAll(childSet);
@@ -99,18 +101,18 @@ class ClausePostProcessor {
99 } 101 }
100 } 102 }
101 103
102 private NodeVariable getRepresentative(NodeVariable variable) { 104 private Variable getRepresentative(Variable variable) {
103 return representatives.computeIfAbsent(variable, Function.identity()); 105 return representatives.computeIfAbsent(variable, Function.identity());
104 } 106 }
105 107
106 private Set<NodeVariable> getEquivalentVariables(NodeVariable variable) { 108 private Set<Variable> getEquivalentVariables(Variable variable) {
107 var representative = getRepresentative(variable); 109 var representative = getRepresentative(variable);
108 if (!representative.equals(variable)) { 110 if (!representative.equals(variable)) {
109 throw new AssertionError("NodeVariable %s already has a representative %s" 111 throw new AssertionError("NodeVariable %s already has a representative %s"
110 .formatted(variable, representative)); 112 .formatted(variable, representative));
111 } 113 }
112 return equivalencePartition.computeIfAbsent(variable, key -> { 114 return equivalencePartition.computeIfAbsent(variable, key -> {
113 var set = new HashSet<NodeVariable>(1); 115 var set = new HashSet<Variable>(1);
114 set.add(key); 116 set.add(key);
115 return set; 117 return set;
116 }); 118 });
@@ -121,7 +123,7 @@ class ClausePostProcessor {
121 var left = pair.getKey(); 123 var left = pair.getKey();
122 var right = pair.getValue(); 124 var right = pair.getValue();
123 if (!left.equals(right) && parameters.containsKey(left) && parameters.containsKey(right)) { 125 if (!left.equals(right) && parameters.containsKey(left) && parameters.containsKey(right)) {
124 substitutedLiterals.add(left.isEquivalent(right)); 126 substitutedLiterals.add(new EquivalenceLiteral(true, left, right));
125 } 127 }
126 } 128 }
127 } 129 }
@@ -147,20 +149,7 @@ class ClausePostProcessor {
147 149
148 private void computeExistentiallyQuantifiedVariables() { 150 private void computeExistentiallyQuantifiedVariables() {
149 for (var literal : substitutedLiterals) { 151 for (var literal : substitutedLiterals) {
150 for (var variable : literal.getOutputVariables()) { 152 existentiallyQuantifiedVariables.addAll(literal.getOutputVariables());
151 boolean added = existentiallyQuantifiedVariables.add(variable);
152 if (!variable.isUnifiable()) {
153 var parameterInfo = parameters.get(variable);
154 if (parameterInfo != null && parameterInfo.direction() == ParameterDirection.IN) {
155 throw new IllegalArgumentException("Trying to bind %s parameter %s"
156 .formatted(ParameterDirection.IN, variable));
157 }
158 if (!added) {
159 throw new IllegalArgumentException("Variable %s has multiple assigned values"
160 .formatted(variable));
161 }
162 }
163 }
164 } 153 }
165 } 154 }
166 155
@@ -172,7 +161,7 @@ class ClausePostProcessor {
172 // Inputs count as positive, because they are already bound when we evaluate literals. 161 // Inputs count as positive, because they are already bound when we evaluate literals.
173 positiveVariables.add(variable); 162 positiveVariables.add(variable);
174 } else if (!existentiallyQuantifiedVariables.contains(variable)) { 163 } else if (!existentiallyQuantifiedVariables.contains(variable)) {
175 throw new IllegalArgumentException("Unbound %s parameter %s" 164 throw new InvalidQueryException("Unbound %s parameter %s"
176 .formatted(ParameterDirection.OUT, variable)); 165 .formatted(ParameterDirection.OUT, variable));
177 } 166 }
178 } 167 }
@@ -184,7 +173,7 @@ class ClausePostProcessor {
184 var representative = pair.getKey(); 173 var representative = pair.getKey();
185 if (!positiveVariables.contains(representative)) { 174 if (!positiveVariables.contains(representative)) {
186 var variableSet = pair.getValue(); 175 var variableSet = pair.getValue();
187 throw new IllegalArgumentException("Variables %s were merged by equivalence but are not bound" 176 throw new InvalidQueryException("Variables %s were merged by equivalence but are not bound"
188 .formatted(variableSet)); 177 .formatted(variableSet));
189 } 178 }
190 } 179 }
@@ -196,7 +185,7 @@ class ClausePostProcessor {
196 for (var variable : literal.getPrivateVariables(positiveVariables)) { 185 for (var variable : literal.getPrivateVariables(positiveVariables)) {
197 var oldLiteral = negativeVariablesMap.put(variable, literal); 186 var oldLiteral = negativeVariablesMap.put(variable, literal);
198 if (oldLiteral != null) { 187 if (oldLiteral != null) {
199 throw new IllegalArgumentException("Unbound variable %s appears in multiple literals %s and %s" 188 throw new InvalidQueryException("Unbound variable %s appears in multiple literals %s and %s"
200 .formatted(variable, oldLiteral, literal)); 189 .formatted(variable, oldLiteral, literal));
201 } 190 }
202 } 191 }
@@ -218,11 +207,60 @@ class ClausePostProcessor {
218 variable.addToSortedLiterals(); 207 variable.addToSortedLiterals();
219 } 208 }
220 if (!variableToLiteralInputMap.isEmpty()) { 209 if (!variableToLiteralInputMap.isEmpty()) {
221 throw new IllegalArgumentException("Unbound input variables %s" 210 throw new InvalidQueryException("Unbound input variables %s"
222 .formatted(variableToLiteralInputMap.keySet())); 211 .formatted(variableToLiteralInputMap.keySet()));
223 } 212 }
224 } 213 }
225 214
215 private boolean hasContradictoryCall(Collection<Literal> filteredLiterals) {
216 var positiveCalls = new HashMap<Constraint, Set<CallLiteral>>();
217 for (var literal : filteredLiterals) {
218 if (literal instanceof CallLiteral callLiteral && callLiteral.getPolarity() == CallPolarity.POSITIVE) {
219 var callsOfTarget = positiveCalls.computeIfAbsent(callLiteral.getTarget(), key -> new HashSet<>());
220 callsOfTarget.add(callLiteral);
221 }
222 }
223 for (var literal : filteredLiterals) {
224 if (literal instanceof CallLiteral callLiteral && callLiteral.getPolarity() == CallPolarity.NEGATIVE) {
225 var callsOfTarget = positiveCalls.get(callLiteral.getTarget());
226 if (contradicts(callLiteral, callsOfTarget)) {
227 return true;
228 }
229 }
230 }
231 return false;
232 }
233
234 private boolean contradicts(CallLiteral negativeCall, Collection<CallLiteral> positiveCalls) {
235 if (positiveCalls == null) {
236 return false;
237 }
238 for (var positiveCall : positiveCalls) {
239 if (contradicts(negativeCall, positiveCall)) {
240 return true;
241 }
242 }
243 return false;
244 }
245
246 private boolean contradicts(CallLiteral negativeCall, CallLiteral positiveCall) {
247 var privateVariables = negativeCall.getPrivateVariables(positiveVariables);
248 var negativeArguments = negativeCall.getArguments();
249 var positiveArguments = positiveCall.getArguments();
250 int arity = negativeArguments.size();
251 for (int i = 0; i < arity; i++) {
252 var negativeArgument = negativeArguments.get(i);
253 if (privateVariables.contains(negativeArgument)) {
254 continue;
255 }
256 var positiveArgument = positiveArguments.get(i);
257 if (!negativeArgument.equals(positiveArgument)) {
258 return false;
259 }
260 }
261 return true;
262 }
263
226 private class SortableLiteral implements Comparable<SortableLiteral> { 264 private class SortableLiteral implements Comparable<SortableLiteral> {
227 private final int index; 265 private final int index;
228 private final Literal literal; 266 private final Literal literal;