From 4fe7fce97aedbd516109ef81afc33e00112b7b68 Mon Sep 17 00:00:00 2001 From: Kristóf Marussy Date: Fri, 28 Aug 2020 18:58:37 +0200 Subject: Must unit propagation --- .../rules/RefinementRuleProvider.xtend | 338 +++++++++++++++------ 1 file changed, 241 insertions(+), 97 deletions(-) (limited to 'Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/rules/RefinementRuleProvider.xtend') diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/rules/RefinementRuleProvider.xtend b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/rules/RefinementRuleProvider.xtend index 699b095d..dca10baf 100644 --- a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/rules/RefinementRuleProvider.xtend +++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/rules/RefinementRuleProvider.xtend @@ -1,5 +1,6 @@ package hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.rules +import com.google.common.collect.ImmutableList import hu.bme.mit.inf.dslreasoner.ecore2logic.ecore2logicannotations.InverseRelationAssertion import hu.bme.mit.inf.dslreasoner.ecore2logic.ecore2logicannotations.LowerMultiplicityAssertion import hu.bme.mit.inf.dslreasoner.logic.model.logiclanguage.BoolTypeReference @@ -29,12 +30,14 @@ import hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage.par import hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage.partialinterpretation.PartialinterpretationFactory import java.lang.reflect.Field import java.util.HashMap +import java.util.Iterator import java.util.LinkedHashMap import java.util.LinkedList import java.util.List import java.util.Map import org.eclipse.viatra.query.runtime.api.AdvancedViatraQueryEngine import org.eclipse.viatra.query.runtime.api.GenericPatternMatch +import org.eclipse.viatra.query.runtime.api.IPatternMatch import org.eclipse.viatra.query.runtime.api.IQuerySpecification import org.eclipse.viatra.query.runtime.api.ViatraQueryEngine import org.eclipse.viatra.query.runtime.api.ViatraQueryMatcher @@ -43,6 +46,7 @@ import org.eclipse.viatra.query.runtime.rete.matcher.ReteBackendFactory import org.eclipse.viatra.transformation.runtime.emf.rules.batch.BatchTransformationRule import org.eclipse.viatra.transformation.runtime.emf.rules.batch.BatchTransformationRuleFactory import org.eclipse.xtend.lib.annotations.Data +import org.eclipse.xtend.lib.annotations.FinalFieldsConstructor import org.eclipse.xtext.xbase.lib.Functions.Function0 class RefinementRuleProvider { @@ -50,57 +54,55 @@ class RefinementRuleProvider { val extension PartialinterpretationFactory factory2 = PartialinterpretationFactory.eINSTANCE val extension LogiclanguageFactory factory3 = LogiclanguageFactory.eINSTANCE - var AdvancedViatraQueryEngine queryEngine - var Field delayMessageDelivery - def canonizeName(String name) { return name.replace(' ', '_') } + def createUnitPrulePropagator(LogicProblem p, PartialInterpretation i, GeneratedPatterns patterns, + ScopePropagator scopePropagator, ModelGenerationStatistics statistics) { + new UnitRulePropagator(p, i, this, scopePropagator, patterns.mustRelationPropagationQueries, statistics) + } + def LinkedHashMap>> createObjectRefinementRules( LogicProblem p, PartialInterpretation i, GeneratedPatterns patterns, - ScopePropagator scopePropagator, + UnitRulePropagator unitRulePropagator, boolean nameNewElement, ModelGenerationStatistics statistics ) { val res = new LinkedHashMap val recursiveObjectCreation = recursiveObjectCreation(p, i) - queryEngine = ViatraQueryEngine.on(new EMFScope(i)) as AdvancedViatraQueryEngine - delayMessageDelivery = queryEngine.class.getDeclaredField("delayMessageDelivery") - delayMessageDelivery.accessible = true for (LHSEntry : patterns.refineObjectQueries.entrySet) { val containmentRelation = LHSEntry.key.containmentRelation val inverseRelation = LHSEntry.key.inverseContainment val type = LHSEntry.key.newType val lhs = LHSEntry.value as IQuerySpecification> val rule = createObjectCreationRule(p, containmentRelation, inverseRelation, type, - recursiveObjectCreation.get(type), lhs, nameNewElement, scopePropagator, statistics) + recursiveObjectCreation.get(type), lhs, nameNewElement, unitRulePropagator, statistics) res.put(LHSEntry.key, rule) } return res } def private createObjectCreationRule(LogicProblem p, Relation containmentRelation, Relation inverseRelation, - Type type, List recursiceObjectCreations, + Type type, List recursiveObjectCreations, IQuerySpecification> lhs, boolean nameNewElement, - ScopePropagator scopePropagator, ModelGenerationStatistics statistics) { + UnitRulePropagator unitRulePropagator, ModelGenerationStatistics statistics) { val name = '''addObject_«type.name.canonizeName»«IF containmentRelation!==null»_by_«containmentRelation.name.canonizeName»«ENDIF»''' val ruleBuilder = factory.createRule(lhs).name(name) if (containmentRelation !== null) { if (inverseRelation !== null) { ruleBuilder.action [ match | statistics.incrementTransformationCount -// println(name) +// println(name) + val startTime = System.nanoTime // val problem = match.get(0) as LogicProblem val interpretation = match.get(1) as PartialInterpretation val relationInterpretation = match.get(2) as PartialRelationInterpretation val inverseRelationInterpretation = match.get(3) as PartialRelationInterpretation val typeInterpretation = match.get(4) as PartialComplexTypeInterpretation val container = match.get(5) as DefinedElement - - val startTime = System.nanoTime createObjectActionWithContainmentAndInverse( nameNewElement, interpretation, @@ -109,29 +111,24 @@ class RefinementRuleProvider { relationInterpretation, inverseRelationInterpretation, [createDefinedElement], - recursiceObjectCreations, - scopePropagator + recursiveObjectCreations, + unitRulePropagator ) statistics.addExecutionTime(System.nanoTime - startTime) - flushQueryEngine(scopePropagator) - - // Scope propagation - val propagatorStartTime = System.nanoTime - scopePropagator.propagateAllScopeConstraints() - statistics.addScopePropagationTime(System.nanoTime - propagatorStartTime) + unitRulePropagator.propagate ] } else { ruleBuilder.action [ match | statistics.incrementTransformationCount // println(name) + val startTime = System.nanoTime // val problem = match.get(0) as LogicProblem val interpretation = match.get(1) as PartialInterpretation val relationInterpretation = match.get(2) as PartialRelationInterpretation val typeInterpretation = match.get(3) as PartialComplexTypeInterpretation val container = match.get(4) as DefinedElement - val startTime = System.nanoTime createObjectActionWithContainment( nameNewElement, interpretation, @@ -139,44 +136,34 @@ class RefinementRuleProvider { container, relationInterpretation, [createDefinedElement], - recursiceObjectCreations, - scopePropagator + recursiveObjectCreations, + unitRulePropagator ) statistics.addExecutionTime(System.nanoTime - startTime) - flushQueryEngine(scopePropagator) - - // Scope propagation - val propagatorStartTime = System.nanoTime - scopePropagator.propagateAllScopeConstraints() - statistics.addScopePropagationTime(System.nanoTime - propagatorStartTime) + unitRulePropagator.propagate ] } } else { ruleBuilder.action [ match | statistics.incrementTransformationCount // println(name) + val startTime = System.nanoTime // val problem = match.get(0) as LogicProblem val interpretation = match.get(1) as PartialInterpretation val typeInterpretation = match.get(2) as PartialComplexTypeInterpretation - val startTime = System.nanoTime createObjectAction( nameNewElement, interpretation, typeInterpretation, [createDefinedElement], - recursiceObjectCreations, - scopePropagator + recursiveObjectCreations, + unitRulePropagator ) statistics.addExecutionTime(System.nanoTime - startTime) - flushQueryEngine(scopePropagator) - - // Scope propagation - val propagatorStartTime = System.nanoTime - scopePropagator.propagateAllScopeConstraints() - statistics.addScopePropagationTime(System.nanoTime - propagatorStartTime) + unitRulePropagator.propagate ] } return ruleBuilder.build @@ -342,14 +329,14 @@ class RefinementRuleProvider { [createStringElement] } - def createRelationRefinementRules(GeneratedPatterns patterns, ScopePropagator scopePropagator, + def createRelationRefinementRules(GeneratedPatterns patterns, UnitRulePropagator unitRulePropagator, ModelGenerationStatistics statistics) { val res = new LinkedHashMap - for (LHSEntry : patterns.refinerelationQueries.entrySet) { + for (LHSEntry : patterns.refineRelationQueries.entrySet) { val declaration = LHSEntry.key.key val inverseReference = LHSEntry.key.value val lhs = LHSEntry.value as IQuerySpecification> - val rule = createRelationRefinementRule(declaration, inverseReference, lhs, scopePropagator, statistics) + val rule = createRelationRefinementRule(declaration, inverseReference, lhs, unitRulePropagator, statistics) res.put(LHSEntry.key, rule) } return res @@ -357,57 +344,29 @@ class RefinementRuleProvider { def private BatchTransformationRule> createRelationRefinementRule( RelationDeclaration declaration, Relation inverseRelation, - IQuerySpecification> lhs, ScopePropagator scopePropagator, + IQuerySpecification> lhs, UnitRulePropagator unitRulePropagator, ModelGenerationStatistics statistics) { val name = '''addRelation_«declaration.name.canonizeName»«IF inverseRelation !== null»_and_«inverseRelation.name.canonizeName»«ENDIF»''' val ruleBuilder = factory.createRule(lhs).name(name) if (inverseRelation === null) { ruleBuilder.action [ match | statistics.incrementTransformationCount - // println(name) - // val problem = match.get(0) as LogicProblem - // val interpretation = match.get(1) as PartialInterpretation - val relationInterpretation = match.get(2) as PartialRelationInterpretation - val src = match.get(3) as DefinedElement - val trg = match.get(4) as DefinedElement - val startTime = System.nanoTime - createRelationLinkAction(src, trg, relationInterpretation) + createRelationLinkAction(match, unitRulePropagator) statistics.addExecutionTime(System.nanoTime - startTime) - // Scope propagation - if (scopePropagator.isPropagationNeededAfterAdditionToRelation(declaration)) { - flushQueryEngine(scopePropagator) - - val propagatorStartTime = System.nanoTime - scopePropagator.propagateAllScopeConstraints() - statistics.addScopePropagationTime(System.nanoTime - propagatorStartTime) - } + unitRulePropagator.propagate ] } else { ruleBuilder.action [ match | statistics.incrementTransformationCount // println(name) - // val problem = match.get(0) as LogicProblem - // val interpretation = match.get(1) as PartialInterpretation - val relationInterpretation = match.get(2) as PartialRelationInterpretation - val inverseInterpretation = match.get(3) as PartialRelationInterpretation - val src = match.get(4) as DefinedElement - val trg = match.get(5) as DefinedElement - val startTime = System.nanoTime - createRelationLinkWithInverse(src, trg, relationInterpretation, inverseInterpretation) + createRelationLinkWithInverse(match, unitRulePropagator) statistics.addExecutionTime(System.nanoTime - startTime) - // Scope propagation - if (scopePropagator.isPropagationNeededAfterAdditionToRelation(declaration)) { - flushQueryEngine(scopePropagator) - - val propagatorStartTime = System.nanoTime - scopePropagator.propagateAllScopeConstraints() - statistics.addScopePropagationTime(System.nanoTime - propagatorStartTime) - } + unitRulePropagator.propagate ] } @@ -418,7 +377,7 @@ class RefinementRuleProvider { // Actions // /////////////////////// protected def void createObjectAction(boolean nameNewElement, ObjectCreationInterpretationData data, - DefinedElement container, ScopePropagator scopePropagator) { + DefinedElement container, UnitRulePropagator unitRulePropagator) { if (data.containerInterpretation !== null) { if (data.containerInverseInterpretation !== null) { createObjectActionWithContainmentAndInverse( @@ -430,7 +389,7 @@ class RefinementRuleProvider { data.containerInverseInterpretation, data.constructor, data.recursiveConstructors, - scopePropagator + unitRulePropagator ) } else { createObjectActionWithContainment( @@ -441,7 +400,7 @@ class RefinementRuleProvider { data.containerInterpretation, data.constructor, data.recursiveConstructors, - scopePropagator + unitRulePropagator ) } } else { @@ -451,7 +410,7 @@ class RefinementRuleProvider { data.typeInterpretation, data.constructor, data.recursiveConstructors, - scopePropagator + unitRulePropagator ) } @@ -466,7 +425,7 @@ class RefinementRuleProvider { PartialRelationInterpretation inverseRelationInterpretation, Function0 constructor, List recursiceObjectCreations, - ScopePropagator scopePropagator + UnitRulePropagator unitRulePropagator ) { val newElement = constructor.apply if (nameNewElement) { @@ -486,14 +445,16 @@ class RefinementRuleProvider { inverseRelationInterpretation.relationlinks += newLink2 // Scope propagation - scopePropagator.decrementTypeScope(typeInterpretation) + unitRulePropagator.decrementTypeScope(typeInterpretation) + unitRulePropagator.addedToRelation(relationInterpretation.interpretationOf) + unitRulePropagator.addedToRelation(inverseRelationInterpretation.interpretationOf) // Existence interpretation.newElements += newElement // Do recursive object creation for (newConstructor : recursiceObjectCreations) { - createObjectAction(nameNewElement, newConstructor, newElement, scopePropagator) + createObjectAction(nameNewElement, newConstructor, newElement, unitRulePropagator) } return newElement @@ -507,7 +468,7 @@ class RefinementRuleProvider { PartialRelationInterpretation relationInterpretation, Function0 constructor, List recursiceObjectCreations, - ScopePropagator scopePropagator + UnitRulePropagator unitRulePropagator ) { val newElement = constructor.apply if (nameNewElement) { @@ -522,16 +483,17 @@ class RefinementRuleProvider { // ContainmentRelation val newLink = factory2.createBinaryElementRelationLink => [it.param1 = container it.param2 = newElement] relationInterpretation.relationlinks += newLink + unitRulePropagator.addedToRelation(relationInterpretation.interpretationOf) // Scope propagation - scopePropagator.decrementTypeScope(typeInterpretation) + unitRulePropagator.decrementTypeScope(typeInterpretation) // Existence interpretation.newElements += newElement // Do recursive object creation for (newConstructor : recursiceObjectCreations) { - createObjectAction(nameNewElement, newConstructor, newElement, scopePropagator) + createObjectAction(nameNewElement, newConstructor, newElement, unitRulePropagator) } return newElement @@ -539,7 +501,7 @@ class RefinementRuleProvider { protected def createObjectAction(boolean nameNewElement, PartialInterpretation interpretation, PartialTypeInterpratation typeInterpretation, Function0 constructor, - List recursiceObjectCreations, ScopePropagator scopePropagator) { + List recursiceObjectCreations, UnitRulePropagator unitRulePropagator) { val newElement = constructor.apply if (nameNewElement) { newElement.name = '''new «interpretation.newElements.size»''' @@ -552,38 +514,220 @@ class RefinementRuleProvider { } // Scope propagation - scopePropagator.decrementTypeScope(typeInterpretation) + unitRulePropagator.decrementTypeScope(typeInterpretation) // Existence interpretation.newElements += newElement // Do recursive object creation for (newConstructor : recursiceObjectCreations) { - createObjectAction(nameNewElement, newConstructor, newElement, scopePropagator) + createObjectAction(nameNewElement, newConstructor, newElement, unitRulePropagator) } return newElement } - protected def boolean createRelationLinkAction(DefinedElement src, DefinedElement trg, - PartialRelationInterpretation relationInterpretation) { + protected def createRelationLinkAction(IPatternMatch match, UnitRulePropagator unitRulePropagator) { + // val problem = match.get(0) as LogicProblem + // val interpretation = match.get(1) as PartialInterpretation + val relationInterpretation = match.get(2) as PartialRelationInterpretation + val src = match.get(3) as DefinedElement + val trg = match.get(4) as DefinedElement + createRelationLinkAction(src, trg, relationInterpretation, unitRulePropagator) + } + + protected def void createRelationLinkAction(DefinedElement src, DefinedElement trg, + PartialRelationInterpretation relationInterpretation, UnitRulePropagator unitRulePropagator) { val link = createBinaryElementRelationLink => [it.param1 = src it.param2 = trg] relationInterpretation.relationlinks += link + unitRulePropagator.addedToRelation(relationInterpretation.interpretationOf) } - protected def boolean createRelationLinkWithInverse(DefinedElement src, DefinedElement trg, - PartialRelationInterpretation relationInterpretation, PartialRelationInterpretation inverseInterpretation) { + protected def void createRelationLinkWithInverse(IPatternMatch match, UnitRulePropagator unitRulePropagator) { + // val problem = match.get(0) as LogicProblem + // val interpretation = match.get(1) as PartialInterpretation + val relationInterpretation = match.get(2) as PartialRelationInterpretation + val inverseInterpretation = match.get(3) as PartialRelationInterpretation + val src = match.get(4) as DefinedElement + val trg = match.get(5) as DefinedElement + createRelationLinkWithInverse(src, trg, relationInterpretation, inverseInterpretation, unitRulePropagator) + } + + protected def void createRelationLinkWithInverse(DefinedElement src, DefinedElement trg, + PartialRelationInterpretation relationInterpretation, PartialRelationInterpretation inverseInterpretation, + UnitRulePropagator unitRulePropagator) { val link = createBinaryElementRelationLink => [it.param1 = src it.param2 = trg] relationInterpretation.relationlinks += link val inverseLink = createBinaryElementRelationLink => [it.param1 = trg it.param2 = src] inverseInterpretation.relationlinks += inverseLink + unitRulePropagator.addedToRelation(relationInterpretation.interpretationOf) + unitRulePropagator.addedToRelation(inverseInterpretation.interpretationOf) } - protected def flushQueryEngine(ScopePropagator scopePropagator) { - if (scopePropagator.queryEngineFlushRequiredBeforePropagation && queryEngine.updatePropagationDelayed) { - delayMessageDelivery.setBoolean(queryEngine, false) - queryEngine.getQueryBackend(ReteBackendFactory.INSTANCE).flushUpdates - delayMessageDelivery.setBoolean(queryEngine, true) + static class UnitRulePropagator { + val LogicProblem p + val PartialInterpretation i + val RefinementRuleProvider refinementRuleProvider + var AdvancedViatraQueryEngine queryEngine + var Field delayMessageDelivery + val ScopePropagator scopePropagator + val List> propagators + val ModelGenerationStatistics statistics + + new(LogicProblem p, PartialInterpretation i, RefinementRuleProvider refinementRuleProvider, + ScopePropagator scopePropagator, + Map, IQuerySpecification>> mustRelationPropagationQueries, + ModelGenerationStatistics statistics) { + this.p = p + this.i = i + this.refinementRuleProvider = refinementRuleProvider + queryEngine = ViatraQueryEngine.on(new EMFScope(i)) as AdvancedViatraQueryEngine + delayMessageDelivery = queryEngine.class.getDeclaredField("delayMessageDelivery") + delayMessageDelivery.accessible = true + this.scopePropagator = scopePropagator + propagators = ImmutableList.copyOf(mustRelationPropagationQueries.entrySet.map [ entry | + val matcher = queryEngine.getMatcher(entry.value) + getPropagator(entry.key.key, entry.key.value, matcher) + ]) + this.statistics = statistics + } + + def decrementTypeScope(PartialTypeInterpratation partialTypeInterpratation) { + scopePropagator.decrementTypeScope(partialTypeInterpratation) + } + + def addedToRelation(Relation r) { + scopePropagator.addedToRelation(r) + } + + def propagate() { + var boolean changed + do { + val scopeChanged = propagateScope() + val mustChanged = propagateMustRelations() + changed = scopeChanged || mustChanged + } while (changed) + } + + protected def flushQueryEngine() { + if (queryEngine.updatePropagationDelayed) { + delayMessageDelivery.setBoolean(queryEngine, false) + queryEngine.getQueryBackend(ReteBackendFactory.INSTANCE).flushUpdates + delayMessageDelivery.setBoolean(queryEngine, true) + } + } + + protected def propagateScope() { + if (scopePropagator.scopePropagationNeeded) { + if (scopePropagator.queryEngineFlushRequiredBeforePropagation) { + flushQueryEngine() + } + val propagatorStartTime = System.nanoTime + scopePropagator.propagateAllScopeConstraints() + statistics.addScopePropagationTime(System.nanoTime - propagatorStartTime) + true + } else { + false + } + } + + protected def propagateMustRelations() { + if (propagators.empty) { + return false + } + flushQueryEngine() + val propagatorStartTime = System.nanoTime + var changed = false + for (propagator : propagators) { + changed = propagator.propagate(p, i, refinementRuleProvider, this) || changed + } + statistics.addMustRelationPropagationTime(System.nanoTime - propagatorStartTime) + changed + } + + private static def getPropagator(Relation relation, Relation inverseRelation, + ViatraQueryMatcher matcher) { + if (inverseRelation === null) { + new MustRelationPropagator(matcher) + } else if (relation == inverseRelation) { + new MustRelationPropagatorWithSelfInverse(matcher) + } else { + new MustRelationPropagatorWithInverse(matcher) + } + } + + @FinalFieldsConstructor + private static abstract class AbstractMustRelationPropagator { + val ViatraQueryMatcher matcher + + def propagate(LogicProblem p, PartialInterpretation i, RefinementRuleProvider refinementRuleProvider, + UnitRulePropagator unitRulePropagator) { + val iterator = getIterator(p, i) + if (!iterator.hasNext) { + return false + } + iterate(iterator, refinementRuleProvider, unitRulePropagator) + true + } + + def iterate(Iterator iterator, RefinementRuleProvider refinementRuleProvider, + UnitRulePropagator unitRulePropagator) { + while (iterator.hasNext) { + doPropagate(iterator.next, refinementRuleProvider, unitRulePropagator) + } + } + + protected def getIterator(LogicProblem p, PartialInterpretation i) { + val partialMatch = matcher.newEmptyMatch + partialMatch.set(0, p) + partialMatch.set(1, i) + matcher.streamAllMatches(partialMatch).iterator + } + + protected def void doPropagate(T match, RefinementRuleProvider refinementRuleProvider, + UnitRulePropagator unitRulePropagator) + } + + private static class MustRelationPropagator extends AbstractMustRelationPropagator { + new(ViatraQueryMatcher matcher) { + super(matcher) + } + + override protected doPropagate(T match, RefinementRuleProvider refinementRuleProvider, + UnitRulePropagator unitRulePropagator) { + refinementRuleProvider.createRelationLinkAction(match, unitRulePropagator) + } + } + + private static class MustRelationPropagatorWithInverse extends AbstractMustRelationPropagator { + new(ViatraQueryMatcher matcher) { + super(matcher) + } + + override protected doPropagate(T match, RefinementRuleProvider refinementRuleProvider, + UnitRulePropagator unitRulePropagator) { + refinementRuleProvider.createRelationLinkWithInverse(match, unitRulePropagator) + } + } + + private static class MustRelationPropagatorWithSelfInverse extends MustRelationPropagatorWithInverse { + new(ViatraQueryMatcher matcher) { + super(matcher) + } + + override iterate(Iterator iterator, RefinementRuleProvider refinementRuleProvider, + UnitRulePropagator unitRulePropagator) { + val pairs = newHashSet + while (iterator.hasNext) { + val match = iterator.next + val src = match.get(4) as DefinedElement + val trg = match.get(5) as DefinedElement + if (!pairs.contains(trg -> src)) { + pairs.add(src -> trg) + doPropagate(match, refinementRuleProvider, unitRulePropagator) + } + } + } } } } -- cgit v1.2.3-70-g09d2