diff options
-rw-r--r-- | Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend | 123 | ||||
-rw-r--r-- | Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/MinAggregatorTest.xtend | 67 | ||||
-rw-r--r-- | Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/PowerTest.xtend | 43 | ||||
-rw-r--r-- | Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/SumAggregatorTest.xtend (renamed from Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/SumTest.xtend) | 2 |
4 files changed, 198 insertions, 37 deletions
diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend index 173be0be..4f0f594f 100644 --- a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend +++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend | |||
@@ -54,26 +54,28 @@ abstract class Interval implements Comparable<Interval> { | |||
54 | } | 54 | } |
55 | 55 | ||
56 | abstract def Interval min(Interval other) | 56 | abstract def Interval min(Interval other) |
57 | 57 | ||
58 | abstract def Interval max(Interval other) | 58 | abstract def Interval max(Interval other) |
59 | 59 | ||
60 | abstract def Interval join(Interval other) | 60 | abstract def Interval join(Interval other) |
61 | 61 | ||
62 | def operator_plus() { | 62 | def +() { |
63 | this | 63 | this |
64 | } | 64 | } |
65 | 65 | ||
66 | abstract def Interval operator_minus() | 66 | abstract def Interval -() |
67 | |||
68 | abstract def Interval +(Interval other) | ||
67 | 69 | ||
68 | abstract def Interval operator_plus(Interval other) | 70 | abstract def Interval -(Interval other) |
69 | 71 | ||
70 | abstract def Interval operator_minus(Interval other) | 72 | abstract def Interval *(int count) |
71 | 73 | ||
72 | abstract def Interval operator_multiply(int count) | 74 | abstract def Interval *(Interval other) |
73 | 75 | ||
74 | abstract def Interval operator_multiply(Interval other) | 76 | abstract def Interval /(Interval other) |
75 | 77 | ||
76 | abstract def Interval operator_divide(Interval other) | 78 | abstract def Interval **(Interval other) |
77 | 79 | ||
78 | public static val EMPTY = new Interval { | 80 | public static val EMPTY = new Interval { |
79 | override mustEqual(Interval other) { | 81 | override mustEqual(Interval other) { |
@@ -95,7 +97,7 @@ abstract class Interval implements Comparable<Interval> { | |||
95 | override min(Interval other) { | 97 | override min(Interval other) { |
96 | EMPTY | 98 | EMPTY |
97 | } | 99 | } |
98 | 100 | ||
99 | override max(Interval other) { | 101 | override max(Interval other) { |
100 | EMPTY | 102 | EMPTY |
101 | } | 103 | } |
@@ -104,27 +106,31 @@ abstract class Interval implements Comparable<Interval> { | |||
104 | other | 106 | other |
105 | } | 107 | } |
106 | 108 | ||
107 | override operator_minus() { | 109 | override -() { |
108 | EMPTY | 110 | EMPTY |
109 | } | 111 | } |
110 | 112 | ||
111 | override operator_plus(Interval other) { | 113 | override +(Interval other) { |
112 | EMPTY | 114 | EMPTY |
113 | } | 115 | } |
114 | 116 | ||
115 | override operator_minus(Interval other) { | 117 | override -(Interval other) { |
116 | EMPTY | 118 | EMPTY |
117 | } | 119 | } |
118 | 120 | ||
119 | override operator_multiply(int count) { | 121 | override *(int count) { |
120 | EMPTY | 122 | EMPTY |
121 | } | 123 | } |
122 | 124 | ||
123 | override operator_multiply(Interval other) { | 125 | override *(Interval other) { |
124 | EMPTY | 126 | EMPTY |
125 | } | 127 | } |
126 | 128 | ||
127 | override operator_divide(Interval other) { | 129 | override /(Interval other) { |
130 | EMPTY | ||
131 | } | ||
132 | |||
133 | override **(Interval other) { | ||
128 | EMPTY | 134 | EMPTY |
129 | } | 135 | } |
130 | 136 | ||
@@ -221,14 +227,14 @@ abstract class Interval implements Comparable<Interval> { | |||
221 | default: throw new IllegalArgumentException("Unknown interval: " + other) | 227 | default: throw new IllegalArgumentException("Unknown interval: " + other) |
222 | } | 228 | } |
223 | } | 229 | } |
224 | 230 | ||
225 | def min(NonEmpty other) { | 231 | def min(NonEmpty other) { |
226 | new NonEmpty( | 232 | new NonEmpty( |
227 | lower.tryMin(other.lower), | 233 | lower.tryMin(other.lower), |
228 | if (other.upper === null) upper else upper?.min(other.upper) | 234 | if(other.upper === null) upper else if(upper === null) other.upper else upper.min(other.upper) |
229 | ) | 235 | ) |
230 | } | 236 | } |
231 | 237 | ||
232 | override max(Interval other) { | 238 | override max(Interval other) { |
233 | switch (other) { | 239 | switch (other) { |
234 | case EMPTY: this | 240 | case EMPTY: this |
@@ -236,10 +242,10 @@ abstract class Interval implements Comparable<Interval> { | |||
236 | default: throw new IllegalArgumentException("Unknown interval: " + other) | 242 | default: throw new IllegalArgumentException("Unknown interval: " + other) |
237 | } | 243 | } |
238 | } | 244 | } |
239 | 245 | ||
240 | def max(NonEmpty other) { | 246 | def max(NonEmpty other) { |
241 | new NonEmpty( | 247 | new NonEmpty( |
242 | if (other.lower === null) lower else lower?.min(other.lower), | 248 | if(other.lower === null) lower else if(lower === null) other.lower else lower.max(other.lower), |
243 | upper.tryMax(other.upper) | 249 | upper.tryMax(other.upper) |
244 | ) | 250 | ) |
245 | } | 251 | } |
@@ -252,19 +258,19 @@ abstract class Interval implements Comparable<Interval> { | |||
252 | } | 258 | } |
253 | } | 259 | } |
254 | 260 | ||
255 | override operator_minus() { | 261 | override -() { |
256 | new NonEmpty(upper?.negate(ROUND_DOWN), lower?.negate(ROUND_UP)) | 262 | new NonEmpty(upper?.negate(ROUND_DOWN), lower?.negate(ROUND_UP)) |
257 | } | 263 | } |
258 | 264 | ||
259 | override operator_plus(Interval other) { | 265 | override +(Interval other) { |
260 | switch (other) { | 266 | switch (other) { |
261 | case EMPTY: EMPTY | 267 | case EMPTY: EMPTY |
262 | NonEmpty: operator_plus(other) | 268 | NonEmpty: this + other |
263 | default: throw new IllegalArgumentException("Unknown interval: " + other) | 269 | default: throw new IllegalArgumentException("Unknown interval: " + other) |
264 | } | 270 | } |
265 | } | 271 | } |
266 | 272 | ||
267 | def operator_plus(NonEmpty other) { | 273 | def +(NonEmpty other) { |
268 | new NonEmpty( | 274 | new NonEmpty( |
269 | lower.tryAdd(other.lower, ROUND_DOWN), | 275 | lower.tryAdd(other.lower, ROUND_DOWN), |
270 | upper.tryAdd(other.upper, ROUND_UP) | 276 | upper.tryAdd(other.upper, ROUND_UP) |
@@ -279,15 +285,15 @@ abstract class Interval implements Comparable<Interval> { | |||
279 | } | 285 | } |
280 | } | 286 | } |
281 | 287 | ||
282 | override operator_minus(Interval other) { | 288 | override -(Interval other) { |
283 | switch (other) { | 289 | switch (other) { |
284 | case EMPTY: EMPTY | 290 | case EMPTY: EMPTY |
285 | NonEmpty: operator_minus(other) | 291 | NonEmpty: this - other |
286 | default: throw new IllegalArgumentException("Unknown interval: " + other) | 292 | default: throw new IllegalArgumentException("Unknown interval: " + other) |
287 | } | 293 | } |
288 | } | 294 | } |
289 | 295 | ||
290 | def operator_minus(NonEmpty other) { | 296 | def -(NonEmpty other) { |
291 | new NonEmpty( | 297 | new NonEmpty( |
292 | lower.trySubtract(other.upper, ROUND_DOWN), | 298 | lower.trySubtract(other.upper, ROUND_DOWN), |
293 | upper.trySubtract(other.lower, ROUND_UP) | 299 | upper.trySubtract(other.lower, ROUND_UP) |
@@ -302,7 +308,7 @@ abstract class Interval implements Comparable<Interval> { | |||
302 | } | 308 | } |
303 | } | 309 | } |
304 | 310 | ||
305 | override operator_multiply(int count) { | 311 | override *(int count) { |
306 | val bigCount = new BigDecimal(count) | 312 | val bigCount = new BigDecimal(count) |
307 | new NonEmpty( | 313 | new NonEmpty( |
308 | lower.tryMultiply(bigCount, ROUND_DOWN), | 314 | lower.tryMultiply(bigCount, ROUND_DOWN), |
@@ -310,15 +316,15 @@ abstract class Interval implements Comparable<Interval> { | |||
310 | ) | 316 | ) |
311 | } | 317 | } |
312 | 318 | ||
313 | override operator_multiply(Interval other) { | 319 | override *(Interval other) { |
314 | switch (other) { | 320 | switch (other) { |
315 | case EMPTY: EMPTY | 321 | case EMPTY: EMPTY |
316 | NonEmpty: operator_multiply(other) | 322 | NonEmpty: this * other |
317 | default: throw new IllegalArgumentException("") | 323 | default: throw new IllegalArgumentException("Unknown interval: " + other) |
318 | } | 324 | } |
319 | } | 325 | } |
320 | 326 | ||
321 | def operator_multiply(NonEmpty other) { | 327 | def *(NonEmpty other) { |
322 | if (this == ZERO || other == ZERO) { | 328 | if (this == ZERO || other == ZERO) { |
323 | ZERO | 329 | ZERO |
324 | } else if (nonpositive) { | 330 | } else if (nonpositive) { |
@@ -407,15 +413,15 @@ abstract class Interval implements Comparable<Interval> { | |||
407 | } | 413 | } |
408 | } | 414 | } |
409 | 415 | ||
410 | override operator_divide(Interval other) { | 416 | override /(Interval other) { |
411 | switch (other) { | 417 | switch (other) { |
412 | case EMPTY: EMPTY | 418 | case EMPTY: EMPTY |
413 | NonEmpty: operator_divide(other) | 419 | NonEmpty: this / other |
414 | default: throw new IllegalArgumentException("Unknown interval: " + other) | 420 | default: throw new IllegalArgumentException("Unknown interval: " + other) |
415 | } | 421 | } |
416 | } | 422 | } |
417 | 423 | ||
418 | def operator_divide(NonEmpty other) { | 424 | def /(NonEmpty other) { |
419 | if (other == ZERO) { | 425 | if (other == ZERO) { |
420 | EMPTY | 426 | EMPTY |
421 | } else if (this == ZERO) { | 427 | } else if (this == ZERO) { |
@@ -493,6 +499,51 @@ abstract class Interval implements Comparable<Interval> { | |||
493 | } | 499 | } |
494 | } | 500 | } |
495 | 501 | ||
502 | override **(Interval other) { | ||
503 | switch (other) { | ||
504 | case EMPTY: EMPTY | ||
505 | NonEmpty: this ** other | ||
506 | default: throw new IllegalArgumentException("Unknown interval: " + other) | ||
507 | } | ||
508 | } | ||
509 | |||
510 | def **(NonEmpty other) { | ||
511 | // XXX This should use proper rounding for log and exp instead of | ||
512 | // converting to double. | ||
513 | // XXX We should not ignore (integer) powers of negative numbers. | ||
514 | val lowerLog = if (lower === null || lower <= BigDecimal.ZERO) { | ||
515 | null | ||
516 | } else { | ||
517 | new BigDecimal(Math.log(lower.doubleValue), ROUND_DOWN) | ||
518 | } | ||
519 | val upperLog = if (upper === null) { | ||
520 | null | ||
521 | } else if (upper == BigDecimal.ZERO) { | ||
522 | return ZERO | ||
523 | } else if (upper < BigDecimal.ZERO) { | ||
524 | return EMPTY | ||
525 | } else { | ||
526 | new BigDecimal(Math.log(upper.doubleValue), ROUND_UP) | ||
527 | } | ||
528 | val log = new NonEmpty(lowerLog, upperLog) | ||
529 | val product = log * other | ||
530 | if (product instanceof NonEmpty) { | ||
531 | val lowerResult = if (product.lower === null) { | ||
532 | BigDecimal.ZERO | ||
533 | } else { | ||
534 | new BigDecimal(Math.exp(product.lower.doubleValue), ROUND_DOWN) | ||
535 | } | ||
536 | val upperResult = if (product.upper === null) { | ||
537 | null | ||
538 | } else { | ||
539 | new BigDecimal(Math.exp(product.upper.doubleValue), ROUND_UP) | ||
540 | } | ||
541 | new NonEmpty(lowerResult, upperResult) | ||
542 | } else { | ||
543 | throw new IllegalArgumentException("Unknown interval: " + product) | ||
544 | } | ||
545 | } | ||
546 | |||
496 | override toString() { | 547 | override toString() { |
497 | '''«IF lower === null»(-∞«ELSE»[«lower»«ENDIF», «IF upper === null»∞)«ELSE»«upper»]«ENDIF»''' | 548 | '''«IF lower === null»(-∞«ELSE»[«lower»«ENDIF», «IF upper === null»∞)«ELSE»«upper»]«ENDIF»''' |
498 | } | 549 | } |
@@ -501,7 +552,7 @@ abstract class Interval implements Comparable<Interval> { | |||
501 | switch (o) { | 552 | switch (o) { |
502 | case EMPTY: 1 | 553 | case EMPTY: 1 |
503 | NonEmpty: compareTo(o) | 554 | NonEmpty: compareTo(o) |
504 | default: throw new IllegalArgumentException("") | 555 | default: throw new IllegalArgumentException("Unknown interval: " + o) |
505 | } | 556 | } |
506 | } | 557 | } |
507 | 558 | ||
diff --git a/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/MinAggregatorTest.xtend b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/MinAggregatorTest.xtend new file mode 100644 index 00000000..7d46e16c --- /dev/null +++ b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/MinAggregatorTest.xtend | |||
@@ -0,0 +1,67 @@ | |||
1 | package hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests.interval | ||
2 | |||
3 | import com.google.common.collect.HashMultiset | ||
4 | import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval | ||
5 | import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.IntervalAggregationMode | ||
6 | import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.IntervalAggregationOperator | ||
7 | import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.IntervalRedBlackNode | ||
8 | import java.math.BigDecimal | ||
9 | import java.util.Random | ||
10 | import org.junit.Assert | ||
11 | import org.junit.Before | ||
12 | import org.junit.Test | ||
13 | |||
14 | import static hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval.* | ||
15 | |||
16 | class MinAggregatorTest { | ||
17 | val aggregator = new IntervalAggregationOperator(IntervalAggregationMode.MIN) | ||
18 | var IntervalRedBlackNode value = null | ||
19 | |||
20 | @Before | ||
21 | def void reset() { | ||
22 | value = aggregator.createNeutral | ||
23 | } | ||
24 | |||
25 | @Test | ||
26 | def void emptyTest() { | ||
27 | assertEquals(EMPTY) | ||
28 | } | ||
29 | |||
30 | @Test | ||
31 | def void largeTest() { | ||
32 | val starts = #[null, new BigDecimal(-3), new BigDecimal(-2), new BigDecimal(-1)] | ||
33 | val ends = #[new BigDecimal(1), new BigDecimal(2), new BigDecimal(3), null] | ||
34 | val current = HashMultiset.create | ||
35 | val random = new Random(1) | ||
36 | for (var int i = 0; i < 1000; i++) { | ||
37 | val start = starts.get(random.nextInt(starts.size)) | ||
38 | val end = ends.get(random.nextInt(ends.size)) | ||
39 | val interval = Interval.of(start, end) | ||
40 | val isInsert = !current.contains(interval) || random.nextInt(3) == 0 | ||
41 | if (isInsert) { | ||
42 | current.add(interval) | ||
43 | } else { | ||
44 | current.remove(interval) | ||
45 | } | ||
46 | val expected = current.stream.reduce(aggregator.mode).orElse(EMPTY) | ||
47 | update(interval, isInsert) | ||
48 | assertEquals(expected) | ||
49 | } | ||
50 | } | ||
51 | |||
52 | private def update(Interval interval, boolean isInsert) { | ||
53 | value = aggregator.update(value, interval, isInsert) | ||
54 | val nodes = newArrayList | ||
55 | var node = value.min | ||
56 | while (node !== null) { | ||
57 | nodes += node | ||
58 | node = node.successor | ||
59 | } | ||
60 | value.assertSubtreeIsValid | ||
61 | } | ||
62 | |||
63 | private def assertEquals(Interval interval) { | ||
64 | val actual = aggregator.getAggregate(value) | ||
65 | Assert.assertEquals(interval, actual) | ||
66 | } | ||
67 | } | ||
diff --git a/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/PowerTest.xtend b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/PowerTest.xtend new file mode 100644 index 00000000..c842d90d --- /dev/null +++ b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/PowerTest.xtend | |||
@@ -0,0 +1,43 @@ | |||
1 | package hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests.interval | ||
2 | |||
3 | import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval | ||
4 | import java.util.Collection | ||
5 | import org.junit.Assert | ||
6 | import org.junit.Test | ||
7 | import org.junit.runner.RunWith | ||
8 | import org.junit.runners.Parameterized | ||
9 | import org.junit.runners.Parameterized.Parameter | ||
10 | import org.junit.runners.Parameterized.Parameters | ||
11 | |||
12 | import static hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval.* | ||
13 | |||
14 | @RunWith(Parameterized) | ||
15 | class PowerTest { | ||
16 | @Parameters(name="{index}: {0} ** {1} = {2}") | ||
17 | static def Collection<Object[]> data() { | ||
18 | #[ | ||
19 | #[EMPTY, EMPTY, EMPTY], | ||
20 | #[EMPTY, between(-1, 1), EMPTY], | ||
21 | #[between(-1, 1), EMPTY, EMPTY], | ||
22 | #[upTo(-1), between(-1, 2), EMPTY], | ||
23 | #[upTo(0), between(-1, 2), between(0, 0)], | ||
24 | #[upTo(2), between(-1, 2), above(0)], | ||
25 | #[upTo(2), between(1, 2), between(0, 4)], | ||
26 | #[above(1), between(1, 2), above(1)], | ||
27 | #[between(2, 4), upTo(1), between(0, 4)], | ||
28 | #[between(0.25, 0.5), upTo(1), above(0.25)], | ||
29 | #[between(2, 3), above(1), above(2)], | ||
30 | #[between(0.25, 0.5), above(1), between(0, 0.5)], | ||
31 | #[between(1, 2), between(-1, 2), between(0.5, 4)] | ||
32 | ] | ||
33 | } | ||
34 | |||
35 | @Parameter(0) public var Interval a | ||
36 | @Parameter(1) public var Interval b | ||
37 | @Parameter(2) public var Interval result | ||
38 | |||
39 | @Test | ||
40 | def void powerTest() { | ||
41 | Assert.assertEquals(result, a ** b) | ||
42 | } | ||
43 | } | ||
diff --git a/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/SumTest.xtend b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/SumAggregatorTest.xtend index 530c081c..56172b6c 100644 --- a/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/SumTest.xtend +++ b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/SumAggregatorTest.xtend | |||
@@ -13,7 +13,7 @@ import org.junit.Test | |||
13 | 13 | ||
14 | import static hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval.* | 14 | import static hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval.* |
15 | 15 | ||
16 | class SumTest { | 16 | class SumAggregatorTest { |
17 | val aggregator = new IntervalAggregationOperator(IntervalAggregationMode.SUM) | 17 | val aggregator = new IntervalAggregationOperator(IntervalAggregationMode.SUM) |
18 | var IntervalRedBlackNode value = null | 18 | var IntervalRedBlackNode value = null |
19 | 19 | ||