aboutsummaryrefslogtreecommitdiffstats
path: root/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/IntervalRedBlackNode.xtend
blob: 3aa575bc9756ca59c430db2612162f91d321868e (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
package hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval

/**
 * A red-black tree node that stores an {@link Interval} together with its
 * multiplicity ({@link #count}), and caches in {@link #result} the combination
 * (via {@link #op}) of the contributions of every node in the subtree rooted
 * here.
 * <p>
 * Subclasses supply the combining operation {@link #op} and may override
 * {@link #multiply} / {@link #isMultiplicitySensitive} to make the
 * augmentation depend on how many times an interval was added.
 */
abstract class IntervalRedBlackNode extends RedBlackNode<IntervalRedBlackNode> {
	/** The interval stored in this node; it is the key used to order nodes. */
	public val Interval interval

	/**
	 * How many times {@link #interval} has been added. Expected to be positive
	 * for every non-leaf node (checked by {@link #assertNodeIsValid}).
	 */
	public var int count = 1

	/** Cached augmentation of the subtree rooted at this node, refreshed by {@link #augment}. */
	public var Interval result

	new(Interval interval) {
		this.interval = interval
	}

	/**
	 * Whether the augmentation depends on {@link #count}. When {@code true},
	 * changing the count of a node that stays in the tree re-augments that
	 * node and its ancestors. Defaults to {@code false}.
	 */
	def boolean isMultiplicitySensitive() {
		false
	}

	/**
	 * The contribution of {@code interval} occurring {@code count} times.
	 * The default ignores {@code count} and returns {@code interval} unchanged.
	 */
	def Interval multiply(Interval interval, int count) {
		interval
	}

	/** Combines two augmentation values into one. */
	abstract def Interval op(Interval left, Interval right)

	/**
	 * Recomputes {@link #result} from this node and its children.
	 *
	 * @return {@code true} if {@link #result} changed, {@code false} if it was already up to date
	 */
	override augment() {
		val value = calculateAugmentation()
		if (result == value) {
			false
		} else {
			result = value
			true
		}
	}

	/**
	 * Computes this node's augmentation: {@code multiply(interval, count)},
	 * combined via {@link #op} with the left child's {@link #result} and then
	 * with the right child's; leaf children contribute nothing.
	 */
	private def calculateAugmentation() {
		var value = multiply(interval, count)
		if (!left.leaf) {
			value = op(value, left.result)
		}
		if (!right.leaf) {
			value = op(value, right.result)
		}
		value
	}

	/**
	 * Runs the base-class node checks and, for non-leaf nodes, additionally
	 * checks that {@link #count} is positive and {@link #result} is up to date.
	 *
	 * @throws IllegalStateException if one of these invariants is violated
	 */
	override assertNodeIsValid() {
		super.assertNodeIsValid()
		if (leaf) {
			return
		}
		if (count <= 0) {
			throw new IllegalStateException("Node with nonpositive count")
		}
		val value = calculateAugmentation()
		if (result != value) {
			throw new IllegalStateException("Node with invalid augmentation: " + result + " != " + value)
		}
	}

	/** Runs the base-class subtree checks, then this node's own checks. */
	override assertSubtreeIsValid() {
		super.assertSubtreeIsValid()
		assertNodeIsValid()
	}

	/**
	 * Orders nodes by their intervals.
	 *
	 * @throws IllegalArgumentException if either node is a leaf
	 */
	override compareTo(IntervalRedBlackNode other) {
		if (leaf || other.leaf) {
			throw new IllegalArgumentException("One of the nodes is a leaf node")
		}
		interval.compareTo(other.interval)
	}

	/**
	 * Inserts {@code newNode} into the tree whose root is this node.
	 * <p>
	 * If the tree already holds a node whose interval compares equal to
	 * {@code newNode.interval}, that node's {@link #count} is increased by
	 * {@code newNode.count} and {@code newNode} is not linked into the tree.
	 *
	 * @param newNode the node to insert; its left, right and parent links are overwritten
	 * @return the root of the resulting tree (as returned by {@code fixInsertion()}
	 *         when a new node was linked in)
	 * @throws IllegalArgumentException if this node is not the root of a tree
	 */
	def add(IntervalRedBlackNode newNode) {
		if (parent !== null) {
			throw new IllegalArgumentException("This is not the root of a tree")
		}
		if (leaf) {
			// Empty tree: newNode becomes a black root with leaf children.
			newNode.isRed = false
			newNode.left = this
			newNode.right = this
			newNode.parent = null
			newNode.augment
			return newNode
		}
		val modifiedNode = addWithoutFixup(newNode)
		if (modifiedNode === newNode) {
			// Must augment here, because fixInsertion() might call augment()
			// on a node repeatedly, which might lose the change notification the
			// second time it is called, and the augmentation will fail to
			// reach the root.
			modifiedNode.augmentRecursively
			modifiedNode.isRed = true
			return modifiedNode.fixInsertion
		}
		// An existing node's count changed; only count-aware augmentations need refreshing.
		if (multiplicitySensitive) {
			modifiedNode.augmentRecursively
		}
		this
	}

	/**
	 * Plain binary-search-tree insertion without red-black rebalancing.
	 * <p>
	 * NOTE: descends to the <em>left</em> when the current node's interval
	 * compares less than the new one, so larger intervals end up in left
	 * subtrees; {@link #find} uses the same convention.
	 *
	 * @return {@code newNode} if it was linked in as a new leaf-level node, or the
	 *         existing node with an equal interval whose count was increased
	 * @throws IllegalStateException if the search unexpectedly reaches a leaf
	 */
	private def addWithoutFixup(IntervalRedBlackNode newNode) {
		var node = this
		while (!node.leaf) {
			val comparison = node.interval.compareTo(newNode.interval)
			if (comparison < 0) {
				if (node.left.leaf) {
					newNode.left = node.left
					newNode.right = node.left
					node.left = newNode
					newNode.parent = node
					return newNode
				} else {
					node = node.left
				}
			} else if (comparison > 0) {
				if (node.right.leaf) {
					newNode.left = node.right
					newNode.right = node.right
					node.right = newNode
					newNode.parent = node
					return newNode
				} else {
					node = node.right
				}
			} else { // comparison == 0
				// Equal interval already present: merge multiplicities instead of
				// linking newNode, keeping whatever count newNode carried (the
				// other insertion paths keep newNode.count as well).
				newNode.parent = null
				node.count = node.count + newNode.count
				return node
			}
		}
		throw new IllegalStateException("Reached leaf node while searching for insertion point")
	}

	/**
	 * Re-augments this node and then its ancestors, stopping at the first node
	 * whose {@link #result} did not change.
	 */
	private def augmentRecursively() {
		for (var node = this; node !== null; node = node.parent) {
			if (!node.augment) {
				return
			}
		}
	}

	/**
	 * Removes one occurrence of {@code interval} from the tree whose root is
	 * this node.
	 * <p>
	 * The matching node's {@link #count} is decremented; when it reaches zero
	 * the node is unlinked via {@code RedBlackNode.remove()}, whose return
	 * value is used as the new root.
	 *
	 * @param interval the interval to remove; must be present in the tree
	 * @return the root of the resulting tree
	 * @throws IllegalArgumentException if this node is not the root of a tree
	 * @throws IllegalStateException if no node holds an interval equal to {@code interval}
	 */
	def remove(Interval interval) {
		// Same precondition as add(): the returned value is only a valid root
		// if this method is invoked on the root.
		if (parent !== null) {
			throw new IllegalArgumentException("This is not the root of a tree")
		}
		val node = find(interval)
		node.count--
		if (node.count == 0) {
			return node.remove
		}
		// Node stays in the tree; only count-aware augmentations need refreshing.
		if (multiplicitySensitive) {
			node.augmentRecursively
		}
		this
	}

	/**
	 * Returns the node whose interval compares equal to {@code interval},
	 * descending with the same (left-for-larger) convention as
	 * {@link #addWithoutFixup}.
	 *
	 * @throws IllegalStateException if no such node exists
	 */
	private def find(Interval interval) {
		var node = this
		while (!node.leaf) {
			val comparison = node.interval.compareTo(interval)
			if (comparison < 0) {
				node = node.left
			} else if (comparison > 0) {
				node = node.right
			} else { // comparison == 0
				return node
			}
		}
		throw new IllegalStateException("Reached leaf node while searching for interval to remove")
	}

	/**
	 * Debug rendering of the subtree: a leaf prints as {@code L}; any other
	 * node prints its color (R/B), count, interval and result, followed by its
	 * left and right subtrees on indented lines.
	 */
	override toString() {
		if (leaf) {
			"L"
		} else {
			'''
				«IF isRed»R«ELSE»B«ENDIF» «count»«interval» : «result»
					«left»
					«right»
			'''
		}
	}

}