aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLibravatar Kristóf Marussy <kristof@marussy.com>2023-08-19 21:28:57 +0200
committerLibravatar Kristóf Marussy <kristof@marussy.com>2023-08-19 21:28:57 +0200
commit38464bf2e8e1fa9e81836329ee496ac8055736ff (patch)
tree0c4ce067f15ff6e4d2b76313e415f785c98d20c5
parentfix: abstract type chain elimination (diff)
downloadrefinery-38464bf2e8e1fa9e81836329ee496ac8055736ff.tar.gz
refinery-38464bf2e8e1fa9e81836329ee496ac8055736ff.tar.zst
refinery-38464bf2e8e1fa9e81836329ee496ac8055736ff.zip
fix: nullary model initialization
Decision trees can only handle relations with 1 level and up, so we use a special case for nullary relations.
-rw-r--r--subprojects/language-semantics/src/main/java/tools/refinery/language/semantics/model/ModelInitializer.java12
-rw-r--r--subprojects/language-semantics/src/main/java/tools/refinery/language/semantics/model/internal/DecisionTree.java16
-rw-r--r--subprojects/language-semantics/src/main/java/tools/refinery/language/semantics/model/internal/MutableSeed.java28
-rw-r--r--subprojects/language-semantics/src/main/java/tools/refinery/language/semantics/model/internal/NullaryMutableSeed.java83
-rw-r--r--subprojects/language-semantics/src/test/java/tools/refinery/language/semantics/model/internal/DecisionTreeTests.java (renamed from subprojects/language-semantics/src/test/java/tools/refinery/language/semantics/model/tests/DecisionTreeTests.java)3
5 files changed, 129 insertions, 13 deletions
diff --git a/subprojects/language-semantics/src/main/java/tools/refinery/language/semantics/model/ModelInitializer.java b/subprojects/language-semantics/src/main/java/tools/refinery/language/semantics/model/ModelInitializer.java
index 12bb94c2..5f854ac3 100644
--- a/subprojects/language-semantics/src/main/java/tools/refinery/language/semantics/model/ModelInitializer.java
+++ b/subprojects/language-semantics/src/main/java/tools/refinery/language/semantics/model/ModelInitializer.java
@@ -9,7 +9,7 @@ import com.google.inject.Inject;
9import org.eclipse.collections.api.factory.primitive.ObjectIntMaps; 9import org.eclipse.collections.api.factory.primitive.ObjectIntMaps;
10import org.eclipse.collections.api.map.primitive.MutableObjectIntMap; 10import org.eclipse.collections.api.map.primitive.MutableObjectIntMap;
11import tools.refinery.language.model.problem.*; 11import tools.refinery.language.model.problem.*;
12import tools.refinery.language.semantics.model.internal.DecisionTree; 12import tools.refinery.language.semantics.model.internal.MutableSeed;
13import tools.refinery.language.utils.BuiltinSymbols; 13import tools.refinery.language.utils.BuiltinSymbols;
14import tools.refinery.language.utils.ProblemDesugarer; 14import tools.refinery.language.utils.ProblemDesugarer;
15import tools.refinery.language.utils.ProblemUtil; 15import tools.refinery.language.utils.ProblemUtil;
@@ -306,7 +306,7 @@ public class ModelInitializer {
306 } 306 }
307 307
308 private void collectEnumAssertions(EnumDeclaration enumDeclaration) { 308 private void collectEnumAssertions(EnumDeclaration enumDeclaration) {
309 var overlay = new DecisionTree(1, null); 309 var overlay = MutableSeed.of(1, null);
310 for (var literal : enumDeclaration.getLiterals()) { 310 for (var literal : enumDeclaration.getLiterals()) {
311 collectIndividualAssertions(literal); 311 collectIndividualAssertions(literal);
312 var nodeId = getNodeId(literal); 312 var nodeId = getNodeId(literal);
@@ -535,15 +535,15 @@ public class ModelInitializer {
535 return argumentList; 535 return argumentList;
536 } 536 }
537 537
538 private record RelationInfo(PartialRelation partialRelation, DecisionTree assertions, 538 private record RelationInfo(PartialRelation partialRelation, MutableSeed<TruthValue> assertions,
539 DecisionTree defaultAssertions) { 539 MutableSeed<TruthValue> defaultAssertions) {
540 public RelationInfo(String name, int arity, TruthValue value, TruthValue defaultValue) { 540 public RelationInfo(String name, int arity, TruthValue value, TruthValue defaultValue) {
541 this(new PartialRelation(name, arity), value, defaultValue); 541 this(new PartialRelation(name, arity), value, defaultValue);
542 } 542 }
543 543
544 public RelationInfo(PartialRelation partialRelation, TruthValue value, TruthValue defaultValue) { 544 public RelationInfo(PartialRelation partialRelation, TruthValue value, TruthValue defaultValue) {
545 this(partialRelation, new DecisionTree(partialRelation.arity(), value), 545 this(partialRelation, MutableSeed.of(partialRelation.arity(), value),
546 new DecisionTree(partialRelation.arity(), defaultValue)); 546 MutableSeed.of(partialRelation.arity(), defaultValue));
547 } 547 }
548 548
549 public Seed<TruthValue> toSeed(int nodeCount) { 549 public Seed<TruthValue> toSeed(int nodeCount) {
diff --git a/subprojects/language-semantics/src/main/java/tools/refinery/language/semantics/model/internal/DecisionTree.java b/subprojects/language-semantics/src/main/java/tools/refinery/language/semantics/model/internal/DecisionTree.java
index d693dec3..c5479859 100644
--- a/subprojects/language-semantics/src/main/java/tools/refinery/language/semantics/model/internal/DecisionTree.java
+++ b/subprojects/language-semantics/src/main/java/tools/refinery/language/semantics/model/internal/DecisionTree.java
@@ -7,11 +7,10 @@ package tools.refinery.language.semantics.model.internal;
7 7
8import org.eclipse.collections.api.factory.primitive.IntObjectMaps; 8import org.eclipse.collections.api.factory.primitive.IntObjectMaps;
9import tools.refinery.store.map.Cursor; 9import tools.refinery.store.map.Cursor;
10import tools.refinery.store.reasoning.seed.Seed;
11import tools.refinery.store.tuple.Tuple;
12import tools.refinery.store.representation.TruthValue; 10import tools.refinery.store.representation.TruthValue;
11import tools.refinery.store.tuple.Tuple;
13 12
14public class DecisionTree implements Seed<TruthValue> { 13class DecisionTree implements MutableSeed<TruthValue> {
15 private final int levels; 14 private final int levels;
16 15
17 private final DecisionTreeNode root; 16 private final DecisionTreeNode root;
@@ -50,26 +49,33 @@ public class DecisionTree implements Seed<TruthValue> {
50 return root.getValue(levels - 1, tuple).getTruthValue(); 49 return root.getValue(levels - 1, tuple).getTruthValue();
51 } 50 }
52 51
52 @Override
53 public void mergeValue(Tuple tuple, TruthValue truthValue) { 53 public void mergeValue(Tuple tuple, TruthValue truthValue) {
54 if (truthValue != null) { 54 if (truthValue != null) {
55 root.mergeValue(levels - 1, tuple, truthValue); 55 root.mergeValue(levels - 1, tuple, truthValue);
56 } 56 }
57 } 57 }
58 58
59 @Override
59 public void setIfMissing(Tuple tuple, TruthValue truthValue) { 60 public void setIfMissing(Tuple tuple, TruthValue truthValue) {
60 if (truthValue != null) { 61 if (truthValue != null) {
61 root.setIfMissing(levels - 1, tuple, truthValue); 62 root.setIfMissing(levels - 1, tuple, truthValue);
62 } 63 }
63 } 64 }
64 65
66 @Override
65 public void setAllMissing(TruthValue truthValue) { 67 public void setAllMissing(TruthValue truthValue) {
66 if (truthValue != null) { 68 if (truthValue != null) {
67 root.setAllMissing(truthValue); 69 root.setAllMissing(truthValue);
68 } 70 }
69 } 71 }
70 72
71 public void overwriteValues(DecisionTree values) { 73 @Override
72 root.overwriteValues(values.root); 74 public void overwriteValues(MutableSeed<TruthValue> values) {
75 if (!(values instanceof DecisionTree decisionTree)) {
76 throw new IllegalArgumentException("Incompatible overwrite: " + values);
77 }
78 root.overwriteValues(decisionTree.root);
73 } 79 }
74 80
75 public TruthValue getReducedValue() { 81 public TruthValue getReducedValue() {
diff --git a/subprojects/language-semantics/src/main/java/tools/refinery/language/semantics/model/internal/MutableSeed.java b/subprojects/language-semantics/src/main/java/tools/refinery/language/semantics/model/internal/MutableSeed.java
new file mode 100644
index 00000000..99019e2a
--- /dev/null
+++ b/subprojects/language-semantics/src/main/java/tools/refinery/language/semantics/model/internal/MutableSeed.java
@@ -0,0 +1,28 @@
1/*
2 * SPDX-FileCopyrightText: 2023 The Refinery Authors <https://refinery.tools/>
3 *
4 * SPDX-License-Identifier: EPL-2.0
5 */
6package tools.refinery.language.semantics.model.internal;
7
8import tools.refinery.store.reasoning.seed.Seed;
9import tools.refinery.store.representation.TruthValue;
10import tools.refinery.store.tuple.Tuple;
11
12public interface MutableSeed<T> extends Seed<T> {
13 void mergeValue(Tuple tuple, T value);
14
15 void setIfMissing(Tuple tuple, T value);
16
17 void setAllMissing(T value);
18
19 void overwriteValues(MutableSeed<T> other);
20
21 static MutableSeed<TruthValue> of(int levels, TruthValue initialValue) {
22 if (levels == 0) {
23 return new NullaryMutableSeed(initialValue);
24 } else {
25 return new DecisionTree(levels, initialValue);
26 }
27 }
28}
diff --git a/subprojects/language-semantics/src/main/java/tools/refinery/language/semantics/model/internal/NullaryMutableSeed.java b/subprojects/language-semantics/src/main/java/tools/refinery/language/semantics/model/internal/NullaryMutableSeed.java
new file mode 100644
index 00000000..80644b1f
--- /dev/null
+++ b/subprojects/language-semantics/src/main/java/tools/refinery/language/semantics/model/internal/NullaryMutableSeed.java
@@ -0,0 +1,83 @@
1/*
2 * SPDX-FileCopyrightText: 2023 The Refinery Authors <https://refinery.tools/>
3 *
4 * SPDX-License-Identifier: EPL-2.0
5 */
6package tools.refinery.language.semantics.model.internal;
7
8import tools.refinery.store.map.Cursor;
9import tools.refinery.store.map.Cursors;
10import tools.refinery.store.representation.TruthValue;
11import tools.refinery.store.tuple.Tuple;
12
13class NullaryMutableSeed implements MutableSeed<TruthValue> {
14 private DecisionTreeValue value;
15
16 public NullaryMutableSeed(TruthValue reducedValue) {
17
18 value = DecisionTreeValue.fromTruthValue(reducedValue);
19 }
20
21 @Override
22 public int arity() {
23 return 0;
24 }
25
26 @Override
27 public Class<TruthValue> valueType() {
28 return TruthValue.class;
29 }
30
31 @Override
32 public TruthValue reducedValue() {
33 return value.getTruthValue();
34 }
35
36 @Override
37 public TruthValue get(Tuple key) {
38 validateKey(key);
39 return reducedValue();
40 }
41
42 private static void validateKey(Tuple key) {
43 if (key.getSize() > 0) {
44 throw new IllegalArgumentException("Invalid key: " + key);
45 }
46 }
47
48 @Override
49 public Cursor<Tuple, TruthValue> getCursor(TruthValue defaultValue, int nodeCount) {
50 if (value == DecisionTreeValue.UNSET || value.getTruthValue() == defaultValue) {
51 return Cursors.empty();
52 }
53 return Cursors.singleton(Tuple.of(), value.getTruthValue());
54 }
55
56 @Override
57 public void mergeValue(Tuple tuple, TruthValue value) {
58 this.value = DecisionTreeValue.fromTruthValue(this.value.merge(value));
59 }
60
61 @Override
62 public void setIfMissing(Tuple tuple, TruthValue value) {
63 validateKey(tuple);
64 setAllMissing(value);
65 }
66
67 @Override
68 public void setAllMissing(TruthValue value) {
69 if (this.value == DecisionTreeValue.UNSET) {
70 this.value = DecisionTreeValue.fromTruthValue(value);
71 }
72 }
73
74 @Override
75 public void overwriteValues(MutableSeed<TruthValue> other) {
76 if (!(other instanceof NullaryMutableSeed nullaryMutableSeed)) {
77 throw new IllegalArgumentException("Incompatible overwrite: " + other);
78 }
79 if (nullaryMutableSeed.value != DecisionTreeValue.UNSET) {
80 value = nullaryMutableSeed.value;
81 }
82 }
83}
diff --git a/subprojects/language-semantics/src/test/java/tools/refinery/language/semantics/model/tests/DecisionTreeTests.java b/subprojects/language-semantics/src/test/java/tools/refinery/language/semantics/model/internal/DecisionTreeTests.java
index 3c43d3bd..5d039308 100644
--- a/subprojects/language-semantics/src/test/java/tools/refinery/language/semantics/model/tests/DecisionTreeTests.java
+++ b/subprojects/language-semantics/src/test/java/tools/refinery/language/semantics/model/internal/DecisionTreeTests.java
@@ -3,10 +3,9 @@
3 * 3 *
4 * SPDX-License-Identifier: EPL-2.0 4 * SPDX-License-Identifier: EPL-2.0
5 */ 5 */
6package tools.refinery.language.semantics.model.tests; 6package tools.refinery.language.semantics.model.internal;
7 7
8import org.junit.jupiter.api.Test; 8import org.junit.jupiter.api.Test;
9import tools.refinery.language.semantics.model.internal.DecisionTree;
10import tools.refinery.store.representation.TruthValue; 9import tools.refinery.store.representation.TruthValue;
11import tools.refinery.store.tuple.Tuple; 10import tools.refinery.store.tuple.Tuple;
12 11