package hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.cardinality

import com.microsoft.z3.AlgebraicNum
import com.microsoft.z3.ArithExpr
import com.microsoft.z3.Context
import com.microsoft.z3.Expr
import com.microsoft.z3.IntNum
import com.microsoft.z3.Optimize
import com.microsoft.z3.RatNum
import com.microsoft.z3.Status
import com.microsoft.z3.Symbol
import java.math.BigDecimal
import java.math.MathContext
import java.math.RoundingMode
import java.util.Map
import org.eclipse.xtend.lib.annotations.FinalFieldsConstructor

/**
 * {@link PolyhedronSolver} that saturates polyhedra with the Z3 optimizing solver
 * (see {@link Z3SaturationOperator}).
 */
class Z3PolyhedronSolver implements PolyhedronSolver {
	/** When {@code true}, dimensions are modeled as real constants instead of integer constants. */
	val boolean lpRelaxation
	/**
	 * Timeout of each individual Z3 optimization query, in seconds. Values that round up to a
	 * negative number of milliseconds (such as the -1 used by the no-argument constructor)
	 * disable the timeout.
	 */
	val double timeoutSeconds

	/**
	 * Creates a solver with explicit settings. {@code @FinalFieldsConstructor} adds the final
	 * fields ({@code lpRelaxation}, {@code timeoutSeconds}) as constructor parameters.
	 */
	@FinalFieldsConstructor
	new() {
	}

	/** Creates an integer (non-relaxed) solver without a timeout. */
	new() {
		this(false, -1)
	}

	/** Creates a fresh operator (with its own Z3 context) for the given polyhedron. */
	override createSaturationOperator(Polyhedron polyhedron) {
		new Z3SaturationOperator(polyhedron, lpRelaxation, timeoutSeconds)
	}
}

/**
 * Tightens the bounds of a {@link Polyhedron} by minimizing and maximizing every expression in
 * {@code polyhedron.expressionsToSaturate} with Z3's {@link Optimize} engine, writing the optima
 * back into the expressions' {@code lowerBound}/{@code upperBound}.
 *
 * The operator owns a Z3 {@link Context}, which is released by {@link #close()}.
 */
class Z3SaturationOperator extends AbstractPolyhedronSaturationOperator {
	// Name of the symbol Z3 uses for infinity in optimization results.
	static val INFINITY_SYMBOL_NAME = "oo"
	// Name of the multiplication operator, used to recognize negative infinity (* (- 1) oo).
	static val MULT_SYMBOL_NAME = "*"
	static val TIMEOUT_SYMBOL_NAME = "timeout"
	// Number of decimal digits in Integer.MAX_VALUE. Dividing with this many significant digits
	// (rounding toward the target direction) keeps the integer part exact for every quotient whose
	// magnitude is below 10^precision, so the following setScale(0, ...) yields the true
	// floor/ceiling for all values representable as an int; larger values are rejected by
	// intValueExact.
	static val INTEGER_PRECISION = new BigDecimal(Integer.MAX_VALUE).precision
	static val ROUND_DOWN = new MathContext(INTEGER_PRECISION, RoundingMode.FLOOR)
	static val ROUND_UP = new MathContext(INTEGER_PRECISION, RoundingMode.CEILING)
	// Precision argument of AlgebraicNum.toLower/toUpper: the interval isolating the number is
	// smaller than 1/10^precision, so 0 requests an isolating interval narrower than 1.
	static val ALGEBRAIC_NUMBER_ROUNDING = 0

	extension val Context context
	val Symbol infinitySymbol
	val Symbol multSymbol
	// Z3 constant (integer, or real under LP relaxation) for each dimension of the polyhedron.
	val Map variables
	// Negative means "no timeout" (see prepareOptimize).
	val int timeoutMilliseconds

	/**
	 * @param polyhedron the polyhedron whose bounds are tightened in place
	 * @param lpRelaxation if {@code true}, dimensions become real constants, otherwise integer constants
	 * @param timeoutSeconds per-query timeout in seconds; rounded up to whole milliseconds
	 */
	new(Polyhedron polyhedron, boolean lpRelaxation, double timeoutSeconds) {
		super(polyhedron)
		context = new Context
		infinitySymbol = context.mkSymbol(INFINITY_SYMBOL_NAME)
		multSymbol = context.mkSymbol(MULT_SYMBOL_NAME)
		variables = polyhedron.dimensions.toInvertedMap [ dimension |
			val name = dimension.name
			if (lpRelaxation) {
				mkRealConst(name)
			} else {
				mkIntConst(name)
			}
		]
		timeoutMilliseconds = Math.ceil(timeoutSeconds * 1000) as int
	}

