aboutsummaryrefslogtreecommitdiffstats
path: root/subprojects/language-semantics/src/main/java/tools/refinery/language/semantics/model/ModelInitializer.java
blob: 06b8ad77fd8fe533cb382cb301b3f4c6e3ab404e (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
/*
 * SPDX-FileCopyrightText: 2021-2023 The Refinery Authors <https://refinery.tools/>
 *
 * SPDX-License-Identifier: EPL-2.0
 */
package tools.refinery.language.semantics.model;

import com.google.inject.Inject;
import org.eclipse.collections.api.factory.primitive.ObjectIntMaps;
import org.eclipse.collections.api.map.primitive.MutableObjectIntMap;
import tools.refinery.language.model.problem.*;
import tools.refinery.language.semantics.model.internal.DecisionTree;
import tools.refinery.language.utils.ProblemDesugarer;
import tools.refinery.language.utils.RelationInfo;
import tools.refinery.store.representation.Symbol;
import tools.refinery.store.representation.TruthValue;
import tools.refinery.store.tuple.Tuple;

import java.util.HashMap;
import java.util.Map;

/**
 * Walks a {@link Problem} and records, for every collected node, a consecutive integer id and, for every
 * collected relation, a {@link Symbol} carrying {@link TruthValue}s.
 * <p>
 * The traces and the node counter are instance fields that {@link #createModel(Problem)} appends to.
 */
public class ModelInitializer {
	@Inject
	private ProblemDesugarer desugarer;

	/** Maps each collected node to the integer id assigned to it. */
	private final MutableObjectIntMap<Node> nodeTrace = ObjectIntMaps.mutable.empty();

	/** Maps each collected relation to the symbol created for it. */
	private final Map<tools.refinery.language.model.problem.Relation, Symbol<TruthValue>> relationTrace =
			new HashMap<>();

	/** Number of node ids handed out so far; also the id that the next node receives. */
	private int nodeCount = 0;

	/**
	 * Assigns ids to the nodes of {@code problem} and creates a symbol for each of its relations.
	 * <p>
	 * Nodes are numbered in the iteration order of the collected node map. For every relation the
	 * assertions are merged first and then a symbol is registered in the relation trace; the builtin
	 * {@code equals} relation gets {@link TruthValue#FALSE} as its default value, every other relation
	 * gets {@link TruthValue#UNKNOWN}.
	 *
	 * @param problem the problem to process
	 * @throws IllegalArgumentException if the desugarer reports no builtin symbols for {@code problem},
	 *                                  or if an assertion has an argument of an unsupported kind
	 */
	public void createModel(Problem problem) {
		var builtins = desugarer.getBuiltinSymbols(problem)
				.orElseThrow(() -> new IllegalArgumentException("Problem has no builtin library"));
		var symbols = desugarer.collectSymbols(problem);

		for (var node : symbols.nodes().keySet()) {
			nodeTrace.put(node, nodeCount);
			nodeCount++;
		}

		for (var entry : symbols.relations().entrySet()) {
			var relation = entry.getKey();
			var info = entry.getValue();
			boolean isEquals = relation == builtins.equals();

			// The merged tree is not used any further in this method. The call is retained so that any
			// exception raised while translating the assertions still propagates to the caller.
			mergeAssertions(info, isEquals);

			TruthValue defaultValue;
			if (isEquals) {
				defaultValue = TruthValue.FALSE;
			} else {
				defaultValue = TruthValue.UNKNOWN;
			}
			var symbol = Symbol.of(info.name(), info.arity(), TruthValue.class, defaultValue);
			relationTrace.put(relation, symbol);
		}
	}

	/**
	 * Combines the default and the non-default assertions of a relation into one decision tree.
	 * <p>
	 * Default assertions are merged into a tree that starts out with {@link TruthValue#UNKNOWN} (or with
	 * {@code null} for the {@code equals} relation); non-default assertions are merged into a second tree
	 * which is then written over the first one. For the {@code equals} relation, every diagonal tuple
	 * {@code (i, i)} with {@code i < nodeCount} is afterwards set to {@link TruthValue#TRUE} if missing,
	 * and all remaining missing entries are set to {@link TruthValue#FALSE}.
	 *
	 * @param relationInfo     arity and assertions of the relation
	 * @param isEqualsRelation whether the relation is the builtin {@code equals} relation
	 * @return the tree of default assertions after the non-default ones were written over it
	 */
	private DecisionTree mergeAssertions(RelationInfo relationInfo, boolean isEqualsRelation) {
		int arity = relationInfo.arity();
		TruthValue initialDefault = isEqualsRelation ? null : TruthValue.UNKNOWN;
		var merged = new DecisionTree(arity, initialDefault);
		var explicit = new DecisionTree(arity);

		for (var assertion : relationInfo.assertions()) {
			var tuple = getTuple(assertion);
			var value = getTruthValue(assertion.getValue());
			var target = assertion.isDefault() ? merged : explicit;
			target.mergeValue(tuple, value);
		}
		merged.overwriteValues(explicit);

		if (!isEqualsRelation) {
			return merged;
		}
		for (int nodeId = 0; nodeId < nodeCount; nodeId++) {
			merged.setIfMissing(Tuple.of(nodeId, nodeId), TruthValue.TRUE);
		}
		merged.setAllMissing(TruthValue.FALSE);
		return merged;
	}

	/**
	 * Translates the arguments of an assertion into a tuple of node ids.
	 * <p>
	 * A node argument becomes the id recorded for its node in the node trace; a wildcard argument becomes
	 * {@code -1}.
	 *
	 * @param assertion the assertion whose arguments are translated
	 * @return a tuple with one entry per argument, in argument order
	 * @throws IllegalArgumentException if an argument is neither a node nor a wildcard argument
	 */
	private Tuple getTuple(Assertion assertion) {
		var arguments = assertion.getArguments();
		var nodeIds = new int[arguments.size()];
		for (int index = 0; index < nodeIds.length; index++) {
			var argument = arguments.get(index);
			if (argument instanceof NodeAssertionArgument nodeArgument) {
				nodeIds[index] = nodeTrace.getOrThrow(nodeArgument.getNode());
				continue;
			}
			if (!(argument instanceof WildcardAssertionArgument)) {
				throw new IllegalArgumentException("Unknown assertion argument: " + argument);
			}
			nodeIds[index] = -1;
		}
		return Tuple.of(nodes(nodeIds));
	}

	/** Returns {@code nodeIds} unchanged; exists only to name the array passed on to {@link Tuple#of}. */
	private static int[] nodes(int[] nodeIds) {
		return nodeIds;
	}

	/**
	 * Converts an assertion value expression to a {@link TruthValue}.
	 *
	 * @param expr the value expression of an assertion
	 * @return the truth value matching the logic constant, or {@link TruthValue#ERROR} if {@code expr} is
	 * not a {@link LogicConstant}
	 */
	private static TruthValue getTruthValue(Expr expr) {
		if (expr instanceof LogicConstant logicConstant) {
			return switch (logicConstant.getLogicValue()) {
				case TRUE -> TruthValue.TRUE;
				case FALSE -> TruthValue.FALSE;
				case UNKNOWN -> TruthValue.UNKNOWN;
				case ERROR -> TruthValue.ERROR;
			};
		}
		return TruthValue.ERROR;
	}
}