aboutsummaryrefslogtreecommitdiffstats
path: root/subprojects/store-reasoning/src/main/java/tools/refinery/store/reasoning/seed/ModelSeed.java
diff options
context:
space:
mode:
Diffstat (limited to 'subprojects/store-reasoning/src/main/java/tools/refinery/store/reasoning/seed/ModelSeed.java')
-rw-r--r--subprojects/store-reasoning/src/main/java/tools/refinery/store/reasoning/seed/ModelSeed.java95
1 files changed, 95 insertions, 0 deletions
diff --git a/subprojects/store-reasoning/src/main/java/tools/refinery/store/reasoning/seed/ModelSeed.java b/subprojects/store-reasoning/src/main/java/tools/refinery/store/reasoning/seed/ModelSeed.java
new file mode 100644
index 00000000..e6b3eaf9
--- /dev/null
+++ b/subprojects/store-reasoning/src/main/java/tools/refinery/store/reasoning/seed/ModelSeed.java
@@ -0,0 +1,95 @@
1/*
2 * SPDX-FileCopyrightText: 2023 The Refinery Authors <https://refinery.tools/>
3 *
4 * SPDX-License-Identifier: EPL-2.0
5 */
6package tools.refinery.store.reasoning.seed;
7
8import tools.refinery.store.map.Cursor;
9import tools.refinery.store.reasoning.representation.AnyPartialSymbol;
10import tools.refinery.store.reasoning.representation.PartialSymbol;
11import tools.refinery.store.tuple.Tuple;
12
13import java.util.Collections;
14import java.util.LinkedHashMap;
15import java.util.Map;
16import java.util.Set;
17import java.util.function.Consumer;
18
19public class ModelSeed {
20 private final int nodeCount;
21 private final Map<AnyPartialSymbol, Seed<?>> seeds;
22
23 private ModelSeed(int nodeCount, Map<AnyPartialSymbol, Seed<?>> seeds) {
24 this.nodeCount = nodeCount;
25 this.seeds = seeds;
26 }
27
28 public int getNodeCount() {
29 return nodeCount;
30 }
31
32 public <A> Seed<A> getSeed(PartialSymbol<A, ?> partialSymbol) {
33 var seed = seeds.get(partialSymbol);
34 if (seed == null) {
35 throw new IllegalArgumentException("No seed for partial symbol " + partialSymbol);
36 }
37 // The builder makes sure only well-typed seeds can be added.
38 @SuppressWarnings("unchecked")
39 var typedSeed = (Seed<A>) seed;
40 return typedSeed;
41 }
42
43 public boolean containsSeed(AnyPartialSymbol symbol) {
44 return seeds.containsKey(symbol);
45 }
46
47 public Set<AnyPartialSymbol> getSeededSymbols() {
48 return Collections.unmodifiableSet(seeds.keySet());
49 }
50
51 public <A> Cursor<Tuple, A> getCursor(PartialSymbol<A, ?> partialSymbol, A defaultValue) {
52 return getSeed(partialSymbol).getCursor(defaultValue, nodeCount);
53 }
54
55 public static Builder builder(int nodeCount) {
56 return new Builder(nodeCount);
57 }
58
59 public static class Builder {
60 private final int nodeCount;
61 private final Map<AnyPartialSymbol, Seed<?>> seeds = new LinkedHashMap<>();
62
63 private Builder(int nodeCount) {
64 if (nodeCount < 0) {
65 throw new IllegalArgumentException("Node count must not be negative");
66 }
67 this.nodeCount = nodeCount;
68 }
69
70 public <A> Builder seed(PartialSymbol<A, ?> partialSymbol, Seed<A> seed) {
71 if (seed.arity() != partialSymbol.arity()) {
72 throw new IllegalStateException("Expected seed of arity %d for partial symbol %s, but got %d instead"
73 .formatted(partialSymbol.arity(), partialSymbol, seed.arity()));
74 }
75 if (!seed.valueType().equals(partialSymbol.abstractDomain().abstractType())) {
76 throw new IllegalStateException("Expected seed of type %s for partial symbol %s, but got %s instead"
77 .formatted(partialSymbol.abstractDomain().abstractType(), partialSymbol, seed.valueType()));
78 }
79 if (seeds.put(partialSymbol, seed) != null) {
80 throw new IllegalArgumentException("Duplicate seed for partial symbol " + partialSymbol);
81 }
82 return this;
83 }
84
85 public <A> Builder seed(PartialSymbol<A, ?> partialSymbol, Consumer<Seed.Builder<A>> callback) {
86 var builder = Seed.builder(partialSymbol);
87 callback.accept(builder);
88 return seed(partialSymbol, builder.build());
89 }
90
91 public ModelSeed build() {
92 return new ModelSeed(nodeCount, Collections.unmodifiableMap(seeds));
93 }
94 }
95}