diff options
Diffstat (limited to 'Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/NumericSolver.xtend')
-rw-r--r-- | Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/NumericSolver.xtend | 164 |
1 files changed, 164 insertions, 0 deletions
diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/NumericSolver.xtend b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/NumericSolver.xtend new file mode 100644 index 00000000..71793aa6 --- /dev/null +++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/NumericSolver.xtend | |||
@@ -0,0 +1,164 @@ | |||
1 | package hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner.dse | ||
2 | |||
3 | import hu.bme.mit.inf.dslreasoner.viatra2logic.NumericProblemSolver | ||
4 | import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.ModelGenerationMethod | ||
5 | import java.util.HashMap | ||
6 | import org.eclipse.viatra.query.runtime.api.ViatraQueryEngine | ||
7 | import org.eclipse.viatra.query.runtime.api.IPatternMatch | ||
8 | import org.eclipse.viatra.query.runtime.api.ViatraQueryMatcher | ||
9 | import org.eclipse.viatra.query.runtime.matchers.psystem.PConstraint | ||
10 | import hu.bme.mit.inf.dslreasoner.viatra2logic.NumericTranslator | ||
11 | import org.eclipse.viatra.dse.base.ThreadContext | ||
12 | import org.eclipse.emf.ecore.EObject | ||
13 | import java.util.Map | ||
14 | import hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage.partialinterpretation.PartialInterpretation | ||
15 | import hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage.partialinterpretation.PrimitiveElement | ||
16 | import hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage.partialinterpretation.BooleanElement | ||
17 | import hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage.partialinterpretation.IntegerElement | ||
18 | import hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage.partialinterpretation.StringElement | ||
19 | import hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage.partialinterpretation.RealElement | ||
20 | import java.util.List | ||
21 | import java.math.BigDecimal | ||
22 | import java.util.LinkedHashSet | ||
23 | import java.util.LinkedHashMap | ||
24 | |||
25 | class NumericSolver { | ||
26 | val ThreadContext threadContext; | ||
27 | val constraint2UnitPropagationPrecondition = new HashMap<PConstraint,ViatraQueryMatcher<? extends IPatternMatch>> | ||
28 | NumericTranslator t = new NumericTranslator | ||
29 | |||
30 | val boolean caching; | ||
31 | Map<LinkedHashMap<PConstraint, Iterable<List<Integer>>>,Boolean> satisfiabilityCache = new HashMap | ||
32 | |||
33 | var long runtime = 0 | ||
34 | var long cachingTime = 0 | ||
35 | var int numberOfSolverCalls = 0 | ||
36 | var int numberOfCachedSolverCalls = 0 | ||
37 | |||
38 | new(ThreadContext threadContext, ModelGenerationMethod method, boolean caching) { | ||
39 | this.threadContext = threadContext | ||
40 | val engine = threadContext.queryEngine | ||
41 | for(entry : method.unitPropagationPreconditions.entrySet) { | ||
42 | val constraint = entry.key | ||
43 | val querySpec = entry.value | ||
44 | val matcher = querySpec.getMatcher(engine); | ||
45 | constraint2UnitPropagationPrecondition.put(constraint,matcher) | ||
46 | } | ||
47 | this.caching = caching | ||
48 | } | ||
49 | |||
50 | def getRuntime(){runtime} | ||
51 | def getCachingTime(){cachingTime} | ||
52 | def getNumberOfSolverCalls(){numberOfSolverCalls} | ||
53 | def getNumberOfCachedSolverCalls(){numberOfCachedSolverCalls} | ||
54 | |||
55 | def boolean isSatisfiable() { | ||
56 | val start = System.nanoTime | ||
57 | var boolean finalResult | ||
58 | if(constraint2UnitPropagationPrecondition.empty){ | ||
59 | finalResult=true | ||
60 | } else { | ||
61 | val propagatedConstraints = new HashMap | ||
62 | for(entry : constraint2UnitPropagationPrecondition.entrySet) { | ||
63 | val constraint = entry.key | ||
64 | //println(constraint) | ||
65 | val allMatches = entry.value.allMatches.map[it.toArray] | ||
66 | //println(allMatches.toList) | ||
67 | propagatedConstraints.put(constraint,allMatches) | ||
68 | } | ||
69 | if(propagatedConstraints.values.forall[empty]) { | ||
70 | finalResult=true | ||
71 | } else { | ||
72 | if(caching) { | ||
73 | val code = getCode(propagatedConstraints) | ||
74 | val cachedResult = satisfiabilityCache.get(code) | ||
75 | if(cachedResult === null) { | ||
76 | // println('''new problem, call solver''') | ||
77 | // for(entry : code.entrySet) { | ||
78 | // println('''«entry.key» -> «entry.value»''') | ||
79 | // } | ||
80 | //println(code.hashCode) | ||
81 | this.numberOfSolverCalls++ | ||
82 | val res = t.delegateIsSatisfiable(propagatedConstraints) | ||
83 | satisfiabilityCache.put(code,res) | ||
84 | finalResult=res | ||
85 | } else { | ||
86 | //println('''similar problem, answer from cache''') | ||
87 | finalResult=cachedResult | ||
88 | this.numberOfCachedSolverCalls++ | ||
89 | } | ||
90 | } else { | ||
91 | finalResult= t.delegateIsSatisfiable(propagatedConstraints) | ||
92 | this.numberOfSolverCalls++ | ||
93 | } | ||
94 | } | ||
95 | } | ||
96 | this.runtime+=System.nanoTime-start | ||
97 | return finalResult | ||
98 | } | ||
99 | |||
100 | def getCode(HashMap<PConstraint, Iterable<Object[]>> propagatedConstraints) { | ||
101 | val start = System.nanoTime | ||
102 | val involvedObjects = new LinkedHashSet(propagatedConstraints.values.flatten.map[toList].flatten.toList).toList | ||
103 | val res = new LinkedHashMap(propagatedConstraints.mapValues[matches | matches.map[objects | objects.map[object | involvedObjects.indexOf(object)].toList]]) | ||
104 | this.cachingTime += System.nanoTime-start | ||
105 | return res | ||
106 | } | ||
107 | |||
108 | def fillSolutionCopy(Map<EObject, EObject> trace) { | ||
109 | val model = threadContext.getModel as PartialInterpretation | ||
110 | val dataObjects = model.newElements.filter(PrimitiveElement).filter[!model.openWorldElements.contains(it)].toList | ||
111 | if(constraint2UnitPropagationPrecondition.empty) { | ||
112 | fillWithDefaultValues(dataObjects,trace) | ||
113 | } else { | ||
114 | val propagatedConstraints = new HashMap | ||
115 | for(entry : constraint2UnitPropagationPrecondition.entrySet) { | ||
116 | val constraint = entry.key | ||
117 | val allMatches = entry.value.allMatches.map[it.toArray] | ||
118 | propagatedConstraints.put(constraint,allMatches) | ||
119 | } | ||
120 | |||
121 | if(propagatedConstraints.values.forall[empty]) { | ||
122 | fillWithDefaultValues(dataObjects,trace) | ||
123 | } else { | ||
124 | val solution = t.delegateGetSolution(dataObjects,propagatedConstraints) | ||
125 | fillWithSolutions(dataObjects,solution,trace) | ||
126 | } | ||
127 | } | ||
128 | } | ||
129 | |||
130 | def protected fillWithDefaultValues(List<PrimitiveElement> elements, Map<EObject, EObject> trace) { | ||
131 | for(element : elements) { | ||
132 | if(element.valueSet==false) { | ||
133 | val value = getDefaultValue(element) | ||
134 | val target = trace.get(element) as PrimitiveElement | ||
135 | target.fillWithValue(value) | ||
136 | } | ||
137 | } | ||
138 | } | ||
139 | |||
140 | def protected dispatch getDefaultValue(BooleanElement e) {false} | ||
141 | def protected dispatch getDefaultValue(IntegerElement e) {0} | ||
142 | def protected dispatch getDefaultValue(RealElement e) {0.0} | ||
143 | def protected dispatch getDefaultValue(StringElement e) {""} | ||
144 | |||
145 | def protected fillWithSolutions(List<PrimitiveElement> elements, Map<PrimitiveElement, Integer> solution, Map<EObject, EObject> trace) { | ||
146 | for(element : elements) { | ||
147 | if(element.valueSet==false) { | ||
148 | if(solution.containsKey(element)) { | ||
149 | val value = solution.get(element) | ||
150 | val target = trace.get(element) as PrimitiveElement | ||
151 | target.fillWithValue(value) | ||
152 | } else { | ||
153 | val target = trace.get(element) as PrimitiveElement | ||
154 | target.fillWithValue(target.defaultValue) | ||
155 | } | ||
156 | } | ||
157 | } | ||
158 | } | ||
159 | |||
160 | def protected dispatch fillWithValue(BooleanElement e, Object value) {e.valueSet=true e.value=value as Boolean} | ||
161 | def protected dispatch fillWithValue(IntegerElement e, Object value) {e.valueSet=true e.value=value as Integer} | ||
162 | def protected dispatch fillWithValue(RealElement e, Object value) {e.valueSet=true e.value=BigDecimal.valueOf(value as Double) } | ||
163 | def protected dispatch fillWithValue(StringElement e, Object value) {e.valueSet=true e.value=value as String} | ||
164 | } \ No newline at end of file | ||