From 4fe7fce97aedbd516109ef81afc33e00112b7b68 Mon Sep 17 00:00:00 2001 From: Kristóf Marussy Date: Fri, 28 Aug 2020 18:58:37 +0200 Subject: Must unit propagation --- .../src/modes3/run/Modes3TypeScopeHint.xtend | 79 ++++++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 Domains/ca.mcgill.rtgmrt.example.modes3/src/modes3/run/Modes3TypeScopeHint.xtend (limited to 'Domains/ca.mcgill.rtgmrt.example.modes3/src/modes3/run/Modes3TypeScopeHint.xtend') diff --git a/Domains/ca.mcgill.rtgmrt.example.modes3/src/modes3/run/Modes3TypeScopeHint.xtend b/Domains/ca.mcgill.rtgmrt.example.modes3/src/modes3/run/Modes3TypeScopeHint.xtend new file mode 100644 index 00000000..94e5eb08 --- /dev/null +++ b/Domains/ca.mcgill.rtgmrt.example.modes3/src/modes3/run/Modes3TypeScopeHint.xtend @@ -0,0 +1,79 @@ +package modes3.run + +import hu.bme.mit.inf.dslreasoner.ecore2logic.Ecore2Logic +import hu.bme.mit.inf.dslreasoner.ecore2logic.Ecore2Logic_Trace +import hu.bme.mit.inf.dslreasoner.logic.model.logiclanguage.Type +import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.Modality +import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.cardinality.LinearTypeConstraintHint +import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.cardinality.LinearTypeExpressionBuilderFactory +import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.patterns.PatternGenerator +import java.util.Map +import modes3.Modes3Package +import modes3.queries.Adjacent +import org.eclipse.viatra.query.runtime.api.IPatternMatch +import org.eclipse.viatra.query.runtime.api.ViatraQueryMatcher +import org.eclipse.viatra.query.runtime.matchers.psystem.queries.PQuery +import hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage.partialinterpretation.PartialInterpretation + +class Modes3TypeScopeHint implements LinearTypeConstraintHint { + static val TURNOUT_NEIGHBOR_COUNT = "turnoutNeighborCount" + + val Type segmentType + val Type turnoutType + + new(extension Ecore2Logic ecore2Logic, Ecore2Logic_Trace ecore2LogicTrace) { + extension val Modes3Package = Modes3Package.eINSTANCE + segmentType = ecore2LogicTrace.TypeofEClass(segment) + turnoutType = ecore2LogicTrace.TypeofEClass(turnout) + } + + override getAdditionalPatterns(extension PatternGenerator patternGenerator, Map fqnToPQuery) { + ''' + pattern «TURNOUT_NEIGHBOR_COUNT»_helper(problem: LogicProblem, interpretation: PartialInterpretation, source: DefinedElement, target: DefinedElement) { + find interpretation(problem, interpretation); + find mustExist(problem, interpretation, source); + find mustExist(problem, interpretation, target); + «typeIndexer.referInstanceOf(turnoutType, Modality.MUST, "source")» + «typeIndexer.referInstanceOf(segmentType, Modality.MUST, "target")» + «relationDefinitionIndexer.referPattern(fqnToPQuery.get(Adjacent.instance.fullyQualifiedName), #["source", "target"], Modality.MUST, true, false)» + } + + pattern «TURNOUT_NEIGHBOR_COUNT»(problem: LogicProblem, interpretation: PartialInterpretation, element: DefinedElement, N) { + find interpretation(problem, interpretation); + find mustExist(problem, interpretation, element); + «typeIndexer.referInstanceOf(turnoutType, Modality.MUST, "element")» + N == count find «TURNOUT_NEIGHBOR_COUNT»_helper(problem, interpretation, element, _); + } + ''' + } + + override createConstraintUpdater(LinearTypeExpressionBuilderFactory builderFactory) { + val turnoutNeighborCountMatcher = builderFactory.createMatcher(TURNOUT_NEIGHBOR_COUNT) + val newNeighbors = builderFactory.createBuilder.add(1, segmentType).build + + return [ partialInterpretation | + val requiredNeighbbors = turnoutNeighborCountMatcher.getRemainingCount(partialInterpretation, 3) + newNeighbors.tightenLowerBound(requiredNeighbbors) + ] + } + + private static def getRemainingCount(ViatraQueryMatcher matcher, + PartialInterpretation partialInterpretation, int capacity) { + val partialMatch = matcher.newEmptyMatch + partialMatch.set(0, partialInterpretation.problem) + partialMatch.set(1, partialInterpretation) + val iterator = matcher.streamAllMatches(partialMatch).iterator + var int max = 0 + while (iterator.hasNext) { + val match = iterator.next + val n = (match.get(3) as Integer).intValue + if (n < capacity) { + val required = capacity - n + if (max < required) { + max = required + } + } + } + max + } +} -- cgit v1.2.3-70-g09d2