aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLibravatar Kristóf Marussy <kristof@marussy.com>2024-05-31 19:22:46 +0200
committerLibravatar Kristóf Marussy <kristof@marussy.com>2024-06-01 20:17:47 +0200
commit8a62e425b77c43a541d58d4615a13cd182d12a81 (patch)
tree2f84b7d5e3f3dbc4dbda7da06ee93732418024ce
parentfix(reasoning): candidate rounding mode (diff)
downloadrefinery-8a62e425b77c43a541d58d4615a13cd182d12a81.tar.gz
refinery-8a62e425b77c43a541d58d4615a13cd182d12a81.tar.zst
refinery-8a62e425b77c43a541d58d4615a13cd182d12a81.zip
feat: generate multiple solutions
Switch to partial interpretation based neighborhood calculation when multiple models are request to avoid returning isomorphic models.
-rw-r--r--subprojects/generator-cli/src/main/java/tools/refinery/generator/cli/commands/GenerateCommand.java52
-rw-r--r--subprojects/generator/src/main/java/tools/refinery/generator/ModelGenerator.java43
-rw-r--r--subprojects/generator/src/main/java/tools/refinery/generator/ModelGeneratorFactory.java20
-rw-r--r--subprojects/store-reasoning/src/main/java/tools/refinery/store/reasoning/interpretation/PartialNeighbourhoodCalculator.java63
-rw-r--r--subprojects/store/build.gradle.kts2
-rw-r--r--subprojects/store/src/main/java/tools/refinery/store/statecoding/neighbourhood/AbstractNeighbourhoodCalculator.java171
-rw-r--r--subprojects/store/src/main/java/tools/refinery/store/statecoding/neighbourhood/LazyNeighbourhoodCalculator.java195
-rw-r--r--subprojects/store/src/main/java/tools/refinery/store/statecoding/neighbourhood/NeighbourhoodCalculator.java117
8 files changed, 327 insertions, 336 deletions
diff --git a/subprojects/generator-cli/src/main/java/tools/refinery/generator/cli/commands/GenerateCommand.java b/subprojects/generator-cli/src/main/java/tools/refinery/generator/cli/commands/GenerateCommand.java
index b1c28df0..0fb97316 100644
--- a/subprojects/generator-cli/src/main/java/tools/refinery/generator/cli/commands/GenerateCommand.java
+++ b/subprojects/generator-cli/src/main/java/tools/refinery/generator/cli/commands/GenerateCommand.java
@@ -1,5 +1,5 @@
1/* 1/*
2 * SPDX-FileCopyrightText: 2023 The Refinery Authors <https://refinery.tools/> 2 * SPDX-FileCopyrightText: 2023-2024 The Refinery Authors <https://refinery.tools/>
3 * 3 *
4 * SPDX-License-Identifier: EPL-2.0 4 * SPDX-License-Identifier: EPL-2.0
5 */ 5 */
@@ -17,9 +17,12 @@ import java.io.IOException;
17import java.util.ArrayList; 17import java.util.ArrayList;
18import java.util.List; 18import java.util.List;
19import java.util.Map; 19import java.util.Map;
20import java.util.regex.Pattern;
20 21
21@Parameters(commandDescription = "Generate a model from a partial model") 22@Parameters(commandDescription = "Generate a model from a partial model")
22public class GenerateCommand { 23public class GenerateCommand {
24 private static final Pattern EXTENSION_REGEX = Pattern.compile("(.+)\\.([^./\\\\]+)");
25
23 @Inject 26 @Inject
24 private ProblemLoader loader; 27 private ProblemLoader loader;
25 28
@@ -31,6 +34,7 @@ public class GenerateCommand {
31 private List<String> scopes = new ArrayList<>(); 34 private List<String> scopes = new ArrayList<>();
32 private List<String> overrideScopes = new ArrayList<>(); 35 private List<String> overrideScopes = new ArrayList<>();
33 private long randomSeed = 1; 36 private long randomSeed = 1;
37 private int count = 1;
34 38
35 @Parameter(description = "input path", required = true) 39 @Parameter(description = "input path", required = true)
36 public void setInputPath(String inputPath) { 40 public void setInputPath(String inputPath) {
@@ -57,21 +61,47 @@ public class GenerateCommand {
57 this.randomSeed = randomSeed; 61 this.randomSeed = randomSeed;
58 } 62 }
59 63
64 @Parameter(names = {"-solution-number", "-n"}, description = "Maximum number of solutions")
65 public void setCount(int count) {
66 if (count <= 0) {
67 throw new IllegalArgumentException("Count must be positive");
68 }
69 this.count = count;
70 }
71
60 public void run() throws IOException { 72 public void run() throws IOException {
73 if (count > 1 && isStandardStream(outputPath)) {
74 throw new IllegalArgumentException("Must provide output path if count is larger than 1");
75 }
61 loader.extraPath(System.getProperty("user.dir")); 76 loader.extraPath(System.getProperty("user.dir"));
62 var problem = isStandardStream(inputPath) ? loader.loadStream(System.in) : loader.loadFile(inputPath); 77 var problem = isStandardStream(inputPath) ? loader.loadStream(System.in) : loader.loadFile(inputPath);
63 problem = loader.loadScopeConstraints(problem, scopes, overrideScopes); 78 problem = loader.loadScopeConstraints(problem, scopes, overrideScopes);
79 generatorFactory.partialInterpretationBasedNeighbourhoods(count >= 2);
64 var generator = generatorFactory.createGenerator(problem); 80 var generator = generatorFactory.createGenerator(problem);
65 generator.setRandomSeed(randomSeed); 81 generator.setRandomSeed(randomSeed);
82 generator.setMaxNumberOfSolutions(count);
66 generator.generate(); 83 generator.generate();
67 var solution = generator.serializeSolution();
68 var solutionResource = solution.eResource();
69 var saveOptions = Map.of(); 84 var saveOptions = Map.of();
70 if (isStandardStream(outputPath)) { 85 if (count == 1) {
71 printSolution(solutionResource, saveOptions); 86 var solution = generator.serializeSolution();
87 var solutionResource = solution.eResource();
88 if (isStandardStream(outputPath)) {
89 printSolution(solutionResource, saveOptions);
90 } else {
91 try (var outputStream = new FileOutputStream(outputPath)) {
92 solutionResource.save(outputStream, saveOptions);
93 }
94 }
72 } else { 95 } else {
73 try (var outputStream = new FileOutputStream(outputPath)) { 96 int solutionCount = generator.getSolutionCount();
74 solutionResource.save(outputStream, saveOptions); 97 for (int i = 0; i < solutionCount; i++) {
98 generator.loadSolution(i);
99 var solution = generator.serializeSolution();
100 var solutionResource = solution.eResource();
101 var pathWithIndex = getFileNameWithIndex(outputPath, i + 1);
102 try (var outputStream = new FileOutputStream(pathWithIndex)) {
103 solutionResource.save(outputStream, saveOptions);
104 }
75 } 105 }
76 } 106 }
77 } 107 }
@@ -85,4 +115,12 @@ public class GenerateCommand {
85 private void printSolution(Resource solutionResource, Map<?, ?> saveOptions) throws IOException { 115 private void printSolution(Resource solutionResource, Map<?, ?> saveOptions) throws IOException {
86 solutionResource.save(System.out, saveOptions); 116 solutionResource.save(System.out, saveOptions);
87 } 117 }
118
119 private String getFileNameWithIndex(String simpleName, int index) {
120 var match = EXTENSION_REGEX.matcher(simpleName);
121 if (match.matches()) {
122 return "%s_%03d.%s".formatted(match.group(1), index, match.group(2));
123 }
124 return "%s_%03d".formatted(simpleName, index);
125 }
88} 126}
diff --git a/subprojects/generator/src/main/java/tools/refinery/generator/ModelGenerator.java b/subprojects/generator/src/main/java/tools/refinery/generator/ModelGenerator.java
index 36190b76..8dff5622 100644
--- a/subprojects/generator/src/main/java/tools/refinery/generator/ModelGenerator.java
+++ b/subprojects/generator/src/main/java/tools/refinery/generator/ModelGenerator.java
@@ -1,5 +1,5 @@
1/* 1/*
2 * SPDX-FileCopyrightText: 2023 The Refinery Authors <https://refinery.tools/> 2 * SPDX-FileCopyrightText: 2023-2024 The Refinery Authors <https://refinery.tools/>
3 * 3 *
4 * SPDX-License-Identifier: EPL-2.0 4 * SPDX-License-Identifier: EPL-2.0
5 */ 5 */
@@ -11,6 +11,7 @@ import tools.refinery.language.semantics.ProblemTrace;
11import tools.refinery.language.semantics.SolutionSerializer; 11import tools.refinery.language.semantics.SolutionSerializer;
12import tools.refinery.logic.AbstractValue; 12import tools.refinery.logic.AbstractValue;
13import tools.refinery.store.dse.strategy.BestFirstStoreManager; 13import tools.refinery.store.dse.strategy.BestFirstStoreManager;
14import tools.refinery.store.dse.transition.statespace.SolutionStore;
14import tools.refinery.store.map.Version; 15import tools.refinery.store.map.Version;
15import tools.refinery.store.model.ModelStore; 16import tools.refinery.store.model.ModelStore;
16import tools.refinery.store.reasoning.interpretation.PartialInterpretation; 17import tools.refinery.store.reasoning.interpretation.PartialInterpretation;
@@ -22,7 +23,8 @@ public class ModelGenerator extends ModelFacade {
22 private final Version initialVersion; 23 private final Version initialVersion;
23 private final Provider<SolutionSerializer> solutionSerializerProvider; 24 private final Provider<SolutionSerializer> solutionSerializerProvider;
24 private long randomSeed = 1; 25 private long randomSeed = 1;
25 private boolean lastGenerationSuccessful; 26 private int maxNumberOfSolutions = 1;
27 private SolutionStore solutionStore;
26 28
27 ModelGenerator(ProblemTrace problemTrace, ModelStore store, ModelSeed modelSeed, 29 ModelGenerator(ProblemTrace problemTrace, ModelStore store, ModelSeed modelSeed,
28 Provider<SolutionSerializer> solutionSerializerProvider) { 30 Provider<SolutionSerializer> solutionSerializerProvider) {
@@ -37,26 +39,51 @@ public class ModelGenerator extends ModelFacade {
37 39
38 public void setRandomSeed(long randomSeed) { 40 public void setRandomSeed(long randomSeed) {
39 this.randomSeed = randomSeed; 41 this.randomSeed = randomSeed;
40 this.lastGenerationSuccessful = false; 42 this.solutionStore = null;
41 } 43 }
42 44
45 public int getMaxNumberOfSolutions() {
46 return maxNumberOfSolutions;
47 }
48
49 public void setMaxNumberOfSolutions(int maxNumberOfSolutions) {
50 this.maxNumberOfSolutions = maxNumberOfSolutions;
51 this.solutionStore = null;
52 }
53
54 public int getSolutionCount() {
55 if (!isLastGenerationSuccessful()) {
56 return 0;
57 }
58 return this.solutionStore.getSolutions().size();
59 }
60
61 public void loadSolution(int index) {
62 if (index >= getSolutionCount()) {
63 throw new IndexOutOfBoundsException("No such solution");
64 }
65 getModel().restore(solutionStore.getSolutions().get(index).version());
66 }
67
68 // It makes more sense to check for success than for failure.
69 @SuppressWarnings("BooleanMethodIsAlwaysInverted")
43 public boolean isLastGenerationSuccessful() { 70 public boolean isLastGenerationSuccessful() {
44 return lastGenerationSuccessful; 71 return solutionStore != null;
45 } 72 }
46 73
47 // This method only makes sense if it returns {@code true} on success. 74 // This method only makes sense if it returns {@code true} on success.
48 @SuppressWarnings("BooleanMethodIsAlwaysInverted") 75 @SuppressWarnings("BooleanMethodIsAlwaysInverted")
49 public boolean tryGenerate() { 76 public boolean tryGenerate() {
50 lastGenerationSuccessful = false; 77 solutionStore = null;
51 randomSeed++; 78 randomSeed++;
52 var bestFirst = new BestFirstStoreManager(getModelStore(), 1); 79 var bestFirst = new BestFirstStoreManager(getModelStore(), maxNumberOfSolutions);
53 bestFirst.startExploration(initialVersion, randomSeed); 80 bestFirst.startExploration(initialVersion, randomSeed);
54 var solutions = bestFirst.getSolutionStore().getSolutions(); 81 var solutions = bestFirst.getSolutionStore().getSolutions();
55 if (solutions.isEmpty()) { 82 if (solutions.isEmpty()) {
56 return false; 83 return false;
57 } 84 }
58 getModel().restore(solutions.getFirst().version()); 85 getModel().restore(solutions.getFirst().version());
59 lastGenerationSuccessful = true; 86 solutionStore = bestFirst.getSolutionStore();
60 return true; 87 return true;
61 } 88 }
62 89
@@ -80,7 +107,7 @@ public class ModelGenerator extends ModelFacade {
80 } 107 }
81 108
82 private void checkSuccessfulGeneration() { 109 private void checkSuccessfulGeneration() {
83 if (!lastGenerationSuccessful) { 110 if (!isLastGenerationSuccessful()) {
84 throw new IllegalStateException("No generated model is available"); 111 throw new IllegalStateException("No generated model is available");
85 } 112 }
86 } 113 }
diff --git a/subprojects/generator/src/main/java/tools/refinery/generator/ModelGeneratorFactory.java b/subprojects/generator/src/main/java/tools/refinery/generator/ModelGeneratorFactory.java
index 587601f2..ec273cf4 100644
--- a/subprojects/generator/src/main/java/tools/refinery/generator/ModelGeneratorFactory.java
+++ b/subprojects/generator/src/main/java/tools/refinery/generator/ModelGeneratorFactory.java
@@ -15,8 +15,11 @@ import tools.refinery.store.dse.transition.DesignSpaceExplorationAdapter;
15import tools.refinery.store.model.ModelStore; 15import tools.refinery.store.model.ModelStore;
16import tools.refinery.store.query.interpreter.QueryInterpreterAdapter; 16import tools.refinery.store.query.interpreter.QueryInterpreterAdapter;
17import tools.refinery.store.reasoning.ReasoningAdapter; 17import tools.refinery.store.reasoning.ReasoningAdapter;
18import tools.refinery.store.reasoning.interpretation.PartialNeighbourhoodCalculator;
18import tools.refinery.store.reasoning.literal.Concreteness; 19import tools.refinery.store.reasoning.literal.Concreteness;
20import tools.refinery.store.statecoding.StateCodeCalculatorFactory;
19import tools.refinery.store.statecoding.StateCoderAdapter; 21import tools.refinery.store.statecoding.StateCoderAdapter;
22import tools.refinery.store.statecoding.neighbourhood.NeighbourhoodCalculator;
20import tools.refinery.store.util.CancellationToken; 23import tools.refinery.store.util.CancellationToken;
21 24
22import java.util.Collection; 25import java.util.Collection;
@@ -33,6 +36,8 @@ public final class ModelGeneratorFactory {
33 36
34 private boolean debugPartialInterpretations; 37 private boolean debugPartialInterpretations;
35 38
39 private boolean partialInterpretationBasedNeighbourhoods;
40
36 public ModelGeneratorFactory cancellationToken(CancellationToken cancellationToken) { 41 public ModelGeneratorFactory cancellationToken(CancellationToken cancellationToken) {
37 this.cancellationToken = cancellationToken; 42 this.cancellationToken = cancellationToken;
38 return this; 43 return this;
@@ -43,6 +48,10 @@ public final class ModelGeneratorFactory {
43 return this; 48 return this;
44 } 49 }
45 50
51 public void partialInterpretationBasedNeighbourhoods(boolean partialInterpretationBasedNeighbourhoods) {
52 this.partialInterpretationBasedNeighbourhoods = partialInterpretationBasedNeighbourhoods;
53 }
54
46 public ModelGenerator createGenerator(Problem problem) { 55 public ModelGenerator createGenerator(Problem problem) {
47 var initializer = initializerProvider.get(); 56 var initializer = initializerProvider.get();
48 initializer.readProblem(problem); 57 initializer.readProblem(problem);
@@ -51,7 +60,8 @@ public final class ModelGeneratorFactory {
51 .cancellationToken(cancellationToken) 60 .cancellationToken(cancellationToken)
52 .with(QueryInterpreterAdapter.builder()) 61 .with(QueryInterpreterAdapter.builder())
53 .with(PropagationAdapter.builder()) 62 .with(PropagationAdapter.builder())
54 .with(StateCoderAdapter.builder()) 63 .with(StateCoderAdapter.builder()
64 .stateCodeCalculatorFactory(getStateCoderCalculatorFactory()))
55 .with(DesignSpaceExplorationAdapter.builder()) 65 .with(DesignSpaceExplorationAdapter.builder())
56 .with(ReasoningAdapter.builder() 66 .with(ReasoningAdapter.builder()
57 .requiredInterpretations(getRequiredInterpretations())); 67 .requiredInterpretations(getRequiredInterpretations()));
@@ -62,7 +72,13 @@ public final class ModelGeneratorFactory {
62 } 72 }
63 73
64 private Collection<Concreteness> getRequiredInterpretations() { 74 private Collection<Concreteness> getRequiredInterpretations() {
65 return debugPartialInterpretations ? Set.of(Concreteness.PARTIAL, Concreteness.CANDIDATE) : 75 return debugPartialInterpretations || partialInterpretationBasedNeighbourhoods ?
76 Set.of(Concreteness.PARTIAL, Concreteness.CANDIDATE) :
66 Set.of(Concreteness.CANDIDATE); 77 Set.of(Concreteness.CANDIDATE);
67 } 78 }
79
80 private StateCodeCalculatorFactory getStateCoderCalculatorFactory() {
81 return partialInterpretationBasedNeighbourhoods ? PartialNeighbourhoodCalculator.FACTORY :
82 NeighbourhoodCalculator::new;
83 }
68} 84}
diff --git a/subprojects/store-reasoning/src/main/java/tools/refinery/store/reasoning/interpretation/PartialNeighbourhoodCalculator.java b/subprojects/store-reasoning/src/main/java/tools/refinery/store/reasoning/interpretation/PartialNeighbourhoodCalculator.java
new file mode 100644
index 00000000..859cf7c1
--- /dev/null
+++ b/subprojects/store-reasoning/src/main/java/tools/refinery/store/reasoning/interpretation/PartialNeighbourhoodCalculator.java
@@ -0,0 +1,63 @@
1/*
2 * SPDX-FileCopyrightText: 2024 The Refinery Authors <https://refinery.tools/>
3 *
4 * SPDX-License-Identifier: EPL-2.0
5 */
6package tools.refinery.store.reasoning.interpretation;
7
8import org.eclipse.collections.api.set.primitive.IntSet;
9import tools.refinery.store.map.Cursor;
10import tools.refinery.store.model.Model;
11import tools.refinery.store.query.ModelQueryAdapter;
12import tools.refinery.store.reasoning.ReasoningAdapter;
13import tools.refinery.store.reasoning.literal.Concreteness;
14import tools.refinery.store.reasoning.representation.PartialSymbol;
15import tools.refinery.store.statecoding.StateCodeCalculatorFactory;
16import tools.refinery.store.statecoding.StateCoderResult;
17import tools.refinery.store.statecoding.neighbourhood.AbstractNeighbourhoodCalculator;
18import tools.refinery.store.tuple.Tuple;
19
20import java.util.List;
21
22public class PartialNeighbourhoodCalculator extends AbstractNeighbourhoodCalculator<PartialInterpretation<?, ?>> {
23 private final ModelQueryAdapter queryAdapter;
24
25 public static final StateCodeCalculatorFactory FACTORY =
26 (model, ignoredInterpretations, individuals) -> new PartialNeighbourhoodCalculator(model, individuals);
27
28 protected PartialNeighbourhoodCalculator(Model model, IntSet individuals) {
29 super(model, individuals);
30 queryAdapter = model.getAdapter(ModelQueryAdapter.class);
31 }
32
33 @Override
34 public StateCoderResult calculateCodes() {
35 queryAdapter.flushChanges();
36 return super.calculateCodes();
37 }
38
39 @Override
40 protected List<PartialInterpretation<?, ?>> getInterpretations() {
41 var adapter = getModel().getAdapter(ReasoningAdapter.class);
42 var partialSymbols = adapter.getStoreAdapter().getPartialSymbols();
43 return partialSymbols.stream()
44 .<PartialInterpretation<?, ?>>map(partialSymbol ->
45 adapter.getPartialInterpretation(Concreteness.PARTIAL, (PartialSymbol<?, ?>) partialSymbol))
46 .toList();
47 }
48
49 @Override
50 protected int getArity(PartialInterpretation<?, ?> interpretation) {
51 return interpretation.getPartialSymbol().arity();
52 }
53
54 @Override
55 protected Object getNullValue(PartialInterpretation<?, ?> interpretation) {
56 return interpretation.get(Tuple.of());
57 }
58
59 @Override
60 protected Cursor<Tuple, ?> getCursor(PartialInterpretation<?, ?> interpretation) {
61 return interpretation.getAll();
62 }
63}
diff --git a/subprojects/store/build.gradle.kts b/subprojects/store/build.gradle.kts
index f96922a9..e48c0088 100644
--- a/subprojects/store/build.gradle.kts
+++ b/subprojects/store/build.gradle.kts
@@ -10,6 +10,6 @@ plugins {
10} 10}
11 11
12dependencies { 12dependencies {
13 implementation(libs.eclipseCollections.api) 13 api(libs.eclipseCollections.api)
14 runtimeOnly(libs.eclipseCollections) 14 runtimeOnly(libs.eclipseCollections)
15} 15}
diff --git a/subprojects/store/src/main/java/tools/refinery/store/statecoding/neighbourhood/AbstractNeighbourhoodCalculator.java b/subprojects/store/src/main/java/tools/refinery/store/statecoding/neighbourhood/AbstractNeighbourhoodCalculator.java
index 4cef6786..5bfc4725 100644
--- a/subprojects/store/src/main/java/tools/refinery/store/statecoding/neighbourhood/AbstractNeighbourhoodCalculator.java
+++ b/subprojects/store/src/main/java/tools/refinery/store/statecoding/neighbourhood/AbstractNeighbourhoodCalculator.java
@@ -1,5 +1,5 @@
1/* 1/*
2 * SPDX-FileCopyrightText: 2023 The Refinery Authors <https://refinery.tools/> 2 * SPDX-FileCopyrightText: 2023-2024 The Refinery Authors <https://refinery.tools/>
3 * 3 *
4 * SPDX-License-Identifier: EPL-2.0 4 * SPDX-License-Identifier: EPL-2.0
5 */ 5 */
@@ -8,39 +8,87 @@ package tools.refinery.store.statecoding.neighbourhood;
8import org.eclipse.collections.api.factory.primitive.IntLongMaps; 8import org.eclipse.collections.api.factory.primitive.IntLongMaps;
9import org.eclipse.collections.api.map.primitive.MutableIntLongMap; 9import org.eclipse.collections.api.map.primitive.MutableIntLongMap;
10import org.eclipse.collections.api.set.primitive.IntSet; 10import org.eclipse.collections.api.set.primitive.IntSet;
11import tools.refinery.store.model.AnyInterpretation; 11import tools.refinery.store.map.Cursor;
12import tools.refinery.store.model.Interpretation;
13import tools.refinery.store.model.Model; 12import tools.refinery.store.model.Model;
14import tools.refinery.store.statecoding.ObjectCode; 13import tools.refinery.store.statecoding.ObjectCode;
14import tools.refinery.store.statecoding.StateCodeCalculator;
15import tools.refinery.store.statecoding.StateCoderResult;
15import tools.refinery.store.tuple.Tuple; 16import tools.refinery.store.tuple.Tuple;
16import tools.refinery.store.tuple.Tuple0;
17 17
18import java.util.*; 18import java.util.*;
19 19
20public abstract class AbstractNeighbourhoodCalculator { 20public abstract class AbstractNeighbourhoodCalculator<T> implements StateCodeCalculator {
21 protected final Model model; 21 private final Model model;
22 protected final List<AnyInterpretation> nullImpactValues; 22 private final IntSet individuals;
23 protected final LinkedHashMap<AnyInterpretation, long[]> impactValues; 23 private List<T> nullImpactValues;
24 protected final MutableIntLongMap individualHashValues = IntLongMaps.mutable.empty(); 24 private LinkedHashMap<T, long[]> impactValues;
25 private MutableIntLongMap individualHashValues;
26 private ObjectCodeImpl previousObjectCode = new ObjectCodeImpl();
27 private ObjectCodeImpl nextObjectCode = new ObjectCodeImpl();
25 28
26 protected static final long PRIME = 31; 29 protected static final long PRIME = 31;
27 30
28 protected AbstractNeighbourhoodCalculator(Model model, List<? extends AnyInterpretation> interpretations, 31 protected AbstractNeighbourhoodCalculator(Model model, IntSet individuals) {
29 IntSet individuals) {
30 this.model = model; 32 this.model = model;
31 this.nullImpactValues = new ArrayList<>(); 33 this.individuals = individuals;
32 this.impactValues = new LinkedHashMap<>(); 34 }
35
36 protected Model getModel() {
37 return model;
38 }
39
40 protected abstract List<T> getInterpretations();
41
42 protected abstract int getArity(T interpretation);
43
44 protected abstract Object getNullValue(T interpretation);
45
46 // We need the wildcard here, because we don't know the value type.
47 @SuppressWarnings("squid:S1452")
48 protected abstract Cursor<Tuple, ?> getCursor(T interpretation);
49
50 @Override
51 public StateCoderResult calculateCodes() {
52 model.checkCancelled();
53 ensureInitialized();
54 previousObjectCode.clear();
55 nextObjectCode.clear();
56 initializeWithIndividuals(previousObjectCode);
57
58 int rounds = 0;
59 do {
60 model.checkCancelled();
61 constructNextObjectCodes(previousObjectCode, nextObjectCode);
62 var tempObjectCode = previousObjectCode;
63 previousObjectCode = nextObjectCode;
64 nextObjectCode = tempObjectCode;
65 nextObjectCode.clear();
66 rounds++;
67 } while (rounds <= 7 && rounds <= previousObjectCode.getEffectiveSize());
68
69 long result = calculateLastSum(previousObjectCode);
70 return new StateCoderResult((int) result, previousObjectCode);
71 }
72
73 private void ensureInitialized() {
74 if (impactValues != null) {
75 return;
76 }
77
78 nullImpactValues = new ArrayList<>();
79 impactValues = new LinkedHashMap<>();
80 individualHashValues = IntLongMaps.mutable.empty();
33 // Random isn't used for cryptographical purposes but just to assign distinguishable identifiers to symbols. 81 // Random isn't used for cryptographical purposes but just to assign distinguishable identifiers to symbols.
34 @SuppressWarnings("squid:S2245") 82 @SuppressWarnings("squid:S2245")
35 Random random = new Random(1); 83 Random random = new Random(1);
36 84
37 var individualsInOrder = individuals.toSortedList(Integer::compare); 85 var individualsInOrder = individuals.toSortedList(Integer::compare);
38 for(int i = 0; i<individualsInOrder.size(); i++) { 86 for (int i = 0; i < individualsInOrder.size(); i++) {
39 individualHashValues.put(individualsInOrder.get(i), random.nextLong()); 87 individualHashValues.put(individualsInOrder.get(i), random.nextLong());
40 } 88 }
41 89
42 for (AnyInterpretation interpretation : interpretations) { 90 for (var interpretation : getInterpretations()) {
43 int arity = interpretation.getSymbol().arity(); 91 int arity = getArity(interpretation);
44 if (arity == 0) { 92 if (arity == 0) {
45 nullImpactValues.add(interpretation); 93 nullImpactValues.add(interpretation);
46 } else { 94 } else {
@@ -50,6 +98,88 @@ public abstract class AbstractNeighbourhoodCalculator {
50 } 98 }
51 impactValues.put(interpretation, impact); 99 impactValues.put(interpretation, impact);
52 } 100 }
101
102 }
103 }
104
105 private long calculateLastSum(ObjectCode codes) {
106 long result = 0;
107 for (var nullImpactValue : nullImpactValues) {
108 result = result * PRIME + Objects.hashCode(getNullValue(nullImpactValue));
109 }
110
111 for (int i = 0; i < codes.getSize(); i++) {
112 final long hash = codes.get(i);
113 result += hash*PRIME;
114 }
115
116 return result;
117 }
118
119 private void constructNextObjectCodes(ObjectCodeImpl previous, ObjectCodeImpl next) {
120 for (var impactValueEntry : this.impactValues.entrySet()) {
121 model.checkCancelled();
122 var interpretation = impactValueEntry.getKey();
123 var cursor = getCursor(interpretation);
124 int arity = getArity(interpretation);
125 long[] impactValue = impactValueEntry.getValue();
126 impactCalculation(previous, next, impactValue, cursor, arity);
127 }
128 }
129
130 protected void impactCalculation(ObjectCodeImpl previous, ObjectCodeImpl next, long[] impactValue,
131 Cursor<Tuple, ?> cursor, int arity) {
132 switch (arity) {
133 case 1 -> {
134 while (cursor.move()) {
135 impactCalculation1(previous, next, impactValue, cursor);
136 }
137 }
138 case 2 -> {
139 while (cursor.move()) {
140 impactCalculation2(previous, next, impactValue, cursor);
141 }
142 }
143 default -> {
144 while (cursor.move()) {
145 impactCalculationN(previous, next, impactValue, cursor);
146 }
147 }
148 }
149 }
150
151 private void impactCalculation1(ObjectCodeImpl previous, ObjectCodeImpl next, long[] impactValues,
152 Cursor<Tuple, ?> cursor) {
153
154 Tuple tuple = cursor.getKey();
155 int o = tuple.get(0);
156 Object value = cursor.getValue();
157 long tupleHash = getTupleHash1(tuple, value, previous);
158 addHash(next, o, impactValues[0], tupleHash);
159 }
160
161 private void impactCalculation2(ObjectCodeImpl previous, ObjectCodeImpl next, long[] impactValues,
162 Cursor<Tuple, ?> cursor) {
163 final Tuple tuple = cursor.getKey();
164 final int o1 = tuple.get(0);
165 final int o2 = tuple.get(1);
166
167 Object value = cursor.getValue();
168 long tupleHash = getTupleHash2(tuple, value, previous);
169
170 addHash(next, o1, impactValues[0], tupleHash);
171 addHash(next, o2, impactValues[1], tupleHash);
172 }
173
174 private void impactCalculationN(ObjectCodeImpl previous, ObjectCodeImpl next, long[] impactValues,
175 Cursor<Tuple, ?> cursor) {
176 final Tuple tuple = cursor.getKey();
177
178 Object value = cursor.getValue();
179 long tupleHash = getTupleHashN(tuple, value, previous);
180
181 for (int i = 0; i < tuple.getSize(); i++) {
182 addHash(next, tuple.get(i), impactValues[i], tupleHash);
53 } 183 }
54 } 184 }
55 185
@@ -88,13 +218,4 @@ public abstract class AbstractNeighbourhoodCalculator {
88 long x = tupleHash * impact; 218 long x = tupleHash * impact;
89 objectCodeImpl.set(o, objectCodeImpl.get(o) + x); 219 objectCodeImpl.set(o, objectCodeImpl.get(o) + x);
90 } 220 }
91
92 protected long calculateModelCode(long lastSum) {
93 long result = 0;
94 for (var nullImpactValue : nullImpactValues) {
95 result = result * PRIME + Objects.hashCode(((Interpretation<?>) nullImpactValue).get(Tuple0.INSTANCE));
96 }
97 result += lastSum;
98 return result;
99 }
100} 221}
diff --git a/subprojects/store/src/main/java/tools/refinery/store/statecoding/neighbourhood/LazyNeighbourhoodCalculator.java b/subprojects/store/src/main/java/tools/refinery/store/statecoding/neighbourhood/LazyNeighbourhoodCalculator.java
deleted file mode 100644
index 04335141..00000000
--- a/subprojects/store/src/main/java/tools/refinery/store/statecoding/neighbourhood/LazyNeighbourhoodCalculator.java
+++ /dev/null
@@ -1,195 +0,0 @@
1/*
2 * SPDX-FileCopyrightText: 2023 The Refinery Authors <https://refinery.tools/>
3 *
4 * SPDX-License-Identifier: EPL-2.0
5 */
6package tools.refinery.store.statecoding.neighbourhood;
7
8import org.eclipse.collections.api.factory.primitive.LongIntMaps;
9import org.eclipse.collections.api.map.primitive.LongIntMap;
10import org.eclipse.collections.api.map.primitive.MutableLongIntMap;
11import org.eclipse.collections.api.set.primitive.IntSet;
12import tools.refinery.store.map.Cursor;
13import tools.refinery.store.model.AnyInterpretation;
14import tools.refinery.store.model.Interpretation;
15import tools.refinery.store.model.Model;
16import tools.refinery.store.statecoding.StateCodeCalculator;
17import tools.refinery.store.statecoding.StateCoderResult;
18import tools.refinery.store.tuple.Tuple;
19
20import java.util.List;
21
22public class LazyNeighbourhoodCalculator extends AbstractNeighbourhoodCalculator implements StateCodeCalculator {
23 public LazyNeighbourhoodCalculator(Model model, List<? extends AnyInterpretation> interpretations,
24 IntSet individuals) {
25 super(model, interpretations, individuals);
26 }
27
28 public StateCoderResult calculateCodes() {
29 ObjectCodeImpl previousObjectCode = new ObjectCodeImpl();
30 MutableLongIntMap prevHash2Amount = LongIntMaps.mutable.empty();
31
32 long lastSum;
33 // All hash code is 0, except to the individuals.
34 int lastSize = 1;
35 boolean first = true;
36
37 boolean grows;
38 int rounds = 0;
39 do {
40 final ObjectCodeImpl nextObjectCode;
41 if (first) {
42 nextObjectCode = new ObjectCodeImpl();
43 initializeWithIndividuals(nextObjectCode);
44 } else {
45 nextObjectCode = new ObjectCodeImpl(previousObjectCode);
46 }
47 constructNextObjectCodes(previousObjectCode, nextObjectCode, prevHash2Amount);
48
49 MutableLongIntMap nextHash2Amount = LongIntMaps.mutable.empty();
50 lastSum = calculateLastSum(previousObjectCode, nextObjectCode, prevHash2Amount, nextHash2Amount);
51
52 int nextSize = nextHash2Amount.size();
53 grows = nextSize > lastSize;
54 lastSize = nextSize;
55 first = false;
56
57 previousObjectCode = nextObjectCode;
58 prevHash2Amount = nextHash2Amount;
59 } while (grows && rounds++ < 4/*&& lastSize < previousObjectCode.getSize()*/);
60
61 long result = calculateModelCode(lastSum);
62
63 return new StateCoderResult((int) result, previousObjectCode);
64 }
65
66 private long calculateLastSum(ObjectCodeImpl previous, ObjectCodeImpl next, LongIntMap hash2Amount,
67 MutableLongIntMap nextHash2Amount) {
68 long lastSum = 0;
69 for (int i = 0; i < next.getSize(); i++) {
70 final long hash;
71 if (isUnique(hash2Amount, previous, i)) {
72 hash = previous.get(i);
73 next.set(i, hash);
74 } else {
75 hash = next.get(i);
76 }
77
78 final int amount = nextHash2Amount.get(hash);
79 nextHash2Amount.put(hash, amount + 1);
80
81 final long shifted1 = hash >>> 8;
82 final long shifted2 = hash << 8;
83 final long shifted3 = hash >> 2;
84 lastSum += shifted1 * shifted3 + shifted2;
85 }
86 return lastSum;
87 }
88
89 private void constructNextObjectCodes(ObjectCodeImpl previous, ObjectCodeImpl next, LongIntMap hash2Amount) {
90 for (var impactValueEntry : this.impactValues.entrySet()) {
91 Interpretation<?> interpretation = (Interpretation<?>) impactValueEntry.getKey();
92 var cursor = interpretation.getAll();
93 int arity = interpretation.getSymbol().arity();
94 long[] impactValue = impactValueEntry.getValue();
95
96 if (arity == 1) {
97 while (cursor.move()) {
98 lazyImpactCalculation1(hash2Amount, previous, next, impactValue, cursor);
99 }
100 } else if (arity == 2) {
101 while (cursor.move()) {
102 lazyImpactCalculation2(hash2Amount, previous, next, impactValue, cursor);
103 }
104 } else {
105 while (cursor.move()) {
106 lazyImpactCalculationN(hash2Amount, previous, next, impactValue, cursor);
107 }
108 }
109 }
110 }
111
112 private boolean isUnique(LongIntMap hash2Amount, ObjectCodeImpl objectCodeImpl, int object) {
113 final long hash = objectCodeImpl.get(object);
114 if (hash == 0) {
115 return false;
116 }
117 final int amount = hash2Amount.get(hash);
118 return amount == 1;
119 }
120
121 private void lazyImpactCalculation1(LongIntMap hash2Amount, ObjectCodeImpl previous, ObjectCodeImpl next,
122 long[] impactValues, Cursor<Tuple, ?> cursor) {
123
124 Tuple tuple = cursor.getKey();
125 int o = tuple.get(0);
126
127 if (isUnique(hash2Amount, previous, o)) {
128 next.ensureSize(o);
129 } else {
130 Object value = cursor.getValue();
131 long tupleHash = getTupleHash1(tuple, value, previous);
132
133 addHash(next, o, impactValues[0], tupleHash);
134 }
135 }
136
137 private void lazyImpactCalculation2(LongIntMap hash2Amount, ObjectCodeImpl previous, ObjectCodeImpl next,
138 long[] impactValues, Cursor<Tuple, ?> cursor) {
139 final Tuple tuple = cursor.getKey();
140 final int o1 = tuple.get(0);
141 final int o2 = tuple.get(1);
142
143 final boolean u1 = isUnique(hash2Amount, previous, o1);
144 final boolean u2 = isUnique(hash2Amount, previous, o2);
145
146 if (u1 && u2) {
147 next.ensureSize(o1);
148 next.ensureSize(o2);
149 } else {
150 Object value = cursor.getValue();
151 long tupleHash = getTupleHash2(tuple, value, previous);
152
153 if (!u1) {
154 addHash(next, o1, impactValues[0], tupleHash);
155 next.ensureSize(o2);
156 }
157 if (!u2) {
158 next.ensureSize(o1);
159 addHash(next, o2, impactValues[1], tupleHash);
160 }
161 }
162 }
163
164 private void lazyImpactCalculationN(LongIntMap hash2Amount, ObjectCodeImpl previous, ObjectCodeImpl next,
165 long[] impactValues, Cursor<Tuple, ?> cursor) {
166 final Tuple tuple = cursor.getKey();
167
168 final boolean[] uniques = new boolean[tuple.getSize()];
169 boolean allUnique = true;
170 for (int i = 0; i < tuple.getSize(); i++) {
171 final boolean isUnique = isUnique(hash2Amount, previous, tuple.get(i));
172 uniques[i] = isUnique;
173 allUnique &= isUnique;
174 }
175
176 if (allUnique) {
177 for (int i = 0; i < tuple.getSize(); i++) {
178 next.ensureSize(tuple.get(i));
179 }
180 } else {
181 Object value = cursor.getValue();
182 long tupleHash = getTupleHashN(tuple, value, previous);
183
184 for (int i = 0; i < tuple.getSize(); i++) {
185 int o = tuple.get(i);
186 if (!uniques[i]) {
187 addHash(next, o, impactValues[i], tupleHash);
188 } else {
189 next.ensureSize(o);
190 }
191 }
192 }
193 }
194
195}
diff --git a/subprojects/store/src/main/java/tools/refinery/store/statecoding/neighbourhood/NeighbourhoodCalculator.java b/subprojects/store/src/main/java/tools/refinery/store/statecoding/neighbourhood/NeighbourhoodCalculator.java
index 5e6de53b..f6071c5b 100644
--- a/subprojects/store/src/main/java/tools/refinery/store/statecoding/neighbourhood/NeighbourhoodCalculator.java
+++ b/subprojects/store/src/main/java/tools/refinery/store/statecoding/neighbourhood/NeighbourhoodCalculator.java
@@ -1,5 +1,5 @@
1/* 1/*
2 * SPDX-FileCopyrightText: 2023 The Refinery Authors <https://refinery.tools/> 2 * SPDX-FileCopyrightText: 2023-2024 The Refinery Authors <https://refinery.tools/>
3 * 3 *
4 * SPDX-License-Identifier: EPL-2.0 4 * SPDX-License-Identifier: EPL-2.0
5 */ 5 */
@@ -9,115 +9,36 @@ import org.eclipse.collections.api.set.primitive.IntSet;
9import tools.refinery.store.map.Cursor; 9import tools.refinery.store.map.Cursor;
10import tools.refinery.store.model.Interpretation; 10import tools.refinery.store.model.Interpretation;
11import tools.refinery.store.model.Model; 11import tools.refinery.store.model.Model;
12import tools.refinery.store.statecoding.ObjectCode;
13import tools.refinery.store.statecoding.StateCodeCalculator;
14import tools.refinery.store.statecoding.StateCoderResult;
15import tools.refinery.store.tuple.Tuple; 12import tools.refinery.store.tuple.Tuple;
16import tools.refinery.store.tuple.Tuple0;
17 13
18import java.util.List; 14import java.util.List;
19import java.util.Objects;
20 15
21public class NeighbourhoodCalculator extends AbstractNeighbourhoodCalculator implements StateCodeCalculator { 16public class NeighbourhoodCalculator extends AbstractNeighbourhoodCalculator<Interpretation<?>> {
22 private ObjectCodeImpl previousObjectCode = new ObjectCodeImpl(); 17 private final List<Interpretation<?>> interpretations;
23 private ObjectCodeImpl nextObjectCode = new ObjectCodeImpl();
24 18
25 public NeighbourhoodCalculator(Model model, List<? extends Interpretation<?>> interpretations, IntSet individuals) { 19 public NeighbourhoodCalculator(Model model, List<? extends Interpretation<?>> interpretations,
26 super(model, interpretations, individuals); 20 IntSet individuals) {
21 super(model, individuals);
22 this.interpretations = List.copyOf(interpretations);
27 } 23 }
28 24
29 public StateCoderResult calculateCodes() { 25 @Override
30 model.checkCancelled(); 26 public List<Interpretation<?>> getInterpretations() {
31 previousObjectCode.clear(); 27 return interpretations;
32 nextObjectCode.clear();
33 initializeWithIndividuals(previousObjectCode);
34
35 int rounds = 0;
36 do {
37 model.checkCancelled();
38 constructNextObjectCodes(previousObjectCode, nextObjectCode);
39 var tempObjectCode = previousObjectCode;
40 previousObjectCode = nextObjectCode;
41 nextObjectCode = tempObjectCode;
42 nextObjectCode.clear();
43 rounds++;
44 } while (rounds <= 7 && rounds <= previousObjectCode.getEffectiveSize());
45
46 long result = calculateLastSum(previousObjectCode);
47 return new StateCoderResult((int) result, previousObjectCode);
48 } 28 }
49 29
50 private long calculateLastSum(ObjectCode codes) { 30 @Override
51 long result = 0; 31 protected int getArity(Interpretation<?> interpretation) {
52 for (var nullImpactValue : nullImpactValues) { 32 return interpretation.getSymbol().arity();
53 result = result * PRIME + Objects.hashCode(((Interpretation<?>) nullImpactValue).get(Tuple0.INSTANCE));
54 }
55
56 for (int i = 0; i < codes.getSize(); i++) {
57 final long hash = codes.get(i);
58 result += hash*PRIME;
59 }
60
61 return result;
62 } 33 }
63 34
64 private void constructNextObjectCodes(ObjectCodeImpl previous, ObjectCodeImpl next) { 35 @Override
65 for (var impactValueEntry : this.impactValues.entrySet()) { 36 protected Object getNullValue(Interpretation<?> interpretation) {
66 model.checkCancelled(); 37 return interpretation.get(Tuple.of());
67 Interpretation<?> interpretation = (Interpretation<?>) impactValueEntry.getKey();
68 var cursor = interpretation.getAll();
69 int arity = interpretation.getSymbol().arity();
70 long[] impactValue = impactValueEntry.getValue();
71
72 if (arity == 1) {
73 while (cursor.move()) {
74 impactCalculation1(previous, next, impactValue, cursor);
75 }
76 } else if (arity == 2) {
77 while (cursor.move()) {
78 impactCalculation2(previous, next, impactValue, cursor);
79 }
80 } else {
81 while (cursor.move()) {
82 impactCalculationN(previous, next, impactValue, cursor);
83 }
84 }
85 }
86 } 38 }
87 39
88 40 @Override
89 private void impactCalculation1(ObjectCodeImpl previous, ObjectCodeImpl next, long[] impactValues, 41 protected Cursor<Tuple, ?> getCursor(Interpretation<?> interpretation) {
90 Cursor<Tuple, ?> cursor) { 42 return interpretation.getAll();
91
92 Tuple tuple = cursor.getKey();
93 int o = tuple.get(0);
94 Object value = cursor.getValue();
95 long tupleHash = getTupleHash1(tuple, value, previous);
96 addHash(next, o, impactValues[0], tupleHash);
97 }
98
99 private void impactCalculation2(ObjectCodeImpl previous, ObjectCodeImpl next, long[] impactValues,
100 Cursor<Tuple, ?> cursor) {
101 final Tuple tuple = cursor.getKey();
102 final int o1 = tuple.get(0);
103 final int o2 = tuple.get(1);
104
105 Object value = cursor.getValue();
106 long tupleHash = getTupleHash2(tuple, value, previous);
107
108 addHash(next, o1, impactValues[0], tupleHash);
109 addHash(next, o2, impactValues[1], tupleHash);
110 }
111
112 private void impactCalculationN(ObjectCodeImpl previous, ObjectCodeImpl next, long[] impactValues,
113 Cursor<Tuple, ?> cursor) {
114 final Tuple tuple = cursor.getKey();
115
116 Object value = cursor.getValue();
117 long tupleHash = getTupleHashN(tuple, value, previous);
118
119 for (int i = 0; i < tuple.getSize(); i++) {
120 addHash(next, tuple.get(i), impactValues[i], tupleHash);
121 }
122 } 43 }
123} 44}