aboutsummaryrefslogtreecommitdiffstats
path: root/subprojects/store-reasoning/src/main/java/tools/refinery/store/reasoning/translator/multiobject/MultiObjectInitializer.java
blob: 084bf6f9b497aae505ec89eb0ab8407f2b281606 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
/*
 * SPDX-FileCopyrightText: 2023 The Refinery Authors <https://refinery.tools/>
 *
 * SPDX-License-Identifier: EPL-2.0
 */
package tools.refinery.store.reasoning.translator.multiobject;

import org.jetbrains.annotations.NotNull;
import tools.refinery.store.model.Model;
import tools.refinery.store.reasoning.ReasoningAdapter;
import tools.refinery.store.reasoning.refinement.PartialModelInitializer;
import tools.refinery.store.reasoning.seed.ModelSeed;
import tools.refinery.store.representation.Symbol;
import tools.refinery.store.representation.TruthValue;
import tools.refinery.store.representation.cardinality.CardinalityInterval;
import tools.refinery.store.representation.cardinality.CardinalityIntervals;
import tools.refinery.store.tuple.Tuple;

import java.util.Arrays;

/**
 * Initializes the interpretation of a multi-object count symbol from a {@link ModelSeed}.
 * <p>
 * A cardinality interval is computed for every node of the seed:
 * <ul>
 *     <li>start from the seed of {@link MultiObjectTranslator#COUNT_SYMBOL} if present, otherwise from
 *     {@link CardinalityIntervals#SET};</li>
 *     <li>refine it with the seed of {@link ReasoningAdapter#EXISTS_SYMBOL}, if present;</li>
 *     <li>refine it with the seed of {@link ReasoningAdapter#EQUALS_SYMBOL}, if present.</li>
 * </ul>
 * The resulting intervals are then written into the interpretation of the count symbol passed to the constructor.
 */
class MultiObjectInitializer implements PartialModelInitializer {
	private final Symbol<CardinalityInterval> countSymbol;

	/**
	 * @param countSymbol the symbol whose interpretation receives the computed cardinality intervals
	 */
	public MultiObjectInitializer(Symbol<CardinalityInterval> countSymbol) {
		this.countSymbol = countSymbol;
	}

	/**
	 * Computes the cardinality interval of every node and stores it in the model.
	 * <p>
	 * All intervals are validated before anything is written, so an inconsistent seed does not leave the count
	 * interpretation partially initialized.
	 *
	 * @param model     the model whose interpretation of the count symbol is written
	 * @param modelSeed the seed providing the node count and the count, existence and equality seeds
	 * @throws IllegalArgumentException if required seeds are missing, a seed refers to an out-of-range node id,
	 *                                  or the seeds are inconsistent for some node
	 */
	@Override
	public void initialize(Model model, ModelSeed modelSeed) {
		var intervals = initializeIntervals(modelSeed);
		initializeExists(intervals, modelSeed);
		initializeEquals(intervals, modelSeed);
		// Validate first: an empty interval means the count, existence and equality constraints contradict.
		for (int i = 0; i < intervals.length; i++) {
			if (intervals[i].isEmpty()) {
				throw new IllegalArgumentException("Inconsistent existence or equality for node " + i);
			}
		}
		var countInterpretation = model.getInterpretation(countSymbol);
		for (int i = 0; i < intervals.length; i++) {
			countInterpretation.put(Tuple.of(i), intervals[i]);
		}
	}

	/**
	 * Creates the initial cardinality interval of every node.
	 * <p>
	 * If there is a seed for {@link MultiObjectTranslator#COUNT_SYMBOL}, every node starts at
	 * {@link CardinalityIntervals#ONE}, and nodes visited by the seed cursor take the seeded value. Otherwise,
	 * every node starts at {@link CardinalityIntervals#SET}. In that case seeds for both existence and equality
	 * are required so that the intervals can be refined afterwards.
	 *
	 * @param modelSeed the seed to read
	 * @return a fresh array with one interval per node
	 * @throws IllegalArgumentException if there is no count seed and the existence or equality seed is missing,
	 *                                  or if the count seed refers to an out-of-range node id
	 */
	@NotNull
	private CardinalityInterval[] initializeIntervals(ModelSeed modelSeed) {
		var intervals = new CardinalityInterval[modelSeed.getNodeCount()];
		if (!modelSeed.containsSeed(MultiObjectTranslator.COUNT_SYMBOL)) {
			if (!modelSeed.containsSeed(ReasoningAdapter.EXISTS_SYMBOL) ||
					!modelSeed.containsSeed(ReasoningAdapter.EQUALS_SYMBOL)) {
				throw new IllegalArgumentException("Seed for %s and %s is required if there is no seed for %s"
						.formatted(ReasoningAdapter.EXISTS_SYMBOL, ReasoningAdapter.EQUALS_SYMBOL,
								MultiObjectTranslator.COUNT_SYMBOL));
			}
			Arrays.fill(intervals, CardinalityIntervals.SET);
			return intervals;
		}
		Arrays.fill(intervals, CardinalityIntervals.ONE);
		var cursor = modelSeed.getCursor(MultiObjectTranslator.COUNT_SYMBOL, CardinalityIntervals.ONE);
		while (cursor.move()) {
			int i = cursor.getKey().get(0);
			checkNodeId(intervals, i);
			intervals[i] = cursor.getValue();
		}
		return intervals;
	}

