aboutsummaryrefslogtreecommitdiffstats
path: root/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/cardinality/Z3PolyhedronSolver.xtend
diff options
context:
space:
mode:
Diffstat (limited to 'Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/cardinality/Z3PolyhedronSolver.xtend')
-rw-r--r--Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/cardinality/Z3PolyhedronSolver.xtend248
1 files changed, 248 insertions, 0 deletions
diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/cardinality/Z3PolyhedronSolver.xtend b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/cardinality/Z3PolyhedronSolver.xtend
new file mode 100644
index 00000000..23444956
--- /dev/null
+++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/cardinality/Z3PolyhedronSolver.xtend
@@ -0,0 +1,248 @@
1package hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.cardinality
2
3import com.microsoft.z3.AlgebraicNum
4import com.microsoft.z3.ArithExpr
5import com.microsoft.z3.Context
6import com.microsoft.z3.Expr
7import com.microsoft.z3.IntNum
8import com.microsoft.z3.Optimize
9import com.microsoft.z3.RatNum
10import com.microsoft.z3.Status
11import com.microsoft.z3.Symbol
12import java.math.BigDecimal
13import java.math.MathContext
14import java.math.RoundingMode
15import java.util.Map
16import org.eclipse.xtend.lib.annotations.FinalFieldsConstructor
17
18class Z3PolyhedronSolver implements PolyhedronSolver {
19 val boolean lpRelaxation
20 val double timeoutSeconds
21
22 @FinalFieldsConstructor
23 new() {
24 }
25
26 new() {
27 this(false, -1)
28 }
29
30 override createSaturationOperator(Polyhedron polyhedron) {
31 new Z3SaturationOperator(polyhedron, lpRelaxation, timeoutSeconds)
32 }
33}
34
35class Z3SaturationOperator extends AbstractPolyhedronSaturationOperator {
36 static val INFINITY_SYMBOL_NAME = "oo"
37 static val MULT_SYMBOL_NAME = "*"
38 static val TIMEOUT_SYMBOL_NAME = "timeout"
39 static val INTEGER_PRECISION = new BigDecimal(Integer.MAX_VALUE).precision
40 static val ROUND_DOWN = new MathContext(INTEGER_PRECISION, RoundingMode.FLOOR)
41 static val ROUND_UP = new MathContext(INTEGER_PRECISION, RoundingMode.CEILING)
42 // The interval isolating the number is smaller than 1/10^precision.
43 static val ALGEBRAIC_NUMBER_ROUNDING = 0
44
45 extension val Context context
46 val Symbol infinitySymbol
47 val Symbol multSymbol
48 val Map<Dimension, ArithExpr> variables
49 val int timeoutMilliseconds
50
51 new(Polyhedron polyhedron, boolean lpRelaxation, double timeoutSeconds) {
52 super(polyhedron)
53 context = new Context
54 infinitySymbol = context.mkSymbol(INFINITY_SYMBOL_NAME)
55 multSymbol = context.mkSymbol(MULT_SYMBOL_NAME)
56 variables = polyhedron.dimensions.toInvertedMap [ dimension |
57 val name = dimension.name
58 if (lpRelaxation) {
59 mkRealConst(name)
60 } else {
61 mkIntConst(name)
62 }
63 ]
64 timeoutMilliseconds = Math.ceil(timeoutSeconds * 1000) as int
65 }
66
67 override doSaturate() {
68 val status = executeSolver()
69 convertStatusToSaturationResult(status)
70 }
71
72 private def convertStatusToSaturationResult(Status status) {
73 switch (status) {
74 case SATISFIABLE:
75 PolyhedronSaturationResult.SATURATED
76 case UNSATISFIABLE:
77 PolyhedronSaturationResult.EMPTY
78 case UNKNOWN:
79 PolyhedronSaturationResult.UNKNOWN
80 default:
81 throw new IllegalArgumentException("Unknown Status: " + status)
82 }
83 }
84
85 private def executeSolver() {
86 for (expressionToSaturate : polyhedron.expressionsToSaturate) {
87 val expr = expressionToSaturate.toExpr
88 val lowerResult = saturateLowerBound(expr, expressionToSaturate)
89 if (lowerResult != Status.SATISFIABLE) {
90 return lowerResult
91 }
92 val upperResult = saturateUpperBound(expr, expressionToSaturate)
93 if (upperResult != Status.SATISFIABLE) {
94 return upperResult
95 }
96 }
97 Status.SATISFIABLE
98 }
99
100 private def saturateLowerBound(ArithExpr expr, LinearBoundedExpression expressionToSaturate) {
101 val optimize = prepareOptimize
102 val handle = optimize.MkMinimize(expr)
103 val status = optimize.Check()
104 if (status == Status.SATISFIABLE) {
105 val value = switch (resultExpr : handle.lower) {
106 IntNum:
107 resultExpr.getInt()
108 RatNum:
109 floor(resultExpr)
110 AlgebraicNum:
111 floor(resultExpr.toLower(ALGEBRAIC_NUMBER_ROUNDING))
112 default:
113 if (isNegativeInfinity(resultExpr)) {
114 null
115 } else {
116 throw new IllegalArgumentException("Integer result expected, got: " + resultExpr)
117 }
118 }
119 expressionToSaturate.lowerBound = value
120 }
121 status
122 }
123
124 private def floor(RatNum ratNum) {
125 val numerator = new BigDecimal(ratNum.numerator.bigInteger)
126 val denominator = new BigDecimal(ratNum.denominator.bigInteger)
127 numerator.divide(denominator, ROUND_DOWN).setScale(0, RoundingMode.FLOOR).intValue
128 }
129
130 private def saturateUpperBound(ArithExpr expr, LinearBoundedExpression expressionToSaturate) {
131 val optimize = prepareOptimize
132 val handle = optimize.MkMaximize(expr)
133 val status = optimize.Check()
134 if (status == Status.SATISFIABLE) {
135 val value = switch (resultExpr : handle.upper) {
136 IntNum:
137 resultExpr.getInt()
138 RatNum:
139 ceil(resultExpr)
140 AlgebraicNum:
141 ceil(resultExpr.toUpper(ALGEBRAIC_NUMBER_ROUNDING))
142 default:
143 if (isPositiveInfinity(resultExpr)) {
144 null
145 } else {
146 throw new IllegalArgumentException("Integer result expected, got: " + resultExpr)
147 }
148 }
149 expressionToSaturate.upperBound = value
150 }
151 status
152 }
153
154 private def ceil(RatNum ratNum) {
155 val numerator = new BigDecimal(ratNum.numerator.bigInteger)
156 val denominator = new BigDecimal(ratNum.denominator.bigInteger)
157 numerator.divide(denominator, ROUND_UP).setScale(0, RoundingMode.CEILING).intValue
158 }
159
160 private def isPositiveInfinity(Expr expr) {
161 expr.app && expr.getFuncDecl.name == infinitySymbol
162 }
163
164 private def isNegativeInfinity(Expr expr) {
165 // Negative infinity is represented as (* (- 1) oo)
166 if (!expr.app || expr.getFuncDecl.name != multSymbol || expr.numArgs != 2) {
167 return false
168 }
169 isPositiveInfinity(expr.args.get(1))
170 }
171
172 private def prepareOptimize() {
173 val optimize = mkOptimize()
174 if (timeoutMilliseconds >= 0) {
175 val params = mkParams()
176 // We cannot turn TIMEOUT_SYMBOL_NAME into a Symbol in the constructor,
177 // because there is no add(Symbol, int) overload.
178 params.add(TIMEOUT_SYMBOL_NAME, timeoutMilliseconds)
179 optimize.parameters = params
180 }
181 assertConstraints(optimize)
182 optimize
183 }
184
185 private def assertConstraints(Optimize it) {
186 for (pair : variables.entrySet) {
187 assertBounds(pair.value, pair.key)
188 }
189 for (constraint : nonTrivialConstraints) {
190 val expr = createLinearCombination(constraint.coefficients)
191 assertBounds(expr, constraint)
192 }
193 }
194
195 private def assertBounds(Optimize it, ArithExpr expression, LinearBoundedExpression bounds) {
196 val lowerBound = bounds.lowerBound
197 val upperBound = bounds.upperBound
198 if (lowerBound == upperBound) {
199 if (lowerBound === null) {
200 return
201 }
202 Assert(mkEq(expression, mkInt(lowerBound)))
203 } else {
204 if (lowerBound !== null) {
205 Assert(mkGe(expression, mkInt(lowerBound)))
206 }
207 if (upperBound !== null) {
208 Assert(mkLe(expression, mkInt(upperBound)))
209 }
210 }
211 }
212
213 private def toExpr(LinearBoundedExpression linearBoundedExpression) {
214 switch (linearBoundedExpression) {
215 Dimension: variables.get(linearBoundedExpression)
216 LinearConstraint: createLinearCombination(linearBoundedExpression.coefficients)
217 default: throw new IllegalArgumentException("Unknown linear bounded expression:" + linearBoundedExpression)
218 }
219 }
220
221 private def createLinearCombination(Map<Dimension, Integer> coefficients) {
222 val size = coefficients.size
223 if (size == 0) {
224 return mkInt(0)
225 }
226 val array = newArrayOfSize(size)
227 var int i = 0
228 for (pair : coefficients.entrySet) {
229 val variable = variables.get(pair.key)
230 if (variable === null) {
231 throw new IllegalArgumentException("Unknown dimension: " + pair.key.name)
232 }
233 val coefficient = pair.value
234 val term = if (coefficient == 1) {
235 variable
236 } else {
237 mkMul(mkInt(coefficient), variable)
238 }
239 array.set(i, term)
240 i++
241 }
242 mkAdd(array)
243 }
244
245 override close() throws Exception {
246 context.close()
247 }
248}