package hu.bme.mit.inf.dslreasoner.viatra2logic;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;

import org.eclipse.xtext.common.types.JvmIdentifiableElement;
import org.eclipse.xtext.xbase.XBinaryOperation;
import org.eclipse.xtext.xbase.XExpression;
import org.eclipse.xtext.xbase.XFeatureCall;
import org.eclipse.xtext.xbase.XNumberLiteral;

import com.microsoft.z3.ArithExpr;
import com.microsoft.z3.BoolExpr;
import com.microsoft.z3.Context;
import com.microsoft.z3.Expr;
import com.microsoft.z3.IntExpr;
import com.microsoft.z3.Model;
import com.microsoft.z3.Solver;
import com.microsoft.z3.Status;
import com.microsoft.z3.enumerations.Z3_ast_print_mode;

import hu.bme.mit.inf.dslreasoner.logic.model.logiclanguage.Term;
import hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage.partialinterpretation.IntegerElement;
import hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage.partialinterpretation.PrimitiveElement;

/**
 * Translates Xbase check expressions, instantiated over pattern matches, into
 * Z3 integer constraints and asks Z3 whether they can be satisfied.
 *
 * <p>A "match" maps every variable referenced by an expression to a
 * {@link PrimitiveElement}. Elements whose value is already set are encoded as
 * integer constants; all other elements become Z3 integer variables. The same
 * element always maps to the same Z3 expression for the lifetime of this solver
 * instance (see {@code varMap}).
 *
 * <p>The {@code test*} methods are ad-hoc benchmark drivers that print their
 * results to standard output.
 */
public class NumericProblemSolver {
    private static final String N_Base = "org.eclipse.xtext.xbase.lib.";
    private static final String N_PLUS = "operator_plus";
    private static final String N_MINUS = "operator_minus";
    private static final String N_POWER = "operator_power";
    private static final String N_MULTIPLY = "operator_multiply";
    private static final String N_DIVIDE = "operator_divide";
    private static final String N_MODULO = "operator_modulo";
    private static final String N_LESSTHAN = "operator_lessThan";
    private static final String N_LESSEQUALSTHAN = "operator_lessEqualsThan";
    private static final String N_GREATERTHAN = "operator_greaterThan";
    private static final String N_GREATEREQUALTHAN = "operator_greaterEqualsThan";
    private static final String N_EQUALS = "operator_equals";
    private static final String N_NOTEQUALS = "operator_notEquals";
    private static final String N_EQUALS3 = "operator_tripleEquals";
    private static final String N_NOTEQUALS3 = "operator_tripleNotEquals";

    private final Context ctx;
    private final Solver s;
    /** Z3 expression (variable or constant) created for each matched element. */
    private final Map<Object, Expr> varMap;

    /** Creates a Z3 context with model generation enabled and a fresh solver. */
    public NumericProblemSolver() {
        HashMap<String, String> cfg = new HashMap<>();
        cfg.put("model", "true");
        ctx = new Context(cfg);
        ctx.setPrintMode(Z3_ast_print_mode.Z3_PRINT_SMTLIB_FULL);
        s = ctx.mkSolver();
        varMap = new HashMap<>();
    }

    /** @return the Z3 context owned by this solver */
    public Context getNumericProblemContext() {
        return ctx;
    }

    /**
     * Benchmark driver: instantiates {@code expression} over 10000 matches of
     * fresh, independent integer elements and prints the satisfiability result
     * and the running time.
     *
     * @param expression binary check expression to instantiate
     * @param t unused
     * @throws Exception if the expression cannot be translated
     */
    public void testIsSat(XExpression expression, Term t) throws Exception {
        int count = 10000;
        Map<XExpression, Set<Map<JvmIdentifiableElement, PrimitiveElement>>> matches = new HashMap<>();
        Set<Map<JvmIdentifiableElement, PrimitiveElement>> matchSet = new HashSet<>();
        List<JvmIdentifiableElement> allElem = getJvmIdentifiableElements(expression);

        for (int i = 0; i < count; i++) {
            Map<JvmIdentifiableElement, PrimitiveElement> match = new HashMap<>();
            for (JvmIdentifiableElement e : allElem) {
                match.put(e, new FakeIntegerElement());
            }
            matchSet.add(match);
        }
        matches.put(expression, matchSet);

        long start = System.currentTimeMillis();
        boolean sat = isSatisfiable(matches);
        long end = System.currentTimeMillis();

        System.out.println(sat);
        System.out.println("Number of matches: " + count);
        System.out.println("Running time:" + (end - start));
    }