	/** Runs all optimization queries and maps the resulting Z3 status to a saturation result. */
	override doSaturate() {
		val status = executeSolver()
		convertStatusToSaturationResult(status)
	}

	/**
	 * Maps a Z3 {@link Status} to the corresponding {@link PolyhedronSaturationResult}.
	 *
	 * @throws IllegalArgumentException for any status other than SATISFIABLE, UNSATISFIABLE, UNKNOWN
	 */
	private def convertStatusToSaturationResult(Status status) {
		switch (status) {
			case SATISFIABLE: PolyhedronSaturationResult.SATURATED
			case UNSATISFIABLE: PolyhedronSaturationResult.EMPTY
			case UNKNOWN: PolyhedronSaturationResult.UNKNOWN
			default: throw new IllegalArgumentException("Unknown Status: " + status)
		}
	}

	/**
	 * Minimizes and then maximizes each expression to saturate, stopping at the first query whose
	 * status is not SATISFIABLE and returning that status. Every query re-asserts the current
	 * bounds, so bounds tightened by earlier queries constrain later ones.
	 *
	 * @return SATISFIABLE if every query succeeded, otherwise the first non-SATISFIABLE status
	 */
	private def executeSolver() {
		for (expressionToSaturate : polyhedron.expressionsToSaturate) {
			val expr = expressionToSaturate.toExpr
			val lowerResult = saturateLowerBound(expr, expressionToSaturate)
			if (lowerResult != Status.SATISFIABLE) {
				return lowerResult
			}
			val upperResult = saturateUpperBound(expr, expressionToSaturate)
			if (upperResult != Status.SATISFIABLE) {
				return upperResult
			}
		}
		Status.SATISFIABLE
	}

	/**
	 * Minimizes {@code expr} and, if the query is satisfiable, stores the floor of the minimum as
	 * the lower bound of {@code expressionToSaturate} ({@code null} when Z3 reports negative
	 * infinity). The bound is left untouched for any other status.
	 *
	 * @return the status of the optimization query
	 * @throws IllegalArgumentException if Z3 returns a value of an unexpected shape
	 * @throws ArithmeticException if the floor of the minimum does not fit in an int
	 */
	private def saturateLowerBound(ArithExpr expr, LinearBoundedExpression expressionToSaturate) {
		val optimize = prepareOptimize
		val handle = optimize.MkMinimize(expr)
		val status = optimize.Check()
		if (status == Status.SATISFIABLE) {
			val value = switch (resultExpr : handle.lower) {
				IntNum:
					resultExpr.getInt()
				RatNum:
					floor(resultExpr)
				AlgebraicNum:
					// The rational lower approximation is <= the true value, so its floor is a sound bound.
					floor(resultExpr.toLower(ALGEBRAIC_NUMBER_ROUNDING))
				default:
					if (isNegativeInfinity(resultExpr)) {
						null
					} else {
						throw new IllegalArgumentException("Integer result expected, got: " + resultExpr)
					}
			}
			expressionToSaturate.lowerBound = value
		}
		status
	}

	/**
	 * Rounds a rational number down to the nearest int.
	 *
	 * @throws ArithmeticException if the result does not fit in an int (previously the value was
	 *         silently truncated to its low-order 32 bits, producing a wrong bound)
	 */
	private def floor(RatNum ratNum) {
		val numerator = new BigDecimal(ratNum.numerator.bigInteger)
		val denominator = new BigDecimal(ratNum.denominator.bigInteger)
		numerator.divide(denominator, ROUND_DOWN).setScale(0, RoundingMode.FLOOR).intValueExact
	}

	/**
	 * Maximizes {@code expr} and, if the query is satisfiable, stores the ceiling of the maximum as
	 * the upper bound of {@code expressionToSaturate} ({@code null} when Z3 reports positive
	 * infinity). The bound is left untouched for any other status.
	 *
	 * @return the status of the optimization query
	 * @throws IllegalArgumentException if Z3 returns a value of an unexpected shape
	 * @throws ArithmeticException if the ceiling of the maximum does not fit in an int
	 */
	private def saturateUpperBound(ArithExpr expr, LinearBoundedExpression expressionToSaturate) {
		val optimize = prepareOptimize
		val handle = optimize.MkMaximize(expr)
		val status = optimize.Check()
		if (status == Status.SATISFIABLE) {
			val value = switch (resultExpr : handle.upper) {
				IntNum:
					resultExpr.getInt()
				RatNum:
					ceil(resultExpr)
				AlgebraicNum:
					// The rational upper approximation is >= the true value, so its ceiling is a sound bound.
					ceil(resultExpr.toUpper(ALGEBRAIC_NUMBER_ROUNDING))
				default:
					if (isPositiveInfinity(resultExpr)) {
						null
					} else {
						throw new IllegalArgumentException("Integer result expected, got: " + resultExpr)
					}
			}
			expressionToSaturate.upperBound = value
		}
		status
	}

