aboutsummaryrefslogtreecommitdiffstats
path: root/Framework/hu.bme.mit.inf.dslreasoner.viatra2logic/src/hu/bme/mit/inf/dslreasoner/viatra2logic/NumericZ3ProblemSolver.java
diff options
context:
space:
mode:
authorLibravatar Aren Babikian <aren.babikian@mail.mcgill.ca>2020-12-13 20:37:05 -0500
committerLibravatar Aren Babikian <aren.babikian@mail.mcgill.ca>2021-01-06 00:02:48 +0100
commitca1e53f0ca5e8d61699ce7c34494cb85c2b04ee8 (patch)
tree8a4614b464faf151409bbce112a9de9c1f734f62 /Framework/hu.bme.mit.inf.dslreasoner.viatra2logic/src/hu/bme/mit/inf/dslreasoner/viatra2logic/NumericZ3ProblemSolver.java
parentimplement setup for dreal calls (diff)
downloadVIATRA-Generator-ca1e53f0ca5e8d61699ce7c34494cb85c2b04ee8.tar.gz
VIATRA-Generator-ca1e53f0ca5e8d61699ce7c34494cb85c2b04ee8.tar.zst
VIATRA-Generator-ca1e53f0ca5e8d61699ce7c34494cb85c2b04ee8.zip
prep for refactoring Numeric Probelm Solvers
Diffstat (limited to 'Framework/hu.bme.mit.inf.dslreasoner.viatra2logic/src/hu/bme/mit/inf/dslreasoner/viatra2logic/NumericZ3ProblemSolver.java')
-rw-r--r--Framework/hu.bme.mit.inf.dslreasoner.viatra2logic/src/hu/bme/mit/inf/dslreasoner/viatra2logic/NumericZ3ProblemSolver.java465
1 files changed, 465 insertions, 0 deletions
diff --git a/Framework/hu.bme.mit.inf.dslreasoner.viatra2logic/src/hu/bme/mit/inf/dslreasoner/viatra2logic/NumericZ3ProblemSolver.java b/Framework/hu.bme.mit.inf.dslreasoner.viatra2logic/src/hu/bme/mit/inf/dslreasoner/viatra2logic/NumericZ3ProblemSolver.java
new file mode 100644
index 00000000..8b7ee043
--- /dev/null
+++ b/Framework/hu.bme.mit.inf.dslreasoner.viatra2logic/src/hu/bme/mit/inf/dslreasoner/viatra2logic/NumericZ3ProblemSolver.java
@@ -0,0 +1,465 @@
1package hu.bme.mit.inf.dslreasoner.viatra2logic;
2
3import java.util.ArrayList;
4import java.util.HashMap;
5import java.util.List;
6import java.util.Map;
7
8import org.eclipse.xtext.common.types.JvmIdentifiableElement;
9import org.eclipse.xtext.xbase.XBinaryOperation;
10import org.eclipse.xtext.xbase.XExpression;
11import org.eclipse.xtext.xbase.XFeatureCall;
12import org.eclipse.xtext.xbase.XNumberLiteral;
13
14import com.microsoft.z3.ArithExpr;
15import com.microsoft.z3.BoolExpr;
16import com.microsoft.z3.Context;
17import com.microsoft.z3.Expr;
18import com.microsoft.z3.IntExpr;
19import com.microsoft.z3.Model;
20import com.microsoft.z3.RealExpr;
21import com.microsoft.z3.Solver;
22import com.microsoft.z3.Status;
23import com.microsoft.z3.enumerations.Z3_ast_print_mode;
24
25import hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage.partialinterpretation.IntegerElement;
26import hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage.partialinterpretation.PrimitiveElement;
27import hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage.partialinterpretation.RealElement;
28
29
30public class NumericZ3ProblemSolver {
31 private static final String N_Base = "org.eclipse.xtext.xbase.lib.";
32 private static final String N_PLUS = "operator_plus";
33 private static final String N_MINUS = "operator_minus";
34 private static final String N_POWER = "operator_power";
35 private static final String N_MULTIPLY = "operator_multiply";
36 private static final String N_DIVIDE = "operator_divide";
37 private static final String N_MODULO = "operator_modulo";
38 private static final String N_LESSTHAN = "operator_lessThan";
39 private static final String N_LESSEQUALSTHAN = "operator_lessEqualsThan";
40 private static final String N_GREATERTHAN = "operator_greaterThan";
41 private static final String N_GREATEREQUALTHAN = "operator_greaterEqualsThan";
42 private static final String N_EQUALS = "operator_equals";
43 private static final String N_NOTEQUALS = "operator_notEquals";
44 private static final String N_EQUALS3 = "operator_tripleEquals";
45 private static final String N_NOTEQUALS3 = "operator_tripleNotEquals";
46
47
48 private Context ctx;
49 private Solver s;
50 private Map<Object, Expr> varMap;
51
52 long endformingProblem=0;
53 long endSolvingProblem=0;
54 long endFormingSolution=0;
55
56 public NumericZ3ProblemSolver() {
57 HashMap<String, String> cfg = new HashMap<String, String>();
58 cfg.put("model", "true");
59 ctx = new Context(cfg);
60 ctx.setPrintMode(Z3_ast_print_mode.Z3_PRINT_SMTLIB_FULL);
61 s = ctx.mkSolver();
62 varMap = new HashMap<Object, Expr>();
63 }
64
65 public Context getNumericProblemContext() {
66 return ctx;
67 }
68
69 public long getEndformingProblem() {
70 return endformingProblem;
71 }
72
73 public long getEndSolvingProblem() {
74 return endSolvingProblem;
75 }
76
77 public long getEndFormingSolution() {
78 return endFormingSolution;
79 }
80
81 private ArrayList<JvmIdentifiableElement> getJvmIdentifiableElements(XExpression expression) {
82 ArrayList<JvmIdentifiableElement> allElem = new ArrayList<JvmIdentifiableElement>();
83 XExpression left = ((XBinaryOperation) expression).getLeftOperand();
84 XExpression right = ((XBinaryOperation) expression).getRightOperand();
85
86 getJvmIdentifiableElementsHelper(left, allElem);
87 getJvmIdentifiableElementsHelper(right, allElem);
88 return allElem;
89 }
90
91 private void getJvmIdentifiableElementsHelper(XExpression e, List<JvmIdentifiableElement> allElem) {
92 if (e instanceof XFeatureCall) {
93 allElem.add(((XFeatureCall) e).getFeature());
94 } else if (e instanceof XBinaryOperation) {
95 getJvmIdentifiableElementsHelper(((XBinaryOperation) e).getLeftOperand(), allElem);
96 getJvmIdentifiableElementsHelper(((XBinaryOperation) e).getRightOperand(), allElem);
97 }
98 }
99
100 public boolean isSatisfiable(Map<XExpression, Iterable<Map<JvmIdentifiableElement,PrimitiveElement>>> matches) throws Exception {
101 long startformingProblem = System.nanoTime();
102 BoolExpr problemInstance = formNumericProblemInstance(matches);
103 s.add(problemInstance);
104 endformingProblem = System.nanoTime()-startformingProblem;
105 long startSolvingProblem = System.nanoTime();
106 boolean result = (s.check() == Status.SATISFIABLE);
107 endSolvingProblem = System.nanoTime()-startSolvingProblem;
108 this.ctx.close();
109 return result;
110 }
111
112 public Map<PrimitiveElement,Number> getOneSolution(List<PrimitiveElement> objs, Map<XExpression, Iterable<Map<JvmIdentifiableElement,PrimitiveElement>>> matches) throws Exception {
113 Map<PrimitiveElement,Number> sol = new HashMap<PrimitiveElement, Number>();
114 long startformingProblem = System.nanoTime();
115 BoolExpr problemInstance = formNumericProblemInstance(matches);
116 endformingProblem = System.nanoTime()-startformingProblem;
117 //System.out.println("Forming problem: " + (endformingProblem - startformingProblem));
118 s.add(problemInstance);
119 long startSolvingProblem = System.nanoTime();
120 if (s.check() == Status.SATISFIABLE) {
121 Model m = s.getModel();
122 endSolvingProblem = System.nanoTime()-startSolvingProblem;
123 //System.out.println("Solving problem: " + (endSolvingProblem - startSolvingProblem));
124 long startFormingSolution = System.nanoTime();
125 for (PrimitiveElement o: objs) {
126 if(varMap.containsKey(o)) {
127 if (o instanceof IntegerElement) {
128 IntExpr val =(IntExpr) m.evaluate(varMap.get(o), false);
129 Integer oSol = Integer.parseInt(val.toString());
130 sol.put(o, oSol);
131 } else {
132 RealExpr val = (RealExpr) m.evaluate(varMap.get(o), false);
133 Double oSol = Double.parseDouble(val.toString());
134 sol.put(o, oSol);
135 }
136 //System.out.println("Solution:" + o + "->" + oSol);
137
138 } else {
139 //System.out.println("not used var:" + o);
140 }
141 }
142 endFormingSolution = System.nanoTime()-startFormingSolution;
143 //System.out.println("Forming solution: " + (endFormingSolution - startFormingSolution));
144 } else {
145 System.out.println("Unsatisfiable numerical problem");
146 }
147 this.ctx.close();
148 return sol;
149 }
150
151 private BoolExpr formNumericConstraint(XExpression e, Map<JvmIdentifiableElement, PrimitiveElement> aMatch) throws Exception {
152 if (!(e instanceof XBinaryOperation)) {
153 throw new Exception ("error in check expression!!!");
154 }
155
156 String name = ((XBinaryOperation) e).getFeature().getQualifiedName();
157
158 BoolExpr constraint = null;
159
160 ArithExpr left_operand = formNumericConstraintHelper(((XBinaryOperation) e).getLeftOperand(), aMatch);
161 ArithExpr right_operand = formNumericConstraintHelper(((XBinaryOperation) e).getRightOperand(), aMatch);
162
163 if (nameEndsWith(name, N_LESSTHAN)) {
164 constraint = ctx.mkLt(left_operand, right_operand);
165 } else if (nameEndsWith(name, N_LESSEQUALSTHAN)) {
166 constraint = ctx.mkLe(left_operand, right_operand);
167 } else if (nameEndsWith(name, N_GREATERTHAN)) {
168 constraint = ctx.mkGt(left_operand, right_operand);
169 } else if (nameEndsWith(name, N_GREATEREQUALTHAN)) {
170 constraint = ctx.mkGe(left_operand, right_operand);
171 } else if (nameEndsWith(name, N_EQUALS)) {
172 constraint = ctx.mkEq(left_operand, right_operand);
173 } else if (nameEndsWith(name, N_NOTEQUALS)) {
174 constraint = ctx.mkDistinct(left_operand, right_operand);
175 } else if (nameEndsWith(name, N_EQUALS3)) {
176 constraint = ctx.mkGe(left_operand, right_operand); // ???
177 } else if (nameEndsWith(name, N_NOTEQUALS3)) {
178 constraint = ctx.mkGe(left_operand, right_operand); // ???
179 } else {
180 throw new Exception ("Unsupported binary operation " + name);
181 }
182
183 return constraint;
184 }
185
186 private ArithExpr formNumericConstraintHelper(XExpression e, Map<JvmIdentifiableElement, PrimitiveElement> aMatch) throws Exception {
187 ArithExpr expr = null;
188 // Variables
189 if (e instanceof XFeatureCall) {
190 PrimitiveElement matchedObj = aMatch.get(((XFeatureCall) e).getFeature());
191 boolean isInt = matchedObj instanceof IntegerElement;
192 if (!matchedObj.isValueSet()) {
193 if (varMap.get(matchedObj) == null) {
194 String var_name = ((XFeatureCall) e).getFeature().getQualifiedName() + matchedObj.toString();
195 if (isInt) {
196 expr = (ArithExpr) ctx.mkConst(ctx.mkSymbol(var_name), ctx.getIntSort());
197 } else {
198 expr = (ArithExpr) ctx.mkConst(ctx.mkSymbol(var_name), ctx.getRealSort());
199 }
200 varMap.put(matchedObj, expr);
201 } else {
202 expr = (ArithExpr) varMap.get(matchedObj);
203 }
204 } else {
205 if (isInt) {
206 int value = ((IntegerElement) matchedObj).getValue();
207 expr = (ArithExpr) ctx.mkInt(value);
208 } else {
209 double value = ((RealElement) matchedObj).getValue().doubleValue();
210 expr = (ArithExpr) ctx.mkReal(String.valueOf(value));
211 }
212 varMap.put(matchedObj, expr);
213 }
214 }
215 // Constants
216 else if (e instanceof XNumberLiteral) {
217 String value = ((XNumberLiteral) e).getValue();
218 try{
219 int val = Integer.parseInt(value);
220 expr = (ArithExpr) ctx.mkInt(val);
221 } catch(NumberFormatException err){
222 try{
223 expr = (ArithExpr) ctx.mkReal(value);
224 } catch(NumberFormatException err2){}
225 }
226 }
227 // Expressions with operators
228 else if (e instanceof XBinaryOperation) {
229 String name = ((XBinaryOperation) e).getFeature().getQualifiedName();
230 ArithExpr left_operand = formNumericConstraintHelper(((XBinaryOperation) e).getLeftOperand(), aMatch);
231 ArithExpr right_operand = formNumericConstraintHelper(((XBinaryOperation) e).getRightOperand(), aMatch);
232
233 if (nameEndsWith(name, N_PLUS)) {
234 expr = ctx.mkAdd(left_operand, right_operand);
235 } else if (nameEndsWith(name, N_MINUS)) {
236 expr = ctx.mkAdd(left_operand, ctx.mkUnaryMinus(right_operand));
237 } else if (nameEndsWith(name, N_POWER)) {
238 expr = ctx.mkPower(left_operand, right_operand);
239 } else if (nameEndsWith(name, N_MULTIPLY)) {
240 expr = ctx.mkMul(left_operand, right_operand);
241 } else if (nameEndsWith(name, N_DIVIDE)) {
242 expr = ctx.mkDiv(left_operand, right_operand);
243 } else if (nameEndsWith(name, N_MODULO)) {
244 expr = ctx.mkMod((IntExpr)left_operand, (IntExpr)right_operand);
245 } else {
246 throw new Exception ("Unsupported binary operation " + name);
247 }
248 } else {
249 throw new Exception ("Unsupported expression " + e.getClass().getSimpleName());
250 }
251 return expr;
252
253 }
254
255 private boolean nameEndsWith(String name, String end) {
256 return name.startsWith(N_Base) && name.endsWith(end);
257 }
258
259 private BoolExpr formNumericProblemInstance(Map<XExpression, Iterable<Map<JvmIdentifiableElement,PrimitiveElement>>> matches) throws Exception {
260 BoolExpr constraintInstances = ctx.mkTrue();
261 for (XExpression e: matches.keySet()) {
262 Iterable<Map<JvmIdentifiableElement, PrimitiveElement>> matchSets = matches.get(e);
263 for (Map<JvmIdentifiableElement, PrimitiveElement> aMatch: matchSets) {
264 BoolExpr constraintInstance = ctx.mkNot(formNumericConstraint(e, aMatch));
265 constraintInstances = ctx.mkAnd(constraintInstances, constraintInstance);
266 }
267 }
268 return constraintInstances;
269 }
270
271
272 /*
273 public void testIsSat(XExpression expression, Term t) throws Exception {
274 int count = 10000;
275 Map<XExpression, Set<Map<JvmIdentifiableElement,PrimitiveElement>>> matches = new HashMap<XExpression, Set<Map<JvmIdentifiableElement,PrimitiveElement>>>();
276 Set<Map<JvmIdentifiableElement,PrimitiveElement>> matchSet = new HashSet<Map<JvmIdentifiableElement,PrimitiveElement>>();
277 ArrayList<JvmIdentifiableElement> allElem = getJvmIdentifiableElements(expression);
278
279 for (int i = 0; i < count; i++) {
280 Map<JvmIdentifiableElement,PrimitiveElement> match = new HashMap<JvmIdentifiableElement,PrimitiveElement>();
281 for (JvmIdentifiableElement e: allElem) {
282 FakeIntegerElement intE = new FakeIntegerElement();
283 match.put(e, intE);
284 }
285 matchSet.add(match);
286 }
287
288 matches.put(expression, matchSet);
289 long start = System.currentTimeMillis();
290 boolean sat = isSatisfiable(matches);
291 long end = System.currentTimeMillis();
292 System.out.println(sat);
293 System.out.println("Number of matches: " + count);
294 System.out.println("Running time:" + (end - start));
295 }
296
297 public void testIsNotSat(XExpression expression, Term t) throws Exception {
298 Map<XExpression, Set<Map<JvmIdentifiableElement,PrimitiveElement>>> matches = new HashMap<XExpression, Set<Map<JvmIdentifiableElement,PrimitiveElement>>>();
299 Set<Map<JvmIdentifiableElement,PrimitiveElement>> matchSet = new HashSet<Map<JvmIdentifiableElement,PrimitiveElement>>();
300 Map<JvmIdentifiableElement,PrimitiveElement> match = new HashMap<JvmIdentifiableElement,PrimitiveElement>();
301 ArrayList<JvmIdentifiableElement> allElem = getJvmIdentifiableElements(expression);
302 FakeIntegerElement int1 = null;
303 FakeIntegerElement int2 = null;
304 boolean first = true;
305 for (JvmIdentifiableElement e: allElem) {
306 FakeIntegerElement intE = new FakeIntegerElement();
307 if (first) {
308 int1 = intE;
309 first = false;
310 } else {
311 int2 = intE;
312 }
313
314 match.put(e, intE);
315 }
316 matchSet.add(match);
317
318 Map<JvmIdentifiableElement,PrimitiveElement> match2 = new HashMap<JvmIdentifiableElement,PrimitiveElement>();
319 boolean first2 = true;
320 for (JvmIdentifiableElement e: allElem) {
321 if (first2) {
322 match2.put(e, int2);
323 first2 = false;
324 } else {
325 match2.put(e, int1);
326 }
327 }
328 matchSet.add(match2);
329
330 matches.put(expression, matchSet);
331 long start = System.currentTimeMillis();
332 boolean sat = isSatisfiable(matches);
333 long end = System.currentTimeMillis();
334 System.out.println(sat);
335 System.out.println("Number of matches: ");
336 System.out.println("Running time:" + (end - start));
337 }
338 */
339
340 /* public void testGetOneSol(XExpression expression, Term t) throws Exception {
341 int count = 10;
342 Map<XExpression, Iterable<Map<JvmIdentifiableElement,PrimitiveElement>>> matches = new HashMap<XExpression, Iterable<Map<JvmIdentifiableElement,PrimitiveElement>>>();
343 Iterable<Map<JvmIdentifiableElement,PrimitiveElement>> matchSet = new ArrayList<Map<JvmIdentifiableElement,PrimitiveElement>>();
344
345 ArrayList<JvmIdentifiableElement> allElem = getJvmIdentifiableElements(expression);
346 List<PrimitiveElement> obj = new ArrayList<PrimitiveElement>();
347
348 for (int i = 0; i < count; i++) {
349 Map<JvmIdentifiableElement,PrimitiveElement> match = new HashMap<JvmIdentifiableElement,PrimitiveElement>();
350 for (JvmIdentifiableElement e: allElem) {
351 FakeIntegerElement intE = new FakeIntegerElement();
352 obj.add(intE);
353 match.put(e, intE);
354 }
355 ((ArrayList) matchSet).add(match);
356 matches.put(expression, matchSet);
357 }
358
359 long start = System.currentTimeMillis();
360 Map<PrimitiveElement,Integer> sol = getOneSolution(obj, matches);
361 long end = System.currentTimeMillis();
362
363
364 // Print sol
365 for (Object o: sol.keySet()) {
366 System.out.println(o + " :" + sol.get(o));
367 }
368
369
370 System.out.println("Number of matches: " + count);
371 System.out.println("Running time:" + (end - start));
372 }*/
373 /*
374 public void testGetOneSol2(XExpression expression, Term t) throws Exception {
375 int count = 250;
376 Map<XExpression, Set<Map<JvmIdentifiableElement,PrimitiveElement>>> matches = new HashMap<XExpression, Set<Map<JvmIdentifiableElement,PrimitiveElement>>>();
377 Set<Map<JvmIdentifiableElement,PrimitiveElement>> matchSet = new HashSet<Map<JvmIdentifiableElement,PrimitiveElement>>();
378 ArrayList<JvmIdentifiableElement> allElem = getJvmIdentifiableElements(expression);
379 List<Object> obj = new ArrayList<Object>();
380 for (int i = 0; i < count; i++) {
381 Map<JvmIdentifiableElement,PrimitiveElement> match = new HashMap<JvmIdentifiableElement,PrimitiveElement>();
382 FakeIntegerElement int2 = null;
383 boolean first = false;
384 for (JvmIdentifiableElement e: allElem) {
385 FakeIntegerElement intE = new FakeIntegerElement();
386 if (first) {
387 first = false;
388 } else {
389 int2 = intE;
390 }
391 obj.add(intE);
392 match.put(e, intE);
393 }
394
395 Map<JvmIdentifiableElement,PrimitiveElement> match2 = new HashMap<JvmIdentifiableElement,PrimitiveElement>();
396 boolean first2 = true;
397 for (JvmIdentifiableElement e: allElem) {
398 FakeIntegerElement intE = null;
399 if (first2) {
400 intE = int2;
401 first2 = false;
402 } else {
403 intE = new FakeIntegerElement();
404 }
405 obj.add(intE);
406 match2.put(e, intE);
407 }
408
409
410 matchSet.add(match);
411 matchSet.add(match2);
412 }
413 matches.put(expression, matchSet);
414
415 System.out.println("Number of matches: " + matchSet.size());
416 for (int i = 0; i < 10; i++) {
417 Map<Object,Integer> sol = getOneSolution(obj, matches);
418 System.out.println("**********************");
419 Thread.sleep(3000);
420 }
421 }
422
423 public void testGetOneSol3(XExpression expression, Term t) throws Exception {
424 int count = 15000;
425 Random rand = new Random();
426 Map<XExpression, Set<Map<JvmIdentifiableElement,PrimitiveElement>>> matches = new HashMap<XExpression, Set<Map<JvmIdentifiableElement,PrimitiveElement>>>();
427 Set<Map<JvmIdentifiableElement,PrimitiveElement>> matchSet = new HashSet<Map<JvmIdentifiableElement,PrimitiveElement>>();
428 ArrayList<JvmIdentifiableElement> allElem = getJvmIdentifiableElements(expression);
429 List<Object> obj = new ArrayList<Object>();
430 for (int i = 0; i < count; i++) {
431 Map<JvmIdentifiableElement,PrimitiveElement> match = new HashMap<JvmIdentifiableElement,PrimitiveElement>();
432 if (obj.size() > 1) {
433 for (JvmIdentifiableElement e: allElem) {
434 FakeIntegerElement intE = null;
435 int useOld = rand.nextInt(10);
436 if (useOld == 1) {
437 System.out.println("here ");
438 int index = rand.nextInt(obj.size());
439 intE = (FakeIntegerElement) obj.get(index);
440 } else {
441 intE = new FakeIntegerElement();
442 }
443 obj.add(intE);
444 match.put(e, intE);
445 }
446 } else {
447 for (JvmIdentifiableElement e: allElem) {
448 FakeIntegerElement intE = new FakeIntegerElement();
449 obj.add(intE);
450 match.put(e, intE);
451 }
452 }
453 matchSet.add(match);
454 }
455 matches.put(expression, matchSet);
456
457 System.out.println("Number of matches: " + matchSet.size());
458 for (int i = 0; i < 10; i++) {
459 Map<Object,Integer> sol = getOneSolution(obj, matches);
460 System.out.println("**********************");
461 Thread.sleep(3000);
462 }
463 }
464 */
465}