diff options
Diffstat (limited to 'Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/ViatraReasonerSolutionSaver.xtend')
-rw-r--r-- | Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/ViatraReasonerSolutionSaver.xtend | 250 |
1 files changed, 250 insertions, 0 deletions
diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/ViatraReasonerSolutionSaver.xtend b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/ViatraReasonerSolutionSaver.xtend new file mode 100644 index 00000000..e00f76ff --- /dev/null +++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/ViatraReasonerSolutionSaver.xtend | |||
@@ -0,0 +1,250 @@ | |||
1 | package hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner.dse | ||
2 | |||
3 | import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.cardinality.Bounds | ||
4 | import hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner.optimization.DirectionalThresholdObjective | ||
5 | import hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner.optimization.IObjectiveBoundsProvider | ||
6 | import hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner.optimization.ObjectiveThreshold | ||
7 | import java.util.HashMap | ||
8 | import java.util.Map | ||
9 | import org.eclipse.viatra.dse.api.DSEException | ||
10 | import org.eclipse.viatra.dse.api.Solution | ||
11 | import org.eclipse.viatra.dse.api.SolutionTrajectory | ||
12 | import org.eclipse.viatra.dse.base.ThreadContext | ||
13 | import org.eclipse.viatra.dse.objectives.Fitness | ||
14 | import org.eclipse.viatra.dse.objectives.IObjective | ||
15 | import org.eclipse.viatra.dse.objectives.ObjectiveComparatorHelper | ||
16 | import org.eclipse.viatra.dse.solutionstore.SolutionStore.ISolutionSaver | ||
17 | import org.eclipse.xtend.lib.annotations.Accessors | ||
18 | |||
19 | /** | ||
20 | * Based on {@link SolutionStore.BestSolutionSaver}. | ||
21 | * | ||
22 | * Will also automatically fill any missing numerical values in the saved solutions | ||
23 | * using the supplied {@link NumericSolver}. | ||
24 | */ | ||
25 | class ViatraReasonerSolutionSaver implements ISolutionSaver, IObjectiveBoundsProvider { | ||
26 | static val TOLERANCE = 1e-10 | ||
27 | |||
28 | @Accessors val SolutionCopier solutionCopier | ||
29 | @Accessors val DiversityChecker diversityChecker | ||
30 | val IObjective[][] leveledExtremalObjectives | ||
31 | val boolean hasExtremalObjectives | ||
32 | val int numberOfRequiredSolutions | ||
33 | val ObjectiveComparatorHelper comparatorHelper | ||
34 | val Map<SolutionTrajectory, Fitness> trajectories = new HashMap | ||
35 | |||
36 | @Accessors var NumericSolver numericSolver | ||
37 | @Accessors var Map<Object, Solution> solutionsCollection | ||
38 | |||
39 | new(IObjective[][] leveledExtremalObjectives, int numberOfRequiredSolutions, DiversityChecker diversityChecker) { | ||
40 | this.diversityChecker = diversityChecker | ||
41 | comparatorHelper = new ObjectiveComparatorHelper(leveledExtremalObjectives) | ||
42 | this.leveledExtremalObjectives = leveledExtremalObjectives | ||
43 | hasExtremalObjectives = leveledExtremalObjectives.exists[!empty] | ||
44 | this.numberOfRequiredSolutions = numberOfRequiredSolutions | ||
45 | this.solutionCopier = new SolutionCopier | ||
46 | } | ||
47 | |||
48 | def setNumericSolver(NumericSolver numericSolver) { | ||
49 | this.numericSolver = numericSolver | ||
50 | solutionCopier.numericSolver = numericSolver | ||
51 | } | ||
52 | |||
53 | override saveSolution(ThreadContext context, Object id, SolutionTrajectory solutionTrajectory) { | ||
54 | if (hasExtremalObjectives) { | ||
55 | saveBestSolutionOnly(context, id, solutionTrajectory) | ||
56 | } else { | ||
57 | saveAnyDiverseSolution(context, id, solutionTrajectory) | ||
58 | } | ||
59 | } | ||
60 | |||
61 | private def saveBestSolutionOnly(ThreadContext context, Object id, SolutionTrajectory solutionTrajectory) { | ||
62 | val fitness = context.lastFitness | ||
63 | if (!shouldSaveSolution(fitness, context)) { | ||
64 | return false | ||
65 | } | ||
66 | println("Found: " + fitness) | ||
67 | val dominatedTrajectories = newArrayList | ||
68 | for (entry : trajectories.entrySet) { | ||
69 | val isLastFitnessBetter = comparatorHelper.compare(fitness, entry.value) | ||
70 | if (isLastFitnessBetter < 0) { | ||
71 | // Found a trajectory that dominates the current one, no need to save | ||
72 | return false | ||
73 | } | ||
74 | if (isLastFitnessBetter > 0) { | ||
75 | dominatedTrajectories += entry.key | ||
76 | } | ||
77 | } | ||
78 | if (dominatedTrajectories.size == 0 && !needsMoreSolutionsWithSameFitness) { | ||
79 | return false | ||
80 | } | ||
81 | if (!diversityChecker.newSolution(context, id, dominatedTrajectories.map[solution.stateCode])) { | ||
82 | return false | ||
83 | } | ||
84 | // We must save the new trajectory before removing dominated trajectories | ||
85 | // to avoid removing the current solution when it is reachable only via dominated trajectories. | ||
86 | val solutionSaved = basicSaveSolution(context, id, solutionTrajectory, fitness) | ||
87 | for (dominatedTrajectory : dominatedTrajectories) { | ||
88 | trajectories -= dominatedTrajectory | ||
89 | val dominatedSolution = dominatedTrajectory.solution | ||
90 | if (!dominatedSolution.trajectories.remove(dominatedTrajectory)) { | ||
91 | throw new DSEException( | ||
92 | "Dominated solution is not reachable from dominated trajectory. This should never happen!") | ||
93 | } | ||
94 | if (dominatedSolution.trajectories.empty) { | ||
95 | val dominatedSolutionId = dominatedSolution.stateCode | ||
96 | solutionCopier.markAsObsolete(dominatedSolutionId) | ||
97 | solutionsCollection -= dominatedSolutionId | ||
98 | } | ||
99 | } | ||
100 | solutionSaved | ||
101 | } | ||
102 | |||
103 | private def saveAnyDiverseSolution(ThreadContext context, Object id, SolutionTrajectory solutionTrajectory) { | ||
104 | val fitness = context.lastFitness | ||
105 | if (!shouldSaveSolution(fitness, context)) { | ||
106 | return false | ||
107 | } | ||
108 | if (!diversityChecker.newSolution(context, id, emptyList)) { | ||
109 | return false | ||
110 | } | ||
111 | basicSaveSolution(context, id, solutionTrajectory, fitness) | ||
112 | } | ||
113 | |||
114 | private def shouldSaveSolution(Fitness fitness, ThreadContext context) { | ||
115 | fitness.satisifiesHardObjectives && (numericSolver === null || numericSolver.currentSatisfiable) | ||
116 | } | ||
117 | |||
118 | private def basicSaveSolution(ThreadContext context, Object id, SolutionTrajectory solutionTrajectory, | ||
119 | Fitness fitness) { | ||
120 | var boolean solutionSaved = false | ||
121 | var dseSolution = solutionsCollection.get(id) | ||
122 | if (dseSolution === null) { | ||
123 | solutionCopier.copySolution(context, id) | ||
124 | dseSolution = new Solution(id, solutionTrajectory) | ||
125 | solutionsCollection.put(id, dseSolution) | ||
126 | solutionSaved = true | ||
127 | } else { | ||
128 | solutionSaved = dseSolution.trajectories.add(solutionTrajectory) | ||
129 | } | ||
130 | if (solutionSaved) { | ||
131 | solutionTrajectory.solution = dseSolution | ||
132 | trajectories.put(solutionTrajectory, fitness) | ||
133 | } | ||
134 | solutionSaved | ||
135 | } | ||
136 | |||
137 | def fitnessMayBeSaved(Fitness fitness) { | ||
138 | if (!hasExtremalObjectives) { | ||
139 | return true | ||
140 | } | ||
141 | var boolean mayDominate | ||
142 | for (existingFitness : trajectories.values) { | ||
143 | val isNewFitnessBetter = comparatorHelper.compare(fitness, existingFitness) | ||
144 | if (isNewFitnessBetter < 0) { | ||
145 | return false | ||
146 | } | ||
147 | if (isNewFitnessBetter > 0) { | ||
148 | mayDominate = true | ||
149 | } | ||
150 | } | ||
151 | mayDominate || needsMoreSolutionsWithSameFitness | ||
152 | } | ||
153 | |||
154 | private def boolean needsMoreSolutionsWithSameFitness() { | ||
155 | if (solutionsCollection === null) { | ||
156 | // The solutions collection will only be initialized upon saving the first solution. | ||
157 | return true | ||
158 | } | ||
159 | solutionsCollection.size < numberOfRequiredSolutions | ||
160 | } | ||
161 | |||
162 | def getTotalCopierRuntime() { | ||
163 | solutionCopier.totalCopierRuntime | ||
164 | } | ||
165 | |||
166 | override computeRequiredBounds(IObjective objective, Bounds bounds) { | ||
167 | if (!hasExtremalObjectives) { | ||
168 | return | ||
169 | } | ||
170 | if (objective instanceof DirectionalThresholdObjective) { | ||
171 | switch (threshold : objective.threshold) { | ||
172 | case ObjectiveThreshold.NO_THRESHOLD: { | ||
173 | // No threshold to set. | ||
174 | } | ||
175 | ObjectiveThreshold.Exclusive: { | ||
176 | switch (kind : objective.kind) { | ||
177 | case HIGHER_IS_BETTER: | ||
178 | bounds.tightenLowerBound(Math.floor(threshold.threshold + 1) as int) | ||
179 | case LOWER_IS_BETTER: | ||
180 | bounds.tightenUpperBound(Math.ceil(threshold.threshold - 1) as int) | ||
181 | default: | ||
182 | throw new IllegalArgumentException("Unknown objective kind" + kind) | ||
183 | } | ||
184 | if (threshold.clampToThreshold) { | ||
185 | return | ||
186 | } | ||
187 | } | ||
188 | ObjectiveThreshold.Inclusive: { | ||
189 | switch (kind : objective.kind) { | ||
190 | case HIGHER_IS_BETTER: | ||
191 | bounds.tightenLowerBound(Math.ceil(threshold.threshold) as int) | ||
192 | case LOWER_IS_BETTER: | ||
193 | bounds.tightenUpperBound(Math.floor(threshold.threshold) as int) | ||
194 | default: | ||
195 | throw new IllegalArgumentException("Unknown objective kind" + kind) | ||
196 | } | ||
197 | if (threshold.clampToThreshold) { | ||
198 | return | ||
199 | } | ||
200 | } | ||
201 | default: | ||
202 | throw new IllegalArgumentException("Unknown threshold: " + threshold) | ||
203 | } | ||
204 | for (level : leveledExtremalObjectives) { | ||
205 | switch (level.size) { | ||
206 | case 0: { | ||
207 | // Nothing to do, wait for the next level. | ||
208 | } | ||
209 | case 1: { | ||
210 | val primaryObjective = level.get(0) | ||
211 | if (primaryObjective != objective) { | ||
212 | // There are no worst-case bounds for secondary objectives. | ||
213 | return | ||
214 | } | ||
215 | } | ||
216 | default: | ||
217 | // There are no worst-case bounds for Pareto front calculation. | ||
218 | return | ||
219 | } | ||
220 | } | ||
221 | val fitnessIterator = trajectories.values.iterator | ||
222 | if (!fitnessIterator.hasNext) { | ||
223 | return | ||
224 | } | ||
225 | val fitness = fitnessIterator.next.get(objective.name) | ||
226 | while (fitnessIterator.hasNext) { | ||
227 | val otherFitness = fitnessIterator.next.get(objective.name) | ||
228 | if (Math.abs(fitness - otherFitness) > TOLERANCE) { | ||
229 | throw new IllegalStateException("Inconsistent fitness: " + objective.name) | ||
230 | } | ||
231 | } | ||
232 | switch (kind : objective.kind) { | ||
233 | case HIGHER_IS_BETTER: | ||
234 | if (needsMoreSolutionsWithSameFitness) { | ||
235 | bounds.tightenLowerBound(Math.floor(fitness) as int) | ||
236 | } else { | ||
237 | bounds.tightenLowerBound(Math.floor(fitness + 1) as int) | ||
238 | } | ||
239 | case LOWER_IS_BETTER: | ||
240 | if (needsMoreSolutionsWithSameFitness) { | ||
241 | bounds.tightenUpperBound(Math.ceil(fitness) as int) | ||
242 | } else { | ||
243 | bounds.tightenUpperBound(Math.ceil(fitness - 1) as int) | ||
244 | } | ||
245 | default: | ||
246 | throw new IllegalArgumentException("Unknown objective kind" + kind) | ||
247 | } | ||
248 | } | ||
249 | } | ||
250 | } | ||