diff options
Diffstat (limited to 'Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/cardinality/Z3PolyhedronSolver.xtend')
-rw-r--r-- | Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/cardinality/Z3PolyhedronSolver.xtend | 272 |
1 files changed, 272 insertions, 0 deletions
diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/cardinality/Z3PolyhedronSolver.xtend b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/cardinality/Z3PolyhedronSolver.xtend new file mode 100644 index 00000000..3b831433 --- /dev/null +++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/cardinality/Z3PolyhedronSolver.xtend | |||
@@ -0,0 +1,272 @@ | |||
1 | package hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.cardinality | ||
2 | |||
3 | import com.microsoft.z3.AlgebraicNum | ||
4 | import com.microsoft.z3.ArithExpr | ||
5 | import com.microsoft.z3.Context | ||
6 | import com.microsoft.z3.Expr | ||
7 | import com.microsoft.z3.IntNum | ||
8 | import com.microsoft.z3.Optimize | ||
9 | import com.microsoft.z3.RatNum | ||
10 | import com.microsoft.z3.Status | ||
11 | import com.microsoft.z3.Symbol | ||
12 | import java.math.BigDecimal | ||
13 | import java.math.MathContext | ||
14 | import java.math.RoundingMode | ||
15 | import java.util.Map | ||
16 | import org.eclipse.xtend.lib.annotations.Accessors | ||
17 | import org.eclipse.xtend.lib.annotations.FinalFieldsConstructor | ||
18 | |||
19 | class Z3PolyhedronSolver implements PolyhedronSolver { | ||
20 | val boolean lpRelaxation | ||
21 | val double timeoutSeconds | ||
22 | |||
23 | @FinalFieldsConstructor | ||
24 | new() { | ||
25 | } | ||
26 | |||
27 | new() { | ||
28 | this(false, -1) | ||
29 | } | ||
30 | |||
31 | override createSaturationOperator(Polyhedron polyhedron) { | ||
32 | new DisposingZ3SaturationOperator(this, polyhedron) | ||
33 | } | ||
34 | |||
35 | def createPersistentSaturationOperator(Polyhedron polyhedron) { | ||
36 | new Z3SaturationOperator(polyhedron, lpRelaxation, timeoutSeconds) | ||
37 | } | ||
38 | } | ||
39 | |||
40 | @FinalFieldsConstructor | ||
41 | class DisposingZ3SaturationOperator implements PolyhedronSaturationOperator { | ||
42 | val Z3PolyhedronSolver solver | ||
43 | @Accessors val Polyhedron polyhedron | ||
44 | |||
45 | override saturate() { | ||
46 | val persistentOperator = solver.createPersistentSaturationOperator(polyhedron) | ||
47 | try { | ||
48 | persistentOperator.saturate | ||
49 | } finally { | ||
50 | persistentOperator.close | ||
51 | } | ||
52 | } | ||
53 | |||
54 | override close() throws Exception { | ||
55 | // Nothing to close. | ||
56 | } | ||
57 | } | ||
58 | |||
59 | class Z3SaturationOperator extends AbstractPolyhedronSaturationOperator { | ||
60 | static val INFINITY_SYMBOL_NAME = "oo" | ||
61 | static val MULT_SYMBOL_NAME = "*" | ||
62 | static val TIMEOUT_SYMBOL_NAME = "timeout" | ||
63 | static val INTEGER_PRECISION = new BigDecimal(Integer.MAX_VALUE).precision | ||
64 | static val ROUND_DOWN = new MathContext(INTEGER_PRECISION, RoundingMode.FLOOR) | ||
65 | static val ROUND_UP = new MathContext(INTEGER_PRECISION, RoundingMode.CEILING) | ||
66 | // The interval isolating the number is smaller than 1/10^precision. | ||
67 | static val ALGEBRAIC_NUMBER_ROUNDING = 0 | ||
68 | |||
69 | extension val Context context | ||
70 | val Symbol infinitySymbol | ||
71 | val Symbol multSymbol | ||
72 | val Map<Dimension, ArithExpr> variables | ||
73 | val int timeoutMilliseconds | ||
74 | |||
75 | new(Polyhedron polyhedron, boolean lpRelaxation, double timeoutSeconds) { | ||
76 | super(polyhedron) | ||
77 | context = new Context | ||
78 | infinitySymbol = context.mkSymbol(INFINITY_SYMBOL_NAME) | ||
79 | multSymbol = context.mkSymbol(MULT_SYMBOL_NAME) | ||
80 | variables = polyhedron.dimensions.toInvertedMap [ dimension | | ||
81 | val name = dimension.name | ||
82 | if (lpRelaxation) { | ||
83 | mkRealConst(name) | ||
84 | } else { | ||
85 | mkIntConst(name) | ||
86 | } | ||
87 | ] | ||
88 | timeoutMilliseconds = Math.ceil(timeoutSeconds * 1000) as int | ||
89 | } | ||
90 | |||
91 | override doSaturate() { | ||
92 | val status = executeSolver() | ||
93 | convertStatusToSaturationResult(status) | ||
94 | } | ||
95 | |||
96 | private def convertStatusToSaturationResult(Status status) { | ||
97 | switch (status) { | ||
98 | case SATISFIABLE: | ||
99 | PolyhedronSaturationResult.SATURATED | ||
100 | case UNSATISFIABLE: | ||
101 | PolyhedronSaturationResult.EMPTY | ||
102 | case UNKNOWN: | ||
103 | PolyhedronSaturationResult.UNKNOWN | ||
104 | default: | ||
105 | throw new IllegalArgumentException("Unknown Status: " + status) | ||
106 | } | ||
107 | } | ||
108 | |||
109 | private def executeSolver() { | ||
110 | for (expressionToSaturate : polyhedron.expressionsToSaturate) { | ||
111 | val expr = expressionToSaturate.toExpr | ||
112 | val lowerResult = saturateLowerBound(expr, expressionToSaturate) | ||
113 | if (lowerResult != Status.SATISFIABLE) { | ||
114 | return lowerResult | ||
115 | } | ||
116 | val upperResult = saturateUpperBound(expr, expressionToSaturate) | ||
117 | if (upperResult != Status.SATISFIABLE) { | ||
118 | return upperResult | ||
119 | } | ||
120 | } | ||
121 | Status.SATISFIABLE | ||
122 | } | ||
123 | |||
124 | private def saturateLowerBound(ArithExpr expr, LinearBoundedExpression expressionToSaturate) { | ||
125 | val optimize = prepareOptimize | ||
126 | val handle = optimize.MkMinimize(expr) | ||
127 | val status = optimize.Check() | ||
128 | if (status == Status.SATISFIABLE) { | ||
129 | val value = switch (resultExpr : handle.lower) { | ||
130 | IntNum: | ||
131 | resultExpr.getInt() | ||
132 | RatNum: | ||
133 | ceil(resultExpr) | ||
134 | AlgebraicNum: | ||
135 | ceil(resultExpr.toUpper(ALGEBRAIC_NUMBER_ROUNDING)) | ||
136 | default: | ||
137 | if (isNegativeInfinity(resultExpr)) { | ||
138 | null | ||
139 | } else { | ||
140 | throw new IllegalArgumentException("Integer result expected, got: " + resultExpr) | ||
141 | } | ||
142 | } | ||
143 | expressionToSaturate.lowerBound = value | ||
144 | } | ||
145 | status | ||
146 | } | ||
147 | |||
148 | private def floor(RatNum ratNum) { | ||
149 | val numerator = new BigDecimal(ratNum.numerator.bigInteger) | ||
150 | val denominator = new BigDecimal(ratNum.denominator.bigInteger) | ||
151 | numerator.divide(denominator, ROUND_DOWN).setScale(0, RoundingMode.FLOOR).intValue | ||
152 | } | ||
153 | |||
154 | private def saturateUpperBound(ArithExpr expr, LinearBoundedExpression expressionToSaturate) { | ||
155 | val optimize = prepareOptimize | ||
156 | val handle = optimize.MkMaximize(expr) | ||
157 | val status = optimize.Check() | ||
158 | if (status == Status.SATISFIABLE) { | ||
159 | val value = switch (resultExpr : handle.upper) { | ||
160 | IntNum: | ||
161 | resultExpr.getInt() | ||
162 | RatNum: | ||
163 | floor(resultExpr) | ||
164 | AlgebraicNum: | ||
165 | floor(resultExpr.toLower(ALGEBRAIC_NUMBER_ROUNDING)) | ||
166 | default: | ||
167 | if (isPositiveInfinity(resultExpr)) { | ||
168 | null | ||
169 | } else { | ||
170 | throw new IllegalArgumentException("Integer result expected, got: " + resultExpr) | ||
171 | } | ||
172 | } | ||
173 | expressionToSaturate.upperBound = value | ||
174 | } | ||
175 | status | ||
176 | } | ||
177 | |||
178 | private def ceil(RatNum ratNum) { | ||
179 | val numerator = new BigDecimal(ratNum.numerator.bigInteger) | ||
180 | val denominator = new BigDecimal(ratNum.denominator.bigInteger) | ||
181 | numerator.divide(denominator, ROUND_UP).setScale(0, RoundingMode.CEILING).intValue | ||
182 | } | ||
183 | |||
184 | private def isPositiveInfinity(Expr expr) { | ||
185 | expr.app && expr.getFuncDecl.name == infinitySymbol | ||
186 | } | ||
187 | |||
188 | private def isNegativeInfinity(Expr expr) { | ||
189 | // Negative infinity is represented as (* (- 1) oo) | ||
190 | if (!expr.app || expr.getFuncDecl.name != multSymbol || expr.numArgs != 2) { | ||
191 | return false | ||
192 | } | ||
193 | isPositiveInfinity(expr.args.get(1)) | ||
194 | } | ||
195 | |||
196 | private def prepareOptimize() { | ||
197 | val optimize = mkOptimize() | ||
198 | if (timeoutMilliseconds >= 0) { | ||
199 | val params = mkParams() | ||
200 | // We cannot turn TIMEOUT_SYMBOL_NAME into a Symbol in the constructor, | ||
201 | // because there is no add(Symbol, int) overload. | ||
202 | params.add(TIMEOUT_SYMBOL_NAME, timeoutMilliseconds) | ||
203 | optimize.parameters = params | ||
204 | } | ||
205 | assertConstraints(optimize) | ||
206 | optimize | ||
207 | } | ||
208 | |||
209 | private def assertConstraints(Optimize it) { | ||
210 | for (pair : variables.entrySet) { | ||
211 | assertBounds(pair.value, pair.key) | ||
212 | } | ||
213 | for (constraint : nonTrivialConstraints) { | ||
214 | val expr = createLinearCombination(constraint.coefficients) | ||
215 | assertBounds(expr, constraint) | ||
216 | } | ||
217 | } | ||
218 | |||
219 | private def assertBounds(Optimize it, ArithExpr expression, LinearBoundedExpression bounds) { | ||
220 | val lowerBound = bounds.lowerBound | ||
221 | val upperBound = bounds.upperBound | ||
222 | if (lowerBound == upperBound) { | ||
223 | if (lowerBound === null) { | ||
224 | return | ||
225 | } | ||
226 | Assert(mkEq(expression, mkInt(lowerBound))) | ||
227 | } else { | ||
228 | if (lowerBound !== null) { | ||
229 | Assert(mkGe(expression, mkInt(lowerBound))) | ||
230 | } | ||
231 | if (upperBound !== null) { | ||
232 | Assert(mkLe(expression, mkInt(upperBound))) | ||
233 | } | ||
234 | } | ||
235 | } | ||
236 | |||
237 | private def toExpr(LinearBoundedExpression linearBoundedExpression) { | ||
238 | switch (linearBoundedExpression) { | ||
239 | Dimension: variables.get(linearBoundedExpression) | ||
240 | LinearConstraint: createLinearCombination(linearBoundedExpression.coefficients) | ||
241 | default: throw new IllegalArgumentException("Unknown linear bounded expression:" + linearBoundedExpression) | ||
242 | } | ||
243 | } | ||
244 | |||
245 | private def createLinearCombination(Map<Dimension, Integer> coefficients) { | ||
246 | val size = coefficients.size | ||
247 | if (size == 0) { | ||
248 | return mkInt(0) | ||
249 | } | ||
250 | val array = newArrayOfSize(size) | ||
251 | var int i = 0 | ||
252 | for (pair : coefficients.entrySet) { | ||
253 | val variable = variables.get(pair.key) | ||
254 | if (variable === null) { | ||
255 | throw new IllegalArgumentException("Unknown dimension: " + pair.key.name) | ||
256 | } | ||
257 | val coefficient = pair.value | ||
258 | val term = if (coefficient == 1) { | ||
259 | variable | ||
260 | } else { | ||
261 | mkMul(mkInt(coefficient), variable) | ||
262 | } | ||
263 | array.set(i, term) | ||
264 | i++ | ||
265 | } | ||
266 | mkAdd(array) | ||
267 | } | ||
268 | |||
269 | override close() throws Exception { | ||
270 | context.close() | ||
271 | } | ||
272 | } | ||