    /**
     * Benchmark driver: builds two matches over the same elements with the
     * first and last bindings swapped, then prints the satisfiability result
     * and the running time.
     *
     * @param expression binary check expression to instantiate
     * @param t unused
     * @throws Exception if the expression cannot be translated
     */
    public void testIsNotSat(XExpression expression, Term t) throws Exception {
        Map<XExpression, Set<Map<JvmIdentifiableElement, PrimitiveElement>>> matches = new HashMap<>();
        Set<Map<JvmIdentifiableElement, PrimitiveElement>> matchSet = new HashSet<>();
        List<JvmIdentifiableElement> allElem = getJvmIdentifiableElements(expression);

        // First match: a fresh element per variable; remember the first and the last one.
        Map<JvmIdentifiableElement, PrimitiveElement> match = new HashMap<>();
        FakeIntegerElement int1 = null;
        FakeIntegerElement int2 = null;
        boolean first = true;
        for (JvmIdentifiableElement e : allElem) {
            FakeIntegerElement intE = new FakeIntegerElement();
            if (first) {
                int1 = intE;
                first = false;
            } else {
                int2 = intE;
            }
            match.put(e, intE);
        }
        matchSet.add(match);

        // Second match: the first variable gets the last element, every other one the first.
        Map<JvmIdentifiableElement, PrimitiveElement> match2 = new HashMap<>();
        boolean first2 = true;
        for (JvmIdentifiableElement e : allElem) {
            if (first2) {
                match2.put(e, int2);
                first2 = false;
            } else {
                match2.put(e, int1);
            }
        }
        matchSet.add(match2);
        matches.put(expression, matchSet);

        long start = System.currentTimeMillis();
        boolean sat = isSatisfiable(matches);
        long end = System.currentTimeMillis();

        System.out.println(sat);
        System.out.println("Number of matches: " + matchSet.size());
        System.out.println("Running time:" + (end - start));
    }

    /**
     * Benchmark driver: instantiates {@code expression} over 10 independent
     * matches, requests one solution and prints it with the running time.
     *
     * @param expression binary check expression to instantiate
     * @param t unused
     * @throws Exception if the expression cannot be translated
     */
    public void testGetOneSol(XExpression expression, Term t) throws Exception {
        int count = 10;
        Map<XExpression, Set<Map<JvmIdentifiableElement, PrimitiveElement>>> matches = new HashMap<>();
        Set<Map<JvmIdentifiableElement, PrimitiveElement>> matchSet = new HashSet<>();
        List<JvmIdentifiableElement> allElem = getJvmIdentifiableElements(expression);
        List<Object> obj = new ArrayList<>();

        for (int i = 0; i < count; i++) {
            Map<JvmIdentifiableElement, PrimitiveElement> match = new HashMap<>();
            for (JvmIdentifiableElement e : allElem) {
                FakeIntegerElement intE = new FakeIntegerElement();
                obj.add(intE);
                match.put(e, intE);
            }
            matchSet.add(match);
        }
        matches.put(expression, matchSet);

        long start = System.currentTimeMillis();
        Map<Object, Integer> sol = getOneSolution(obj, matches);
        long end = System.currentTimeMillis();

        // Print sol
        for (Map.Entry<Object, Integer> entry : sol.entrySet()) {
            System.out.println(entry.getKey() + " :" + entry.getValue());
        }
        System.out.println("Number of matches: " + count);
        System.out.println("Running time:" + (end - start));
    }

    /**
     * Benchmark driver: builds 250 pairs of matches in which the second match
     * of a pair shares one element with the first, then solves the problem ten
     * times with a three second pause after each run.
     *
     * @param expression binary check expression to instantiate
     * @param t unused
     * @throws Exception if the expression cannot be translated
     */
    public void testGetOneSol2(XExpression expression, Term t) throws Exception {
        int count = 250;
        Map<XExpression, Set<Map<JvmIdentifiableElement, PrimitiveElement>>> matches = new HashMap<>();
        Set<Map<JvmIdentifiableElement, PrimitiveElement>> matchSet = new HashSet<>();
        List<JvmIdentifiableElement> allElem = getJvmIdentifiableElements(expression);
        List<Object> obj = new ArrayList<>();

        for (int i = 0; i < count; i++) {
            Map<JvmIdentifiableElement, PrimitiveElement> match = new HashMap<>();
            // Ends up holding the element bound to the last variable of this match.
            FakeIntegerElement int2 = null;
            for (JvmIdentifiableElement e : allElem) {
                FakeIntegerElement intE = new FakeIntegerElement();
                int2 = intE;
                obj.add(intE);
                match.put(e, intE);
            }

            // The paired match reuses that element for its first variable.
            Map<JvmIdentifiableElement, PrimitiveElement> match2 = new HashMap<>();
            boolean first2 = true;
            for (JvmIdentifiableElement e : allElem) {
                FakeIntegerElement intE;
                if (first2) {
                    intE = int2;
                    first2 = false;
                } else {
                    intE = new FakeIntegerElement();
                }
                obj.add(intE);
                match2.put(e, intE);
            }

            matchSet.add(match);
            matchSet.add(match2);
        }
        matches.put(expression, matchSet);

        System.out.println("Number of matches: " + matchSet.size());
        for (int i = 0; i < 10; i++) {
            getOneSolution(obj, matches);
            System.out.println("**********************");
            Thread.sleep(3000);
        }
    }

