diff options
Diffstat (limited to 'Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/cardinality/Z3PolyhedronSolver.xtend')
-rw-r--r-- | Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/cardinality/Z3PolyhedronSolver.xtend | 248 |
1 files changed, 248 insertions, 0 deletions
diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/cardinality/Z3PolyhedronSolver.xtend b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/cardinality/Z3PolyhedronSolver.xtend new file mode 100644 index 00000000..23444956 --- /dev/null +++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/cardinality/Z3PolyhedronSolver.xtend | |||
@@ -0,0 +1,248 @@ | |||
1 | package hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.cardinality | ||
2 | |||
3 | import com.microsoft.z3.AlgebraicNum | ||
4 | import com.microsoft.z3.ArithExpr | ||
5 | import com.microsoft.z3.Context | ||
6 | import com.microsoft.z3.Expr | ||
7 | import com.microsoft.z3.IntNum | ||
8 | import com.microsoft.z3.Optimize | ||
9 | import com.microsoft.z3.RatNum | ||
10 | import com.microsoft.z3.Status | ||
11 | import com.microsoft.z3.Symbol | ||
12 | import java.math.BigDecimal | ||
13 | import java.math.MathContext | ||
14 | import java.math.RoundingMode | ||
15 | import java.util.Map | ||
16 | import org.eclipse.xtend.lib.annotations.FinalFieldsConstructor | ||
17 | |||
18 | class Z3PolyhedronSolver implements PolyhedronSolver { | ||
19 | val boolean lpRelaxation | ||
20 | val double timeoutSeconds | ||
21 | |||
22 | @FinalFieldsConstructor | ||
23 | new() { | ||
24 | } | ||
25 | |||
26 | new() { | ||
27 | this(false, -1) | ||
28 | } | ||
29 | |||
30 | override createSaturationOperator(Polyhedron polyhedron) { | ||
31 | new Z3SaturationOperator(polyhedron, lpRelaxation, timeoutSeconds) | ||
32 | } | ||
33 | } | ||
34 | |||
35 | class Z3SaturationOperator extends AbstractPolyhedronSaturationOperator { | ||
36 | static val INFINITY_SYMBOL_NAME = "oo" | ||
37 | static val MULT_SYMBOL_NAME = "*" | ||
38 | static val TIMEOUT_SYMBOL_NAME = "timeout" | ||
39 | static val INTEGER_PRECISION = new BigDecimal(Integer.MAX_VALUE).precision | ||
40 | static val ROUND_DOWN = new MathContext(INTEGER_PRECISION, RoundingMode.FLOOR) | ||
41 | static val ROUND_UP = new MathContext(INTEGER_PRECISION, RoundingMode.CEILING) | ||
42 | // The interval isolating the number is smaller than 1/10^precision. | ||
43 | static val ALGEBRAIC_NUMBER_ROUNDING = 0 | ||
44 | |||
45 | extension val Context context | ||
46 | val Symbol infinitySymbol | ||
47 | val Symbol multSymbol | ||
48 | val Map<Dimension, ArithExpr> variables | ||
49 | val int timeoutMilliseconds | ||
50 | |||
51 | new(Polyhedron polyhedron, boolean lpRelaxation, double timeoutSeconds) { | ||
52 | super(polyhedron) | ||
53 | context = new Context | ||
54 | infinitySymbol = context.mkSymbol(INFINITY_SYMBOL_NAME) | ||
55 | multSymbol = context.mkSymbol(MULT_SYMBOL_NAME) | ||
56 | variables = polyhedron.dimensions.toInvertedMap [ dimension | | ||
57 | val name = dimension.name | ||
58 | if (lpRelaxation) { | ||
59 | mkRealConst(name) | ||
60 | } else { | ||
61 | mkIntConst(name) | ||
62 | } | ||
63 | ] | ||
64 | timeoutMilliseconds = Math.ceil(timeoutSeconds * 1000) as int | ||
65 | } | ||
66 | |||
67 | override doSaturate() { | ||
68 | val status = executeSolver() | ||
69 | convertStatusToSaturationResult(status) | ||
70 | } | ||
71 | |||
72 | private def convertStatusToSaturationResult(Status status) { | ||
73 | switch (status) { | ||
74 | case SATISFIABLE: | ||
75 | PolyhedronSaturationResult.SATURATED | ||
76 | case UNSATISFIABLE: | ||
77 | PolyhedronSaturationResult.EMPTY | ||
78 | case UNKNOWN: | ||
79 | PolyhedronSaturationResult.UNKNOWN | ||
80 | default: | ||
81 | throw new IllegalArgumentException("Unknown Status: " + status) | ||
82 | } | ||
83 | } | ||
84 | |||
85 | private def executeSolver() { | ||
86 | for (expressionToSaturate : polyhedron.expressionsToSaturate) { | ||
87 | val expr = expressionToSaturate.toExpr | ||
88 | val lowerResult = saturateLowerBound(expr, expressionToSaturate) | ||
89 | if (lowerResult != Status.SATISFIABLE) { | ||
90 | return lowerResult | ||
91 | } | ||
92 | val upperResult = saturateUpperBound(expr, expressionToSaturate) | ||
93 | if (upperResult != Status.SATISFIABLE) { | ||
94 | return upperResult | ||
95 | } | ||
96 | } | ||
97 | Status.SATISFIABLE | ||
98 | } | ||
99 | |||
100 | private def saturateLowerBound(ArithExpr expr, LinearBoundedExpression expressionToSaturate) { | ||
101 | val optimize = prepareOptimize | ||
102 | val handle = optimize.MkMinimize(expr) | ||
103 | val status = optimize.Check() | ||
104 | if (status == Status.SATISFIABLE) { | ||
105 | val value = switch (resultExpr : handle.lower) { | ||
106 | IntNum: | ||
107 | resultExpr.getInt() | ||
108 | RatNum: | ||
109 | floor(resultExpr) | ||
110 | AlgebraicNum: | ||
111 | floor(resultExpr.toLower(ALGEBRAIC_NUMBER_ROUNDING)) | ||
112 | default: | ||
113 | if (isNegativeInfinity(resultExpr)) { | ||
114 | null | ||
115 | } else { | ||
116 | throw new IllegalArgumentException("Integer result expected, got: " + resultExpr) | ||
117 | } | ||
118 | } | ||
119 | expressionToSaturate.lowerBound = value | ||
120 | } | ||
121 | status | ||
122 | } | ||
123 | |||
124 | private def floor(RatNum ratNum) { | ||
125 | val numerator = new BigDecimal(ratNum.numerator.bigInteger) | ||
126 | val denominator = new BigDecimal(ratNum.denominator.bigInteger) | ||
127 | numerator.divide(denominator, ROUND_DOWN).setScale(0, RoundingMode.FLOOR).intValue | ||
128 | } | ||
129 | |||
130 | private def saturateUpperBound(ArithExpr expr, LinearBoundedExpression expressionToSaturate) { | ||
131 | val optimize = prepareOptimize | ||
132 | val handle = optimize.MkMaximize(expr) | ||
133 | val status = optimize.Check() | ||
134 | if (status == Status.SATISFIABLE) { | ||
135 | val value = switch (resultExpr : handle.upper) { | ||
136 | IntNum: | ||
137 | resultExpr.getInt() | ||
138 | RatNum: | ||
139 | ceil(resultExpr) | ||
140 | AlgebraicNum: | ||
141 | ceil(resultExpr.toUpper(ALGEBRAIC_NUMBER_ROUNDING)) | ||
142 | default: | ||
143 | if (isPositiveInfinity(resultExpr)) { | ||
144 | null | ||
145 | } else { | ||
146 | throw new IllegalArgumentException("Integer result expected, got: " + resultExpr) | ||
147 | } | ||
148 | } | ||
149 | expressionToSaturate.upperBound = value | ||
150 | } | ||
151 | status | ||
152 | } | ||
153 | |||
154 | private def ceil(RatNum ratNum) { | ||
155 | val numerator = new BigDecimal(ratNum.numerator.bigInteger) | ||
156 | val denominator = new BigDecimal(ratNum.denominator.bigInteger) | ||
157 | numerator.divide(denominator, ROUND_UP).setScale(0, RoundingMode.CEILING).intValue | ||
158 | } | ||
159 | |||
160 | private def isPositiveInfinity(Expr expr) { | ||
161 | expr.app && expr.getFuncDecl.name == infinitySymbol | ||
162 | } | ||
163 | |||
164 | private def isNegativeInfinity(Expr expr) { | ||
165 | // Negative infinity is represented as (* (- 1) oo) | ||
166 | if (!expr.app || expr.getFuncDecl.name != multSymbol || expr.numArgs != 2) { | ||
167 | return false | ||
168 | } | ||
169 | isPositiveInfinity(expr.args.get(1)) | ||
170 | } | ||
171 | |||
172 | private def prepareOptimize() { | ||
173 | val optimize = mkOptimize() | ||
174 | if (timeoutMilliseconds >= 0) { | ||
175 | val params = mkParams() | ||
176 | // We cannot turn TIMEOUT_SYMBOL_NAME into a Symbol in the constructor, | ||
177 | // because there is no add(Symbol, int) overload. | ||
178 | params.add(TIMEOUT_SYMBOL_NAME, timeoutMilliseconds) | ||
179 | optimize.parameters = params | ||
180 | } | ||
181 | assertConstraints(optimize) | ||
182 | optimize | ||
183 | } | ||
184 | |||
185 | private def assertConstraints(Optimize it) { | ||
186 | for (pair : variables.entrySet) { | ||
187 | assertBounds(pair.value, pair.key) | ||
188 | } | ||
189 | for (constraint : nonTrivialConstraints) { | ||
190 | val expr = createLinearCombination(constraint.coefficients) | ||
191 | assertBounds(expr, constraint) | ||
192 | } | ||
193 | } | ||
194 | |||
195 | private def assertBounds(Optimize it, ArithExpr expression, LinearBoundedExpression bounds) { | ||
196 | val lowerBound = bounds.lowerBound | ||
197 | val upperBound = bounds.upperBound | ||
198 | if (lowerBound == upperBound) { | ||
199 | if (lowerBound === null) { | ||
200 | return | ||
201 | } | ||
202 | Assert(mkEq(expression, mkInt(lowerBound))) | ||
203 | } else { | ||
204 | if (lowerBound !== null) { | ||
205 | Assert(mkGe(expression, mkInt(lowerBound))) | ||
206 | } | ||
207 | if (upperBound !== null) { | ||
208 | Assert(mkLe(expression, mkInt(upperBound))) | ||
209 | } | ||
210 | } | ||
211 | } | ||
212 | |||
213 | private def toExpr(LinearBoundedExpression linearBoundedExpression) { | ||
214 | switch (linearBoundedExpression) { | ||
215 | Dimension: variables.get(linearBoundedExpression) | ||
216 | LinearConstraint: createLinearCombination(linearBoundedExpression.coefficients) | ||
217 | default: throw new IllegalArgumentException("Unknown linear bounded expression:" + linearBoundedExpression) | ||
218 | } | ||
219 | } | ||
220 | |||
221 | private def createLinearCombination(Map<Dimension, Integer> coefficients) { | ||
222 | val size = coefficients.size | ||
223 | if (size == 0) { | ||
224 | return mkInt(0) | ||
225 | } | ||
226 | val array = newArrayOfSize(size) | ||
227 | var int i = 0 | ||
228 | for (pair : coefficients.entrySet) { | ||
229 | val variable = variables.get(pair.key) | ||
230 | if (variable === null) { | ||
231 | throw new IllegalArgumentException("Unknown dimension: " + pair.key.name) | ||
232 | } | ||
233 | val coefficient = pair.value | ||
234 | val term = if (coefficient == 1) { | ||
235 | variable | ||
236 | } else { | ||
237 | mkMul(mkInt(coefficient), variable) | ||
238 | } | ||
239 | array.set(i, term) | ||
240 | i++ | ||
241 | } | ||
242 | mkAdd(array) | ||
243 | } | ||
244 | |||
245 | override close() throws Exception { | ||
246 | context.close() | ||
247 | } | ||
248 | } | ||