aboutsummaryrefslogtreecommitdiffstats
path: root/subprojects/store-reasoning-scope/src/main/java/tools/refinery/store/reasoning/scope/internal/ScopePropagatorAdapterImpl.java
diff options
context:
space:
mode:
Diffstat (limited to 'subprojects/store-reasoning-scope/src/main/java/tools/refinery/store/reasoning/scope/internal/ScopePropagatorAdapterImpl.java')
-rw-r--r--subprojects/store-reasoning-scope/src/main/java/tools/refinery/store/reasoning/scope/internal/ScopePropagatorAdapterImpl.java244
1 files changed, 244 insertions, 0 deletions
diff --git a/subprojects/store-reasoning-scope/src/main/java/tools/refinery/store/reasoning/scope/internal/ScopePropagatorAdapterImpl.java b/subprojects/store-reasoning-scope/src/main/java/tools/refinery/store/reasoning/scope/internal/ScopePropagatorAdapterImpl.java
new file mode 100644
index 00000000..0d594701
--- /dev/null
+++ b/subprojects/store-reasoning-scope/src/main/java/tools/refinery/store/reasoning/scope/internal/ScopePropagatorAdapterImpl.java
@@ -0,0 +1,244 @@
1/*
2 * SPDX-FileCopyrightText: 2023 The Refinery Authors <https://refinery.tools/>
3 *
4 * SPDX-License-Identifier: EPL-2.0
5 */
6package tools.refinery.store.reasoning.scope.internal;
7
8import com.google.ortools.linearsolver.MPConstraint;
9import com.google.ortools.linearsolver.MPObjective;
10import com.google.ortools.linearsolver.MPSolver;
11import com.google.ortools.linearsolver.MPVariable;
12import org.eclipse.collections.api.factory.primitive.IntObjectMaps;
13import org.eclipse.collections.api.factory.primitive.IntSets;
14import org.eclipse.collections.api.map.primitive.MutableIntObjectMap;
15import org.eclipse.collections.api.set.primitive.MutableIntSet;
16import tools.refinery.store.model.Interpretation;
17import tools.refinery.store.model.Model;
18import tools.refinery.store.query.ModelQueryAdapter;
19import tools.refinery.store.reasoning.refinement.RefinementResult;
20import tools.refinery.store.reasoning.scope.ScopePropagatorAdapter;
21import tools.refinery.store.reasoning.scope.ScopePropagatorStoreAdapter;
22import tools.refinery.store.representation.cardinality.*;
23import tools.refinery.store.tuple.Tuple;
24
25class ScopePropagatorAdapterImpl implements ScopePropagatorAdapter {
26 private final Model model;
27 private final ScopePropagatorStoreAdapterImpl storeAdapter;
28 private final ModelQueryAdapter queryEngine;
29 private final Interpretation<CardinalityInterval> countInterpretation;
30 private final MPSolver solver;
31 private final MPObjective objective;
32 private final MutableIntObjectMap<MPVariable> variables = IntObjectMaps.mutable.empty();
33 private final MutableIntSet activeVariables = IntSets.mutable.empty();
34 private final TypeScopePropagator[] propagators;
35 private boolean changed = true;
36
37 public ScopePropagatorAdapterImpl(Model model, ScopePropagatorStoreAdapterImpl storeAdapter) {
38 this.model = model;
39 this.storeAdapter = storeAdapter;
40 queryEngine = model.getAdapter(ModelQueryAdapter.class);
41 countInterpretation = model.getInterpretation(storeAdapter.getCountSymbol());
42 solver = MPSolver.createSolver("GLOP");
43 objective = solver.objective();
44 initializeVariables();
45 countInterpretation.addListener(this::countChanged, true);
46 var propagatorFactories = storeAdapter.getPropagatorFactories();
47 propagators = new TypeScopePropagator[propagatorFactories.size()];
48 for (int i = 0; i < propagators.length; i++) {
49 propagators[i] = propagatorFactories.get(i).createPropagator(this);
50 }
51 }
52
53 @Override
54 public Model getModel() {
55 return model;
56 }
57
58 @Override
59 public ScopePropagatorStoreAdapter getStoreAdapter() {
60 return storeAdapter;
61 }
62
63 private void initializeVariables() {
64 var cursor = countInterpretation.getAll();
65 while (cursor.move()) {
66 var interval = cursor.getValue();
67 if (!interval.equals(CardinalityIntervals.ONE)) {
68 int nodeId = cursor.getKey().get(0);
69 createVariable(nodeId, interval);
70 activeVariables.add(nodeId);
71 }
72 }
73 }
74
75 private MPVariable createVariable(int nodeId, CardinalityInterval interval) {
76 double lowerBound = interval.lowerBound();
77 double upperBound = getUpperBound(interval);
78 var variable = solver.makeNumVar(lowerBound, upperBound, "x" + nodeId);
79 variables.put(nodeId, variable);
80 return variable;
81 }
82
83 private void countChanged(Tuple key, CardinalityInterval fromValue, CardinalityInterval toValue,
84 boolean ignoredRestoring) {
85 int nodeId = key.get(0);
86 if ((toValue == null || toValue.equals(CardinalityIntervals.ONE))) {
87 if (fromValue != null && !fromValue.equals(CardinalityIntervals.ONE)) {
88 var variable = variables.get(nodeId);
89 if (variable == null || !activeVariables.remove(nodeId)) {
90 throw new AssertionError("Variable not active: " + nodeId);
91 }
92 variable.setBounds(0, 0);
93 markAsChanged();
94 }
95 return;
96 }
97 if (fromValue == null || fromValue.equals(CardinalityIntervals.ONE)) {
98 activeVariables.add(nodeId);
99 }
100 var variable = variables.get(nodeId);
101 if (variable == null) {
102 createVariable(nodeId, toValue);
103 markAsChanged();
104 return;
105 }
106 double lowerBound = toValue.lowerBound();
107 double upperBound = getUpperBound(toValue);
108 if (variable.lb() != lowerBound) {
109 variable.setLb(lowerBound);
110 markAsChanged();
111 }
112 if (variable.ub() != upperBound) {
113 variable.setUb(upperBound);
114 markAsChanged();
115 }
116 }
117
118 MPConstraint makeConstraint() {
119 return solver.makeConstraint();
120 }
121
122 MPVariable getVariable(int nodeId) {
123 var variable = variables.get(nodeId);
124 if (variable != null) {
125 return variable;
126 }
127 var interval = countInterpretation.get(Tuple.of(nodeId));
128 if (interval == null || interval.equals(CardinalityIntervals.ONE)) {
129 interval = CardinalityIntervals.NONE;
130 } else {
131 activeVariables.add(nodeId);
132 markAsChanged();
133 }
134 return createVariable(nodeId, interval);
135 }
136
137 void markAsChanged() {
138 changed = true;
139 }
140
141 @Override
142 public RefinementResult propagate() {
143 var result = RefinementResult.UNCHANGED;
144 RefinementResult currentRoundResult;
145 do {
146 currentRoundResult = propagateOne();
147 result = result.andThen(currentRoundResult);
148 if (result.isRejected()) {
149 return result;
150 }
151 } while (currentRoundResult != RefinementResult.UNCHANGED);
152 return result;
153 }
154
155 private RefinementResult propagateOne() {
156 queryEngine.flushChanges();
157 if (!changed) {
158 return RefinementResult.UNCHANGED;
159 }
160 changed = false;
161 for (var propagator : propagators) {
162 propagator.updateBounds();
163 }
164 var result = RefinementResult.UNCHANGED;
165 if (activeVariables.isEmpty()) {
166 return checkEmptiness();
167 }
168 var iterator = activeVariables.intIterator();
169 while (iterator.hasNext()) {
170 int nodeId = iterator.next();
171 var variable = variables.get(nodeId);
172 if (variable == null) {
173 throw new AssertionError("Missing active variable: " + nodeId);
174 }
175 result = result.andThen(propagateNode(nodeId, variable));
176 if (result.isRejected()) {
177 return result;
178 }
179 }
180 return result;
181 }
182
183 private RefinementResult checkEmptiness() {
184 var emptinessCheckingResult = solver.solve();
185 return switch (emptinessCheckingResult) {
186 case OPTIMAL, UNBOUNDED -> RefinementResult.UNCHANGED;
187 case INFEASIBLE -> RefinementResult.REJECTED;
188 default -> throw new IllegalStateException("Failed to check for consistency: " + emptinessCheckingResult);
189 };
190 }
191
192 private RefinementResult propagateNode(int nodeId, MPVariable variable) {
193 objective.setCoefficient(variable, 1);
194 try {
195 objective.setMinimization();
196 var minimizationResult = solver.solve();
197 int lowerBound;
198 switch (minimizationResult) {
199 case OPTIMAL -> lowerBound = RoundingUtil.roundUp(objective.value());
200 case UNBOUNDED -> lowerBound = 0;
201 case INFEASIBLE -> {
202 return RefinementResult.REJECTED;
203 }
204 default -> throw new IllegalStateException("Failed to solve for minimum of %s: %s"
205 .formatted(variable, minimizationResult));
206 }
207
208 objective.setMaximization();
209 var maximizationResult = solver.solve();
210 UpperCardinality upperBound;
211 switch (maximizationResult) {
212 case OPTIMAL -> upperBound = UpperCardinalities.atMost(RoundingUtil.roundDown(objective.value()));
213 case UNBOUNDED -> upperBound = UpperCardinalities.UNBOUNDED;
214 case INFEASIBLE -> {
215 return RefinementResult.REJECTED;
216 }
217 default -> throw new IllegalStateException("Failed to solve for maximum of %s: %s"
218 .formatted(variable, minimizationResult));
219 }
220
221 var newInterval = CardinalityIntervals.between(lowerBound, upperBound);
222 var oldInterval = countInterpretation.put(Tuple.of(nodeId), newInterval);
223 if (newInterval.lowerBound() < oldInterval.lowerBound() ||
224 newInterval.upperBound().compareTo(oldInterval.upperBound()) > 0) {
225 throw new IllegalArgumentException("Failed to refine multiplicity %s of node %d to %s"
226 .formatted(oldInterval, nodeId, newInterval));
227 }
228 return newInterval.equals(oldInterval) ? RefinementResult.UNCHANGED : RefinementResult.REFINED;
229 } finally {
230 objective.setCoefficient(variable, 0);
231 }
232 }
233
234 private static double getUpperBound(CardinalityInterval interval) {
235 var upperBound = interval.upperBound();
236 if (upperBound instanceof FiniteUpperCardinality finiteUpperCardinality) {
237 return finiteUpperCardinality.finiteUpperBound();
238 } else if (upperBound instanceof UnboundedUpperCardinality) {
239 return Double.POSITIVE_INFINITY;
240 } else {
241 throw new IllegalArgumentException("Unknown upper bound: " + upperBound);
242 }
243 }
244}