aboutsummaryrefslogtreecommitdiffstats
path: root/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/IntervalHullAggregatorOperator.xtend
blob: ce48eca18708b8ae61b05e2f82ddbfe0b1a3ecf3 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
package hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval

import java.math.BigDecimal
import java.math.MathContext
import java.util.SortedMap
import java.util.TreeMap
import java.util.stream.Stream
import org.eclipse.viatra.query.runtime.matchers.psystem.aggregations.IMultisetAggregationOperator

/**
 * Multiset aggregation operator that computes the interval hull (the smallest
 * enclosing interval) of the aggregated values.
 *
 * <p>The incremental aggregation state is a sorted map from each distinct value
 * to its multiplicity, so the current minimum and maximum are the first and last
 * keys of the map. Subclasses provide the conversion of {@code T} to
 * {@link BigDecimal} via {@link #toBigDecimal(Comparable, MathContext)}.
 *
 * @param <T> the type of the aggregated values
 */
abstract class IntervalHullAggregatorOperator<T extends Comparable<T>> implements IMultisetAggregationOperator<T, SortedMap<T, Integer>, Interval> {
	protected new() {
	}

	override getName() {
		"intervalHull"
	}

	override getShortDescription() {
		"Calculates the interval hull of a set of numbers"
	}

	/**
	 * Creates an empty aggregation state (no values, no multiplicities).
	 */
	override createNeutral() {
		new TreeMap<T, Integer>
	}

	/**
	 * Returns {@code Interval.EMPTY} for an empty state, otherwise the interval
	 * spanning the smallest and largest value currently in the state.
	 */
	override getAggregate(SortedMap<T, Integer> result) {
		if (result.neutral) {
			Interval.EMPTY
		} else {
			toInterval(result.firstKey, result.lastKey)
		}
	}

	/**
	 * Converts a value to a {@link BigDecimal} using the given math context.
	 * NOTE(review): presumably implementations round according to {@code mc}; confirm in subclasses.
	 *
	 * @param value the value to convert
	 * @param mc the math context used for the conversion
	 * @return the converted value
	 */
	protected abstract def BigDecimal toBigDecimal(T value, MathContext mc)

	/**
	 * Builds an interval from the given bounds, converting the lower bound with
	 * {@code Interval.ROUND_DOWN} and the upper bound with {@code Interval.ROUND_UP}
	 * (presumably so that the resulting interval still encloses the exact values).
	 */
	private def toInterval(T min, T max) {
		Interval.of(min.toBigDecimal(Interval.ROUND_DOWN), max.toBigDecimal(Interval.ROUND_UP))
	}

	/**
	 * An aggregation state is neutral exactly when it contains no values.
	 */
	override isNeutral(SortedMap<T, Integer> result) {
		result.empty
	}

	/**
	 * Inserts or removes one occurrence of {@code updateValue} in {@code oldResult}.
	 * The map is modified in place and returned.
	 *
	 * <p>A key whose multiplicity drops to zero is removed from the map, because
	 * {@code Map.compute} removes the entry when the remapping function returns {@code null}.
	 *
	 * @throws IllegalStateException if a stored multiplicity is not positive, or if a
	 *         value is removed that is not present in the state
	 */
	override update(SortedMap<T, Integer> oldResult, T updateValue, boolean isInsertion) {
		if (isInsertion) {
			oldResult.compute(updateValue) [ key, value |
				if (value === null) {
					1
				} else if (value > 0) {
					value + 1
				} else {
					throw new IllegalStateException("Invalid count: " + value)
				}
			]
		} else {
			oldResult.compute(updateValue) [ key, value |
				if (value === null) {
					// Removing a value that was never inserted. Fail with the descriptive
					// exception instead of a NullPointerException from unboxing below.
					throw new IllegalStateException("Invalid count: " + value)
				}
				// Compare primitive values rather than boxed Integer identity.
				if (value.intValue > 1) {
					value.intValue - 1
				} else if (value.intValue == 1) {
					null
				} else {
					throw new IllegalStateException("Invalid count: " + value)
				}
			]
		}
		oldResult
	}

	/**
	 * Computes the interval hull of all elements of {@code stream} in one pass.
	 * Returns {@code Interval.EMPTY} if the stream has no elements.
	 */
	override aggregateStream(Stream<T> stream) {
		val iterator = stream.iterator
		if (!iterator.hasNext) {
			return Interval.EMPTY
		}
		var min = iterator.next
		var max = min
		while (iterator.hasNext) {
			val element = iterator.next
			if (element.compareTo(min) < 0) {
				min = element
			}
			if (element.compareTo(max) > 0) {
				max = element
			}
		}
		toInterval(min, max)
	}
}