aboutsummaryrefslogtreecommitdiffstats
path: root/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/ViatraReasonerSolutionSaver.xtend
blob: d879d4cc4cba23b118d9a9ab804ac50fe152ff49 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
package hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner.dse

import java.util.HashMap
import java.util.Map
import org.eclipse.viatra.dse.api.DSEException
import org.eclipse.viatra.dse.api.Solution
import org.eclipse.viatra.dse.api.SolutionTrajectory
import org.eclipse.viatra.dse.base.ThreadContext
import org.eclipse.viatra.dse.objectives.Fitness
import org.eclipse.viatra.dse.objectives.IObjective
import org.eclipse.viatra.dse.objectives.ObjectiveComparatorHelper
import org.eclipse.viatra.dse.solutionstore.SolutionStore.ISolutionSaver
import org.eclipse.xtend.lib.annotations.Accessors

/**
 * Solution saver modelled after {@link SolutionStore.BestSolutionSaver}.
 * 
 * When at least one level of extremal objectives is configured, only trajectories that are
 * not dominated by an already saved trajectory are kept, and trajectories dominated by a
 * newly saved one are evicted. Without extremal objectives, every solution that satisfies
 * the hard objectives and is accepted by the {@link DiversityChecker} is kept.
 * 
 * The supplied {@link NumericSolver} is handed to the {@link SolutionCopier}; according to
 * the original documentation of this class it is used to automatically fill any missing
 * numerical values in the saved solutions.
 */
class ViatraReasonerSolutionSaver implements ISolutionSaver {
	/** Copies accepted solutions and is notified when a copied solution becomes obsolete. */
	@Accessors val SolutionCopier solutionCopier
	/** Consulted before every save to decide whether the candidate is accepted as a new solution. */
	@Accessors val DiversityChecker diversityChecker
	/** True if at least one level of the extremal objectives passed to the constructor is non-empty. */
	val boolean hasExtremalObjectives
	/** Threshold compared against the size of the solutions collection. */
	val int numberOfRequiredSolutions
	/** Compares fitness values according to the leveled extremal objectives. */
	val ObjectiveComparatorHelper comparatorHelper
	/** Fitness recorded for every trajectory that is currently saved. */
	val Map<SolutionTrajectory, Fitness> trajectories = new HashMap

	/** Collection of saved solutions keyed by state id; injected through the generated public setter. */
	@Accessors(PUBLIC_SETTER) var Map<Object, Solution> solutionsCollection

	/**
	 * @param leveledExtremalObjectives objectives used to compare fitness values, grouped by level
	 * @param numberOfRequiredSolutions number of solutions to collect before equally fit solutions are rejected
	 * @param diversityChecker checker consulted before a solution is saved
	 * @param numericSolver solver passed on to the {@link SolutionCopier}
	 */
	new(IObjective[][] leveledExtremalObjectives, int numberOfRequiredSolutions, DiversityChecker diversityChecker, NumericSolver numericSolver) {
		this.diversityChecker = diversityChecker
		this.comparatorHelper = new ObjectiveComparatorHelper(leveledExtremalObjectives)
		this.hasExtremalObjectives = leveledExtremalObjectives.exists[level|!level.empty]
		this.numberOfRequiredSolutions = numberOfRequiredSolutions
		this.solutionCopier = new SolutionCopier(numericSolver)
	}

	/**
	 * Entry point of {@link ISolutionSaver}: picks the saving strategy depending on
	 * whether extremal objectives are present.
	 * 
	 * @return whether the trajectory was saved
	 */
	override saveSolution(ThreadContext context, Object id, SolutionTrajectory solutionTrajectory) {
		if (!hasExtremalObjectives) {
			return saveAnyDiverseSolution(context, id, solutionTrajectory)
		}
		saveBestSolutionOnly(context, id, solutionTrajectory)
	}

