diff options
Diffstat (limited to 'subprojects/store-reasoning-scope/src/main/java/tools/refinery/store/reasoning/scope/internal/ScopePropagatorAdapterImpl.java')
-rw-r--r-- | subprojects/store-reasoning-scope/src/main/java/tools/refinery/store/reasoning/scope/internal/ScopePropagatorAdapterImpl.java | 244 |
1 files changed, 244 insertions, 0 deletions
diff --git a/subprojects/store-reasoning-scope/src/main/java/tools/refinery/store/reasoning/scope/internal/ScopePropagatorAdapterImpl.java b/subprojects/store-reasoning-scope/src/main/java/tools/refinery/store/reasoning/scope/internal/ScopePropagatorAdapterImpl.java new file mode 100644 index 00000000..0d594701 --- /dev/null +++ b/subprojects/store-reasoning-scope/src/main/java/tools/refinery/store/reasoning/scope/internal/ScopePropagatorAdapterImpl.java | |||
@@ -0,0 +1,244 @@ | |||
1 | /* | ||
2 | * SPDX-FileCopyrightText: 2023 The Refinery Authors <https://refinery.tools/> | ||
3 | * | ||
4 | * SPDX-License-Identifier: EPL-2.0 | ||
5 | */ | ||
6 | package tools.refinery.store.reasoning.scope.internal; | ||
7 | |||
8 | import com.google.ortools.linearsolver.MPConstraint; | ||
9 | import com.google.ortools.linearsolver.MPObjective; | ||
10 | import com.google.ortools.linearsolver.MPSolver; | ||
11 | import com.google.ortools.linearsolver.MPVariable; | ||
12 | import org.eclipse.collections.api.factory.primitive.IntObjectMaps; | ||
13 | import org.eclipse.collections.api.factory.primitive.IntSets; | ||
14 | import org.eclipse.collections.api.map.primitive.MutableIntObjectMap; | ||
15 | import org.eclipse.collections.api.set.primitive.MutableIntSet; | ||
16 | import tools.refinery.store.model.Interpretation; | ||
17 | import tools.refinery.store.model.Model; | ||
18 | import tools.refinery.store.query.ModelQueryAdapter; | ||
19 | import tools.refinery.store.reasoning.refinement.RefinementResult; | ||
20 | import tools.refinery.store.reasoning.scope.ScopePropagatorAdapter; | ||
21 | import tools.refinery.store.reasoning.scope.ScopePropagatorStoreAdapter; | ||
22 | import tools.refinery.store.representation.cardinality.*; | ||
23 | import tools.refinery.store.tuple.Tuple; | ||
24 | |||
25 | class ScopePropagatorAdapterImpl implements ScopePropagatorAdapter { | ||
26 | private final Model model; | ||
27 | private final ScopePropagatorStoreAdapterImpl storeAdapter; | ||
28 | private final ModelQueryAdapter queryEngine; | ||
29 | private final Interpretation<CardinalityInterval> countInterpretation; | ||
30 | private final MPSolver solver; | ||
31 | private final MPObjective objective; | ||
32 | private final MutableIntObjectMap<MPVariable> variables = IntObjectMaps.mutable.empty(); | ||
33 | private final MutableIntSet activeVariables = IntSets.mutable.empty(); | ||
34 | private final TypeScopePropagator[] propagators; | ||
35 | private boolean changed = true; | ||
36 | |||
37 | public ScopePropagatorAdapterImpl(Model model, ScopePropagatorStoreAdapterImpl storeAdapter) { | ||
38 | this.model = model; | ||
39 | this.storeAdapter = storeAdapter; | ||
40 | queryEngine = model.getAdapter(ModelQueryAdapter.class); | ||
41 | countInterpretation = model.getInterpretation(storeAdapter.getCountSymbol()); | ||
42 | solver = MPSolver.createSolver("GLOP"); | ||
43 | objective = solver.objective(); | ||
44 | initializeVariables(); | ||
45 | countInterpretation.addListener(this::countChanged, true); | ||
46 | var propagatorFactories = storeAdapter.getPropagatorFactories(); | ||
47 | propagators = new TypeScopePropagator[propagatorFactories.size()]; | ||
48 | for (int i = 0; i < propagators.length; i++) { | ||
49 | propagators[i] = propagatorFactories.get(i).createPropagator(this); | ||
50 | } | ||
51 | } | ||
52 | |||
53 | @Override | ||
54 | public Model getModel() { | ||
55 | return model; | ||
56 | } | ||
57 | |||
58 | @Override | ||
59 | public ScopePropagatorStoreAdapter getStoreAdapter() { | ||
60 | return storeAdapter; | ||
61 | } | ||
62 | |||
63 | private void initializeVariables() { | ||
64 | var cursor = countInterpretation.getAll(); | ||
65 | while (cursor.move()) { | ||
66 | var interval = cursor.getValue(); | ||
67 | if (!interval.equals(CardinalityIntervals.ONE)) { | ||
68 | int nodeId = cursor.getKey().get(0); | ||
69 | createVariable(nodeId, interval); | ||
70 | activeVariables.add(nodeId); | ||
71 | } | ||
72 | } | ||
73 | } | ||
74 | |||
75 | private MPVariable createVariable(int nodeId, CardinalityInterval interval) { | ||
76 | double lowerBound = interval.lowerBound(); | ||
77 | double upperBound = getUpperBound(interval); | ||
78 | var variable = solver.makeNumVar(lowerBound, upperBound, "x" + nodeId); | ||
79 | variables.put(nodeId, variable); | ||
80 | return variable; | ||
81 | } | ||
82 | |||
83 | private void countChanged(Tuple key, CardinalityInterval fromValue, CardinalityInterval toValue, | ||
84 | boolean ignoredRestoring) { | ||
85 | int nodeId = key.get(0); | ||
86 | if ((toValue == null || toValue.equals(CardinalityIntervals.ONE))) { | ||
87 | if (fromValue != null && !fromValue.equals(CardinalityIntervals.ONE)) { | ||
88 | var variable = variables.get(nodeId); | ||
89 | if (variable == null || !activeVariables.remove(nodeId)) { | ||
90 | throw new AssertionError("Variable not active: " + nodeId); | ||
91 | } | ||
92 | variable.setBounds(0, 0); | ||
93 | markAsChanged(); | ||
94 | } | ||
95 | return; | ||
96 | } | ||
97 | if (fromValue == null || fromValue.equals(CardinalityIntervals.ONE)) { | ||
98 | activeVariables.add(nodeId); | ||
99 | } | ||
100 | var variable = variables.get(nodeId); | ||
101 | if (variable == null) { | ||
102 | createVariable(nodeId, toValue); | ||
103 | markAsChanged(); | ||
104 | return; | ||
105 | } | ||
106 | double lowerBound = toValue.lowerBound(); | ||
107 | double upperBound = getUpperBound(toValue); | ||
108 | if (variable.lb() != lowerBound) { | ||
109 | variable.setLb(lowerBound); | ||
110 | markAsChanged(); | ||
111 | } | ||
112 | if (variable.ub() != upperBound) { | ||
113 | variable.setUb(upperBound); | ||
114 | markAsChanged(); | ||
115 | } | ||
116 | } | ||
117 | |||
118 | MPConstraint makeConstraint() { | ||
119 | return solver.makeConstraint(); | ||
120 | } | ||
121 | |||
122 | MPVariable getVariable(int nodeId) { | ||
123 | var variable = variables.get(nodeId); | ||
124 | if (variable != null) { | ||
125 | return variable; | ||
126 | } | ||
127 | var interval = countInterpretation.get(Tuple.of(nodeId)); | ||
128 | if (interval == null || interval.equals(CardinalityIntervals.ONE)) { | ||
129 | interval = CardinalityIntervals.NONE; | ||
130 | } else { | ||
131 | activeVariables.add(nodeId); | ||
132 | markAsChanged(); | ||
133 | } | ||
134 | return createVariable(nodeId, interval); | ||
135 | } | ||
136 | |||
137 | void markAsChanged() { | ||
138 | changed = true; | ||
139 | } | ||
140 | |||
141 | @Override | ||
142 | public RefinementResult propagate() { | ||
143 | var result = RefinementResult.UNCHANGED; | ||
144 | RefinementResult currentRoundResult; | ||
145 | do { | ||
146 | currentRoundResult = propagateOne(); | ||
147 | result = result.andThen(currentRoundResult); | ||
148 | if (result.isRejected()) { | ||
149 | return result; | ||
150 | } | ||
151 | } while (currentRoundResult != RefinementResult.UNCHANGED); | ||
152 | return result; | ||
153 | } | ||
154 | |||
155 | private RefinementResult propagateOne() { | ||
156 | queryEngine.flushChanges(); | ||
157 | if (!changed) { | ||
158 | return RefinementResult.UNCHANGED; | ||
159 | } | ||
160 | changed = false; | ||
161 | for (var propagator : propagators) { | ||
162 | propagator.updateBounds(); | ||
163 | } | ||
164 | var result = RefinementResult.UNCHANGED; | ||
165 | if (activeVariables.isEmpty()) { | ||
166 | return checkEmptiness(); | ||
167 | } | ||
168 | var iterator = activeVariables.intIterator(); | ||
169 | while (iterator.hasNext()) { | ||
170 | int nodeId = iterator.next(); | ||
171 | var variable = variables.get(nodeId); | ||
172 | if (variable == null) { | ||
173 | throw new AssertionError("Missing active variable: " + nodeId); | ||
174 | } | ||
175 | result = result.andThen(propagateNode(nodeId, variable)); | ||
176 | if (result.isRejected()) { | ||
177 | return result; | ||
178 | } | ||
179 | } | ||
180 | return result; | ||
181 | } | ||
182 | |||
183 | private RefinementResult checkEmptiness() { | ||
184 | var emptinessCheckingResult = solver.solve(); | ||
185 | return switch (emptinessCheckingResult) { | ||
186 | case OPTIMAL, UNBOUNDED -> RefinementResult.UNCHANGED; | ||
187 | case INFEASIBLE -> RefinementResult.REJECTED; | ||
188 | default -> throw new IllegalStateException("Failed to check for consistency: " + emptinessCheckingResult); | ||
189 | }; | ||
190 | } | ||
191 | |||
192 | private RefinementResult propagateNode(int nodeId, MPVariable variable) { | ||
193 | objective.setCoefficient(variable, 1); | ||
194 | try { | ||
195 | objective.setMinimization(); | ||
196 | var minimizationResult = solver.solve(); | ||
197 | int lowerBound; | ||
198 | switch (minimizationResult) { | ||
199 | case OPTIMAL -> lowerBound = RoundingUtil.roundUp(objective.value()); | ||
200 | case UNBOUNDED -> lowerBound = 0; | ||
201 | case INFEASIBLE -> { | ||
202 | return RefinementResult.REJECTED; | ||
203 | } | ||
204 | default -> throw new IllegalStateException("Failed to solve for minimum of %s: %s" | ||
205 | .formatted(variable, minimizationResult)); | ||
206 | } | ||
207 | |||
208 | objective.setMaximization(); | ||
209 | var maximizationResult = solver.solve(); | ||
210 | UpperCardinality upperBound; | ||
211 | switch (maximizationResult) { | ||
212 | case OPTIMAL -> upperBound = UpperCardinalities.atMost(RoundingUtil.roundDown(objective.value())); | ||
213 | case UNBOUNDED -> upperBound = UpperCardinalities.UNBOUNDED; | ||
214 | case INFEASIBLE -> { | ||
215 | return RefinementResult.REJECTED; | ||
216 | } | ||
217 | default -> throw new IllegalStateException("Failed to solve for maximum of %s: %s" | ||
218 | .formatted(variable, minimizationResult)); | ||
219 | } | ||
220 | |||
221 | var newInterval = CardinalityIntervals.between(lowerBound, upperBound); | ||
222 | var oldInterval = countInterpretation.put(Tuple.of(nodeId), newInterval); | ||
223 | if (newInterval.lowerBound() < oldInterval.lowerBound() || | ||
224 | newInterval.upperBound().compareTo(oldInterval.upperBound()) > 0) { | ||
225 | throw new IllegalArgumentException("Failed to refine multiplicity %s of node %d to %s" | ||
226 | .formatted(oldInterval, nodeId, newInterval)); | ||
227 | } | ||
228 | return newInterval.equals(oldInterval) ? RefinementResult.UNCHANGED : RefinementResult.REFINED; | ||
229 | } finally { | ||
230 | objective.setCoefficient(variable, 0); | ||
231 | } | ||
232 | } | ||
233 | |||
234 | private static double getUpperBound(CardinalityInterval interval) { | ||
235 | var upperBound = interval.upperBound(); | ||
236 | if (upperBound instanceof FiniteUpperCardinality finiteUpperCardinality) { | ||
237 | return finiteUpperCardinality.finiteUpperBound(); | ||
238 | } else if (upperBound instanceof UnboundedUpperCardinality) { | ||
239 | return Double.POSITIVE_INFINITY; | ||
240 | } else { | ||
241 | throw new IllegalArgumentException("Unknown upper bound: " + upperBound); | ||
242 | } | ||
243 | } | ||
244 | } | ||