aboutsummaryrefslogtreecommitdiffstats
path: root/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/optimization/CostElementMatchers.xtend
blob: 885b14e8c4e634c79d02c733228b8949f20fe002 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
package hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner.optimization

import com.google.common.collect.ImmutableList
import hu.bme.mit.inf.dslreasoner.logic.model.logicproblem.LogicproblemPackage
import hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage.partialinterpretation.PartialinterpretationPackage
import java.util.List
import org.eclipse.emf.ecore.EObject
import org.eclipse.viatra.query.runtime.api.IPatternMatch
import org.eclipse.viatra.query.runtime.api.ViatraQueryMatcher
import org.eclipse.xtend.lib.annotations.Data
import hu.bme.mit.inf.dslreasoner.logic.model.logicproblem.LogicProblem
import hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage.partialinterpretation.PartialInterpretation

/**
 * Supplies an upper bound for a single pattern parameter.
 * 
 * CostElementMatchers multiplies these bounds together when estimating the maximum number of matches.
 * It only uses a parameter's bound when that parameter is bound to an open-world ("multi") element.
 */
@FunctionalInterface
interface ParameterScopeBound {
	/** Returns the upper bound for the parameter this instance is associated with. */
	def double getUpperBound()
}

/**
 * A match of a cost element pattern, paired with whether it also satisfies the must matcher.
 */
@Data
class CostElementMatch {
	/** The underlying pattern match (parameters 0 and 1 are expected to be the problem and interpretation). */
	val IPatternMatch match
	/** True if the match is also found by the must matcher. */
	val boolean must

	/**
	 * Returns true if any parameter after the first two is bound to an open-world (multi) element.
	 * 
	 * @throws IllegalArgumentException if the match is not rooted in a LogicProblem / PartialInterpretation
	 */
	def boolean isMulti() {
		return CostElementMatchers.isMultiMatch(this.match)
	}
}

/**
 * Groups the matchers that evaluate one cost element pattern against a partial interpretation.
 * 
 * The matchers are:
 * - current: matches in the current state;
 * - may: matches that may exist;
 * - must: matches that must exist.
 * 
 * The group also holds one scope bound per non-root parameter and a weight.
 * 
 * Every match handled here is expected to bind parameter 0 to a LogicProblem and parameter 1 to a
 * PartialInterpretation. The remaining parameters (index 2 onwards) are the pattern's own parameters.
 */
@Data
class CostElementMatchers {
	val ViatraQueryMatcher<? extends IPatternMatch> currentMatcher
	val ViatraQueryMatcher<? extends IPatternMatch> mayMatcher
	val ViatraQueryMatcher<? extends IPatternMatch> mustMatcher
	val List<ParameterScopeBound> parameterScopeBounds
	val int weight

	/** Number of matches reported by the current-state matcher. */
	def getCurrentNumberOfMatches() {
		return currentMatcher.countMatches
	}

	/** Number of matches reported by the must matcher (a lower bound on the match count). */
	def getMinimumNumberOfMatches() {
		return mustMatcher.countMatches
	}

	/**
	 * Estimates an upper bound on the number of matches.
	 * 
	 * Each may-match contributes the product of the scope bounds of its multi-valued parameters.
	 * A may-match with no multi-valued parameters contributes 1. The contributions are summed.
	 */
	def double getMaximumNumberOfMatches() {
		var double total = 0
		val mayMatches = mayMatcher.streamAllMatches.iterator
		while (mayMatches.hasNext) {
			total += scopeBoundProduct(mayMatches.next)
		}
		total
	}

	/**
	 * Multiplies the upper bounds of those parameters of the given match that are bound to multi elements.
	 * 
	 * Bound i belongs to match parameter i + 2, because the first two parameters are the problem and interpretation.
	 */
	private def double scopeBoundProduct(IPatternMatch mayMatch) {
		var double product = 1
		for (boundIndex : 0 ..< parameterScopeBounds.size) {
			if (isMulti(mayMatch.get(boundIndex + 2))) {
				product *= parameterScopeBounds.get(boundIndex).upperBound
			}
		}
		product
	}

