aboutsummaryrefslogtreecommitdiffstats
path: root/subprojects/logic/src/main/java/tools/refinery/logic/term/real/RealSumAggregator.java
blob: 4b09018859072092b4ab27d41f4b19ac66a157fb (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
/*
 * SPDX-FileCopyrightText: 2021-2023 The Refinery Authors <https://refinery.tools/>
 *
 * SPDX-License-Identifier: EPL-2.0
 */
package tools.refinery.logic.term.real;

import tools.refinery.logic.term.StatefulAggregate;
import tools.refinery.logic.term.StatefulAggregator;

import java.util.Map;
import java.util.TreeMap;

/**
 * Stateful aggregator that sums {@link Double} inputs.
 * <p>
 * Every aggregate created by {@link #createEmptyAggregate()} stores a multiset of the values currently added to it:
 * one occurrence count per distinct value. Because of this, previously added values can be removed again. An
 * aggregate with no values (and {@link #getEmptyResult()}) yields {@code 0.0}.
 */
public final class RealSumAggregator implements StatefulAggregator<Double, Double> {
	/**
	 * Shared instance; the constructor is private, so this is the only instance.
	 */
	public static final RealSumAggregator INSTANCE = new RealSumAggregator();

	private RealSumAggregator() {
		// Not instantiable from outside; use INSTANCE.
	}

	@Override
	public Class<Double> getInputType() {
		return Double.class;
	}

	@Override
	public Class<Double> getResultType() {
		return Double.class;
	}

	@Override
	public Double getEmptyResult() {
		return 0.0;
	}

	@Override
	public StatefulAggregate<Double, Double> createEmptyAggregate() {
		return new Aggregate();
	}

	/**
	 * Mutable sum state backed by a sorted value-to-count map.
	 */
	private static final class Aggregate implements StatefulAggregate<Double, Double> {
		// Distinct value -> number of times it is currently present. Keys whose count would
		// drop to zero are removed, so every stored count is at least 1.
		private final Map<Double, Integer> values;

		public Aggregate() {
			this.values = new TreeMap<>();
		}

		/**
		 * Copy constructor: takes a snapshot of the counts held by {@code source}.
		 *
		 * @param source the aggregate to copy
		 */
		private Aggregate(Aggregate source) {
			this.values = new TreeMap<>(source.values);
		}

		/**
		 * Remapping function for {@link #remove(Double)}: decrements an occurrence count.
		 *
		 * @param key   the value being removed
		 * @param count its current count, or {@code null} if absent
		 * @return the decremented count, or {@code null} to drop the key when the last occurrence goes away
		 * @throws IllegalStateException if the value is not present with a positive count
		 */
		private static Integer decrement(Double key, Integer count) {
			if (count != null && count > 0) {
				return count == 1 ? null : count - 1;
			}
			throw new IllegalStateException("Invalid count %d for value %f".formatted(count, key));
		}

		@Override
		public void add(Double value) {
			// Insert with count 1, or bump the existing count by 1.
			values.merge(value, 1, Integer::sum);
		}

		@Override
		public void remove(Double value) {
			values.compute(value, Aggregate::decrement);
		}

		/**
		 * Folds {@code value * count} over the distinct values in ascending key order.
		 * <p>
		 * The accumulator is seeded with the first term rather than {@code 0.0}, matching a reduction without an
		 * identity element. For example, a lone {@code -0.0} contributes {@code -0.0} rather than being normalized to
		 * {@code +0.0}.
		 *
		 * @return the sum, or {@code 0.0} if no values are present
		 */
		@Override
		public Double getResult() {
			var entries = values.entrySet().iterator();
			if (!entries.hasNext()) {
				return 0.0;
			}
			var head = entries.next();
			double total = head.getKey() * head.getValue();
			while (entries.hasNext()) {
				var next = entries.next();
				double term = next.getKey() * next.getValue();
				total = total + term;
			}
			return total;
		}

		@Override
		public boolean contains(Double value) {
			return values.containsKey(value);
		}

		@Override
		public boolean isEmpty() {
			return values.isEmpty();
		}

		/**
		 * Returns an independent copy; later changes to either aggregate do not affect the other.
		 */
		@Override
		public StatefulAggregate<Double, Double> deepCopy() {
			return new Aggregate(this);
		}
	}
}