	/**
	 * Refines the intervals with the existence seed, if there is one.
	 * <p>
	 * {@code TRUE} restricts a node to {@link CardinalityIntervals#SOME} and {@code FALSE} restricts it to
	 * {@link CardinalityIntervals#NONE}. {@code UNKNOWN} imposes no constraint.
	 *
	 * @throws IllegalArgumentException on an {@code ERROR} value or an out-of-range node id
	 */
	private void initializeExists(CardinalityInterval[] intervals, ModelSeed modelSeed) {
		if (!modelSeed.containsSeed(ReasoningAdapter.EXISTS_SYMBOL)) {
			return;
		}
		var cursor = modelSeed.getCursor(ReasoningAdapter.EXISTS_SYMBOL, TruthValue.UNKNOWN);
		while (cursor.move()) {
			int i = cursor.getKey().get(0);
			checkNodeId(intervals, i);
			switch (cursor.getValue()) {
			case TRUE -> intervals[i] = intervals[i].meet(CardinalityIntervals.SOME);
			case FALSE -> intervals[i] = intervals[i].meet(CardinalityIntervals.NONE);
			case UNKNOWN -> {
				// Unknown existence does not constrain the interval. This case is only reached if the seed
				// cursor yields entries equal to its default value.
			}
			case ERROR -> throw new IllegalArgumentException("Inconsistent existence for node " + i);
			default -> throw new IllegalArgumentException("Invalid existence truth value %s for node %d"
					.formatted(cursor.getValue(), i));
			}
		}
	}

	/**
	 * Refines the intervals with the equality seed, if there is one.
	 * <p>
	 * Only the diagonal {@code (i, i)} may differ from {@code FALSE}. A {@code TRUE} diagonal restricts the node
	 * to {@link CardinalityIntervals#LONE}, and {@code UNKNOWN} imposes no constraint. A {@code FALSE} or
	 * {@code ERROR} diagonal is inconsistent.
	 *
	 * @throws IllegalArgumentException on an off-diagonal entry that is not {@code FALSE}, an inconsistent
	 *                                  diagonal entry, or an out-of-range node id
	 */
	private void initializeEquals(CardinalityInterval[] intervals, ModelSeed modelSeed) {
		if (!modelSeed.containsSeed(ReasoningAdapter.EQUALS_SYMBOL)) {
			return;
		}
		var seed = modelSeed.getSeed(ReasoningAdapter.EQUALS_SYMBOL);
		var cursor = seed.getCursor(TruthValue.FALSE, modelSeed.getNodeCount());
		while (cursor.move()) {
			var key = cursor.getKey();
			int i = key.get(0);
			int otherIndex = key.get(1);
			if (i != otherIndex) {
				throw new IllegalArgumentException("Off-diagonal equivalence (%d, %d) is not permitted"
						.formatted(i, otherIndex));
			}
			checkNodeId(intervals, i);
			switch (cursor.getValue()) {
			case TRUE -> intervals[i] = intervals[i].meet(CardinalityIntervals.LONE);
			case UNKNOWN -> {
				// Nothing to do, {@code intervals} is initialized with unknown equality.
			}
			// Report a FALSE diagonal the same way as the check below, in case the cursor yields it.
			case FALSE, ERROR -> throw new IllegalArgumentException("Inconsistent equality for node " + i);
			default -> throw new IllegalArgumentException("Invalid equality truth value %s for node %d"
					.formatted(cursor.getValue(), i));
			}
		}
		// The cursor above uses FALSE as its default, so a FALSE diagonal may have been skipped. Check it here.
		for (int i = 0; i < intervals.length; i++) {
			if (seed.get(Tuple.of(i, i)) == TruthValue.FALSE) {
				throw new IllegalArgumentException("Inconsistent equality for node " + i);
			}
		}
	}

	/**
	 * Ensures that {@code nodeId} indexes into {@code intervals}.
	 *
	 * @throws IllegalArgumentException if {@code nodeId} is negative or not lower than {@code intervals.length}
	 */
	private void checkNodeId(CardinalityInterval[] intervals, int nodeId) {
		if (nodeId < 0 || nodeId >= intervals.length) {
			throw new IllegalArgumentException("Expected node id %d to be lower than model size %d"
					.formatted(nodeId, intervals.length));
		}
	}
}