    /**
     * Benchmark driver: builds 15000 matches in which roughly every tenth
     * binding reuses a randomly chosen earlier element, then solves the problem
     * ten times with a three second pause after each run.
     *
     * @param expression binary check expression to instantiate
     * @param t unused
     * @throws Exception if the expression cannot be translated
     */
    public void testGetOneSol3(XExpression expression, Term t) throws Exception {
        int count = 15000;
        Random rand = new Random();
        Map<XExpression, Set<Map<JvmIdentifiableElement, PrimitiveElement>>> matches = new HashMap<>();
        Set<Map<JvmIdentifiableElement, PrimitiveElement>> matchSet = new HashSet<>();
        List<JvmIdentifiableElement> allElem = getJvmIdentifiableElements(expression);
        List<Object> obj = new ArrayList<>();

        for (int i = 0; i < count; i++) {
            Map<JvmIdentifiableElement, PrimitiveElement> match = new HashMap<>();
            // Reuse is only considered once at least two elements existed before this match.
            boolean mayReuse = obj.size() > 1;
            for (JvmIdentifiableElement e : allElem) {
                FakeIntegerElement intE;
                if (mayReuse && rand.nextInt(10) == 1) {
                    System.out.println("here ");
                    intE = (FakeIntegerElement) obj.get(rand.nextInt(obj.size()));
                } else {
                    intE = new FakeIntegerElement();
                }
                obj.add(intE);
                match.put(e, intE);
            }
            matchSet.add(match);
        }
        matches.put(expression, matchSet);

        System.out.println("Number of matches: " + matchSet.size());
        for (int i = 0; i < 10; i++) {
            getOneSolution(obj, matches);
            System.out.println("**********************");
            Thread.sleep(3000);
        }
    }

    /**
     * Collects the features referenced by the feature calls of a binary
     * expression, left operand first. Duplicates are kept.
     *
     * @param expression expression to scan; must be an {@link XBinaryOperation}
     * @return referenced features in left-to-right order
     */
    private List<JvmIdentifiableElement> getJvmIdentifiableElements(XExpression expression) {
        List<JvmIdentifiableElement> allElem = new ArrayList<>();
        XBinaryOperation operation = (XBinaryOperation) expression;
        getJvmIdentifiableElementsHelper(operation.getLeftOperand(), allElem);
        getJvmIdentifiableElementsHelper(operation.getRightOperand(), allElem);
        return allElem;
    }

    /** Recursive worker of {@link #getJvmIdentifiableElements}; other expression kinds are ignored. */
    private void getJvmIdentifiableElementsHelper(XExpression e, List<JvmIdentifiableElement> allElem) {
        if (e instanceof XFeatureCall) {
            allElem.add(((XFeatureCall) e).getFeature());
        } else if (e instanceof XBinaryOperation) {
            XBinaryOperation operation = (XBinaryOperation) e;
            getJvmIdentifiableElementsHelper(operation.getLeftOperand(), allElem);
            getJvmIdentifiableElementsHelper(operation.getRightOperand(), allElem);
        }
    }

    /**
     * Checks whether all expressions can hold simultaneously for all of their
     * matches.
     *
     * <p>The query is evaluated inside its own solver scope, so the result
     * depends only on {@code matches} and not on earlier queries.
     *
     * @param matches for each check expression, the set of matches it must hold for
     * @return {@code true} iff Z3 reports the conjunction as satisfiable
     * @throws Exception if an expression cannot be translated
     */
    public boolean isSatisfiable(Map<XExpression, Set<Map<JvmIdentifiableElement, PrimitiveElement>>> matches)
            throws Exception {
        BoolExpr problemInstance = formNumericProblemInstance(matches);
        s.push();
        try {
            s.add(problemInstance);
            return s.check() == Status.SATISFIABLE;
        } finally {
            // Drop the assertions again so that later queries start from a clean solver.
            s.pop();
        }
    }

