diff options
Diffstat (limited to 'Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend')
-rw-r--r-- | Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend | 584 |
1 files changed, 584 insertions, 0 deletions
diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend new file mode 100644 index 00000000..691c8783 --- /dev/null +++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend | |||
@@ -0,0 +1,584 @@ | |||
1 | package hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval | ||
2 | |||
3 | import java.math.BigDecimal | ||
4 | import java.math.MathContext | ||
5 | import java.math.RoundingMode | ||
6 | import org.eclipse.xtend.lib.annotations.Data | ||
7 | |||
8 | abstract class Interval implements Comparable<Interval> { | ||
9 | static val PRECISION = 32 | ||
10 | package static val ROUND_DOWN = new MathContext(PRECISION, RoundingMode.FLOOR) | ||
11 | package static val ROUND_UP = new MathContext(PRECISION, RoundingMode.CEILING) | ||
12 | |||
13 | private new() { | ||
14 | } | ||
15 | |||
16 | abstract def boolean mustEqual(Interval other) | ||
17 | |||
18 | abstract def boolean mayEqual(Interval other) | ||
19 | |||
20 | def mustNotEqual(Interval other) { | ||
21 | !mayEqual(other) | ||
22 | } | ||
23 | |||
24 | def mayNotEqual(Interval other) { | ||
25 | !mustEqual(other) | ||
26 | } | ||
27 | |||
28 | abstract def boolean mustBeLessThan(Interval other) | ||
29 | |||
30 | abstract def boolean mayBeLessThan(Interval other) | ||
31 | |||
32 | def mustBeLessThanOrEqual(Interval other) { | ||
33 | !mayBeGreaterThan(other) | ||
34 | } | ||
35 | |||
36 | def mayBeLessThanOrEqual(Interval other) { | ||
37 | !mustBeGreaterThan(other) | ||
38 | } | ||
39 | |||
40 | def mustBeGreaterThan(Interval other) { | ||
41 | other.mustBeLessThan(this) | ||
42 | } | ||
43 | |||
44 | def mayBeGreaterThan(Interval other) { | ||
45 | other.mayBeLessThan(this) | ||
46 | } | ||
47 | |||
48 | def mustBeGreaterThanOrEqual(Interval other) { | ||
49 | other.mustBeLessThanOrEqual(this) | ||
50 | } | ||
51 | |||
52 | def mayBeGreaterThanOrEqual(Interval other) { | ||
53 | other.mayBeLessThanOrEqual(this) | ||
54 | } | ||
55 | |||
56 | abstract def Interval min(Interval other) | ||
57 | |||
58 | abstract def Interval max(Interval other) | ||
59 | |||
60 | abstract def Interval join(Interval other) | ||
61 | |||
62 | def +() { | ||
63 | this | ||
64 | } | ||
65 | |||
66 | abstract def Interval -() | ||
67 | |||
68 | abstract def Interval +(Interval other) | ||
69 | |||
70 | abstract def Interval -(Interval other) | ||
71 | |||
72 | abstract def Interval *(int count) | ||
73 | |||
74 | abstract def Interval *(Interval other) | ||
75 | |||
76 | abstract def Interval /(Interval other) | ||
77 | |||
78 | abstract def Interval **(Interval other) | ||
79 | |||
80 | public static val EMPTY = new Interval { | ||
81 | override mustEqual(Interval other) { | ||
82 | true | ||
83 | } | ||
84 | |||
85 | override mayEqual(Interval other) { | ||
86 | false | ||
87 | } | ||
88 | |||
89 | override mustBeLessThan(Interval other) { | ||
90 | true | ||
91 | } | ||
92 | |||
93 | override mayBeLessThan(Interval other) { | ||
94 | false | ||
95 | } | ||
96 | |||
97 | override min(Interval other) { | ||
98 | EMPTY | ||
99 | } | ||
100 | |||
101 | override max(Interval other) { | ||
102 | EMPTY | ||
103 | } | ||
104 | |||
105 | override join(Interval other) { | ||
106 | other | ||
107 | } | ||
108 | |||
109 | override -() { | ||
110 | EMPTY | ||
111 | } | ||
112 | |||
113 | override +(Interval other) { | ||
114 | EMPTY | ||
115 | } | ||
116 | |||
117 | override -(Interval other) { | ||
118 | EMPTY | ||
119 | } | ||
120 | |||
121 | override *(int count) { | ||
122 | EMPTY | ||
123 | } | ||
124 | |||
125 | override *(Interval other) { | ||
126 | EMPTY | ||
127 | } | ||
128 | |||
129 | override /(Interval other) { | ||
130 | EMPTY | ||
131 | } | ||
132 | |||
133 | override **(Interval other) { | ||
134 | EMPTY | ||
135 | } | ||
136 | |||
137 | override toString() { | ||
138 | "∅" | ||
139 | } | ||
140 | |||
141 | override compareTo(Interval o) { | ||
142 | if (o == EMPTY) { | ||
143 | 0 | ||
144 | } else { | ||
145 | -1 | ||
146 | } | ||
147 | } | ||
148 | |||
149 | } | ||
150 | |||
151 | public static val Interval ZERO = new NonEmpty(BigDecimal.ZERO, BigDecimal.ZERO) | ||
152 | |||
153 | public static val Interval UNBOUNDED = new NonEmpty(null, null) | ||
154 | |||
155 | static def Interval of(BigDecimal lower, BigDecimal upper) { | ||
156 | new NonEmpty(lower, upper) | ||
157 | } | ||
158 | |||
159 | static def between(double lower, double upper) { | ||
160 | of(new BigDecimal(lower, ROUND_DOWN), new BigDecimal(upper, ROUND_UP)) | ||
161 | } | ||
162 | |||
163 | static def upTo(double upper) { | ||
164 | of(null, new BigDecimal(upper, ROUND_UP)) | ||
165 | } | ||
166 | |||
167 | static def above(double lower) { | ||
168 | of(new BigDecimal(lower, ROUND_DOWN), null) | ||
169 | } | ||
170 | |||
171 | @Data | ||
172 | private static class NonEmpty extends Interval { | ||
173 | val BigDecimal lower | ||
174 | val BigDecimal upper | ||
175 | |||
176 | /** | ||
177 | * Construct a new non-empty interval. | ||
178 | * | ||
179 | * @param lower The lower bound of the interval. Use <code>null</code> for negative infinity. | ||
180 | * @param upper The upper bound of the interval. Use <code>null</code> for positive infinity. | ||
181 | */ | ||
182 | new(BigDecimal lower, BigDecimal upper) { | ||
183 | if (lower !== null && upper !== null && lower > upper) { | ||
184 | throw new IllegalArgumentException("Lower bound of interval must not be larger than upper bound") | ||
185 | } | ||
186 | this.lower = lower | ||
187 | this.upper = upper | ||
188 | } | ||
189 | |||
190 | override mustEqual(Interval other) { | ||
191 | switch (other) { | ||
192 | case EMPTY: true | ||
193 | NonEmpty: lower == upper && lower == other.lower && lower == other.upper | ||
194 | default: throw new IllegalArgumentException("Unknown interval: " + other) | ||
195 | } | ||
196 | } | ||
197 | |||
198 | override mayEqual(Interval other) { | ||
199 | if (other instanceof NonEmpty) { | ||
200 | (lower === null || other.upper === null || lower <= other.upper) && | ||
201 | (other.lower === null || upper === null || other.lower <= upper) | ||
202 | } else { | ||
203 | false | ||
204 | } | ||
205 | } | ||
206 | |||
207 | override mustBeLessThan(Interval other) { | ||
208 | switch (other) { | ||
209 | case EMPTY: true | ||
210 | NonEmpty: upper !== null && other.lower !== null && upper < other.lower | ||
211 | default: throw new IllegalArgumentException("Unknown interval: " + other) | ||
212 | } | ||
213 | } | ||
214 | |||
215 | override mayBeLessThan(Interval other) { | ||
216 | if (other instanceof NonEmpty) { | ||
217 | lower === null || other.upper === null || lower < other.upper | ||
218 | } else { | ||
219 | false | ||
220 | } | ||
221 | } | ||
222 | |||
223 | override min(Interval other) { | ||
224 | switch (other) { | ||
225 | case EMPTY: this | ||
226 | NonEmpty: min(other) | ||
227 | default: throw new IllegalArgumentException("Unknown interval: " + other) | ||
228 | } | ||
229 | } | ||
230 | |||
231 | def min(NonEmpty other) { | ||
232 | new NonEmpty( | ||
233 | lower.tryMin(other.lower), | ||
234 | if(other.upper === null) upper else if(upper === null) other.upper else upper.min(other.upper) | ||
235 | ) | ||
236 | } | ||
237 | |||
238 | override max(Interval other) { | ||
239 | switch (other) { | ||
240 | case EMPTY: this | ||
241 | NonEmpty: max(other) | ||
242 | default: throw new IllegalArgumentException("Unknown interval: " + other) | ||
243 | } | ||
244 | } | ||
245 | |||
246 | def max(NonEmpty other) { | ||
247 | new NonEmpty( | ||
248 | if(other.lower === null) lower else if(lower === null) other.lower else lower.max(other.lower), | ||
249 | upper.tryMax(other.upper) | ||
250 | ) | ||
251 | } | ||
252 | |||
253 | override join(Interval other) { | ||
254 | switch (other) { | ||
255 | case EMPTY: this | ||
256 | NonEmpty: new NonEmpty(lower.tryMin(other.lower), upper.tryMax(other.upper)) | ||
257 | default: throw new IllegalArgumentException("Unknown interval: " + other) | ||
258 | } | ||
259 | } | ||
260 | |||
261 | override -() { | ||
262 | new NonEmpty(upper?.negate(ROUND_DOWN), lower?.negate(ROUND_UP)) | ||
263 | } | ||
264 | |||
265 | override +(Interval other) { | ||
266 | switch (other) { | ||
267 | case EMPTY: EMPTY | ||
268 | NonEmpty: this + other | ||
269 | default: throw new IllegalArgumentException("Unknown interval: " + other) | ||
270 | } | ||
271 | } | ||
272 | |||
273 | def +(NonEmpty other) { | ||
274 | new NonEmpty( | ||
275 | lower.tryAdd(other.lower, ROUND_DOWN), | ||
276 | upper.tryAdd(other.upper, ROUND_UP) | ||
277 | ) | ||
278 | } | ||
279 | |||
280 | private static def tryAdd(BigDecimal a, BigDecimal b, MathContext mc) { | ||
281 | if (b === null) { | ||
282 | null | ||
283 | } else { | ||
284 | a?.add(b, mc) | ||
285 | } | ||
286 | } | ||
287 | |||
288 | override -(Interval other) { | ||
289 | switch (other) { | ||
290 | case EMPTY: EMPTY | ||
291 | NonEmpty: this - other | ||
292 | default: throw new IllegalArgumentException("Unknown interval: " + other) | ||
293 | } | ||
294 | } | ||
295 | |||
296 | def -(NonEmpty other) { | ||
297 | new NonEmpty( | ||
298 | lower.trySubtract(other.upper, ROUND_DOWN), | ||
299 | upper.trySubtract(other.lower, ROUND_UP) | ||
300 | ) | ||
301 | } | ||
302 | |||
303 | private static def trySubtract(BigDecimal a, BigDecimal b, MathContext mc) { | ||
304 | if (b === null) { | ||
305 | null | ||
306 | } else { | ||
307 | a?.subtract(b, mc) | ||
308 | } | ||
309 | } | ||
310 | |||
311 | override *(int count) { | ||
312 | val bigCount = new BigDecimal(count) | ||
313 | new NonEmpty( | ||
314 | lower.tryMultiply(bigCount, ROUND_DOWN), | ||
315 | upper.tryMultiply(bigCount, ROUND_UP) | ||
316 | ) | ||
317 | } | ||
318 | |||
319 | override *(Interval other) { | ||
320 | switch (other) { | ||
321 | case EMPTY: EMPTY | ||
322 | NonEmpty: this * other | ||
323 | default: throw new IllegalArgumentException("Unknown interval: " + other) | ||
324 | } | ||
325 | } | ||
326 | |||
327 | def *(NonEmpty other) { | ||
328 | if (this == ZERO || other == ZERO) { | ||
329 | ZERO | ||
330 | } else if (nonpositive) { | ||
331 | if (other.nonpositive) { | ||
332 | new NonEmpty( | ||
333 | upper.multiply(other.upper, ROUND_DOWN), | ||
334 | lower.tryMultiply(other.lower, ROUND_UP) | ||
335 | ) | ||
336 | } else if (other.nonnegative) { | ||
337 | new NonEmpty( | ||
338 | lower.tryMultiply(other.upper, ROUND_DOWN), | ||
339 | upper.multiply(other.lower, ROUND_UP) | ||
340 | ) | ||
341 | } else { | ||
342 | new NonEmpty( | ||
343 | lower.tryMultiply(other.upper, ROUND_DOWN), | ||
344 | lower.tryMultiply(other.lower, ROUND_UP) | ||
345 | ) | ||
346 | } | ||
347 | } else if (nonnegative) { | ||
348 | if (other.nonpositive) { | ||
349 | new NonEmpty( | ||
350 | upper.tryMultiply(other.lower, ROUND_DOWN), | ||
351 | lower.multiply(other.upper, ROUND_UP) | ||
352 | ) | ||
353 | } else if (other.nonnegative) { | ||
354 | new NonEmpty( | ||
355 | lower.multiply(other.lower, ROUND_DOWN), | ||
356 | upper.tryMultiply(other.upper, ROUND_UP) | ||
357 | ) | ||
358 | } else { | ||
359 | new NonEmpty( | ||
360 | upper.tryMultiply(other.lower, ROUND_DOWN), | ||
361 | upper.tryMultiply(other.upper, ROUND_UP) | ||
362 | ) | ||
363 | } | ||
364 | } else { | ||
365 | if (other.nonpositive) { | ||
366 | new NonEmpty( | ||
367 | upper.tryMultiply(other.lower, ROUND_DOWN), | ||
368 | lower.tryMultiply(other.lower, ROUND_UP) | ||
369 | ) | ||
370 | } else if (other.nonnegative) { | ||
371 | new NonEmpty( | ||
372 | lower.tryMultiply(other.upper, ROUND_DOWN), | ||
373 | upper.tryMultiply(other.upper, ROUND_UP) | ||
374 | ) | ||
375 | } else { | ||
376 | new NonEmpty( | ||
377 | lower.tryMultiply(other.upper, ROUND_DOWN).tryMin(upper.tryMultiply(other.lower, ROUND_DOWN)), | ||
378 | lower.tryMultiply(other.lower, ROUND_UP).tryMax(upper.tryMultiply(other.upper, ROUND_UP)) | ||
379 | ) | ||
380 | } | ||
381 | } | ||
382 | } | ||
383 | |||
384 | private def isNonpositive() { | ||
385 | upper !== null && upper <= BigDecimal.ZERO | ||
386 | } | ||
387 | |||
388 | private def isNonnegative() { | ||
389 | lower !== null && lower >= BigDecimal.ZERO | ||
390 | } | ||
391 | |||
392 | private static def tryMultiply(BigDecimal a, BigDecimal b, MathContext mc) { | ||
393 | if (b === null) { | ||
394 | null | ||
395 | } else { | ||
396 | a?.multiply(b, mc) | ||
397 | } | ||
398 | } | ||
399 | |||
400 | private static def tryMin(BigDecimal a, BigDecimal b) { | ||
401 | if (b === null) { | ||
402 | null | ||
403 | } else { | ||
404 | a?.min(b) | ||
405 | } | ||
406 | } | ||
407 | |||
408 | private static def tryMax(BigDecimal a, BigDecimal b) { | ||
409 | if (b === null) { | ||
410 | null | ||
411 | } else { | ||
412 | a?.max(b) | ||
413 | } | ||
414 | } | ||
415 | |||
416 | override /(Interval other) { | ||
417 | switch (other) { | ||
418 | case EMPTY: EMPTY | ||
419 | NonEmpty: this / other | ||
420 | default: throw new IllegalArgumentException("Unknown interval: " + other) | ||
421 | } | ||
422 | } | ||
423 | |||
424 | def /(NonEmpty other) { | ||
425 | if (other == ZERO) { | ||
426 | EMPTY | ||
427 | } else if (this == ZERO) { | ||
428 | ZERO | ||
429 | } else if (other.strictlyNegative) { | ||
430 | if (nonpositive) { | ||
431 | new NonEmpty( | ||
432 | upper.tryDivide(other.lower, ROUND_DOWN), | ||
433 | lower.tryDivide(other.upper, ROUND_UP) | ||
434 | ) | ||
435 | } else if (nonnegative) { | ||
436 | new NonEmpty( | ||
437 | upper.tryDivide(other.upper, ROUND_DOWN), | ||
438 | lower.tryDivide(other.lower, ROUND_UP) | ||
439 | ) | ||
440 | } else { // lower < 0 < upper | ||
441 | new NonEmpty( | ||
442 | upper.tryDivide(other.upper, ROUND_DOWN), | ||
443 | lower.tryDivide(other.upper, ROUND_UP) | ||
444 | ) | ||
445 | } | ||
446 | } else if (other.strictlyPositive) { | ||
447 | if (nonpositive) { | ||
448 | new NonEmpty( | ||
449 | lower.tryDivide(other.lower, ROUND_DOWN), | ||
450 | upper.tryDivide(other.upper, ROUND_UP) | ||
451 | ) | ||
452 | } else if (nonnegative) { | ||
453 | new NonEmpty( | ||
454 | lower.tryDivide(other.upper, ROUND_DOWN), | ||
455 | upper.tryDivide(other.lower, ROUND_UP) | ||
456 | ) | ||
457 | } else { // lower < 0 < upper | ||
458 | new NonEmpty( | ||
459 | lower.tryDivide(other.lower, ROUND_DOWN), | ||
460 | upper.tryDivide(other.lower, ROUND_UP) | ||
461 | ) | ||
462 | } | ||
463 | } else { // other contains 0 | ||
464 | if (other.lower == BigDecimal.ZERO) { // 0 == other.lower < other.upper, because [0, 0] was exluded earlier | ||
465 | if (nonpositive) { | ||
466 | new NonEmpty(null, upper.tryDivide(other.upper, ROUND_UP)) | ||
467 | } else if (nonnegative) { | ||
468 | new NonEmpty(lower.tryDivide(other.upper, ROUND_DOWN), null) | ||
469 | } else { // lower < 0 < upper | ||
470 | UNBOUNDED | ||
471 | } | ||
472 | } else if (other.upper == BigDecimal.ZERO) { // other.lower < other.upper == 0 | ||
473 | if (nonpositive) { | ||
474 | new NonEmpty(upper.tryDivide(other.lower, ROUND_DOWN), null) | ||
475 | } else if (nonnegative) { | ||
476 | new NonEmpty(null, lower.tryDivide(other.lower, ROUND_UP)) | ||
477 | } else { // lower < 0 < upper | ||
478 | UNBOUNDED | ||
479 | } | ||
480 | } else { // other.lower < 0 < other.upper | ||
481 | UNBOUNDED | ||
482 | } | ||
483 | } | ||
484 | } | ||
485 | |||
486 | private def isStrictlyNegative() { | ||
487 | upper !== null && upper < BigDecimal.ZERO | ||
488 | } | ||
489 | |||
490 | private def isStrictlyPositive() { | ||
491 | lower !== null && lower > BigDecimal.ZERO | ||
492 | } | ||
493 | |||
494 | private static def tryDivide(BigDecimal a, BigDecimal b, MathContext mc) { | ||
495 | if (b === null) { | ||
496 | BigDecimal.ZERO | ||
497 | } else { | ||
498 | a?.divide(b, mc) | ||
499 | } | ||
500 | } | ||
501 | |||
502 | override **(Interval other) { | ||
503 | switch (other) { | ||
504 | case EMPTY: EMPTY | ||
505 | NonEmpty: this ** other | ||
506 | default: throw new IllegalArgumentException("Unknown interval: " + other) | ||
507 | } | ||
508 | } | ||
509 | |||
510 | def **(NonEmpty other) { | ||
511 | // XXX This should use proper rounding for log and exp instead of | ||
512 | // converting to double. | ||
513 | // XXX We should not ignore (integer) powers of negative numbers. | ||
514 | val lowerLog = if (lower === null || lower <= BigDecimal.ZERO) { | ||
515 | null | ||
516 | } else { | ||
517 | new BigDecimal(Math.log(lower.doubleValue), ROUND_DOWN) | ||
518 | } | ||
519 | val upperLog = if (upper === null) { | ||
520 | null | ||
521 | } else if (upper == BigDecimal.ZERO) { | ||
522 | return ZERO | ||
523 | } else if (upper < BigDecimal.ZERO) { | ||
524 | return EMPTY | ||
525 | } else { | ||
526 | new BigDecimal(Math.log(upper.doubleValue), ROUND_UP) | ||
527 | } | ||
528 | val log = new NonEmpty(lowerLog, upperLog) | ||
529 | val product = log * other | ||
530 | if (product instanceof NonEmpty) { | ||
531 | val lowerResult = if (product.lower === null) { | ||
532 | BigDecimal.ZERO | ||
533 | } else { | ||
534 | new BigDecimal(Math.exp(product.lower.doubleValue), ROUND_DOWN) | ||
535 | } | ||
536 | val upperResult = if (product.upper === null) { | ||
537 | null | ||
538 | } else { | ||
539 | new BigDecimal(Math.exp(product.upper.doubleValue), ROUND_UP) | ||
540 | } | ||
541 | new NonEmpty(lowerResult, upperResult) | ||
542 | } else { | ||
543 | throw new IllegalArgumentException("Unknown interval: " + product) | ||
544 | } | ||
545 | } | ||
546 | |||
547 | override toString() { | ||
548 | '''«IF lower === null»(-∞«ELSE»[«lower»«ENDIF», «IF upper === null»∞)«ELSE»«upper»]«ENDIF»''' | ||
549 | } | ||
550 | |||
551 | override compareTo(Interval o) { | ||
552 | switch (o) { | ||
553 | case EMPTY: 1 | ||
554 | NonEmpty: compareTo(o) | ||
555 | default: throw new IllegalArgumentException("Unknown interval: " + o) | ||
556 | } | ||
557 | } | ||
558 | |||
559 | def compareTo(NonEmpty o) { | ||
560 | if (lower === null) { | ||
561 | if (o.lower !== null) { | ||
562 | return -1 | ||
563 | } | ||
564 | } else if (o.lower === null) { // lower !== null | ||
565 | return 1 | ||
566 | } else { // both lower and o.lower are finite | ||
567 | val lowerDifference = lower.compareTo(o.lower) | ||
568 | if (lowerDifference != 0) { | ||
569 | return lowerDifference | ||
570 | } | ||
571 | } | ||
572 | if (upper === null) { | ||
573 | if (o.upper === null) { | ||
574 | return 0 | ||
575 | } else { | ||
576 | return 1 | ||
577 | } | ||
578 | } else if (o.upper === null) { // upper !== null | ||
579 | return -1 | ||
580 | } | ||
581 | upper.compareTo(o.upper) | ||
582 | } | ||
583 | } | ||
584 | } | ||