	/**
	 * Snapshots every may-match as a CostElementMatch.
	 * 
	 * Each CostElementMatch is flagged with whether the must matcher also reports it.
	 */
	def getMatches() {
		val wrapped = mayMatcher.streamAllMatches.iterator.map [ found |
			new CostElementMatch(found, mustMatcher.isMatch(found))
		]
		ImmutableList.copyOf(wrapped)
	}

	/** Projects the given match onto the may matcher's parameters; see projectMatch. */
	def projectMayMatch(IPatternMatch match, int... indices) {
		mayMatcher.projectMatch(match, indices)
	}

	/**
	 * Builds a match of `matcher` from selected parameters of `match`.
	 * 
	 * Parameters 0 and 1 are copied unchanged. Target parameter k + 2 receives `match.get(indices.get(k))`.
	 * 
	 * @throws IllegalArgumentException if `match` is not rooted in the partial interpretation,
	 *         if the number of indices differs from the matcher's non-root parameter count,
	 *         or if the resulting match is not present in `matcher`
	 */
	private static def <T extends IPatternMatch> projectMatch(ViatraQueryMatcher<T> matcher, IPatternMatch match, int... indices) {
		checkMatch(match)
		val projectedArity = matcher.specification.parameters.length - 2
		if (indices.length != projectedArity) {
			throw new IllegalArgumentException("Invalid number of projection indices")
		}
		val projected = matcher.newEmptyMatch
		// The problem and interpretation roots are carried over as-is.
		projected.set(0, match.get(0))
		projected.set(1, match.get(1))
		for (k : 0 ..< projectedArity) {
			projected.set(k + 2, match.get(indices.get(k)))
		}
		if (!matcher.hasMatch(projected)) {
			throw new IllegalArgumentException("Projected match does not exist")
		}
		return projected
	}

	/**
	 * Checks whether `matcher` reports a match with exactly the parameter values of `match`.
	 * 
	 * The values are copied position by position.
	 * 
	 * @throws IllegalArgumentException if the two patterns have different parameter counts
	 */
	private static def <T extends IPatternMatch> isMatch(ViatraQueryMatcher<T> matcher, IPatternMatch match) {
		val arity = matcher.specification.parameters.length
		if (match.specification.parameters.length != arity) {
			throw new IllegalArgumentException("Invalid number of match arguments")
		}
		val probe = matcher.newEmptyMatch
		for (position : 0 ..< arity) {
			probe.set(position, match.get(position))
		}
		return matcher.hasMatch(probe)
	}

	/**
	 * Classifies a parameter value by how it is contained.
	 * 
	 * - Returns true for EObjects contained via PartialInterpretation.openWorldElements.
	 * - Returns false for EObjects contained via LogicProblem.elements or PartialInterpretation.newElements.
	 * - Returns false for values that are not EObjects at all.
	 * 
	 * @throws IllegalStateException for an EObject with any other containment feature
	 */
	static def boolean isMulti(Object o) {
		switch o {
			EObject:
				switch (containment : o.eContainmentFeature) {
					case LogicproblemPackage.eINSTANCE.logicProblem_Elements,
					case PartialinterpretationPackage.eINSTANCE.partialInterpretation_NewElements:
						false
					case PartialinterpretationPackage.eINSTANCE.partialInterpretation_OpenWorldElements:
						true
					default:
						throw new IllegalStateException("Unknown containment feature for element: " + containment)
				}
			default:
				false
		}
	}

	/**
	 * Returns true if any parameter from index 2 onwards is bound to a multi element.
	 * 
	 * @throws IllegalArgumentException if the match is not rooted in the partial interpretation
	 */
	static def isMultiMatch(IPatternMatch match) {
		checkMatch(match)
		val arity = match.specification.parameters.length
		(2 ..< arity).exists[position|isMulti(match.get(position))]
	}

	/**
	 * Validates that the match has at least two parameters.
	 * 
	 * Parameter 0 must be a LogicProblem and parameter 1 must be a PartialInterpretation.
	 * 
	 * @throws IllegalArgumentException otherwise
	 */
	private static def void checkMatch(IPatternMatch match) {
		val arity = match.specification.parameters.length
		val rootedInInterpretation = arity >= 2 && match.get(0) instanceof LogicProblem &&
			match.get(1) instanceof PartialInterpretation
		if (!rootedInInterpretation) {
			throw new IllegalArgumentException("Match is not from the partial interpretation")
		}
	}
}