aboutsummaryrefslogtreecommitdiffstats
path: root/subprojects/store-reasoning-scope/src/main/java/tools/refinery/store/reasoning/scope/BoundScopePropagator.java
diff options
context:
space:
mode:
Diffstat (limited to 'subprojects/store-reasoning-scope/src/main/java/tools/refinery/store/reasoning/scope/BoundScopePropagator.java')
-rw-r--r--subprojects/store-reasoning-scope/src/main/java/tools/refinery/store/reasoning/scope/BoundScopePropagator.java229
1 files changed, 229 insertions, 0 deletions
diff --git a/subprojects/store-reasoning-scope/src/main/java/tools/refinery/store/reasoning/scope/BoundScopePropagator.java b/subprojects/store-reasoning-scope/src/main/java/tools/refinery/store/reasoning/scope/BoundScopePropagator.java
new file mode 100644
index 00000000..62aadb4a
--- /dev/null
+++ b/subprojects/store-reasoning-scope/src/main/java/tools/refinery/store/reasoning/scope/BoundScopePropagator.java
@@ -0,0 +1,229 @@
1/*
2 * SPDX-FileCopyrightText: 2023 The Refinery Authors <https://refinery.tools/>
3 *
4 * SPDX-License-Identifier: EPL-2.0
5 */
6package tools.refinery.store.reasoning.scope;
7
8import com.google.ortools.linearsolver.MPConstraint;
9import com.google.ortools.linearsolver.MPObjective;
10import com.google.ortools.linearsolver.MPSolver;
11import com.google.ortools.linearsolver.MPVariable;
12import org.eclipse.collections.api.factory.primitive.IntObjectMaps;
13import org.eclipse.collections.api.factory.primitive.IntSets;
14import org.eclipse.collections.api.map.primitive.MutableIntObjectMap;
15import org.eclipse.collections.api.set.primitive.MutableIntSet;
16import tools.refinery.store.dse.propagation.BoundPropagator;
17import tools.refinery.store.dse.propagation.PropagationResult;
18import tools.refinery.store.model.Interpretation;
19import tools.refinery.store.model.Model;
20import tools.refinery.store.query.ModelQueryAdapter;
21import tools.refinery.store.representation.cardinality.*;
22import tools.refinery.store.tuple.Tuple;
23
24class BoundScopePropagator implements BoundPropagator {
25 private final ModelQueryAdapter queryEngine;
26 private final Interpretation<CardinalityInterval> countInterpretation;
27 private final MPSolver solver;
28 private final MPObjective objective;
29 private final MutableIntObjectMap<MPVariable> variables = IntObjectMaps.mutable.empty();
30 private final MutableIntSet activeVariables = IntSets.mutable.empty();
31 private final TypeScopePropagator[] propagators;
32 private boolean changed = true;
33
34 public BoundScopePropagator(Model model, ScopePropagator storeAdapter) {
35 queryEngine = model.getAdapter(ModelQueryAdapter.class);
36 countInterpretation = model.getInterpretation(storeAdapter.getCountSymbol());
37 solver = MPSolver.createSolver("GLOP");
38 objective = solver.objective();
39 initializeVariables();
40 countInterpretation.addListener(this::countChanged, true);
41 var propagatorFactories = storeAdapter.getTypeScopePropagatorFactories();
42 propagators = new TypeScopePropagator[propagatorFactories.size()];
43 for (int i = 0; i < propagators.length; i++) {
44 propagators[i] = propagatorFactories.get(i).createPropagator(this);
45 }
46 }
47
48 ModelQueryAdapter getQueryEngine() {
49 return queryEngine;
50 }
51
52 private void initializeVariables() {
53 var cursor = countInterpretation.getAll();
54 while (cursor.move()) {
55 var interval = cursor.getValue();
56 if (!interval.equals(CardinalityIntervals.ONE)) {
57 int nodeId = cursor.getKey().get(0);
58 createVariable(nodeId, interval);
59 activeVariables.add(nodeId);
60 }
61 }
62 }
63
64 private MPVariable createVariable(int nodeId, CardinalityInterval interval) {
65 double lowerBound = interval.lowerBound();
66 double upperBound = getUpperBound(interval);
67 var variable = solver.makeNumVar(lowerBound, upperBound, "x" + nodeId);
68 variables.put(nodeId, variable);
69 return variable;
70 }
71
72 private void countChanged(Tuple key, CardinalityInterval fromValue, CardinalityInterval toValue,
73 boolean ignoredRestoring) {
74 int nodeId = key.get(0);
75 if ((toValue == null || toValue.equals(CardinalityIntervals.ONE))) {
76 if (fromValue != null && !fromValue.equals(CardinalityIntervals.ONE)) {
77 removeActiveVariable(toValue, nodeId);
78 }
79 return;
80 }
81 if (fromValue == null || fromValue.equals(CardinalityIntervals.ONE)) {
82 activeVariables.add(nodeId);
83 }
84 var variable = variables.get(nodeId);
85 if (variable == null) {
86 createVariable(nodeId, toValue);
87 markAsChanged();
88 return;
89 }
90 double lowerBound = toValue.lowerBound();
91 double upperBound = getUpperBound(toValue);
92 if (variable.lb() != lowerBound) {
93 variable.setLb(lowerBound);
94 markAsChanged();
95 }
96 if (variable.ub() != upperBound) {
97 variable.setUb(upperBound);
98 markAsChanged();
99 }
100 }
101
102 private void removeActiveVariable(CardinalityInterval toValue, int nodeId) {
103 var variable = variables.get(nodeId);
104 if (variable == null || !activeVariables.remove(nodeId)) {
105 throw new AssertionError("Variable not active: " + nodeId);
106 }
107 if (toValue == null) {
108 variable.setBounds(0, 0);
109 } else {
110 // Until queries are flushed and the constraints can be properly updated,
111 // the variable corresponding to the (previous) multi-object has to stand in for a single object.
112 variable.setBounds(1, 1);
113 }
114 markAsChanged();
115 }
116
117 MPConstraint makeConstraint() {
118 return solver.makeConstraint();
119 }
120
121 MPVariable getVariable(int nodeId) {
122 var variable = variables.get(nodeId);
123 if (variable != null) {
124 return variable;
125 }
126 var interval = countInterpretation.get(Tuple.of(nodeId));
127 if (interval == null || interval.equals(CardinalityIntervals.ONE)) {
128 interval = CardinalityIntervals.NONE;
129 } else {
130 activeVariables.add(nodeId);
131 markAsChanged();
132 }
133 return createVariable(nodeId, interval);
134 }
135
136 void markAsChanged() {
137 changed = true;
138 }
139
140 @Override
141 public PropagationResult propagateOne() {
142 queryEngine.flushChanges();
143 if (!changed) {
144 return PropagationResult.UNCHANGED;
145 }
146 changed = false;
147 for (var propagator : propagators) {
148 propagator.updateBounds();
149 }
150 var result = PropagationResult.UNCHANGED;
151 if (activeVariables.isEmpty()) {
152 return checkEmptiness();
153 }
154 var iterator = activeVariables.intIterator();
155 while (iterator.hasNext()) {
156 int nodeId = iterator.next();
157 var variable = variables.get(nodeId);
158 if (variable == null) {
159 throw new AssertionError("Missing active variable: " + nodeId);
160 }
161 result = result.andThen(propagateNode(nodeId, variable));
162 if (result.isRejected()) {
163 return result;
164 }
165 }
166 return result;
167 }
168
169 private PropagationResult checkEmptiness() {
170 var emptinessCheckingResult = solver.solve();
171 return switch (emptinessCheckingResult) {
172 case OPTIMAL, UNBOUNDED -> PropagationResult.UNCHANGED;
173 case INFEASIBLE -> PropagationResult.REJECTED;
174 default -> throw new IllegalStateException("Failed to check for consistency: " + emptinessCheckingResult);
175 };
176 }
177
178 private PropagationResult propagateNode(int nodeId, MPVariable variable) {
179 objective.setCoefficient(variable, 1);
180 try {
181 objective.setMinimization();
182 var minimizationResult = solver.solve();
183 int lowerBound;
184 switch (minimizationResult) {
185 case OPTIMAL -> lowerBound = RoundingUtil.roundUp(objective.value());
186 case UNBOUNDED -> lowerBound = 0;
187 case INFEASIBLE -> {
188 return PropagationResult.REJECTED;
189 }
190 default -> throw new IllegalStateException("Failed to solve for minimum of %s: %s"
191 .formatted(variable, minimizationResult));
192 }
193
194 objective.setMaximization();
195 var maximizationResult = solver.solve();
196 UpperCardinality upperBound;
197 switch (maximizationResult) {
198 case OPTIMAL -> upperBound = UpperCardinalities.atMost(RoundingUtil.roundDown(objective.value()));
199 // Problem was feasible when minimizing, the only possible source of {@code UNBOUNDED_OR_INFEASIBLE} is
200 // an unbounded maximization problem. See https://github.com/google/or-tools/issues/3319
201 case UNBOUNDED, INFEASIBLE -> upperBound = UpperCardinalities.UNBOUNDED;
202 default -> throw new IllegalStateException("Failed to solve for maximum of %s: %s"
203 .formatted(variable, minimizationResult));
204 }
205
206 var newInterval = CardinalityIntervals.between(lowerBound, upperBound);
207 var oldInterval = countInterpretation.put(Tuple.of(nodeId), newInterval);
208 if (newInterval.lowerBound() < oldInterval.lowerBound() ||
209 newInterval.upperBound().compareTo(oldInterval.upperBound()) > 0) {
210 throw new IllegalArgumentException("Failed to refine multiplicity %s of node %d to %s"
211 .formatted(oldInterval, nodeId, newInterval));
212 }
213 return newInterval.equals(oldInterval) ? PropagationResult.UNCHANGED : PropagationResult.PROPAGATED;
214 } finally {
215 objective.setCoefficient(variable, 0);
216 }
217 }
218
219 private static double getUpperBound(CardinalityInterval interval) {
220 var upperBound = interval.upperBound();
221 if (upperBound instanceof FiniteUpperCardinality finiteUpperCardinality) {
222 return finiteUpperCardinality.finiteUpperBound();
223 } else if (upperBound instanceof UnboundedUpperCardinality) {
224 return Double.POSITIVE_INFINITY;
225 } else {
226 throw new IllegalArgumentException("Unknown upper bound: " + upperBound);
227 }
228 }
229}