aboutsummaryrefslogtreecommitdiffstats
path: root/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/optimization/CompositeDirectionalThresholdObjective.xtend
blob: 0aa442f54b7cdd2f5644db21d737dbb429a7bc68 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
package hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner.optimization

import com.google.common.collect.ImmutableList
import java.util.Collection
import org.eclipse.viatra.dse.base.ThreadContext

/**
 * A {@link DirectionalThresholdObjective} that aggregates several other directional threshold
 * objectives into a single objective.
 * <p>
 * The raw fitness of the composite is the sum of the values returned by {@code getFitness} of each
 * wrapped objective. All wrapped objectives must share the same {@link ObjectiveKind} and level;
 * their thresholds are combined pairwise with {@code ObjectiveThreshold.merge}.
 */
class CompositeDirectionalThresholdObjective extends DirectionalThresholdObjective {
	/** The wrapped objectives, held as an immutable copy of what the caller passed in. */
	val Collection<DirectionalThresholdObjective> objectives

	/**
	 * Creates a composite objective from the given objectives.
	 * <p>
	 * The kind, threshold and level of the composite are derived from the wrapped objectives.
	 *
	 * @param name the name of the composite objective
	 * @param objectives the objectives to aggregate; must be non-empty and share the same kind and level
	 * @throws IllegalArgumentException if {@code objectives} is empty, or the objectives differ in kind
	 *         or in level
	 */
	new(String name, Collection<DirectionalThresholdObjective> objectives) {
		this(name, objectives, getKind(objectives), getThreshold(objectives), getLevel(objectives))
	}

	/**
	 * Varargs convenience constructor; delegates to the {@link Collection}-based constructor.
	 *
	 * @param name the name of the composite objective
	 * @param objectives the objectives to aggregate; must be non-empty and share the same kind and level
	 * @throws IllegalArgumentException if no objectives are passed, or they differ in kind or in level
	 */
	new(String name, DirectionalThresholdObjective... objectives) {
		// Xtend converts the varargs array to a List so it can be viewed as a Collection.
		this(name, objectives as Collection<DirectionalThresholdObjective>)
	}

	/**
	 * Low-level constructor taking an already computed kind, threshold and level.
	 * <p>
	 * The objectives are copied into an {@link ImmutableList}. Later modification of the caller's
	 * collection therefore does not affect this composite.
	 */
	protected new(String name, Iterable<DirectionalThresholdObjective> objectives, ObjectiveKind kind,
		ObjectiveThreshold threshold, int level) {
		super(name, kind, threshold, level)
		this.objectives = ImmutableList.copyOf(objectives)
	}

	/**
	 * Creates a new composite with the same name, kind, threshold and level.
	 * <p>
	 * Each wrapped objective is replaced by the result of its own {@code createNew}.
	 */
	override createNew() {
		new CompositeDirectionalThresholdObjective(name, objectives.map[createNew as DirectionalThresholdObjective],
			kind, threshold, level)
	}

	/** Initializes every wrapped objective with the given context. */
	override init(ThreadContext context) {
		for (objective : objectives) {
			objective.init(context)
		}
	}

	/** Returns the sum of {@code getFitness(context)} over all wrapped objectives. */
	override protected getRawFitness(ThreadContext context) {
		var double fitness = 0
		for (objective : objectives) {
			fitness += objective.getFitness(context)
		}
		fitness
	}

	/**
	 * Fails fast on an empty objective collection.
	 * <p>
	 * Without this check, an empty input surfaces as a misleading "same kind" error, or as a
	 * {@code null} threshold from {@code reduce}.
	 *
	 * @throws IllegalArgumentException if {@code objectives} is empty
	 */
	private static def void requireNonEmpty(Collection<DirectionalThresholdObjective> objectives) {
		if (objectives.empty) {
			throw new IllegalArgumentException("At least one objective must be passed")
		}
	}

	/**
	 * Returns the common kind of the given objectives.
	 *
	 * @throws IllegalArgumentException if the collection is empty or the objectives differ in kind
	 */
	private static def getKind(Collection<DirectionalThresholdObjective> objectives) {
		requireNonEmpty(objectives)
		val kinds = objectives.map[kind].toSet
		if (kinds.size != 1) {
			throw new IllegalArgumentException("Passed objectives must have the same kind")
		}
		kinds.head
	}

	/**
	 * Merges the thresholds of all given objectives pairwise with {@code ObjectiveThreshold.merge}.
	 *
	 * @throws IllegalArgumentException if the collection is empty
	 */
	private static def getThreshold(Collection<DirectionalThresholdObjective> objectives) {
		requireNonEmpty(objectives)
		objectives.map[threshold].reduce[a, b|a.merge(b)]
	}

	/**
	 * Returns the common level of the given objectives.
	 *
	 * @throws IllegalArgumentException if the collection is empty or the objectives differ in level
	 */
	private static def int getLevel(Collection<DirectionalThresholdObjective> objectives) {
		requireNonEmpty(objectives)
		val levels = objectives.map[level].toSet
		if (levels.size != 1) {
			throw new IllegalArgumentException("Passed objectives must have the same level")
		}
		levels.head
	}
}