aboutsummaryrefslogtreecommitdiffstats
path: root/Domains/ca.mcgill.rtgmrt.example.modes3/src/modes3/run/Modes3TypeScopeHint.xtend
diff options
context:
space:
mode:
Diffstat (limited to 'Domains/ca.mcgill.rtgmrt.example.modes3/src/modes3/run/Modes3TypeScopeHint.xtend')
-rw-r--r--Domains/ca.mcgill.rtgmrt.example.modes3/src/modes3/run/Modes3TypeScopeHint.xtend79
1 files changed, 79 insertions, 0 deletions
diff --git a/Domains/ca.mcgill.rtgmrt.example.modes3/src/modes3/run/Modes3TypeScopeHint.xtend b/Domains/ca.mcgill.rtgmrt.example.modes3/src/modes3/run/Modes3TypeScopeHint.xtend
new file mode 100644
index 00000000..94e5eb08
--- /dev/null
+++ b/Domains/ca.mcgill.rtgmrt.example.modes3/src/modes3/run/Modes3TypeScopeHint.xtend
@@ -0,0 +1,79 @@
1package modes3.run
2
3import hu.bme.mit.inf.dslreasoner.ecore2logic.Ecore2Logic
4import hu.bme.mit.inf.dslreasoner.ecore2logic.Ecore2Logic_Trace
5import hu.bme.mit.inf.dslreasoner.logic.model.logiclanguage.Type
6import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.Modality
7import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.cardinality.LinearTypeConstraintHint
8import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.cardinality.LinearTypeExpressionBuilderFactory
9import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.patterns.PatternGenerator
10import java.util.Map
11import modes3.Modes3Package
12import modes3.queries.Adjacent
13import org.eclipse.viatra.query.runtime.api.IPatternMatch
14import org.eclipse.viatra.query.runtime.api.ViatraQueryMatcher
15import org.eclipse.viatra.query.runtime.matchers.psystem.queries.PQuery
16import hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage.partialinterpretation.PartialInterpretation
17
18class Modes3TypeScopeHint implements LinearTypeConstraintHint {
19 static val TURNOUT_NEIGHBOR_COUNT = "turnoutNeighborCount"
20
21 val Type segmentType
22 val Type turnoutType
23
24 new(extension Ecore2Logic ecore2Logic, Ecore2Logic_Trace ecore2LogicTrace) {
25 extension val Modes3Package = Modes3Package.eINSTANCE
26 segmentType = ecore2LogicTrace.TypeofEClass(segment)
27 turnoutType = ecore2LogicTrace.TypeofEClass(turnout)
28 }
29
30 override getAdditionalPatterns(extension PatternGenerator patternGenerator, Map<String, PQuery> fqnToPQuery) {
31 '''
32 pattern «TURNOUT_NEIGHBOR_COUNT»_helper(problem: LogicProblem, interpretation: PartialInterpretation, source: DefinedElement, target: DefinedElement) {
33 find interpretation(problem, interpretation);
34 find mustExist(problem, interpretation, source);
35 find mustExist(problem, interpretation, target);
36 «typeIndexer.referInstanceOf(turnoutType, Modality.MUST, "source")»
37 «typeIndexer.referInstanceOf(segmentType, Modality.MUST, "target")»
38 «relationDefinitionIndexer.referPattern(fqnToPQuery.get(Adjacent.instance.fullyQualifiedName), #["source", "target"], Modality.MUST, true, false)»
39 }
40
41 pattern «TURNOUT_NEIGHBOR_COUNT»(problem: LogicProblem, interpretation: PartialInterpretation, element: DefinedElement, N) {
42 find interpretation(problem, interpretation);
43 find mustExist(problem, interpretation, element);
44 «typeIndexer.referInstanceOf(turnoutType, Modality.MUST, "element")»
45 N == count find «TURNOUT_NEIGHBOR_COUNT»_helper(problem, interpretation, element, _);
46 }
47 '''
48 }
49
50 override createConstraintUpdater(LinearTypeExpressionBuilderFactory builderFactory) {
51 val turnoutNeighborCountMatcher = builderFactory.createMatcher(TURNOUT_NEIGHBOR_COUNT)
52 val newNeighbors = builderFactory.createBuilder.add(1, segmentType).build
53
54 return [ partialInterpretation |
55 val requiredNeighbbors = turnoutNeighborCountMatcher.getRemainingCount(partialInterpretation, 3)
56 newNeighbors.tightenLowerBound(requiredNeighbbors)
57 ]
58 }
59
60 private static def <T extends IPatternMatch> getRemainingCount(ViatraQueryMatcher<T> matcher,
61 PartialInterpretation partialInterpretation, int capacity) {
62 val partialMatch = matcher.newEmptyMatch
63 partialMatch.set(0, partialInterpretation.problem)
64 partialMatch.set(1, partialInterpretation)
65 val iterator = matcher.streamAllMatches(partialMatch).iterator
66 var int max = 0
67 while (iterator.hasNext) {
68 val match = iterator.next
69 val n = (match.get(3) as Integer).intValue
70 if (n < capacity) {
71 val required = capacity - n
72 if (max < required) {
73 max = required
74 }
75 }
76 }
77 max
78 }
79}