1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
|
package hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests.interval
import com.google.common.collect.HashMultiset
import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval
import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.IntervalAggregationMode
import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.IntervalAggregationOperator
import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.IntervalRedBlackNode
import java.math.BigDecimal
import java.util.Random
import org.junit.Assert
import org.junit.Before
import org.junit.Test
import static hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval.*
/**
 * Unit tests for {@link IntervalAggregationOperator} in {@link IntervalAggregationMode#SUM} mode.
 *
 * Each test feeds insertions and removals of {@link Interval} values into the operator,
 * starting from the neutral state created in {@link #reset()}, and compares the aggregate
 * reported by the operator with the expected interval.
 */
class SumAggregatorTest {
	/** Operator under test, configured for the SUM aggregation mode. */
	val aggregator = new IntervalAggregationOperator(IntervalAggregationMode.SUM)

	/** Current aggregation state; replaced by the result of every update. */
	var IntervalRedBlackNode value = null

	/** Starts every test from the operator's neutral state. */
	@Before
	def void reset() {
		value = aggregator.createNeutral
	}

	/** With nothing inserted, the aggregate is expected to be ZERO. */
	@Test
	def void emptyTest() {
		assertEquals(ZERO)
	}

	/** A single inserted interval is expected to be the aggregate itself. */
	@Test
	def void addSingleTest() {
		add(between(-1, 1))
		assertEquals(between(-1, 1))
	}

	/** Inserting and then removing the same interval is expected to yield ZERO again. */
	@Test
	def void addRemoveTest() {
		add(between(-1, 1))
		remove(between(-1, 1))
		assertEquals(ZERO)
	}

	/** between(-1, 1) plus above(2) is expected to aggregate to above(1). */
	@Test
	def void addTwoTest() {
		add(between(-1, 1))
		add(above(2))
		assertEquals(above(1))
	}

	/** Removing the first of two inserted intervals is expected to leave the second one. */
	@Test
	def void addTwoRemoveFirstTest() {
		add(between(-1, 1))
		add(above(2))
		remove(between(-1, 1))
		assertEquals(above(2))
	}

	/** Removing the second of two inserted intervals is expected to leave the first one. */
	@Test
	def void addTwoRemoveSecondTest() {
		add(between(-1, 1))
		add(above(2))
		remove(above(2))
		assertEquals(between(-1, 1))
	}

	/** Inserting the same interval three times is expected to sum to between(-3, 3). */
	@Test
	def void addMultiplicityTest() {
		add(between(-1, 1))
		add(between(-1, 1))
		add(between(-1, 1))
		assertEquals(between(-3, 3))
	}

	/** Three insertions followed by three removals of the same interval are expected to yield ZERO. */
	@Test
	def void removeAllMultiplicityTest() {
		add(between(-1, 1))
		add(between(-1, 1))
		add(between(-1, 1))
		remove(between(-1, 1))
		remove(between(-1, 1))
		remove(between(-1, 1))
		assertEquals(ZERO)
	}

	/** Three insertions followed by two removals are expected to leave a single occurrence. */
	@Test
	def void removeSomeMultiplicityTest() {
		add(between(-1, 1))
		add(between(-1, 1))
		add(between(-1, 1))
		remove(between(-1, 1))
		remove(between(-1, 1))
		assertEquals(between(-1, 1))
	}

	/**
	 * Randomized comparison against a reference computation, using the fixed seed 1.
	 *
	 * For 1000 steps an interval is built from a randomly chosen start and end bound and is
	 * either inserted or removed. After each step the expected aggregate is recomputed from
	 * scratch by reducing a multiset of the live intervals with {@code aggregator.mode}
	 * (ZERO when the multiset is empty) and compared with the operator's aggregate.
	 */
	@Test
	def void largeTest() {
		// NOTE(review): a null bound presumably means "unbounded" on that side — confirm against Interval.of.
		val starts = #[null, new BigDecimal(-3), new BigDecimal(-2), new BigDecimal(-1)]
		val ends = #[new BigDecimal(1), new BigDecimal(2), new BigDecimal(3), null]
		// Reference model: the intervals currently inserted, with multiplicities.
		val current = HashMultiset.create
		val random = new Random(1)
		for (var int i = 0; i < 1000; i++) {
			val start = starts.get(random.nextInt(starts.size))
			val end = ends.get(random.nextInt(ends.size))
			val interval = Interval.of(start, end)
			// An absent interval is always inserted (no random draw happens in that case);
			// a present one is inserted again with probability 1/3 and removed otherwise.
			val isInsert = !current.contains(interval) || random.nextInt(3) == 0
			if (isInsert) {
				current.add(interval)
			} else {
				current.remove(interval)
			}
			val expected = current.stream.reduce(aggregator.mode).orElse(ZERO)
			update(interval, isInsert)
			assertEquals(expected)
		}
	}

	/**
	 * Applies one insertion or removal to the aggregation state and checks the resulting tree.
	 *
	 * @param interval the interval to insert or remove
	 * @param isInsert {@code true} to insert the interval, {@code false} to remove it
	 */
	private def update(Interval interval, boolean isInsert) {
		value = aggregator.update(value, interval, isInsert)
		// Walk from the minimum node along the successor links until the chain ends.
		// The visited nodes are not collected: nothing ever read the list.
		var node = value.min
		while (node !== null) {
			node = node.successor
		}
		value.assertSubtreeIsValid
	}

	/** Inserts one occurrence of the interval into the aggregation state. */
	private def add(Interval interval) {
		update(interval, true)
	}

	/** Removes one occurrence of the interval from the aggregation state. */
	private def remove(Interval interval) {
		update(interval, false)
	}

	/**
	 * Asserts that the aggregate currently reported by the operator equals the given interval.
	 *
	 * @param interval the expected aggregate
	 */
	private def assertEquals(Interval interval) {
		val actual = aggregator.getAggregate(value)
		Assert.assertEquals(interval, actual)
	}
}
|