about summary refs log tree commit diff stats
path: root/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/SumAggregatorTest.xtend
blob: 56172b6c22cd8d36a146f2d7b0b193bb297df7e5 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
package hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests.interval

import com.google.common.collect.HashMultiset
import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval
import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.IntervalAggregationMode
import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.IntervalAggregationOperator
import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.IntervalRedBlackNode
import java.math.BigDecimal
import java.util.Random
import org.junit.Assert
import org.junit.Before
import org.junit.Test

import static hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval.*

/**
 * Tests {@link IntervalAggregationOperator} configured with {@link IntervalAggregationMode#SUM}.
 * <p>
 * Every test starts from the operator's neutral element (see {@link #reset()}), applies a
 * sequence of interval insertions and removals, and compares the resulting aggregate with the
 * expected interval. After each update the backing {@link IntervalRedBlackNode} tree is validated.
 */
class SumAggregatorTest {
	/** Seed of the randomized test, fixed so that any failure is reproducible. */
	static val RANDOM_SEED = 1

	/** Number of random insert/remove steps performed by {@link #largeTest()}. */
	static val RANDOM_STEPS = 1000

	val aggregator = new IntervalAggregationOperator(IntervalAggregationMode.SUM)

	/** Current aggregation state; replaced by the operator's result on every update. */
	var IntervalRedBlackNode value = null

	/** Resets the aggregation state to the operator's neutral element before each test. */
	@Before
	def void reset() {
		value = aggregator.createNeutral
	}

	/** The neutral element aggregates to {@code ZERO}. */
	@Test
	def void emptyTest() {
		assertEquals(ZERO)
	}

	@Test
	def void addSingleTest() {
		add(between(-1, 1))
		assertEquals(between(-1, 1))
	}

	/** Removing the only element brings the aggregate back to {@code ZERO}. */
	@Test
	def void addRemoveTest() {
		add(between(-1, 1))
		remove(between(-1, 1))
		assertEquals(ZERO)
	}

	/** [-1, 1] + [2, +inf) = [1, +inf). */
	@Test
	def void addTwoTest() {
		add(between(-1, 1))
		add(above(2))
		assertEquals(above(1))
	}

	@Test
	def void addTwoRemoveFirstTest() {
		add(between(-1, 1))
		add(above(2))
		remove(between(-1, 1))
		assertEquals(above(2))
	}

	@Test
	def void addTwoRemoveSecondTest() {
		add(between(-1, 1))
		add(above(2))
		remove(above(2))
		assertEquals(between(-1, 1))
	}

	/** Adding the same interval three times sums it three times. */
	@Test
	def void addMultiplicityTest() {
		add(between(-1, 1))
		add(between(-1, 1))
		add(between(-1, 1))
		assertEquals(between(-3, 3))
	}

	@Test
	def void removeAllMultiplicityTest() {
		add(between(-1, 1))
		add(between(-1, 1))
		add(between(-1, 1))
		remove(between(-1, 1))
		remove(between(-1, 1))
		remove(between(-1, 1))
		assertEquals(ZERO)
	}

	/** Removing only some copies of a repeated interval keeps the remaining ones. */
	@Test
	def void removeSomeMultiplicityTest() {
		add(between(-1, 1))
		add(between(-1, 1))
		add(between(-1, 1))
		remove(between(-1, 1))
		remove(between(-1, 1))
		assertEquals(between(-1, 1))
	}

	/**
	 * Randomized test. A reference multiset of intervals is mutated alongside the operator, and
	 * after every step the operator's aggregate must equal the reference fold of that multiset.
	 */
	@Test
	def void largeTest() {
		// A null bound is passed to Interval.of to denote an open-ended side.
		val starts = #[null, new BigDecimal(-3), new BigDecimal(-2), new BigDecimal(-1)]
		val ends = #[new BigDecimal(1), new BigDecimal(2), new BigDecimal(3), null]
		val current = HashMultiset.create
		val random = new Random(RANDOM_SEED)
		for (var int i = 0; i < RANDOM_STEPS; i++) {
			val start = starts.get(random.nextInt(starts.size))
			val end = ends.get(random.nextInt(ends.size))
			val interval = Interval.of(start, end)
			// An interval that is not present is always inserted. A present interval gets another
			// copy with probability 1/3 and loses one copy otherwise. Because of this, a removal
			// never targets an absent element.
			val isInsert = !current.contains(interval) || random.nextInt(3) == 0
			if (isInsert) {
				current.add(interval)
			} else {
				current.remove(interval)
			}
			// Reference result: reduce the multiset with the aggregation mode itself. An empty
			// multiset corresponds to ZERO.
			val expected = current.stream.reduce(aggregator.mode).orElse(ZERO)
			update(interval, isInsert)
			assertEquals(expected)
		}
	}

	/**
	 * Applies one insertion or removal to {@link #value} and validates the resulting tree.
	 *
	 * @param interval the interval to add or remove
	 * @param isInsert {@code true} to add {@code interval}, {@code false} to remove one copy of it
	 */
	private def update(Interval interval, boolean isInsert) {
		value = aggregator.update(value, interval, isInsert)
		// Walk the successor chain starting from value.min. The visited nodes are not inspected.
		// The walk only exercises the min/successor navigation, so a failure there surfaces here.
		var node = value.min
		while (node !== null) {
			node = node.successor
		}
		value.assertSubtreeIsValid
	}

	/** Inserts one copy of {@code interval}. */
	private def add(Interval interval) {
		update(interval, true)
	}

	/** Removes one copy of {@code interval}. */
	private def remove(Interval interval) {
		update(interval, false)
	}

	/** Asserts that the current aggregate of {@link #value} equals {@code interval}. */
	private def assertEquals(Interval interval) {
		val actual = aggregator.getAggregate(value)
		Assert.assertEquals(interval, actual)
	}
}