aboutsummaryrefslogtreecommitdiffstats
path: root/subprojects/store-reasoning/src/main/java/tools/refinery/store/reasoning/seed/MapBasedSeed.java
blob: 8b1c3bb3d3011a3194642bf54405ed21de6ff2ff (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
/*
 * SPDX-FileCopyrightText: 2023 The Refinery Authors <https://refinery.tools/>
 *
 * SPDX-License-Identifier: EPL-2.0
 */
package tools.refinery.store.reasoning.seed;

import tools.refinery.store.map.Cursor;
import tools.refinery.store.map.Cursors;
import tools.refinery.store.tuple.Tuple;

import java.util.Map;
import java.util.Objects;

/**
 * A {@link Seed} backed by a {@link Map} from tuples to values.
 * <p>
 * Tuples that are absent from {@code map} (or mapped to {@code null}) are treated as having
 * {@code reducedValue}.
 *
 * @param arity        the number of components in each key tuple
 * @param valueType    the class object of the value type
 * @param reducedValue the value reported for tuples without an entry in {@code map}
 * @param map          the explicitly stored tuple-to-value entries
 * @param <T>          the type of the values
 */
record MapBasedSeed<T>(int arity, Class<T> valueType, T reducedValue, Map<Tuple, T> map) implements Seed<T> {
	/**
	 * Looks up the value of a tuple.
	 *
	 * @param key the tuple to look up
	 * @return the value stored in {@code map} for {@code key}, or {@code reducedValue} if there is none
	 */
	@Override
	public T get(Tuple key) {
		var value = map.get(key);
		return value == null ? reducedValue : value;
	}

	/**
	 * Creates a cursor over the entries of this seed relative to a default value.
	 * <p>
	 * If {@code defaultValue} equals {@code reducedValue}, the cursor is created directly from {@code map}.
	 * Otherwise, tuples missing from {@code map} differ from the default as well, so the cursor enumerates
	 * every tuple of length {@code arity} with components in {@code [0, nodeCount)} and reports those whose
	 * value differs from {@code defaultValue}.
	 *
	 * @param defaultValue the value that does not have to be reported by the cursor
	 * @param nodeCount    the exclusive upper bound for the tuple components to enumerate
	 * @return a cursor over the tuples and their values
	 */
	@Override
	public Cursor<Tuple, T> getCursor(T defaultValue, int nodeCount) {
		if (Objects.equals(defaultValue, reducedValue)) {
			return Cursors.of(map);
		}
		return new CartesianProductCursor<>(arity, nodeCount, reducedValue, defaultValue, map);
	}

	/**
	 * Cursor that walks all tuples of a fixed arity over {@code [0, nodeCount)} in lexicographic order
	 * (the last component changes fastest) and stops at each tuple whose value differs from the default.
	 *
	 * @param <T> the type of the values
	 */
	private static class CartesianProductCursor<T> implements Cursor<Tuple, T> {
		private final int nodeCount;
		private final T reducedValue;
		private final T defaultValue;
		private final Map<Tuple, T> map;
		// Components of the tuple currently under inspection; starts as all zeros.
		private final int[] counter;
		private State state = State.INITIAL;
		private Tuple key;
		private T value;

		private CartesianProductCursor(int arity, int nodeCount, T reducedValue, T defaultValue, Map<Tuple, T> map) {
			this.nodeCount = nodeCount;
			this.reducedValue = reducedValue;
			this.defaultValue = defaultValue;
			this.map = map;
			counter = new int[arity];
		}

		@Override
		public Tuple getKey() {
			return key;
		}

		@Override
		public T getValue() {
			return value;
		}

		@Override
		public boolean isTerminated() {
			return state == State.TERMINATED;
		}

		/**
		 * Advances to the next tuple whose value differs from the default value.
		 *
		 * @return {@code true} if such a tuple was found, {@code false} if the enumeration is exhausted
		 */
		@Override
		public boolean move() {
			return switch (state) {
				case INITIAL -> {
					if (counter.length > 0 && nodeCount <= 0) {
						// The Cartesian product over an empty range is empty, so the all-zero tuple
						// must not be reported. A zero-arity cursor still has the single empty tuple.
						state = State.TERMINATED;
						yield false;
					}
					state = State.STARTED;
					// The all-zero tuple is the first candidate and has to be checked before incrementing.
					yield checkValue() || moveToNext();
				}
				case STARTED -> moveToNext();
				case TERMINATED -> false;
			};
		}

		/**
		 * Steps the counter until a tuple with a non-default value is found or the counter overflows.
		 *
		 * @return {@code true} if the cursor stopped at a tuple, {@code false} if it terminated
		 */
		private boolean moveToNext() {
			do {
				increment();
			} while (state != State.TERMINATED && !checkValue());
			return state != State.TERMINATED;
		}

		/**
		 * Increments the counter like a base-{@code nodeCount} number with the last component as the least
		 * significant digit, and marks the cursor as terminated when every component has wrapped around.
		 */
		private void increment() {
			int i = counter.length - 1;
			while (i >= 0) {
				counter[i]++;
				if (counter[i] < nodeCount) {
					return;
				}
				counter[i] = 0;
				i--;
			}
			state = State.TERMINATED;
		}

		/**
		 * Evaluates the tuple described by the counter and stores it in {@code key}.
		 *
		 * @return {@code true} if the value of the tuple differs from the default value, in which case
		 * {@code value} is updated as well
		 */
		private boolean checkValue() {
			key = Tuple.of(counter);
			var valueInMap = map.get(key);
			// Resolve missing entries to the reduced value before comparing with the default value;
			// otherwise a null default value would wrongly hide every tuple that is absent from the map.
			var effectiveValue = valueInMap == null ? reducedValue : valueInMap;
			if (Objects.equals(effectiveValue, defaultValue)) {
				return false;
			}
			value = effectiveValue;
			return true;
		}

		private enum State {
			INITIAL,
			STARTED,
			TERMINATED
		}
	}
}