	/**
	 * Saves the trajectory only if no saved trajectory dominates it, and evicts every
	 * saved trajectory that it dominates.
	 * 
	 * @return whether the trajectory was saved
	 * @throws DSEException if a dominated trajectory is not registered in its own solution
	 */
	private def boolean saveBestSolutionOnly(ThreadContext context, Object id, SolutionTrajectory solutionTrajectory) {
		val candidateFitness = context.lastFitness
		if (!shouldSaveSolution(candidateFitness, context)) {
			return false
		}
		val dominated = <SolutionTrajectory>newArrayList()
		for (saved : trajectories.entrySet) {
			val comparison = comparatorHelper.compare(candidateFitness, saved.value)
			if (comparison < 0) {
				// A saved trajectory dominates the candidate, so there is nothing to save.
				return false
			} else if (comparison > 0) {
				dominated += saved.key
			}
		}
		// A candidate that dominates nothing is only interesting while more solutions are needed.
		if (dominated.empty && !needsMoreSolutionsWithSameFitness) {
			return false
		}
		val dominatedStateCodes = dominated.map[trajectory|trajectory.solution.stateCode]
		if (!diversityChecker.newSolution(context, id, dominatedStateCodes)) {
			return false
		}
		// The candidate has to be stored before the eviction below: otherwise a solution that
		// was reachable only through dominated trajectories would be discarded even though the
		// candidate trajectory leads to it.
		val saved = basicSaveSolution(context, id, solutionTrajectory, candidateFitness)
		for (obsoleteTrajectory : dominated) {
			trajectories.remove(obsoleteTrajectory)
			val owner = obsoleteTrajectory.solution
			val wasRegistered = owner.trajectories.remove(obsoleteTrajectory)
			if (!wasRegistered) {
				throw new DSEException(
					"Dominated solution is not reachable from dominated trajectory. This should never happen!")
			}
			if (owner.trajectories.empty) {
				// No trajectory leads to this solution any more, so drop it entirely.
				val ownerId = owner.stateCode
				solutionCopier.markAsObsolete(ownerId)
				solutionsCollection.remove(ownerId)
			}
		}
		saved
	}

	/**
	 * Saves the trajectory if it satisfies the hard objectives and the diversity checker
	 * accepts it; no fitness comparison takes place.
	 * 
	 * @return whether the trajectory was saved
	 */
	private def boolean saveAnyDiverseSolution(ThreadContext context, Object id, SolutionTrajectory solutionTrajectory) {
		val candidateFitness = context.lastFitness
		// Short-circuit evaluation keeps the order: hard objectives, diversity check, then the actual save.
		shouldSaveSolution(candidateFitness, context) && diversityChecker.newSolution(context, id, emptyList) &&
			basicSaveSolution(context, id, solutionTrajectory, candidateFitness)
	}

	/** A solution is a candidate for saving only if its fitness satisfies the hard objectives. */
	private def boolean shouldSaveSolution(Fitness fitness, ThreadContext context) {
		fitness.satisifiesHardObjectives
	}

	/**
	 * Registers the trajectory under the solution with the given id, copying and creating
	 * the solution first if it is not in the solutions collection yet.
	 * 
	 * @return {@code true} if a new solution was created or the trajectory was newly added to
	 *         an existing solution, {@code false} if the existing solution rejected the trajectory
	 */
	private def boolean basicSaveSolution(ThreadContext context, Object id, SolutionTrajectory solutionTrajectory,
		Fitness fitness) {
		val knownSolution = solutionsCollection.get(id)
		if (knownSolution === null) {
			// First trajectory reaching this state: copy the solution before registering it.
			solutionCopier.copySolution(context, id)
			val createdSolution = new Solution(id, solutionTrajectory)
			solutionsCollection.put(id, createdSolution)
			registerTrajectory(solutionTrajectory, createdSolution, fitness)
			return true
		}
		if (!knownSolution.trajectories.add(solutionTrajectory)) {
			return false
		}
		registerTrajectory(solutionTrajectory, knownSolution, fitness)
		true
	}

	/** Links the trajectory to its solution and records the fitness it was saved with. */
	private def void registerTrajectory(SolutionTrajectory solutionTrajectory, Solution owner, Fitness fitness) {
		solutionTrajectory.solution = owner
		trajectories.put(solutionTrajectory, fitness)
	}

	/**
	 * Checks, without saving anything, whether the given fitness is worth saving with respect
	 * to the currently saved trajectories.
	 * 
	 * @return {@code true} if there are no extremal objectives; otherwise {@code false} if some saved
	 *         fitness dominates the given one, and else whether the given fitness dominates a saved one
	 *         or more solutions are still needed
	 */
	def fitnessMayBeSaved(Fitness fitness) {
		if (!hasExtremalObjectives) {
			return true
		}
		var dominatesSavedFitness = false
		for (savedFitness : trajectories.values) {
			val comparison = comparatorHelper.compare(fitness, savedFitness)
			if (comparison < 0) {
				return false
			}
			dominatesSavedFitness = dominatesSavedFitness || comparison > 0
		}
		dominatesSavedFitness || needsMoreSolutionsWithSameFitness
	}

	/**
	 * Tells whether fewer solutions than required have been collected so far.
	 * The solutions collection is only initialized upon saving the first solution,
	 * so a missing collection means that more solutions are needed.
	 */
	private def boolean needsMoreSolutionsWithSameFitness() {
		solutionsCollection === null || solutionsCollection.size < numberOfRequiredSolutions
	}
	
	/** Exposes the total copier runtime reported by the {@link SolutionCopier}. */
	def getTotalCopierRuntime() {
		return solutionCopier.totalCopierRuntime
	}
}