    /**
     * Solves the numeric problem and reads back one integer value per object.
     * Timing information for each phase is printed to standard output.
     *
     * <p>The query is evaluated inside its own solver scope, so the result
     * depends only on {@code matches} and not on earlier queries.
     *
     * @param objs elements whose values are requested; each must occur in {@code matches}
     *             (or in a match of an earlier query on this instance)
     * @param matches for each check expression, the set of matches it must hold for
     * @return value assigned to each object, or an empty map if the problem is unsatisfiable
     * @throws Exception if an expression cannot be translated or an object has no Z3 expression
     */
    public Map<Object, Integer> getOneSolution(List<Object> objs,
            Map<XExpression, Set<Map<JvmIdentifiableElement, PrimitiveElement>>> matches) throws Exception {
        Map<Object, Integer> sol = new HashMap<>();

        long startformingProblem = System.currentTimeMillis();
        BoolExpr problemInstance = formNumericProblemInstance(matches);
        long endformingProblem = System.currentTimeMillis();
        System.out.println("Forming problem: " + (endformingProblem - startformingProblem));

        s.push();
        try {
            s.add(problemInstance);
            long startSolvingProblem = System.currentTimeMillis();
            if (s.check() == Status.SATISFIABLE) {
                Model m = s.getModel();
                long endSolvingProblem = System.currentTimeMillis();
                System.out.println("Solving problem: " + (endSolvingProblem - startSolvingProblem));

                long startFormingSolution = System.currentTimeMillis();
                for (Object o : objs) {
                    Expr variable = varMap.get(o);
                    if (variable == null) {
                        throw new Exception("No numeric variable was created for object " + o);
                    }
                    // Model completion guarantees a numeral even for variables the model leaves open.
                    IntExpr val = (IntExpr) m.evaluate(variable, true);
                    sol.put(o, Integer.parseInt(val.toString()));
                }
                long endFormingSolution = System.currentTimeMillis();
                System.out.println("Forming solution: " + (endFormingSolution - startFormingSolution));
            } else {
                System.out.println("Unsatisfiable");
            }
        } finally {
            // Drop the assertions again so that later queries start from a clean solver.
            s.pop();
        }
        return sol;
    }

    /**
     * Translates a top-level comparison into a Z3 Boolean constraint for one match.
     *
     * @param e comparison expression; must be an {@link XBinaryOperation}
     * @param aMatch binding of the variables referenced by {@code e}
     * @return the Z3 constraint
     * @throws Exception if {@code e} is not a supported comparison
     */
    private BoolExpr formNumericConstraint(XExpression e, Map<JvmIdentifiableElement, PrimitiveElement> aMatch)
            throws Exception {
        if (!(e instanceof XBinaryOperation)) {
            throw new Exception("error in check expression!!!");
        }
        XBinaryOperation operation = (XBinaryOperation) e;
        String name = operation.getFeature().getQualifiedName();

        ArithExpr left_operand = formNumericConstraintHelper(operation.getLeftOperand(), aMatch);
        ArithExpr right_operand = formNumericConstraintHelper(operation.getRightOperand(), aMatch);

        if (nameEndsWith(name, N_LESSTHAN)) {
            return ctx.mkLt(left_operand, right_operand);
        }
        if (nameEndsWith(name, N_LESSEQUALSTHAN)) {
            return ctx.mkLe(left_operand, right_operand);
        }
        if (nameEndsWith(name, N_GREATERTHAN)) {
            return ctx.mkGt(left_operand, right_operand);
        }
        if (nameEndsWith(name, N_GREATEREQUALTHAN)) {
            return ctx.mkGe(left_operand, right_operand);
        }
        // On integers, == and === (and != and !==) denote the same relation.
        if (nameEndsWith(name, N_EQUALS) || nameEndsWith(name, N_EQUALS3)) {
            return ctx.mkEq(left_operand, right_operand);
        }
        if (nameEndsWith(name, N_NOTEQUALS) || nameEndsWith(name, N_NOTEQUALS3)) {
            return ctx.mkDistinct(left_operand, right_operand);
        }
        throw new Exception("Unsupported binary operation " + name);
    }

