From 8a62e425b77c43a541d58d4615a13cd182d12a81 Mon Sep 17 00:00:00 2001 From: Kristóf Marussy Date: Fri, 31 May 2024 19:22:46 +0200 Subject: feat: generate multiple solutions Switch to partial interpretation based neighborhood calculation when multiple models are request to avoid returning isomorphic models. --- .../AbstractNeighbourhoodCalculator.java | 171 +++++++++++++++--- .../neighbourhood/LazyNeighbourhoodCalculator.java | 195 --------------------- .../neighbourhood/NeighbourhoodCalculator.java | 117 ++----------- 3 files changed, 165 insertions(+), 318 deletions(-) delete mode 100644 subprojects/store/src/main/java/tools/refinery/store/statecoding/neighbourhood/LazyNeighbourhoodCalculator.java (limited to 'subprojects/store/src/main/java/tools') diff --git a/subprojects/store/src/main/java/tools/refinery/store/statecoding/neighbourhood/AbstractNeighbourhoodCalculator.java b/subprojects/store/src/main/java/tools/refinery/store/statecoding/neighbourhood/AbstractNeighbourhoodCalculator.java index 4cef6786..5bfc4725 100644 --- a/subprojects/store/src/main/java/tools/refinery/store/statecoding/neighbourhood/AbstractNeighbourhoodCalculator.java +++ b/subprojects/store/src/main/java/tools/refinery/store/statecoding/neighbourhood/AbstractNeighbourhoodCalculator.java @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: 2023 The Refinery Authors + * SPDX-FileCopyrightText: 2023-2024 The Refinery Authors * * SPDX-License-Identifier: EPL-2.0 */ @@ -8,39 +8,87 @@ package tools.refinery.store.statecoding.neighbourhood; import org.eclipse.collections.api.factory.primitive.IntLongMaps; import org.eclipse.collections.api.map.primitive.MutableIntLongMap; import org.eclipse.collections.api.set.primitive.IntSet; -import tools.refinery.store.model.AnyInterpretation; -import tools.refinery.store.model.Interpretation; +import tools.refinery.store.map.Cursor; import tools.refinery.store.model.Model; import tools.refinery.store.statecoding.ObjectCode; +import tools.refinery.store.statecoding.StateCodeCalculator; +import tools.refinery.store.statecoding.StateCoderResult; import tools.refinery.store.tuple.Tuple; -import tools.refinery.store.tuple.Tuple0; import java.util.*; -public abstract class AbstractNeighbourhoodCalculator { - protected final Model model; - protected final List nullImpactValues; - protected final LinkedHashMap impactValues; - protected final MutableIntLongMap individualHashValues = IntLongMaps.mutable.empty(); +public abstract class AbstractNeighbourhoodCalculator implements StateCodeCalculator { + private final Model model; + private final IntSet individuals; + private List nullImpactValues; + private LinkedHashMap impactValues; + private MutableIntLongMap individualHashValues; + private ObjectCodeImpl previousObjectCode = new ObjectCodeImpl(); + private ObjectCodeImpl nextObjectCode = new ObjectCodeImpl(); protected static final long PRIME = 31; - protected AbstractNeighbourhoodCalculator(Model model, List interpretations, - IntSet individuals) { + protected AbstractNeighbourhoodCalculator(Model model, IntSet individuals) { this.model = model; - this.nullImpactValues = new ArrayList<>(); - this.impactValues = new LinkedHashMap<>(); + this.individuals = individuals; + } + + protected Model getModel() { + return model; + } + + protected abstract List getInterpretations(); + + protected abstract int getArity(T interpretation); + + protected abstract Object getNullValue(T interpretation); + + // We need the wildcard here, because we don't know the value type. + @SuppressWarnings("squid:S1452") + protected abstract Cursor getCursor(T interpretation); + + @Override + public StateCoderResult calculateCodes() { + model.checkCancelled(); + ensureInitialized(); + previousObjectCode.clear(); + nextObjectCode.clear(); + initializeWithIndividuals(previousObjectCode); + + int rounds = 0; + do { + model.checkCancelled(); + constructNextObjectCodes(previousObjectCode, nextObjectCode); + var tempObjectCode = previousObjectCode; + previousObjectCode = nextObjectCode; + nextObjectCode = tempObjectCode; + nextObjectCode.clear(); + rounds++; + } while (rounds <= 7 && rounds <= previousObjectCode.getEffectiveSize()); + + long result = calculateLastSum(previousObjectCode); + return new StateCoderResult((int) result, previousObjectCode); + } + + private void ensureInitialized() { + if (impactValues != null) { + return; + } + + nullImpactValues = new ArrayList<>(); + impactValues = new LinkedHashMap<>(); + individualHashValues = IntLongMaps.mutable.empty(); // Random isn't used for cryptographical purposes but just to assign distinguishable identifiers to symbols. @SuppressWarnings("squid:S2245") Random random = new Random(1); var individualsInOrder = individuals.toSortedList(Integer::compare); - for(int i = 0; i cursor, int arity) { + switch (arity) { + case 1 -> { + while (cursor.move()) { + impactCalculation1(previous, next, impactValue, cursor); + } + } + case 2 -> { + while (cursor.move()) { + impactCalculation2(previous, next, impactValue, cursor); + } + } + default -> { + while (cursor.move()) { + impactCalculationN(previous, next, impactValue, cursor); + } + } + } + } + + private void impactCalculation1(ObjectCodeImpl previous, ObjectCodeImpl next, long[] impactValues, + Cursor cursor) { + + Tuple tuple = cursor.getKey(); + int o = tuple.get(0); + Object value = cursor.getValue(); + long tupleHash = getTupleHash1(tuple, value, previous); + addHash(next, o, impactValues[0], tupleHash); + } + + private void impactCalculation2(ObjectCodeImpl previous, ObjectCodeImpl next, long[] impactValues, + Cursor cursor) { + final Tuple tuple = cursor.getKey(); + final int o1 = tuple.get(0); + final int o2 = tuple.get(1); + + Object value = cursor.getValue(); + long tupleHash = getTupleHash2(tuple, value, previous); + + addHash(next, o1, impactValues[0], tupleHash); + addHash(next, o2, impactValues[1], tupleHash); + } + + private void impactCalculationN(ObjectCodeImpl previous, ObjectCodeImpl next, long[] impactValues, + Cursor cursor) { + final Tuple tuple = cursor.getKey(); + + Object value = cursor.getValue(); + long tupleHash = getTupleHashN(tuple, value, previous); + + for (int i = 0; i < tuple.getSize(); i++) { + addHash(next, tuple.get(i), impactValues[i], tupleHash); } } @@ -88,13 +218,4 @@ public abstract class AbstractNeighbourhoodCalculator { long x = tupleHash * impact; objectCodeImpl.set(o, objectCodeImpl.get(o) + x); } - - protected long calculateModelCode(long lastSum) { - long result = 0; - for (var nullImpactValue : nullImpactValues) { - result = result * PRIME + Objects.hashCode(((Interpretation) nullImpactValue).get(Tuple0.INSTANCE)); - } - result += lastSum; - return result; - } } diff --git a/subprojects/store/src/main/java/tools/refinery/store/statecoding/neighbourhood/LazyNeighbourhoodCalculator.java b/subprojects/store/src/main/java/tools/refinery/store/statecoding/neighbourhood/LazyNeighbourhoodCalculator.java deleted file mode 100644 index 04335141..00000000 --- a/subprojects/store/src/main/java/tools/refinery/store/statecoding/neighbourhood/LazyNeighbourhoodCalculator.java +++ /dev/null @@ -1,195 +0,0 @@ -/* - * SPDX-FileCopyrightText: 2023 The Refinery Authors - * - * SPDX-License-Identifier: EPL-2.0 - */ -package tools.refinery.store.statecoding.neighbourhood; - -import org.eclipse.collections.api.factory.primitive.LongIntMaps; -import org.eclipse.collections.api.map.primitive.LongIntMap; -import org.eclipse.collections.api.map.primitive.MutableLongIntMap; -import org.eclipse.collections.api.set.primitive.IntSet; -import tools.refinery.store.map.Cursor; -import tools.refinery.store.model.AnyInterpretation; -import tools.refinery.store.model.Interpretation; -import tools.refinery.store.model.Model; -import tools.refinery.store.statecoding.StateCodeCalculator; -import tools.refinery.store.statecoding.StateCoderResult; -import tools.refinery.store.tuple.Tuple; - -import java.util.List; - -public class LazyNeighbourhoodCalculator extends AbstractNeighbourhoodCalculator implements StateCodeCalculator { - public LazyNeighbourhoodCalculator(Model model, List interpretations, - IntSet individuals) { - super(model, interpretations, individuals); - } - - public StateCoderResult calculateCodes() { - ObjectCodeImpl previousObjectCode = new ObjectCodeImpl(); - MutableLongIntMap prevHash2Amount = LongIntMaps.mutable.empty(); - - long lastSum; - // All hash code is 0, except to the individuals. - int lastSize = 1; - boolean first = true; - - boolean grows; - int rounds = 0; - do { - final ObjectCodeImpl nextObjectCode; - if (first) { - nextObjectCode = new ObjectCodeImpl(); - initializeWithIndividuals(nextObjectCode); - } else { - nextObjectCode = new ObjectCodeImpl(previousObjectCode); - } - constructNextObjectCodes(previousObjectCode, nextObjectCode, prevHash2Amount); - - MutableLongIntMap nextHash2Amount = LongIntMaps.mutable.empty(); - lastSum = calculateLastSum(previousObjectCode, nextObjectCode, prevHash2Amount, nextHash2Amount); - - int nextSize = nextHash2Amount.size(); - grows = nextSize > lastSize; - lastSize = nextSize; - first = false; - - previousObjectCode = nextObjectCode; - prevHash2Amount = nextHash2Amount; - } while (grows && rounds++ < 4/*&& lastSize < previousObjectCode.getSize()*/); - - long result = calculateModelCode(lastSum); - - return new StateCoderResult((int) result, previousObjectCode); - } - - private long calculateLastSum(ObjectCodeImpl previous, ObjectCodeImpl next, LongIntMap hash2Amount, - MutableLongIntMap nextHash2Amount) { - long lastSum = 0; - for (int i = 0; i < next.getSize(); i++) { - final long hash; - if (isUnique(hash2Amount, previous, i)) { - hash = previous.get(i); - next.set(i, hash); - } else { - hash = next.get(i); - } - - final int amount = nextHash2Amount.get(hash); - nextHash2Amount.put(hash, amount + 1); - - final long shifted1 = hash >>> 8; - final long shifted2 = hash << 8; - final long shifted3 = hash >> 2; - lastSum += shifted1 * shifted3 + shifted2; - } - return lastSum; - } - - private void constructNextObjectCodes(ObjectCodeImpl previous, ObjectCodeImpl next, LongIntMap hash2Amount) { - for (var impactValueEntry : this.impactValues.entrySet()) { - Interpretation interpretation = (Interpretation) impactValueEntry.getKey(); - var cursor = interpretation.getAll(); - int arity = interpretation.getSymbol().arity(); - long[] impactValue = impactValueEntry.getValue(); - - if (arity == 1) { - while (cursor.move()) { - lazyImpactCalculation1(hash2Amount, previous, next, impactValue, cursor); - } - } else if (arity == 2) { - while (cursor.move()) { - lazyImpactCalculation2(hash2Amount, previous, next, impactValue, cursor); - } - } else { - while (cursor.move()) { - lazyImpactCalculationN(hash2Amount, previous, next, impactValue, cursor); - } - } - } - } - - private boolean isUnique(LongIntMap hash2Amount, ObjectCodeImpl objectCodeImpl, int object) { - final long hash = objectCodeImpl.get(object); - if (hash == 0) { - return false; - } - final int amount = hash2Amount.get(hash); - return amount == 1; - } - - private void lazyImpactCalculation1(LongIntMap hash2Amount, ObjectCodeImpl previous, ObjectCodeImpl next, - long[] impactValues, Cursor cursor) { - - Tuple tuple = cursor.getKey(); - int o = tuple.get(0); - - if (isUnique(hash2Amount, previous, o)) { - next.ensureSize(o); - } else { - Object value = cursor.getValue(); - long tupleHash = getTupleHash1(tuple, value, previous); - - addHash(next, o, impactValues[0], tupleHash); - } - } - - private void lazyImpactCalculation2(LongIntMap hash2Amount, ObjectCodeImpl previous, ObjectCodeImpl next, - long[] impactValues, Cursor cursor) { - final Tuple tuple = cursor.getKey(); - final int o1 = tuple.get(0); - final int o2 = tuple.get(1); - - final boolean u1 = isUnique(hash2Amount, previous, o1); - final boolean u2 = isUnique(hash2Amount, previous, o2); - - if (u1 && u2) { - next.ensureSize(o1); - next.ensureSize(o2); - } else { - Object value = cursor.getValue(); - long tupleHash = getTupleHash2(tuple, value, previous); - - if (!u1) { - addHash(next, o1, impactValues[0], tupleHash); - next.ensureSize(o2); - } - if (!u2) { - next.ensureSize(o1); - addHash(next, o2, impactValues[1], tupleHash); - } - } - } - - private void lazyImpactCalculationN(LongIntMap hash2Amount, ObjectCodeImpl previous, ObjectCodeImpl next, - long[] impactValues, Cursor cursor) { - final Tuple tuple = cursor.getKey(); - - final boolean[] uniques = new boolean[tuple.getSize()]; - boolean allUnique = true; - for (int i = 0; i < tuple.getSize(); i++) { - final boolean isUnique = isUnique(hash2Amount, previous, tuple.get(i)); - uniques[i] = isUnique; - allUnique &= isUnique; - } - - if (allUnique) { - for (int i = 0; i < tuple.getSize(); i++) { - next.ensureSize(tuple.get(i)); - } - } else { - Object value = cursor.getValue(); - long tupleHash = getTupleHashN(tuple, value, previous); - - for (int i = 0; i < tuple.getSize(); i++) { - int o = tuple.get(i); - if (!uniques[i]) { - addHash(next, o, impactValues[i], tupleHash); - } else { - next.ensureSize(o); - } - } - } - } - -} diff --git a/subprojects/store/src/main/java/tools/refinery/store/statecoding/neighbourhood/NeighbourhoodCalculator.java b/subprojects/store/src/main/java/tools/refinery/store/statecoding/neighbourhood/NeighbourhoodCalculator.java index 5e6de53b..f6071c5b 100644 --- a/subprojects/store/src/main/java/tools/refinery/store/statecoding/neighbourhood/NeighbourhoodCalculator.java +++ b/subprojects/store/src/main/java/tools/refinery/store/statecoding/neighbourhood/NeighbourhoodCalculator.java @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: 2023 The Refinery Authors + * SPDX-FileCopyrightText: 2023-2024 The Refinery Authors * * SPDX-License-Identifier: EPL-2.0 */ @@ -9,115 +9,36 @@ import org.eclipse.collections.api.set.primitive.IntSet; import tools.refinery.store.map.Cursor; import tools.refinery.store.model.Interpretation; import tools.refinery.store.model.Model; -import tools.refinery.store.statecoding.ObjectCode; -import tools.refinery.store.statecoding.StateCodeCalculator; -import tools.refinery.store.statecoding.StateCoderResult; import tools.refinery.store.tuple.Tuple; -import tools.refinery.store.tuple.Tuple0; import java.util.List; -import java.util.Objects; -public class NeighbourhoodCalculator extends AbstractNeighbourhoodCalculator implements StateCodeCalculator { - private ObjectCodeImpl previousObjectCode = new ObjectCodeImpl(); - private ObjectCodeImpl nextObjectCode = new ObjectCodeImpl(); +public class NeighbourhoodCalculator extends AbstractNeighbourhoodCalculator> { + private final List> interpretations; - public NeighbourhoodCalculator(Model model, List> interpretations, IntSet individuals) { - super(model, interpretations, individuals); + public NeighbourhoodCalculator(Model model, List> interpretations, + IntSet individuals) { + super(model, individuals); + this.interpretations = List.copyOf(interpretations); } - public StateCoderResult calculateCodes() { - model.checkCancelled(); - previousObjectCode.clear(); - nextObjectCode.clear(); - initializeWithIndividuals(previousObjectCode); - - int rounds = 0; - do { - model.checkCancelled(); - constructNextObjectCodes(previousObjectCode, nextObjectCode); - var tempObjectCode = previousObjectCode; - previousObjectCode = nextObjectCode; - nextObjectCode = tempObjectCode; - nextObjectCode.clear(); - rounds++; - } while (rounds <= 7 && rounds <= previousObjectCode.getEffectiveSize()); - - long result = calculateLastSum(previousObjectCode); - return new StateCoderResult((int) result, previousObjectCode); + @Override + public List> getInterpretations() { + return interpretations; } - private long calculateLastSum(ObjectCode codes) { - long result = 0; - for (var nullImpactValue : nullImpactValues) { - result = result * PRIME + Objects.hashCode(((Interpretation) nullImpactValue).get(Tuple0.INSTANCE)); - } - - for (int i = 0; i < codes.getSize(); i++) { - final long hash = codes.get(i); - result += hash*PRIME; - } - - return result; + @Override + protected int getArity(Interpretation interpretation) { + return interpretation.getSymbol().arity(); } - private void constructNextObjectCodes(ObjectCodeImpl previous, ObjectCodeImpl next) { - for (var impactValueEntry : this.impactValues.entrySet()) { - model.checkCancelled(); - Interpretation interpretation = (Interpretation) impactValueEntry.getKey(); - var cursor = interpretation.getAll(); - int arity = interpretation.getSymbol().arity(); - long[] impactValue = impactValueEntry.getValue(); - - if (arity == 1) { - while (cursor.move()) { - impactCalculation1(previous, next, impactValue, cursor); - } - } else if (arity == 2) { - while (cursor.move()) { - impactCalculation2(previous, next, impactValue, cursor); - } - } else { - while (cursor.move()) { - impactCalculationN(previous, next, impactValue, cursor); - } - } - } + @Override + protected Object getNullValue(Interpretation interpretation) { + return interpretation.get(Tuple.of()); } - - private void impactCalculation1(ObjectCodeImpl previous, ObjectCodeImpl next, long[] impactValues, - Cursor cursor) { - - Tuple tuple = cursor.getKey(); - int o = tuple.get(0); - Object value = cursor.getValue(); - long tupleHash = getTupleHash1(tuple, value, previous); - addHash(next, o, impactValues[0], tupleHash); - } - - private void impactCalculation2(ObjectCodeImpl previous, ObjectCodeImpl next, long[] impactValues, - Cursor cursor) { - final Tuple tuple = cursor.getKey(); - final int o1 = tuple.get(0); - final int o2 = tuple.get(1); - - Object value = cursor.getValue(); - long tupleHash = getTupleHash2(tuple, value, previous); - - addHash(next, o1, impactValues[0], tupleHash); - addHash(next, o2, impactValues[1], tupleHash); - } - - private void impactCalculationN(ObjectCodeImpl previous, ObjectCodeImpl next, long[] impactValues, - Cursor cursor) { - final Tuple tuple = cursor.getKey(); - - Object value = cursor.getValue(); - long tupleHash = getTupleHashN(tuple, value, previous); - - for (int i = 0; i < tuple.getSize(); i++) { - addHash(next, tuple.get(i), impactValues[i], tupleHash); - } + @Override + protected Cursor getCursor(Interpretation interpretation) { + return interpretation.getAll(); } } -- cgit v1.2.3-70-g09d2