	/**
	 * Rounds a rational number up to the nearest int.
	 *
	 * @throws ArithmeticException if the result does not fit in an int (previously the value was
	 *         silently truncated to its low-order 32 bits, producing a wrong bound)
	 */
	private def ceil(RatNum ratNum) {
		val numerator = new BigDecimal(ratNum.numerator.bigInteger)
		val denominator = new BigDecimal(ratNum.denominator.bigInteger)
		numerator.divide(denominator, ROUND_UP).setScale(0, RoundingMode.CEILING).intValueExact
	}

	/** Returns {@code true} if {@code expr} is an application of the infinity symbol {@code oo}. */
	private def isPositiveInfinity(Expr expr) {
		expr.app && expr.getFuncDecl.name == infinitySymbol
	}

	/**
	 * Returns {@code true} if {@code expr} is a binary multiplication whose second argument is
	 * {@code oo}. Only the second factor is inspected; the first factor is not checked.
	 */
	private def isNegativeInfinity(Expr expr) {
		// Negative infinity is represented as (* (- 1) oo)
		if (!expr.app || expr.getFuncDecl.name != multSymbol || expr.numArgs != 2) {
			return false
		}
		isPositiveInfinity(expr.args.get(1))
	}

	/**
	 * Creates a fresh {@link Optimize} instance, applies the timeout (when non-negative) and asserts
	 * the current bounds of all dimensions and non-trivial constraints.
	 */
	private def prepareOptimize() {
		val optimize = mkOptimize()
		if (timeoutMilliseconds >= 0) {
			val params = mkParams()
			// We cannot turn TIMEOUT_SYMBOL_NAME into a Symbol in the constructor,
			// because there is no add(Symbol, int) overload.
			params.add(TIMEOUT_SYMBOL_NAME, timeoutMilliseconds)
			optimize.parameters = params
		}
		assertConstraints(optimize)
		optimize
	}

	/** Asserts the bounds of every dimension variable and every non-trivial linear constraint. */
	private def assertConstraints(Optimize it) {
		for (pair : variables.entrySet) {
			assertBounds(pair.value, pair.key)
		}
		for (constraint : nonTrivialConstraints) {
			val expr = createLinearCombination(constraint.coefficients)
			assertBounds(expr, constraint)
		}
	}

	/**
	 * Asserts {@code lowerBound <= expression <= upperBound}. Equal bounds are asserted as a single
	 * equality; a {@code null} bound is not asserted.
	 */
	private def assertBounds(Optimize it, ArithExpr expression, LinearBoundedExpression bounds) {
		val lowerBound = bounds.lowerBound
		val upperBound = bounds.upperBound
		if (lowerBound == upperBound) {
			if (lowerBound === null) {
				return
			}
			Assert(mkEq(expression, mkInt(lowerBound)))
		} else {
			if (lowerBound !== null) {
				Assert(mkGe(expression, mkInt(lowerBound)))
			}
			if (upperBound !== null) {
				Assert(mkLe(expression, mkInt(upperBound)))
			}
		}
	}

	/**
	 * Translates a dimension or linear constraint into a Z3 arithmetic expression.
	 *
	 * @throws IllegalArgumentException for any other kind of linear bounded expression
	 */
	private def toExpr(LinearBoundedExpression linearBoundedExpression) {
		switch (linearBoundedExpression) {
			Dimension: variables.get(linearBoundedExpression)
			LinearConstraint: createLinearCombination(linearBoundedExpression.coefficients)
			default: throw new IllegalArgumentException("Unknown linear bounded expression:" + linearBoundedExpression)
		}
	}

	/**
	 * Builds the sum of {@code coefficient * variable} over the given coefficient map. An empty map
	 * yields the integer constant 0, and a coefficient of 1 contributes the bare variable.
	 *
	 * @throws IllegalArgumentException if a dimension has no corresponding Z3 variable
	 */
	private def createLinearCombination(Map coefficients) {
		val size = coefficients.size
		if (size == 0) {
			return mkInt(0)
		}
		val array = newArrayOfSize(size)
		var int i = 0
		for (pair : coefficients.entrySet) {
			val variable = variables.get(pair.key)
			if (variable === null) {
				throw new IllegalArgumentException("Unknown dimension: " + pair.key.name)
			}
			val coefficient = pair.value
			val term = if (coefficient == 1) {
				variable
			} else {
				mkMul(mkInt(coefficient), variable)
			}
			array.set(i, term)
			i++
		}
		mkAdd(array)
	}

	/** Releases the Z3 context owned by this operator. */
	override close() throws Exception {
		context.close()
	}
}