    // TODO: add variable: state of the solver
    /**
     * Translates an arithmetic sub-expression into a Z3 arithmetic term for one match.
     *
     * <p>Feature calls become the Z3 expression of the element they are bound
     * to: an integer constant if the element's value is set, otherwise an
     * integer variable that is created on first use and cached in {@code varMap}.
     *
     * @param e feature call, number literal or arithmetic binary operation
     * @param aMatch binding of the variables referenced by {@code e}
     * @return the Z3 term
     * @throws Exception if the expression kind, operator or literal is unsupported,
     *                   or a referenced variable is not bound by {@code aMatch}
     */
    private ArithExpr formNumericConstraintHelper(XExpression e, Map<JvmIdentifiableElement, PrimitiveElement> aMatch)
            throws Exception {
        // Variables
        if (e instanceof XFeatureCall) {
            JvmIdentifiableElement feature = ((XFeatureCall) e).getFeature();
            PrimitiveElement matchedObj = aMatch.get(feature);
            if (matchedObj == null) {
                throw new Exception("No match found for variable " + feature.getQualifiedName());
            }
            if (matchedObj.isValueSet()) {
                int value = ((IntegerElement) matchedObj).getValue();
                ArithExpr constant = (ArithExpr) ctx.mkInt(value);
                varMap.put(matchedObj, constant);
                return constant;
            }
            Expr known = varMap.get(matchedObj);
            if (known != null) {
                return (ArithExpr) known;
            }
            // NOTE(review): the symbol is only as unique as matchedObj.toString() - confirm it
            // distinguishes different elements.
            String var_name = feature.getQualifiedName() + matchedObj.toString();
            ArithExpr variable = (ArithExpr) ctx.mkConst(ctx.mkSymbol(var_name), ctx.getIntSort());
            varMap.put(matchedObj, variable);
            return variable;
        }

        // Constants
        if (e instanceof XNumberLiteral) {
            String value = ((XNumberLiteral) e).getValue();
            try {
                return (ArithExpr) ctx.mkInt(Integer.parseInt(value));
            } catch (NumberFormatException err) {
                // Fail here instead of handing a null operand to Z3.
                throw new Exception("Unsupported number literal " + value, err);
            }
        }

        // Expressions with operators
        if (e instanceof XBinaryOperation) {
            XBinaryOperation operation = (XBinaryOperation) e;
            String name = operation.getFeature().getQualifiedName();

            ArithExpr left_operand = formNumericConstraintHelper(operation.getLeftOperand(), aMatch);
            ArithExpr right_operand = formNumericConstraintHelper(operation.getRightOperand(), aMatch);

            if (nameEndsWith(name, N_PLUS)) {
                return ctx.mkAdd(left_operand, right_operand);
            }
            if (nameEndsWith(name, N_MINUS)) {
                return ctx.mkAdd(left_operand, ctx.mkUnaryMinus(right_operand));
            }
            if (nameEndsWith(name, N_POWER)) {
                return ctx.mkPower(left_operand, right_operand);
            }
            if (nameEndsWith(name, N_MULTIPLY)) {
                return ctx.mkMul(left_operand, right_operand);
            }
            if (nameEndsWith(name, N_DIVIDE)) {
                return ctx.mkDiv(left_operand, right_operand);
            }
            if (nameEndsWith(name, N_MODULO)) {
                return ctx.mkMod((IntExpr) left_operand, (IntExpr) right_operand);
            }
            throw new Exception("Unsupported binary operation " + name);
        }

        throw new Exception("Unsupported expression " + e.getClass().getSimpleName());
    }

    /** @return whether {@code name} lies in the Xbase library namespace and ends with {@code end} */
    private boolean nameEndsWith(String name, String end) {
        return name.startsWith(N_Base) && name.endsWith(end);
    }

    /**
     * Builds the conjunction of every expression instantiated over every one of
     * its matches.
     *
     * @param matches for each check expression, the set of matches it must hold for
     * @return the conjunction, or {@code true} if there is nothing to instantiate
     * @throws Exception if an expression cannot be translated
     */
    private BoolExpr formNumericProblemInstance(
            Map<XExpression, Set<Map<JvmIdentifiableElement, PrimitiveElement>>> matches) throws Exception {
        List<BoolExpr> constraintInstances = new ArrayList<>();
        for (Map.Entry<XExpression, Set<Map<JvmIdentifiableElement, PrimitiveElement>>> entry : matches.entrySet()) {
            for (Map<JvmIdentifiableElement, PrimitiveElement> aMatch : entry.getValue()) {
                constraintInstances.add(formNumericConstraint(entry.getKey(), aMatch));
            }
        }
        if (constraintInstances.isEmpty()) {
            return ctx.mkTrue();
        }
        // One flat n-ary conjunction instead of a chain nested as deep as the number of matches.
        return ctx.mkAnd(constraintInstances.toArray(new BoolExpr[0]));
    }
}