aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLibravatar Kristóf Marussy <kris7topher@gmail.com>2019-05-10 17:27:13 -0400
committerLibravatar Kristóf Marussy <kris7topher@gmail.com>2019-05-10 17:27:13 -0400
commit9670538a0e5630edecab8aaf4ba38ae6c81e8606 (patch)
tree155c4dc953dec6d99b5c89ae1029863b3db9ca94
parentMore aggregation operators (diff)
downloadVIATRA-Generator-9670538a0e5630edecab8aaf4ba38ae6c81e8606.tar.gz
VIATRA-Generator-9670538a0e5630edecab8aaf4ba38ae6c81e8606.tar.zst
VIATRA-Generator-9670538a0e5630edecab8aaf4ba38ae6c81e8606.zip
Interval power and aggregator fix
-rw-r--r--Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend123
-rw-r--r--Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/MinAggregatorTest.xtend67
-rw-r--r--Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/PowerTest.xtend43
-rw-r--r--Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/SumAggregatorTest.xtend (renamed from Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/SumTest.xtend)2
4 files changed, 198 insertions, 37 deletions
diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend
index 173be0be..4f0f594f 100644
--- a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend
+++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend
@@ -54,26 +54,28 @@ abstract class Interval implements Comparable<Interval> {
54 } 54 }
55 55
56 abstract def Interval min(Interval other) 56 abstract def Interval min(Interval other)
57 57
58 abstract def Interval max(Interval other) 58 abstract def Interval max(Interval other)
59 59
60 abstract def Interval join(Interval other) 60 abstract def Interval join(Interval other)
61 61
62 def operator_plus() { 62 def +() {
63 this 63 this
64 } 64 }
65 65
66 abstract def Interval operator_minus() 66 abstract def Interval -()
67
68 abstract def Interval +(Interval other)
67 69
68 abstract def Interval operator_plus(Interval other) 70 abstract def Interval -(Interval other)
69 71
70 abstract def Interval operator_minus(Interval other) 72 abstract def Interval *(int count)
71 73
72 abstract def Interval operator_multiply(int count) 74 abstract def Interval *(Interval other)
73 75
74 abstract def Interval operator_multiply(Interval other) 76 abstract def Interval /(Interval other)
75 77
76 abstract def Interval operator_divide(Interval other) 78 abstract def Interval **(Interval other)
77 79
78 public static val EMPTY = new Interval { 80 public static val EMPTY = new Interval {
79 override mustEqual(Interval other) { 81 override mustEqual(Interval other) {
@@ -95,7 +97,7 @@ abstract class Interval implements Comparable<Interval> {
95 override min(Interval other) { 97 override min(Interval other) {
96 EMPTY 98 EMPTY
97 } 99 }
98 100
99 override max(Interval other) { 101 override max(Interval other) {
100 EMPTY 102 EMPTY
101 } 103 }
@@ -104,27 +106,31 @@ abstract class Interval implements Comparable<Interval> {
104 other 106 other
105 } 107 }
106 108
107 override operator_minus() { 109 override -() {
108 EMPTY 110 EMPTY
109 } 111 }
110 112
111 override operator_plus(Interval other) { 113 override +(Interval other) {
112 EMPTY 114 EMPTY
113 } 115 }
114 116
115 override operator_minus(Interval other) { 117 override -(Interval other) {
116 EMPTY 118 EMPTY
117 } 119 }
118 120
119 override operator_multiply(int count) { 121 override *(int count) {
120 EMPTY 122 EMPTY
121 } 123 }
122 124
123 override operator_multiply(Interval other) { 125 override *(Interval other) {
124 EMPTY 126 EMPTY
125 } 127 }
126 128
127 override operator_divide(Interval other) { 129 override /(Interval other) {
130 EMPTY
131 }
132
133 override **(Interval other) {
128 EMPTY 134 EMPTY
129 } 135 }
130 136
@@ -221,14 +227,14 @@ abstract class Interval implements Comparable<Interval> {
221 default: throw new IllegalArgumentException("Unknown interval: " + other) 227 default: throw new IllegalArgumentException("Unknown interval: " + other)
222 } 228 }
223 } 229 }
224 230
225 def min(NonEmpty other) { 231 def min(NonEmpty other) {
226 new NonEmpty( 232 new NonEmpty(
227 lower.tryMin(other.lower), 233 lower.tryMin(other.lower),
228 if (other.upper === null) upper else upper?.min(other.upper) 234 if(other.upper === null) upper else if(upper === null) other.upper else upper.min(other.upper)
229 ) 235 )
230 } 236 }
231 237
232 override max(Interval other) { 238 override max(Interval other) {
233 switch (other) { 239 switch (other) {
234 case EMPTY: this 240 case EMPTY: this
@@ -236,10 +242,10 @@ abstract class Interval implements Comparable<Interval> {
236 default: throw new IllegalArgumentException("Unknown interval: " + other) 242 default: throw new IllegalArgumentException("Unknown interval: " + other)
237 } 243 }
238 } 244 }
239 245
240 def max(NonEmpty other) { 246 def max(NonEmpty other) {
241 new NonEmpty( 247 new NonEmpty(
242 if (other.lower === null) lower else lower?.min(other.lower), 248 if(other.lower === null) lower else if(lower === null) other.lower else lower.max(other.lower),
243 upper.tryMax(other.upper) 249 upper.tryMax(other.upper)
244 ) 250 )
245 } 251 }
@@ -252,19 +258,19 @@ abstract class Interval implements Comparable<Interval> {
252 } 258 }
253 } 259 }
254 260
255 override operator_minus() { 261 override -() {
256 new NonEmpty(upper?.negate(ROUND_DOWN), lower?.negate(ROUND_UP)) 262 new NonEmpty(upper?.negate(ROUND_DOWN), lower?.negate(ROUND_UP))
257 } 263 }
258 264
259 override operator_plus(Interval other) { 265 override +(Interval other) {
260 switch (other) { 266 switch (other) {
261 case EMPTY: EMPTY 267 case EMPTY: EMPTY
262 NonEmpty: operator_plus(other) 268 NonEmpty: this + other
263 default: throw new IllegalArgumentException("Unknown interval: " + other) 269 default: throw new IllegalArgumentException("Unknown interval: " + other)
264 } 270 }
265 } 271 }
266 272
267 def operator_plus(NonEmpty other) { 273 def +(NonEmpty other) {
268 new NonEmpty( 274 new NonEmpty(
269 lower.tryAdd(other.lower, ROUND_DOWN), 275 lower.tryAdd(other.lower, ROUND_DOWN),
270 upper.tryAdd(other.upper, ROUND_UP) 276 upper.tryAdd(other.upper, ROUND_UP)
@@ -279,15 +285,15 @@ abstract class Interval implements Comparable<Interval> {
279 } 285 }
280 } 286 }
281 287
282 override operator_minus(Interval other) { 288 override -(Interval other) {
283 switch (other) { 289 switch (other) {
284 case EMPTY: EMPTY 290 case EMPTY: EMPTY
285 NonEmpty: operator_minus(other) 291 NonEmpty: this - other
286 default: throw new IllegalArgumentException("Unknown interval: " + other) 292 default: throw new IllegalArgumentException("Unknown interval: " + other)
287 } 293 }
288 } 294 }
289 295
290 def operator_minus(NonEmpty other) { 296 def -(NonEmpty other) {
291 new NonEmpty( 297 new NonEmpty(
292 lower.trySubtract(other.upper, ROUND_DOWN), 298 lower.trySubtract(other.upper, ROUND_DOWN),
293 upper.trySubtract(other.lower, ROUND_UP) 299 upper.trySubtract(other.lower, ROUND_UP)
@@ -302,7 +308,7 @@ abstract class Interval implements Comparable<Interval> {
302 } 308 }
303 } 309 }
304 310
305 override operator_multiply(int count) { 311 override *(int count) {
306 val bigCount = new BigDecimal(count) 312 val bigCount = new BigDecimal(count)
307 new NonEmpty( 313 new NonEmpty(
308 lower.tryMultiply(bigCount, ROUND_DOWN), 314 lower.tryMultiply(bigCount, ROUND_DOWN),
@@ -310,15 +316,15 @@ abstract class Interval implements Comparable<Interval> {
310 ) 316 )
311 } 317 }
312 318
313 override operator_multiply(Interval other) { 319 override *(Interval other) {
314 switch (other) { 320 switch (other) {
315 case EMPTY: EMPTY 321 case EMPTY: EMPTY
316 NonEmpty: operator_multiply(other) 322 NonEmpty: this * other
317 default: throw new IllegalArgumentException("") 323 default: throw new IllegalArgumentException("Unknown interval: " + other)
318 } 324 }
319 } 325 }
320 326
321 def operator_multiply(NonEmpty other) { 327 def *(NonEmpty other) {
322 if (this == ZERO || other == ZERO) { 328 if (this == ZERO || other == ZERO) {
323 ZERO 329 ZERO
324 } else if (nonpositive) { 330 } else if (nonpositive) {
@@ -407,15 +413,15 @@ abstract class Interval implements Comparable<Interval> {
407 } 413 }
408 } 414 }
409 415
410 override operator_divide(Interval other) { 416 override /(Interval other) {
411 switch (other) { 417 switch (other) {
412 case EMPTY: EMPTY 418 case EMPTY: EMPTY
413 NonEmpty: operator_divide(other) 419 NonEmpty: this / other
414 default: throw new IllegalArgumentException("Unknown interval: " + other) 420 default: throw new IllegalArgumentException("Unknown interval: " + other)
415 } 421 }
416 } 422 }
417 423
418 def operator_divide(NonEmpty other) { 424 def /(NonEmpty other) {
419 if (other == ZERO) { 425 if (other == ZERO) {
420 EMPTY 426 EMPTY
421 } else if (this == ZERO) { 427 } else if (this == ZERO) {
@@ -493,6 +499,51 @@ abstract class Interval implements Comparable<Interval> {
493 } 499 }
494 } 500 }
495 501
502 override **(Interval other) {
503 switch (other) {
504 case EMPTY: EMPTY
505 NonEmpty: this ** other
506 default: throw new IllegalArgumentException("Unknown interval: " + other)
507 }
508 }
509
510 def **(NonEmpty other) {
511 // XXX This should use proper rounding for log and exp instead of
512 // converting to double.
513 // XXX We should not ignore (integer) powers of negative numbers.
514 val lowerLog = if (lower === null || lower <= BigDecimal.ZERO) {
515 null
516 } else {
517 new BigDecimal(Math.log(lower.doubleValue), ROUND_DOWN)
518 }
519 val upperLog = if (upper === null) {
520 null
521 } else if (upper == BigDecimal.ZERO) {
522 return ZERO
523 } else if (upper < BigDecimal.ZERO) {
524 return EMPTY
525 } else {
526 new BigDecimal(Math.log(upper.doubleValue), ROUND_UP)
527 }
528 val log = new NonEmpty(lowerLog, upperLog)
529 val product = log * other
530 if (product instanceof NonEmpty) {
531 val lowerResult = if (product.lower === null) {
532 BigDecimal.ZERO
533 } else {
534 new BigDecimal(Math.exp(product.lower.doubleValue), ROUND_DOWN)
535 }
536 val upperResult = if (product.upper === null) {
537 null
538 } else {
539 new BigDecimal(Math.exp(product.upper.doubleValue), ROUND_UP)
540 }
541 new NonEmpty(lowerResult, upperResult)
542 } else {
543 throw new IllegalArgumentException("Unknown interval: " + product)
544 }
545 }
546
496 override toString() { 547 override toString() {
497 '''«IF lower === null»(-∞«ELSE»[«lower»«ENDIF», «IF upper === null»∞)«ELSE»«upper»]«ENDIF»''' 548 '''«IF lower === null»(-∞«ELSE»[«lower»«ENDIF», «IF upper === null»∞)«ELSE»«upper»]«ENDIF»'''
498 } 549 }
@@ -501,7 +552,7 @@ abstract class Interval implements Comparable<Interval> {
501 switch (o) { 552 switch (o) {
502 case EMPTY: 1 553 case EMPTY: 1
503 NonEmpty: compareTo(o) 554 NonEmpty: compareTo(o)
504 default: throw new IllegalArgumentException("") 555 default: throw new IllegalArgumentException("Unknown interval: " + o)
505 } 556 }
506 } 557 }
507 558
diff --git a/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/MinAggregatorTest.xtend b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/MinAggregatorTest.xtend
new file mode 100644
index 00000000..7d46e16c
--- /dev/null
+++ b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/MinAggregatorTest.xtend
@@ -0,0 +1,67 @@
1package hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests.interval
2
3import com.google.common.collect.HashMultiset
4import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval
5import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.IntervalAggregationMode
6import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.IntervalAggregationOperator
7import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.IntervalRedBlackNode
8import java.math.BigDecimal
9import java.util.Random
10import org.junit.Assert
11import org.junit.Before
12import org.junit.Test
13
14import static hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval.*
15
16class MinAggregatorTest {
17 val aggregator = new IntervalAggregationOperator(IntervalAggregationMode.MIN)
18 var IntervalRedBlackNode value = null
19
20 @Before
21 def void reset() {
22 value = aggregator.createNeutral
23 }
24
25 @Test
26 def void emptyTest() {
27 assertEquals(EMPTY)
28 }
29
30 @Test
31 def void largeTest() {
32 val starts = #[null, new BigDecimal(-3), new BigDecimal(-2), new BigDecimal(-1)]
33 val ends = #[new BigDecimal(1), new BigDecimal(2), new BigDecimal(3), null]
34 val current = HashMultiset.create
35 val random = new Random(1)
36 for (var int i = 0; i < 1000; i++) {
37 val start = starts.get(random.nextInt(starts.size))
38 val end = ends.get(random.nextInt(ends.size))
39 val interval = Interval.of(start, end)
40 val isInsert = !current.contains(interval) || random.nextInt(3) == 0
41 if (isInsert) {
42 current.add(interval)
43 } else {
44 current.remove(interval)
45 }
46 val expected = current.stream.reduce(aggregator.mode).orElse(EMPTY)
47 update(interval, isInsert)
48 assertEquals(expected)
49 }
50 }
51
52 private def update(Interval interval, boolean isInsert) {
53 value = aggregator.update(value, interval, isInsert)
54 val nodes = newArrayList
55 var node = value.min
56 while (node !== null) {
57 nodes += node
58 node = node.successor
59 }
60 value.assertSubtreeIsValid
61 }
62
63 private def assertEquals(Interval interval) {
64 val actual = aggregator.getAggregate(value)
65 Assert.assertEquals(interval, actual)
66 }
67}
diff --git a/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/PowerTest.xtend b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/PowerTest.xtend
new file mode 100644
index 00000000..c842d90d
--- /dev/null
+++ b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/PowerTest.xtend
@@ -0,0 +1,43 @@
1package hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests.interval
2
3import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval
4import java.util.Collection
5import org.junit.Assert
6import org.junit.Test
7import org.junit.runner.RunWith
8import org.junit.runners.Parameterized
9import org.junit.runners.Parameterized.Parameter
10import org.junit.runners.Parameterized.Parameters
11
12import static hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval.*
13
14@RunWith(Parameterized)
15class PowerTest {
16 @Parameters(name="{index}: {0} ** {1} = {2}")
17 static def Collection<Object[]> data() {
18 #[
19 #[EMPTY, EMPTY, EMPTY],
20 #[EMPTY, between(-1, 1), EMPTY],
21 #[between(-1, 1), EMPTY, EMPTY],
22 #[upTo(-1), between(-1, 2), EMPTY],
23 #[upTo(0), between(-1, 2), between(0, 0)],
24 #[upTo(2), between(-1, 2), above(0)],
25 #[upTo(2), between(1, 2), between(0, 4)],
26 #[above(1), between(1, 2), above(1)],
27 #[between(2, 4), upTo(1), between(0, 4)],
28 #[between(0.25, 0.5), upTo(1), above(0.25)],
29 #[between(2, 3), above(1), above(2)],
30 #[between(0.25, 0.5), above(1), between(0, 0.5)],
31 #[between(1, 2), between(-1, 2), between(0.5, 4)]
32 ]
33 }
34
35 @Parameter(0) public var Interval a
36 @Parameter(1) public var Interval b
37 @Parameter(2) public var Interval result
38
39 @Test
40 def void powerTest() {
41 Assert.assertEquals(result, a ** b)
42 }
43}
diff --git a/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/SumTest.xtend b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/SumAggregatorTest.xtend
index 530c081c..56172b6c 100644
--- a/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/SumTest.xtend
+++ b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/SumAggregatorTest.xtend
@@ -13,7 +13,7 @@ import org.junit.Test
13 13
14import static hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval.* 14import static hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval.*
15 15
16class SumTest { 16class SumAggregatorTest {
17 val aggregator = new IntervalAggregationOperator(IntervalAggregationMode.SUM) 17 val aggregator = new IntervalAggregationOperator(IntervalAggregationMode.SUM)
18 var IntervalRedBlackNode value = null 18 var IntervalRedBlackNode value = null
19 19