From a30e258a60c6619830dff8d17aed4af4763af2c6 Mon Sep 17 00:00:00 2001 From: Kristóf Marussy Date: Wed, 8 May 2019 14:28:54 -0400 Subject: Interval arithmetic WIP --- .../logic2viatra/interval/Interval.xtend | 318 +++++++++++++++++++++ 1 file changed, 318 insertions(+) create mode 100644 Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend (limited to 'Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval') diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend new file mode 100644 index 00000000..5b839fbd --- /dev/null +++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend @@ -0,0 +1,318 @@ +package hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval + +import java.math.BigDecimal +import java.math.MathContext +import java.math.RoundingMode +import org.eclipse.xtend.lib.annotations.Data + +abstract class Interval { + static val PRECISION = 32 + static val ROUND_DOWN = new MathContext(PRECISION, RoundingMode.FLOOR) + static val ROUND_UP = new MathContext(PRECISION, RoundingMode.CEILING) + + private new() { + } + + abstract def boolean isZero() + + def operator_plus() { + this + } + + abstract def Interval operator_minus() + + abstract def Interval operator_plus(Interval other) + + abstract def Interval operator_minus(Interval other) + + abstract def Interval operator_multiply(Interval other) + + abstract def Interval operator_divide(Interval other) + + public static val EMPTY = new Interval { + override isZero() { + false + } + + override operator_minus() { + EMPTY + } + + override operator_plus(Interval other) { + EMPTY + } + + override operator_minus(Interval other) { + EMPTY + } + + override operator_multiply(Interval other) { + EMPTY + } + + override operator_divide(Interval other) { + EMPTY + } + } + + public static val Interval ZERO = new NonEmpty(BigDecimal.ZERO, BigDecimal.ZERO) + + public static val Interval UNBOUNDED = new NonEmpty(null, null) + + @Data + static class NonEmpty extends Interval { + val BigDecimal lower + val BigDecimal upper + + /** + * Construct a new non-empty interval. + * + * @param lower The lower bound of the interval. Use null for negative infinity. + * @param upper The upper bound of the interval. Use null for positive infinity. + */ + new(BigDecimal lower, BigDecimal upper) { + if (lower !== null && upper !== null && lower > upper) { + throw new IllegalArgumentException("Lower bound of interval must not be larger than upper bound") + } + this.lower = lower + this.upper = upper + } + + override isZero() { + upper == BigDecimal.ZERO && lower == BigDecimal.ZERO + } + + override operator_minus() { + new NonEmpty(upper?.negate(ROUND_DOWN), lower?.negate(ROUND_UP)) + } + + override operator_plus(Interval other) { + switch (other) { + case EMPTY: EMPTY + NonEmpty: operator_plus(other) + default: throw new IllegalArgumentException("") + } + } + + def operator_plus(NonEmpty other) { + new NonEmpty( + lower.tryAdd(other.lower, ROUND_DOWN), + upper.tryAdd(other.upper, ROUND_UP) + ) + } + + private static def tryAdd(BigDecimal a, BigDecimal b, MathContext mc) { + if (b === null) { + null + } else { + a?.add(b, mc) + } + } + + override operator_minus(Interval other) { + switch (other) { + case EMPTY: EMPTY + NonEmpty: operator_minus(other) + default: throw new IllegalArgumentException("") + } + } + + def operator_minus(NonEmpty other) { + new NonEmpty( + lower.trySubtract(other.upper, ROUND_DOWN), + upper.trySubtract(other.lower, ROUND_UP) + ) + } + + private static def trySubtract(BigDecimal a, BigDecimal b, MathContext mc) { + if (b === null) { + null + } else { + a?.subtract(b, mc) + } + } + + override operator_multiply(Interval other) { + switch (other) { + case EMPTY: EMPTY + NonEmpty: operator_multiply(other) + default: throw new IllegalArgumentException("") + } + } + + def operator_multiply(NonEmpty other) { + if (nonpositive) { + if (other.nonpositive) { + new NonEmpty( + upper.multiply(other.upper, ROUND_DOWN), + lower.tryMultiply(other.lower, ROUND_UP) + ) + } else if (other.nonnegative) { + new NonEmpty( + lower.tryMultiply(other.upper, ROUND_DOWN), + upper.multiply(other.lower, ROUND_UP) + ) + } else { + new NonEmpty( + lower.tryMultiply(other.upper, ROUND_DOWN), + upper.tryMultiply(other.lower, ROUND_UP) + ) + } + } else if (nonnegative) { + if (other.nonpositive) { + new NonEmpty( + upper.tryMultiply(other.lower, ROUND_DOWN), + lower.multiply(other.upper, ROUND_UP) + ) + } else if (other.nonnegative) { + new NonEmpty( + lower.multiply(other.lower, ROUND_DOWN), + upper.tryMultiply(other.upper, ROUND_UP) + ) + } else { + new NonEmpty( + upper.tryMultiply(other.lower, ROUND_DOWN), + upper.tryMultiply(other.upper, ROUND_UP) + ) + } + } else { + if (other.nonpositive) { + new NonEmpty( + upper.tryMultiply(other.lower, ROUND_DOWN), + lower.tryMultiply(other.lower, ROUND_UP) + ) + } else if (other.nonnegative) { + new NonEmpty( + lower.tryMultiply(other.upper, ROUND_DOWN), + upper.tryMultiply(other.upper, ROUND_UP) + ) + } else { + new NonEmpty( + lower.tryMultiply(other.upper, ROUND_DOWN).tryMin(upper.tryMultiply(other.lower, ROUND_DOWN)), + lower.tryMultiply(other.lower, ROUND_UP).tryMax(upper.tryMultiply(other.upper, ROUND_UP)) + ) + } + } + } + + private def isNonpositive() { + upper !== null && upper <= BigDecimal.ZERO + } + + private def isNonnegative() { + lower !== null && lower >= BigDecimal.ZERO + } + + private static def tryMultiply(BigDecimal a, BigDecimal b, MathContext mc) { + if (b === null) { + null + } else { + a?.multiply(b, mc) + } + } + + private static def tryMin(BigDecimal a, BigDecimal b) { + if (b === null) { + null + } else { + a?.min(b) + } + } + + private static def tryMax(BigDecimal a, BigDecimal b) { + if (b === null) { + null + } else { + a?.max(b) + } + } + + override operator_divide(Interval other) { + switch (other) { + case EMPTY: EMPTY + NonEmpty: operator_divide(other) + default: throw new IllegalArgumentException("") + } + } + + def operator_divide(NonEmpty other) { + if (other.strictlyNegative) { + if (nonpositive) { + new NonEmpty( + upper.tryDivide(other.lower, ROUND_DOWN), + lower.tryDivide(other.upper, ROUND_UP) + ) + } else if (nonnegative) { + new NonEmpty( + upper.tryDivide(other.upper, ROUND_DOWN), + lower.tryDivide(other.lower, ROUND_UP) + ) + } else { // lower < 0 < upper + new NonEmpty( + upper.tryDivide(other.upper, ROUND_DOWN), + lower.tryDivide(other.upper, ROUND_UP) + ) + } + } else if (other.strictlyPositive) { + if (nonpositive) { + new NonEmpty( + lower.tryDivide(other.lower, ROUND_DOWN), + upper.tryDivide(other.upper, ROUND_UP) + ) + } else if (nonnegative) { + new NonEmpty( + lower.tryDivide(other.upper, ROUND_DOWN), + upper.tryDivide(other.lower, ROUND_UP) + ) + } else { // lower < 0 < upper + new NonEmpty( + lower.tryDivide(other.lower, ROUND_DOWN), + upper.tryDivide(other.lower, ROUND_UP) + ) + } + } else { // other contains 0 + if (other.lower == BigDecimal.ZERO) { + if (other.upper == BigDecimal.ZERO) { // [0, 0] + EMPTY + } else { // 0 == other.lower < other.upper + if (nonpositive) { + new NonEmpty(null, upper.tryDivide(other.upper, ROUND_UP)) + } else if (nonnegative) { + new NonEmpty(lower.tryDivide(other.upper, ROUND_DOWN), null) + } else { // lower < 0 < upper + UNBOUNDED + } + } + } else { + if (other.upper == BigDecimal.ZERO) { // other.lower < other.upper == 0 + if (nonpositive) { + new NonEmpty(upper.tryDivide(other.lower, ROUND_DOWN), null) + } else if (nonnegative) { + new NonEmpty(null, lower.tryDivide(other.lower, ROUND_UP)) + } else { // lower < 0 < upper + UNBOUNDED + } + } else { // other.lower < 0 < other.upper + UNBOUNDED + } + } + } + } + + private def isStrictlyNegative() { + upper !== null && upper < BigDecimal.ZERO + } + + private def isStrictlyPositive() { + lower !== null && lower > BigDecimal.ZERO + } + + private static def tryDivide(BigDecimal a, BigDecimal b, MathContext mc) { + if (b === null) { + BigDecimal.ZERO + } else { + a?.divide(b, mc) + } + } + } +} -- cgit v1.2.3-70-g09d2 From c925edcadbabcdc6de5e0442105dc30a387d3088 Mon Sep 17 00:00:00 2001 From: Kristóf Marussy Date: Wed, 8 May 2019 17:50:28 -0400 Subject: Implement interval arithmetic without exponentiation --- .../logic2viatra/interval/Interval.xtend | 74 +++++--- .../.classpath | 16 ++ .../.gitignore | 4 + .../.project | 34 ++++ .../.settings/org.eclipse.jdt.core.prefs | 7 + .../META-INF/MANIFEST.MF | 13 ++ .../build.properties | 5 + .../logic2viatra/tests/interval/AdditionTest.xtend | 49 +++++ .../logic2viatra/tests/interval/DivisionTest.xtend | 202 ++++++++++++++++++++ .../tests/interval/MultiplicationTest.xtend | 205 +++++++++++++++++++++ .../logic2viatra/tests/interval/NegationTest.xtend | 34 ++++ .../tests/interval/SubtractionTest.xtend | 49 +++++ 12 files changed, 667 insertions(+), 25 deletions(-) create mode 100644 Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/.classpath create mode 100644 Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/.gitignore create mode 100644 Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/.project create mode 100644 Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/.settings/org.eclipse.jdt.core.prefs create mode 100644 Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/META-INF/MANIFEST.MF create mode 100644 Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/build.properties create mode 100644 Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/AdditionTest.xtend create mode 100644 Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/DivisionTest.xtend create mode 100644 Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/MultiplicationTest.xtend create mode 100644 Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/NegationTest.xtend create mode 100644 Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/SubtractionTest.xtend (limited to 'Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval') diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend index 5b839fbd..cf22315b 100644 --- a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend +++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend @@ -53,14 +53,34 @@ abstract class Interval { override operator_divide(Interval other) { EMPTY } + + override toString() { + "∅" + } } public static val Interval ZERO = new NonEmpty(BigDecimal.ZERO, BigDecimal.ZERO) public static val Interval UNBOUNDED = new NonEmpty(null, null) + static def Interval of(BigDecimal lower, BigDecimal upper) { + new NonEmpty(lower, upper) + } + + static def between(double lower, double upper) { + of(new BigDecimal(lower, ROUND_DOWN), new BigDecimal(upper, ROUND_UP)) + } + + static def upTo(double upper) { + of(null, new BigDecimal(upper, ROUND_UP)) + } + + static def above(double lower) { + of(new BigDecimal(lower, ROUND_DOWN), null) + } + @Data - static class NonEmpty extends Interval { + private static class NonEmpty extends Interval { val BigDecimal lower val BigDecimal upper @@ -141,7 +161,9 @@ abstract class Interval { } def operator_multiply(NonEmpty other) { - if (nonpositive) { + if (this == ZERO || other == ZERO) { + ZERO + } else if (nonpositive) { if (other.nonpositive) { new NonEmpty( upper.multiply(other.upper, ROUND_DOWN), @@ -155,7 +177,7 @@ abstract class Interval { } else { new NonEmpty( lower.tryMultiply(other.upper, ROUND_DOWN), - upper.tryMultiply(other.lower, ROUND_UP) + lower.tryMultiply(other.lower, ROUND_UP) ) } } else if (nonnegative) { @@ -236,7 +258,11 @@ abstract class Interval { } def operator_divide(NonEmpty other) { - if (other.strictlyNegative) { + if (other == ZERO) { + EMPTY + } else if (this == ZERO) { + ZERO + } else if (other.strictlyNegative) { if (nonpositive) { new NonEmpty( upper.tryDivide(other.lower, ROUND_DOWN), @@ -271,30 +297,24 @@ abstract class Interval { ) } } else { // other contains 0 - if (other.lower == BigDecimal.ZERO) { - if (other.upper == BigDecimal.ZERO) { // [0, 0] - EMPTY - } else { // 0 == other.lower < other.upper - if (nonpositive) { - new NonEmpty(null, upper.tryDivide(other.upper, ROUND_UP)) - } else if (nonnegative) { - new NonEmpty(lower.tryDivide(other.upper, ROUND_DOWN), null) - } else { // lower < 0 < upper - UNBOUNDED - } + if (other.lower == BigDecimal.ZERO) { // 0 == other.lower < other.upper, because [0, 0] was exluded earlier + if (nonpositive) { + new NonEmpty(null, upper.tryDivide(other.upper, ROUND_UP)) + } else if (nonnegative) { + new NonEmpty(lower.tryDivide(other.upper, ROUND_DOWN), null) + } else { // lower < 0 < upper + UNBOUNDED } - } else { - if (other.upper == BigDecimal.ZERO) { // other.lower < other.upper == 0 - if (nonpositive) { - new NonEmpty(upper.tryDivide(other.lower, ROUND_DOWN), null) - } else if (nonnegative) { - new NonEmpty(null, lower.tryDivide(other.lower, ROUND_UP)) - } else { // lower < 0 < upper - UNBOUNDED - } - } else { // other.lower < 0 < other.upper + } else if (other.upper == BigDecimal.ZERO) { // other.lower < other.upper == 0 + if (nonpositive) { + new NonEmpty(upper.tryDivide(other.lower, ROUND_DOWN), null) + } else if (nonnegative) { + new NonEmpty(null, lower.tryDivide(other.lower, ROUND_UP)) + } else { // lower < 0 < upper UNBOUNDED } + } else { // other.lower < 0 < other.upper + UNBOUNDED } } } @@ -314,5 +334,9 @@ abstract class Interval { a?.divide(b, mc) } } + + override toString() { + '''«IF lower === null»(-∞«ELSE»[«lower»«ENDIF», «IF upper === null»∞)«ELSE»«upper»]«ENDIF»''' + } } } diff --git a/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/.classpath b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/.classpath new file mode 100644 index 00000000..ef58158d --- /dev/null +++ b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/.classpath @@ -0,0 +1,16 @@ + + + + + + + + + + + + + + + + diff --git a/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/.gitignore b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/.gitignore new file mode 100644 index 00000000..8ae4e44d --- /dev/null +++ b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/.gitignore @@ -0,0 +1,4 @@ +/bin/ +/src-gen/ +/vql-gen/ +/xtend-gen/ diff --git a/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/.project b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/.project new file mode 100644 index 00000000..5bc946ea --- /dev/null +++ b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/.project @@ -0,0 +1,34 @@ + + + hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests + + + + + + org.eclipse.xtext.ui.shared.xtextBuilder + + + + + org.eclipse.jdt.core.javabuilder + + + + + org.eclipse.pde.ManifestBuilder + + + + + org.eclipse.pde.SchemaBuilder + + + + + + org.eclipse.pde.PluginNature + org.eclipse.jdt.core.javanature + org.eclipse.xtext.ui.shared.xtextNature + + diff --git a/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/.settings/org.eclipse.jdt.core.prefs b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/.settings/org.eclipse.jdt.core.prefs new file mode 100644 index 00000000..0c68a61d --- /dev/null +++ b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/.settings/org.eclipse.jdt.core.prefs @@ -0,0 +1,7 @@ +eclipse.preferences.version=1 +org.eclipse.jdt.core.compiler.codegen.inlineJsrBytecode=enabled +org.eclipse.jdt.core.compiler.codegen.targetPlatform=1.8 +org.eclipse.jdt.core.compiler.compliance=1.8 +org.eclipse.jdt.core.compiler.problem.assertIdentifier=error +org.eclipse.jdt.core.compiler.problem.enumIdentifier=error +org.eclipse.jdt.core.compiler.source=1.8 diff --git a/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/META-INF/MANIFEST.MF b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/META-INF/MANIFEST.MF new file mode 100644 index 00000000..76c113c1 --- /dev/null +++ b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/META-INF/MANIFEST.MF @@ -0,0 +1,13 @@ +Manifest-Version: 1.0 +Bundle-ManifestVersion: 2 +Bundle-Name: Logic2Viatra Tests +Bundle-SymbolicName: hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests +Bundle-Version: 1.0.0.qualifier +Automatic-Module-Name: hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests +Bundle-RequiredExecutionEnvironment: JavaSE-1.8 +Import-Package: org.junit;version="4.12.0" +Require-Bundle: com.google.guava, + org.eclipse.xtext.xbase.lib, + org.eclipse.xtend.lib, + org.eclipse.xtend.lib.macro, + hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatraquery diff --git a/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/build.properties b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/build.properties new file mode 100644 index 00000000..5b9d2918 --- /dev/null +++ b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/build.properties @@ -0,0 +1,5 @@ +source.. = src/ +output.. = bin/ +bin.includes = META-INF/,\ + . +additional.bundles = org.junit diff --git a/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/AdditionTest.xtend b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/AdditionTest.xtend new file mode 100644 index 00000000..de5f40e1 --- /dev/null +++ b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/AdditionTest.xtend @@ -0,0 +1,49 @@ +package hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests.interval + +import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval +import java.util.Collection +import org.junit.Assert +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import org.junit.runners.Parameterized.Parameter +import org.junit.runners.Parameterized.Parameters + +import static hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval.* + +@RunWith(Parameterized) +class AdditionTest { + @Parameters(name = "{index}: {0} + {1} = {2}") + static def Collection data() { + #[ + #[EMPTY, EMPTY, EMPTY], + #[EMPTY, between(-1, 1), EMPTY], + #[between(-1, 1), EMPTY, EMPTY], + #[UNBOUNDED, UNBOUNDED, UNBOUNDED], + #[UNBOUNDED, upTo(2), UNBOUNDED], + #[UNBOUNDED, above(-2), UNBOUNDED], + #[UNBOUNDED, between(-1, 1), UNBOUNDED], + #[upTo(2), UNBOUNDED, UNBOUNDED], + #[upTo(2), upTo(1), upTo(3)], + #[upTo(2), above(-1), UNBOUNDED], + #[upTo(2), between(-1, 2), upTo(4)], + #[above(-2), UNBOUNDED, UNBOUNDED], + #[above(-2), upTo(1), UNBOUNDED], + #[above(-2), above(-1), above(-3)], + #[above(-2), between(-1, 2), above(-3)], + #[between(-2, 3), UNBOUNDED, UNBOUNDED], + #[between(-2, 3), upTo(1), upTo(4)], + #[between(-2, 3), above(-1), above(-3)], + #[between(-2, 3), between(-1, 2.5), between(-3, 5.5)] + ] + } + + @Parameter(0) public var Interval a + @Parameter(1) public var Interval b + @Parameter(2) public var Interval result + + @Test + def void additionTest() { + Assert.assertEquals(result, a + b) + } +} diff --git a/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/DivisionTest.xtend b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/DivisionTest.xtend new file mode 100644 index 00000000..3a8c0c5d --- /dev/null +++ b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/DivisionTest.xtend @@ -0,0 +1,202 @@ +package hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests.interval + +import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval +import java.util.Collection +import org.junit.Assert +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import org.junit.runners.Parameterized.Parameter +import org.junit.runners.Parameterized.Parameters + +import static hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval.* + +@RunWith(Parameterized) +class DivisionTest { + @Parameters(name="{index}: {0} / {1} = {2}") + static def Collection data() { + #[ + #[EMPTY, EMPTY, EMPTY], + #[EMPTY, between(-1, 1), EMPTY], + #[between(-1, 1), EMPTY, EMPTY], + #[UNBOUNDED, UNBOUNDED, UNBOUNDED], + #[UNBOUNDED, upTo(-2), UNBOUNDED], + #[UNBOUNDED, upTo(0), UNBOUNDED], + #[UNBOUNDED, upTo(3), UNBOUNDED], + #[UNBOUNDED, above(-2), UNBOUNDED], + #[UNBOUNDED, above(0), UNBOUNDED], + #[UNBOUNDED, above(3), UNBOUNDED], + #[UNBOUNDED, between(-4, -3), UNBOUNDED], + #[UNBOUNDED, between(-4, 0), UNBOUNDED], + #[UNBOUNDED, between(-3, 4), UNBOUNDED], + #[UNBOUNDED, between(0, 4), UNBOUNDED], + #[UNBOUNDED, between(3, 4), UNBOUNDED], + #[UNBOUNDED, ZERO, EMPTY], + #[upTo(-12), UNBOUNDED, UNBOUNDED], + #[upTo(-12), upTo(-2), above(0)], + #[upTo(-12), upTo(0), above(0)], + #[upTo(-12), upTo(3), UNBOUNDED], + #[upTo(-12), above(-2), UNBOUNDED], + #[upTo(-12), above(0), upTo(0)], + #[upTo(-12), above(3), upTo(0)], + #[upTo(-12), between(-4, -3), above(3)], + #[upTo(-12), between(-4, 0), above(3)], + #[upTo(-12), between(-3, 4), UNBOUNDED], + #[upTo(-12), between(0, 4), upTo(-3)], + #[upTo(-12), between(3, 4), upTo(-3)], + #[upTo(-12), ZERO, EMPTY], + #[upTo(0), UNBOUNDED, UNBOUNDED], + #[upTo(0), upTo(-2), above(0)], + #[upTo(0), upTo(0), above(0)], + #[upTo(0), upTo(3), UNBOUNDED], + #[upTo(0), above(-2), UNBOUNDED], + #[upTo(0), above(0), upTo(0)], + #[upTo(0), above(3), upTo(0)], + #[upTo(0), between(-4, -3), above(0)], + #[upTo(0), between(-4, 0), above(0)], + #[upTo(0), between(-3, 4), UNBOUNDED], + #[upTo(0), between(0, 4), upTo(0)], + #[upTo(0), between(3, 4), upTo(0)], + #[upTo(0), ZERO, EMPTY], + #[upTo(12), UNBOUNDED, UNBOUNDED], + #[upTo(12), upTo(-2), above(-6)], + #[upTo(12), upTo(0), UNBOUNDED], + #[upTo(12), upTo(3), UNBOUNDED], + #[upTo(12), above(-2), UNBOUNDED], + #[upTo(12), above(0), UNBOUNDED], + #[upTo(12), above(3), upTo(4)], + #[upTo(12), between(-4, -3), above(-4)], + #[upTo(12), between(-4, 0), UNBOUNDED], + #[upTo(12), between(-3, 4), UNBOUNDED], + #[upTo(12), between(0, 4), UNBOUNDED], + #[upTo(12), between(3, 4), upTo(4)], + #[upTo(12), ZERO, EMPTY], + #[above(-12), UNBOUNDED, UNBOUNDED], + #[above(-12), upTo(-2), upTo(6)], + #[above(-12), upTo(0), UNBOUNDED], + #[above(-12), upTo(3), UNBOUNDED], + #[above(-12), above(-2), UNBOUNDED], + #[above(-12), above(0), UNBOUNDED], + #[above(-12), above(3), above(-4)], + #[above(-12), between(-4, -3), upTo(4)], + #[above(-12), between(-4, 0), UNBOUNDED], + #[above(-12), between(-3, 4), UNBOUNDED], + #[above(-12), between(0, 4), UNBOUNDED], + #[above(-12), between(3, 4), above(-4)], + #[above(-12), ZERO, EMPTY], + #[above(0), UNBOUNDED, UNBOUNDED], + #[above(0), upTo(-2), upTo(0)], + #[above(0), upTo(0), upTo(0)], + #[above(0), upTo(3), UNBOUNDED], + #[above(0), above(-2), UNBOUNDED], + #[above(0), above(0), above(0)], + #[above(0), above(3), above(0)], + #[above(0), between(-4, -3), upTo(0)], + #[above(0), between(-4, 0), upTo(0)], + #[above(0), between(-3, 4), UNBOUNDED], + #[above(0), between(0, 4), above(0)], + #[above(0), between(3, 4), above(0)], + #[above(0), ZERO, EMPTY], + #[above(12), UNBOUNDED, UNBOUNDED], + #[above(12), upTo(-2), upTo(0)], + #[above(12), upTo(0), upTo(0)], + #[above(12), upTo(3), UNBOUNDED], + #[above(12), above(-2), UNBOUNDED], + #[above(12), above(0), above(0)], + #[above(12), above(3), above(0)], + #[above(12), between(-4, -3), upTo(-3)], + #[above(12), between(-4, 0), upTo(-3)], + #[above(12), between(-3, 4), UNBOUNDED], + #[above(12), between(0, 4), above(3)], + #[above(12), between(3, 4), above(3)], + #[above(12), ZERO, EMPTY], + #[between(-36, -12), UNBOUNDED, UNBOUNDED], + #[between(-36, -12), upTo(-2), between(0, 18)], + #[between(-36, -12), upTo(0), above(0)], + #[between(-36, -12), upTo(3), UNBOUNDED], + #[between(-36, -12), above(-2), UNBOUNDED], + #[between(-36, -12), above(0), upTo(0)], + #[between(-36, -12), above(3), between(-12, 0)], + #[between(-36, -12), between(-4, -3), between(3, 12)], + #[between(-36, -12), between(-4, 0), above(3)], + #[between(-36, -12), between(-3, 4), UNBOUNDED], + #[between(-36, -12), between(0, 4), upTo(-3)], + #[between(-36, -12), between(3, 4), between(-12, -3)], + #[between(-36, -12), ZERO, EMPTY], + #[between(-36, 0), UNBOUNDED, UNBOUNDED], + #[between(-36, 0), upTo(-2), between(0, 18)], + #[between(-36, 0), upTo(0), above(0)], + #[between(-36, 0), upTo(3), UNBOUNDED], + #[between(-36, 0), above(-2), UNBOUNDED], + #[between(-36, 0), above(0), upTo(0)], + #[between(-36, 0), above(3), between(-12, 0)], + #[between(-36, 0), between(-4, -3), between(0, 12)], + #[between(-36, 0), between(-4, 0), above(0)], + #[between(-36, 0), between(-3, 4), UNBOUNDED], + #[between(-36, 0), between(0, 4), upTo(0)], + #[between(-36, 0), between(3, 4), between(-12, 0)], + #[between(-36, 0), ZERO, EMPTY], + #[between(-12, 36), UNBOUNDED, UNBOUNDED], + #[between(-12, 36), upTo(-2), between(-18, 6)], + #[between(-12, 36), upTo(0), UNBOUNDED], + #[between(-12, 36), upTo(3), UNBOUNDED], + #[between(-12, 36), above(-2), UNBOUNDED], + #[between(-12, 36), above(0), UNBOUNDED], + #[between(-12, 36), above(3), between(-4, 12)], + #[between(-12, 36), between(-4, -3), between(-12, 4)], + #[between(-12, 36), between(-4, 0), UNBOUNDED], + #[between(-12, 36), between(-3, 4), UNBOUNDED], + #[between(-12, 36), between(0, 4), UNBOUNDED], + #[between(-12, 36), between(3, 4), between(-4, 12)], + #[between(-12, 36), ZERO, EMPTY], + #[between(0, 36), UNBOUNDED, UNBOUNDED], + #[between(0, 36), upTo(-2), between(-18, 0)], + #[between(0, 36), upTo(0), upTo(0)], + #[between(0, 36), upTo(3), UNBOUNDED], + #[between(0, 36), above(-2), UNBOUNDED], + #[between(0, 36), above(0), above(0)], + #[between(0, 36), above(3), between(0, 12)], + #[between(0, 36), between(-4, -3), between(-12, 0)], + #[between(0, 36), between(-4, 0), upTo(0)], + #[between(0, 36), between(-3, 4), UNBOUNDED], + #[between(0, 36), between(0, 4), above(0)], + #[between(0, 36), between(3, 4), between(0, 12)], + #[between(0, 36), ZERO, EMPTY], + #[between(12, 36), UNBOUNDED, UNBOUNDED], + #[between(12, 36), upTo(-2), between(-18, 0)], + #[between(12, 36), upTo(0), upTo(0)], + #[between(12, 36), upTo(3), UNBOUNDED], + #[between(12, 36), above(-2), UNBOUNDED], + #[between(12, 36), above(0), above(0)], + #[between(12, 36), above(3), between(0, 12)], + #[between(12, 36), between(-4, -3), between(-12, -3)], + #[between(12, 36), between(-4, 0), upTo(-3)], + #[between(12, 36), between(-3, 4), UNBOUNDED], + #[between(12, 36), between(0, 4), above(3)], + #[between(12, 36), between(3, 4), between(3, 12)], + #[between(12, 36), ZERO, EMPTY], + #[ZERO, UNBOUNDED, ZERO], + #[ZERO, upTo(-2), ZERO], + #[ZERO, upTo(0), ZERO], + #[ZERO, upTo(3), ZERO], + #[ZERO, above(-2), ZERO], + #[ZERO, above(0), ZERO], + #[ZERO, above(3), ZERO], + #[ZERO, between(-4, -3), ZERO], + #[ZERO, between(-4, 0), ZERO], + #[ZERO, between(-3, 4), ZERO], + #[ZERO, between(0, 4), ZERO], + #[ZERO, between(3, 4), ZERO], + #[ZERO, ZERO, EMPTY] + ] + } + + @Parameter(0) public var Interval a + @Parameter(1) public var Interval b + @Parameter(2) public var Interval result + + @Test + def void divisionTest() { + Assert.assertEquals(result, a / b) + } +} diff --git a/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/MultiplicationTest.xtend b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/MultiplicationTest.xtend new file mode 100644 index 00000000..5f997094 --- /dev/null +++ b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/MultiplicationTest.xtend @@ -0,0 +1,205 @@ +package hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests.interval + +import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval +import java.util.Collection +import org.junit.Assert +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import org.junit.runners.Parameterized.Parameter +import org.junit.runners.Parameterized.Parameters + +import static hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval.* + +@RunWith(Parameterized) +class MultiplicationTest { + @Parameters(name="{index}: {0} * {1} = {2}") + static def Collection data() { + #[ + #[EMPTY, EMPTY, EMPTY], + #[EMPTY, between(-1, 1), EMPTY], + #[between(-1, 1), EMPTY, EMPTY], + #[UNBOUNDED, UNBOUNDED, UNBOUNDED], + #[UNBOUNDED, upTo(-2), UNBOUNDED], + #[UNBOUNDED, upTo(0), UNBOUNDED], + #[UNBOUNDED, upTo(3), UNBOUNDED], + #[UNBOUNDED, above(-2), UNBOUNDED], + #[UNBOUNDED, above(0), UNBOUNDED], + #[UNBOUNDED, above(3), UNBOUNDED], + #[UNBOUNDED, between(-4, -3), UNBOUNDED], + #[UNBOUNDED, between(-4, 0), UNBOUNDED], + #[UNBOUNDED, between(-3, 4), UNBOUNDED], + #[UNBOUNDED, between(0, 4), UNBOUNDED], + #[UNBOUNDED, between(3, 4), UNBOUNDED], + #[UNBOUNDED, ZERO, ZERO], + #[upTo(-5), UNBOUNDED, UNBOUNDED], + #[upTo(-5), upTo(-2), above(10)], + #[upTo(-5), upTo(0), above(0)], + #[upTo(-5), upTo(3), UNBOUNDED], + #[upTo(-5), above(-2), UNBOUNDED], + #[upTo(-5), above(0), upTo(0)], + #[upTo(-5), above(3), upTo(-15)], + #[upTo(-5), between(-4, -3), above(15)], + #[upTo(-5), between(-4, 0), above(0)], + #[upTo(-5), between(-3, 4), UNBOUNDED], + #[upTo(-5), between(0, 4), upTo(0)], + #[upTo(-5), between(3, 4), upTo(-15)], + #[upTo(-5), ZERO, ZERO], + #[upTo(0), UNBOUNDED, UNBOUNDED], + #[upTo(0), upTo(-2), above(0)], + #[upTo(0), upTo(0), above(0)], + #[upTo(0), upTo(3), UNBOUNDED], + #[upTo(0), above(-2), UNBOUNDED], + #[upTo(0), above(0), upTo(0)], + #[upTo(0), above(3), upTo(0)], + #[upTo(0), between(-4, -3), above(0)], + #[upTo(0), between(-4, 0), above(0)], + #[upTo(0), between(-3, 4), UNBOUNDED], + #[upTo(0), between(0, 4), upTo(0)], + #[upTo(0), between(3, 4), upTo(0)], + #[upTo(0), ZERO, ZERO], + #[upTo(5), UNBOUNDED, UNBOUNDED], + #[upTo(5), upTo(-2), UNBOUNDED], + #[upTo(5), upTo(0), UNBOUNDED], + #[upTo(5), upTo(3), UNBOUNDED], + #[upTo(5), above(-2), UNBOUNDED], + #[upTo(5), above(0), UNBOUNDED], + #[upTo(5), above(3), UNBOUNDED], + #[upTo(5), between(-4, -3), above(-20)], + #[upTo(5), between(-4, 0), above(-20)], + #[upTo(5), between(-3, 4), UNBOUNDED], + #[upTo(5), between(0, 4), upTo(20)], + #[upTo(5), between(3, 4), upTo(20)], + #[upTo(5), ZERO, ZERO], + #[above(-5), UNBOUNDED, UNBOUNDED], + #[above(-5), upTo(-2), UNBOUNDED], + #[above(-5), upTo(0), UNBOUNDED], + #[above(-5), upTo(3), UNBOUNDED], + #[above(-5), above(-2), UNBOUNDED], + #[above(-5), above(0), UNBOUNDED], + #[above(-5), above(3), UNBOUNDED], + #[above(-5), between(-4, -3), upTo(20)], + #[above(-5), between(-4, 0), upTo(20)], + #[above(-5), between(-3, 4), UNBOUNDED], + #[above(-5), between(0, 4), above(-20)], + #[above(-5), between(3, 4), above(-20)], + #[above(-5), ZERO, ZERO], + #[above(0), UNBOUNDED, UNBOUNDED], + #[above(0), upTo(-2), upTo(0)], + #[above(0), upTo(0), upTo(0)], + #[above(0), upTo(3), UNBOUNDED], + #[above(0), above(-2), UNBOUNDED], + #[above(0), above(0), above(0)], + #[above(0), above(3), above(0)], + #[above(0), between(-4, -3), upTo(0)], + #[above(0), between(-4, 0), upTo(0)], + #[above(0), between(-3, 4), UNBOUNDED], + #[above(0), between(0, 4), above(0)], + #[above(0), between(3, 4), above(0)], + #[above(0), ZERO, ZERO], + #[above(5), UNBOUNDED, UNBOUNDED], + #[above(5), upTo(-2), upTo(-10)], + #[above(5), upTo(0), upTo(0)], + #[above(5), upTo(3), UNBOUNDED], + #[above(5), above(-2), UNBOUNDED], + #[above(5), above(0), above(0)], + #[above(5), above(3), above(15)], + #[above(5), between(-4, -3), upTo(-15)], + #[above(5), between(-4, 0), upTo(0)], + #[above(5), between(-3, 4), UNBOUNDED], + #[above(5), between(0, 4), above(0)], + #[above(5), between(3, 4), above(15)], + #[above(5), ZERO, ZERO], + #[between(-6, -5), UNBOUNDED, UNBOUNDED], + #[between(-6, -5), upTo(-2), above(10)], + #[between(-6, -5), upTo(0), above(0)], + #[between(-6, -5), upTo(3), above(-18)], + #[between(-6, -5), above(-2), upTo(12)], + #[between(-6, -5), above(0), upTo(0)], + #[between(-6, -5), above(3), upTo(-15)], + #[between(-6, -5), between(-4, -3), between(15, 24)], + #[between(-6, -5), between(-4, 0), between(0, 24)], + #[between(-6, -5), between(-3, 4), between(-24, 18)], + #[between(-6, -5), between(0, 4), between(-24, 0)], + #[between(-6, -5), between(3, 4), between(-24, -15)], + #[between(-6, -5), ZERO, ZERO], + #[between(-6, 0), UNBOUNDED, UNBOUNDED], + #[between(-6, 0), upTo(-2), above(0)], + #[between(-6, 0), upTo(0), above(0)], + #[between(-6, 0), upTo(3), above(-18)], + #[between(-6, 0), above(-2), upTo(12)], + #[between(-6, 0), above(0), upTo(0)], + #[between(-6, 0), above(3), upTo(0)], + #[between(-6, 0), between(-4, -3), between(0, 24)], + #[between(-6, 0), between(-4, 0), between(0, 24)], + #[between(-6, 0), between(-3, 4), between(-24, 18)], + #[between(-6, 0), between(0, 4), between(-24, 0)], + #[between(-6, 0), between(3, 4), between(-24, 0)], + #[between(-6, 0), ZERO, ZERO], + #[between(-5, 6), UNBOUNDED, UNBOUNDED], + #[between(-5, 6), upTo(-2), UNBOUNDED], + #[between(-5, 6), upTo(0), UNBOUNDED], + #[between(-5, 6), upTo(3), UNBOUNDED], + #[between(-5, 6), above(-2), UNBOUNDED], + #[between(-5, 6), above(0), UNBOUNDED], + #[between(-5, 6), above(3), UNBOUNDED], + #[between(-5, 6), between(-4, -3), between(-24, 20)], + #[between(-5, 6), between(-4, 0), between(-24, 20)], + #[between(-5, 6), between(-3, 4), between(-20, 24)], + #[between(-5, 6), between(-3, 2), between(-18, 15)], + #[between(-5, 1), between(-3, 4), between(-20, 15)], + #[between(-5, 1), between(-3, 2), between(-10, 15)], + #[between(-5, 6), between(0, 4), between(-20, 24)], + #[between(-5, 6), between(3, 4), between(-20, 24)], + #[between(-5, 6), ZERO, ZERO], + #[between(0, 6), UNBOUNDED, UNBOUNDED], + #[between(0, 6), upTo(-2), upTo(0)], + #[between(0, 6), upTo(0), upTo(0)], + #[between(0, 6), upTo(3), upTo(18)], + #[between(0, 6), above(-2), above(-12)], + #[between(0, 6), above(0), above(0)], + #[between(0, 6), above(3), above(0)], + #[between(0, 6), between(-4, -3), between(-24, 0)], + #[between(0, 6), between(-4, 0), between(-24, 0)], + #[between(0, 6), between(-3, 4), between(-18, 24)], + #[between(0, 6), between(0, 4), between(0, 24)], + #[between(0, 6), between(3, 4), between(0, 24)], + #[between(0, 6), ZERO, ZERO], + #[between(5, 6), UNBOUNDED, UNBOUNDED], + #[between(5, 6), upTo(-2), upTo(-10)], + #[between(5, 6), upTo(0), upTo(0)], + #[between(5, 6), upTo(3), upTo(18)], + #[between(5, 6), above(-2), above(-12)], + #[between(5, 6), above(0), above(0)], + #[between(5, 6), above(3), above(15)], + #[between(5, 6), between(-4, -3), between(-24, -15)], + #[between(5, 6), between(-4, 0), between(-24, 0)], + #[between(5, 6), between(-3, 4), between(-18, 24)], + #[between(5, 6), between(0, 4), between(0, 24)], + #[between(5, 6), between(3, 4), between(15, 24)], + #[between(5, 6), ZERO, ZERO], + #[ZERO, UNBOUNDED, ZERO], + #[ZERO, upTo(-2), ZERO], + #[ZERO, upTo(0), ZERO], + #[ZERO, upTo(3), ZERO], + #[ZERO, above(-2), ZERO], + #[ZERO, above(0), ZERO], + #[ZERO, above(3), ZERO], + #[ZERO, between(-4, -3), ZERO], + #[ZERO, between(-4, 0), ZERO], + #[ZERO, between(-3, 4), ZERO], + #[ZERO, between(0, 4), ZERO], + #[ZERO, between(3, 4), ZERO], + #[ZERO, ZERO, ZERO] + ] + } + + @Parameter(0) public var Interval a + @Parameter(1) public var Interval b + @Parameter(2) public var Interval result + + @Test + def void multiplicatonTest() { + Assert.assertEquals(result, a * b) + } +} diff --git a/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/NegationTest.xtend b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/NegationTest.xtend new file mode 100644 index 00000000..477e925e --- /dev/null +++ b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/NegationTest.xtend @@ -0,0 +1,34 @@ +package hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests.interval + +import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval +import java.util.Collection +import org.junit.Assert +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import org.junit.runners.Parameterized.Parameter +import org.junit.runners.Parameterized.Parameters + +import static hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval.* + +@RunWith(Parameterized) +class NegationTest { + @Parameters(name = "{index}: -{0} = {1}") + static def Collection data() { + #[ + #[EMPTY, EMPTY], + #[UNBOUNDED, UNBOUNDED], + #[upTo(1), above(-1)], + #[above(1), upTo(-1)], + #[between(2, 3), between(-3, -2)] + ] + } + + @Parameter(0) public var Interval a + @Parameter(1) public var Interval result + + @Test + def void negationTest() { + Assert.assertEquals(result, -a) + } +} diff --git a/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/SubtractionTest.xtend b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/SubtractionTest.xtend new file mode 100644 index 00000000..30709a9e --- /dev/null +++ b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/SubtractionTest.xtend @@ -0,0 +1,49 @@ +package hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests.interval + +import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval +import java.util.Collection +import org.junit.Assert +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import org.junit.runners.Parameterized.Parameter +import org.junit.runners.Parameterized.Parameters + +import static hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval.* + +@RunWith(Parameterized) +class SubtractionTest { + @Parameters(name = "{index}: {0} - {1} = {2}") + static def Collection data() { + #[ + #[EMPTY, EMPTY, EMPTY], + #[EMPTY, between(-1, 1), EMPTY], + #[between(-1, 1), EMPTY, EMPTY], + #[UNBOUNDED, UNBOUNDED, UNBOUNDED], + #[UNBOUNDED, upTo(2), UNBOUNDED], + #[UNBOUNDED, above(-2), UNBOUNDED], + #[UNBOUNDED, between(-1, 1), UNBOUNDED], + #[upTo(2), UNBOUNDED, UNBOUNDED], + #[upTo(2), upTo(1), UNBOUNDED], + #[upTo(2), above(-1), upTo(3)], + #[upTo(2), between(-1, 2), upTo(3)], + #[above(-2), UNBOUNDED, UNBOUNDED], + #[above(-2), upTo(1), above(-3)], + #[above(-2), above(-1), UNBOUNDED], + #[above(-2), between(-1, 2), above(-4)], + #[between(-2, 3), UNBOUNDED, UNBOUNDED], + #[between(-2, 3), upTo(1), above(-3)], + #[between(-2, 3), above(-1), upTo(4)], + #[between(-2, 3), between(-1, 2.5), between(-4.5, 4)] + ] + } + + @Parameter(0) public var Interval a + @Parameter(1) public var Interval b + @Parameter(2) public var Interval result + + @Test + def void subtractionTest() { + Assert.assertEquals(result, a - b) + } +} -- cgit v1.2.3-70-g09d2 From 1999ab4733071c6a4c9989c137eb44ec62b09847 Mon Sep 17 00:00:00 2001 From: Kristóf Marussy Date: Thu, 9 May 2019 00:49:19 -0400 Subject: Interval comparison --- .../logic2viatra/interval/Interval.xtend | 101 ++++++++++++++++- .../logic2viatra/tests/interval/RelationTest.xtend | 120 +++++++++++++++++++++ 2 files changed, 217 insertions(+), 4 deletions(-) create mode 100644 Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/RelationTest.xtend (limited to 'Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval') diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend index cf22315b..93749767 100644 --- a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend +++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend @@ -13,7 +13,47 @@ abstract class Interval { private new() { } - abstract def boolean isZero() + abstract def boolean mustEqual(Interval other) + + abstract def boolean mayEqual(Interval other) + + def mustNotEqual(Interval other) { + !mayEqual(other) + } + + def mayNotEqual(Interval other) { + !mustEqual(other) + } + + abstract def boolean mustBeLessThan(Interval other) + + abstract def boolean mayBeLessThan(Interval other) + + def mustBeLessThanOrEqual(Interval other) { + !mayBeGreaterThan(other) + } + + def mayBeLessThanOrEqual(Interval other) { + !mustBeGreaterThan(other) + } + + def mustBeGreaterThan(Interval other) { + other.mustBeLessThan(this) + } + + def mayBeGreaterThan(Interval other) { + other.mayBeLessThan(this) + } + + def mustBeGreaterThanOrEqual(Interval other) { + other.mustBeLessThanOrEqual(this) + } + + def mayBeGreaterThanOrEqual(Interval other) { + other.mayBeLessThanOrEqual(this) + } + + abstract def Interval join(Interval other) def operator_plus() { this @@ -30,9 +70,25 @@ abstract class Interval { abstract def Interval operator_divide(Interval other) public static val EMPTY = new Interval { - override isZero() { + override mustEqual(Interval other) { + true + } + + override mayEqual(Interval other) { false } + + override mustBeLessThan(Interval other) { + true + } + + override mayBeLessThan(Interval other) { + false + } + + override join(Interval other) { + other + } override operator_minus() { EMPTY @@ -98,8 +154,45 @@ abstract class Interval { this.upper = upper } - override isZero() { - upper == BigDecimal.ZERO && lower == BigDecimal.ZERO + override mustEqual(Interval other) { + switch (other) { + case EMPTY: true + NonEmpty: lower == upper && lower == other.lower && lower == other.upper + default: throw new IllegalArgumentException("") + } + } + + override mayEqual(Interval other) { + if (other instanceof NonEmpty) { + (lower === null || other.upper === null || lower <= other.upper) && + (other.lower === null || upper === null || other.lower <= upper) + } else { + false + } + } + + override mustBeLessThan(Interval other) { + switch (other) { + case EMPTY: true + NonEmpty: upper !== null && other.lower !== null && upper < other.lower + default: throw new IllegalArgumentException("") + } + } + + override mayBeLessThan(Interval other) { + if (other instanceof NonEmpty) { + lower === null || other.upper === null || lower < other.upper + } else { + false + } + } + + override join(Interval other) { + switch (other) { + case EMPTY: this + NonEmpty: new NonEmpty(lower.tryMin(other.lower), upper.tryMin(other.upper)) + default: throw new IllegalArgumentException("") + } } override operator_minus() { diff --git a/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/RelationTest.xtend b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/RelationTest.xtend new file mode 100644 index 00000000..23fc69ea --- /dev/null +++ b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/RelationTest.xtend @@ -0,0 +1,120 @@ +package hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests.interval + +import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval +import java.util.Collection +import org.junit.Assert +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import org.junit.runners.Parameterized.Parameter +import org.junit.runners.Parameterized.Parameters + +import static hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval.* + +@RunWith(Parameterized) +class RelationTest { + @Parameters(name = "{index}: {0} <> {1}") + static def Collection data() { + #[ + #[EMPTY, EMPTY, true, false, true, false], + #[EMPTY, between(1, 2), true, false, true, false], + #[between(1, 2), EMPTY, true, false, true, false], + #[upTo(1), upTo(0), false, true, false, true], + #[upTo(1), upTo(1), false, true, false, true], + #[upTo(1), upTo(2), false, true, false, true], + #[upTo(1), above(0), false, true, false, true], + #[upTo(1), above(1), false, true, false, true], + #[upTo(1), above(2), false, false, true, true], + #[upTo(1), between(-1, -1), false, true, false, true], + #[upTo(1), between(-1, 0), false, true, false, true], + #[upTo(1), between(-1, 1), false, true, false, true], + #[upTo(1), between(-1, 2), false, true, false, true], + #[upTo(1), between(1, 1), false, true, false, true], + #[upTo(1), between(1, 2), false, true, false, true], + #[upTo(1), between(2, 2), false, false, true, true], + #[upTo(1), between(2, 3), false, false, true, true], + #[above(1), upTo(0), false, false, false, false], + #[above(1), upTo(1), false, true, false, false], + #[above(1), upTo(2), false, true, false, true], + #[above(1), above(0), false, true, false, true], + #[above(1), above(1), false, true, false, true], + #[above(1), above(2), false, true, false, true], + #[above(1), between(-1, -1), false, false, false, false], + #[above(1), between(-1, 0), false, false, false, false], + #[above(1), between(-1, 1), false, true, false, false], + #[above(1), between(-1, 2), false, true, false, true], + #[above(1), between(1, 1), false, true, false, false], + #[above(1), between(1, 2), false, true, false, true], + #[above(1), between(2, 2), false, true, false, true], + #[above(1), between(2, 3), false, true, false, true], + #[between(1, 1), upTo(0), false, false, false, false], + #[between(1, 1), upTo(1), false, true, false, false], + #[between(1, 1), upTo(2), false, true, false, true], + #[between(1, 1), above(0), false, true, false, true], + #[between(1, 1), above(1), false, true, false, true], + #[between(1, 1), above(2), false, false, true, true], + #[between(1, 1), between(-1, -1), false, false, false, false], + #[between(1, 1), between(-1, 0), false, false, false, false], + #[between(1, 1), between(-1, 1), false, true, false, false], + #[between(1, 1), between(-1, 2), false, true, false, true], + #[between(1, 1), between(1, 1), true, true, false, false], + #[between(1, 1), between(1, 2), false, true, false, true], + #[between(1, 1), between(2, 2), false, false, true, true], + #[between(1, 1), between(2, 3), false, false, true, true], + #[between(-1, 1), upTo(-2), false, false, false, false], + #[between(-1, 1), upTo(-1), false, true, false, false], + #[between(-1, 1), upTo(0), false, true, false, true], + #[between(-1, 1), upTo(1), false, true, false, true], + #[between(-1, 1), upTo(2), false, true, false, true], + #[between(-1, 1), above(-2), false, true, false, true], + #[between(-1, 1), above(-1), false, true, false, true], + #[between(-1, 1), above(0), false, true, false, true], + #[between(-1, 1), above(1), false, true, false, true], + #[between(-1, 1), above(2), false, false, true, true], + #[between(-1, 1), between(-3, -2), false, false, false, false], + #[between(-1, 1), between(-2, -2), false, false, false, false], + #[between(-1, 1), between(-2, -1), false, true, false, false], + #[between(-1, 1), between(-2, 0), false, true, false, true], + #[between(-1, 1), between(-2, 1), false, true, false, true], + #[between(-1, 1), between(-2, 2), false, true, false, true], + #[between(-1, 1), between(-1, -1), false, true, false, false], + #[between(-1, 1), between(-1, 0), false, true, false, true], + #[between(-1, 1), between(-1, 1), false, true, false, true], + #[between(-1, 1), between(-1, 2), false, true, false, true], + #[between(-1, 1), between(0, 0), false, true, false, true], + #[between(-1, 1), between(0, 1), false, true, false, true], + #[between(-1, 1), between(0, 2), false, true, false, true], + #[between(-1, 1), between(1, 1), false, true, false, true], + #[between(-1, 1), between(1, 2), false, true, false, true], + #[between(-1, 1), between(2, 2), false, false, true, true], + #[between(-1, 1), between(2, 3), false, false, true, true] + ] + } + + @Parameter(0) public var Interval a + @Parameter(1) public var Interval b + @Parameter(2) public var boolean mustEqual + @Parameter(3) public var boolean mayEqual + @Parameter(4) public var boolean mustBeLessThan + @Parameter(5) public var boolean mayBeLessThan + + @Test + def void mustEqualTest() { + Assert.assertEquals(mustEqual, a.mustEqual(b)) + } + + @Test + def void mayEqualTest() { + Assert.assertEquals(mayEqual, a.mayEqual(b)) + } + + @Test + def void mustBeLessThanTest() { + Assert.assertEquals(mustBeLessThan, a.mustBeLessThan(b)) + } + + @Test + def void mayBeLessThanTest() { + Assert.assertEquals(mayBeLessThan, a.mayBeLessThan(b)) + } +} -- cgit v1.2.3-70-g09d2 From 94a7e721fba3c3bf6bcda75cde474e21c5afdf39 Mon Sep 17 00:00:00 2001 From: Kristóf Marussy Date: Thu, 9 May 2019 09:28:40 -0400 Subject: Fix interval join --- .../logic2viatra/interval/Interval.xtend | 2 +- .../logic2viatra/tests/interval/RelationTest.xtend | 150 +++++++++++---------- 2 files changed, 79 insertions(+), 73 deletions(-) (limited to 'Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval') diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend index 93749767..6ea96866 100644 --- a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend +++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend @@ -190,7 +190,7 @@ abstract class Interval { override join(Interval other) { switch (other) { case EMPTY: this - NonEmpty: new NonEmpty(lower.tryMin(other.lower), upper.tryMin(other.upper)) + NonEmpty: new NonEmpty(lower.tryMin(other.lower), upper.tryMax(other.upper)) default: throw new IllegalArgumentException("") } } diff --git a/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/RelationTest.xtend b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/RelationTest.xtend index 23fc69ea..5527fbaa 100644 --- a/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/RelationTest.xtend +++ b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/RelationTest.xtend @@ -16,78 +16,78 @@ class RelationTest { @Parameters(name = "{index}: {0} <> {1}") static def Collection data() { #[ - #[EMPTY, EMPTY, true, false, true, false], - #[EMPTY, between(1, 2), true, false, true, false], - #[between(1, 2), EMPTY, true, false, true, false], - #[upTo(1), upTo(0), false, true, false, true], - #[upTo(1), upTo(1), false, true, false, true], - #[upTo(1), upTo(2), false, true, false, true], - #[upTo(1), above(0), false, true, false, true], - #[upTo(1), above(1), false, true, false, true], - #[upTo(1), above(2), false, false, true, true], - #[upTo(1), between(-1, -1), false, true, false, true], - #[upTo(1), between(-1, 0), false, true, false, true], - #[upTo(1), between(-1, 1), false, true, false, true], - #[upTo(1), between(-1, 2), false, true, false, true], - #[upTo(1), between(1, 1), false, true, false, true], - #[upTo(1), between(1, 2), false, true, false, true], - #[upTo(1), between(2, 2), false, false, true, true], - #[upTo(1), between(2, 3), false, false, true, true], - #[above(1), upTo(0), false, false, false, false], - #[above(1), upTo(1), false, true, false, false], - #[above(1), upTo(2), false, true, false, true], - #[above(1), above(0), false, true, false, true], - #[above(1), above(1), false, true, false, true], - #[above(1), above(2), false, true, false, true], - #[above(1), between(-1, -1), false, false, false, false], - #[above(1), between(-1, 0), false, false, false, false], - #[above(1), between(-1, 1), false, true, false, false], - #[above(1), between(-1, 2), false, true, false, true], - #[above(1), between(1, 1), false, true, false, false], - #[above(1), between(1, 2), false, true, false, true], - #[above(1), between(2, 2), false, true, false, true], - #[above(1), between(2, 3), false, true, false, true], - #[between(1, 1), upTo(0), false, false, false, false], - #[between(1, 1), upTo(1), false, true, false, false], - #[between(1, 1), upTo(2), false, true, false, true], - #[between(1, 1), above(0), false, true, false, true], - #[between(1, 1), above(1), false, true, false, true], - #[between(1, 1), above(2), false, false, true, true], - #[between(1, 1), between(-1, -1), false, false, false, false], - #[between(1, 1), between(-1, 0), false, false, false, false], - #[between(1, 1), between(-1, 1), false, true, false, false], - #[between(1, 1), between(-1, 2), false, true, false, true], - #[between(1, 1), between(1, 1), true, true, false, false], - #[between(1, 1), between(1, 2), false, true, false, true], - #[between(1, 1), between(2, 2), false, false, true, true], - #[between(1, 1), between(2, 3), false, false, true, true], - #[between(-1, 1), upTo(-2), false, false, false, false], - #[between(-1, 1), upTo(-1), false, true, false, false], - #[between(-1, 1), upTo(0), false, true, false, true], - #[between(-1, 1), upTo(1), false, true, false, true], - #[between(-1, 1), upTo(2), false, true, false, true], - #[between(-1, 1), above(-2), false, true, false, true], - #[between(-1, 1), above(-1), false, true, false, true], - #[between(-1, 1), above(0), false, true, false, true], - #[between(-1, 1), above(1), false, true, false, true], - #[between(-1, 1), above(2), false, false, true, true], - #[between(-1, 1), between(-3, -2), false, false, false, false], - #[between(-1, 1), between(-2, -2), false, false, false, false], - #[between(-1, 1), between(-2, -1), false, true, false, false], - #[between(-1, 1), between(-2, 0), false, true, false, true], - #[between(-1, 1), between(-2, 1), false, true, false, true], - #[between(-1, 1), between(-2, 2), false, true, false, true], - #[between(-1, 1), between(-1, -1), false, true, false, false], - #[between(-1, 1), between(-1, 0), false, true, false, true], - #[between(-1, 1), between(-1, 1), false, true, false, true], - #[between(-1, 1), between(-1, 2), false, true, false, true], - #[between(-1, 1), between(0, 0), false, true, false, true], - #[between(-1, 1), between(0, 1), false, true, false, true], - #[between(-1, 1), between(0, 2), false, true, false, true], - #[between(-1, 1), between(1, 1), false, true, false, true], - #[between(-1, 1), between(1, 2), false, true, false, true], - #[between(-1, 1), between(2, 2), false, false, true, true], - #[between(-1, 1), between(2, 3), false, false, true, true] + #[EMPTY, EMPTY, true, false, true, false, EMPTY], + #[EMPTY, between(1, 2), true, false, true, false, between(1, 2)], + #[between(1, 2), EMPTY, true, false, true, false, between(1, 2)], + #[upTo(1), upTo(0), false, true, false, true, upTo(1)], + #[upTo(1), upTo(1), false, true, false, true, upTo(1)], + #[upTo(1), upTo(2), false, true, false, true, upTo(2)], + #[upTo(1), above(0), false, true, false, true, UNBOUNDED], + #[upTo(1), above(1), false, true, false, true, UNBOUNDED], + #[upTo(1), above(2), false, false, true, true, UNBOUNDED], + #[upTo(1), between(-1, -1), false, true, false, true, upTo(1)], + #[upTo(1), between(-1, 0), false, true, false, true, upTo(1)], + #[upTo(1), between(-1, 1), false, true, false, true, upTo(1)], + #[upTo(1), between(-1, 2), false, true, false, true, upTo(2)], + #[upTo(1), between(1, 1), false, true, false, true, upTo(1)], + #[upTo(1), between(1, 2), false, true, false, true, upTo(2)], + #[upTo(1), between(2, 2), false, false, true, true, upTo(2)], + #[upTo(1), between(2, 3), false, false, true, true, upTo(3)], + #[above(1), upTo(0), false, false, false, false, UNBOUNDED], + #[above(1), upTo(1), false, true, false, false, UNBOUNDED], + #[above(1), upTo(2), false, true, false, true, UNBOUNDED], + #[above(1), above(0), false, true, false, true, above(0)], + #[above(1), above(1), false, true, false, true, above(1)], + #[above(1), above(2), false, true, false, true, above(1)], + #[above(1), between(-1, -1), false, false, false, false, above(-1)], + #[above(1), between(-1, 0), false, false, false, false, above(-1)], + #[above(1), between(-1, 1), false, true, false, false, above(-1)], + #[above(1), between(-1, 2), false, true, false, true, above(-1)], + #[above(1), between(1, 1), false, true, false, false, above(1)], + #[above(1), between(1, 2), false, true, false, true, above(1)], + #[above(1), between(2, 2), false, true, false, true, above(1)], + #[above(1), between(2, 3), false, true, false, true, above(1)], + #[between(1, 1), upTo(0), false, false, false, false, upTo(1)], + #[between(1, 1), upTo(1), false, true, false, false, upTo(1)], + #[between(1, 1), upTo(2), false, true, false, true, upTo(2)], + #[between(1, 1), above(0), false, true, false, true, above(0)], + #[between(1, 1), above(1), false, true, false, true, above(1)], + #[between(1, 1), above(2), false, false, true, true, above(1)], + #[between(1, 1), between(-1, -1), false, false, false, false, between(-1, 1)], + #[between(1, 1), between(-1, 0), false, false, false, false, between(-1, 1)], + #[between(1, 1), between(-1, 1), false, true, false, false, between(-1, 1)], + #[between(1, 1), between(-1, 2), false, true, false, true, between(-1, 2)], + #[between(1, 1), between(1, 1), true, true, false, false, between(1, 1)], + #[between(1, 1), between(1, 2), false, true, false, true, between(1, 2)], + #[between(1, 1), between(2, 2), false, false, true, true, between(1, 2)], + #[between(1, 1), between(2, 3), false, false, true, true, between(1, 3)], + #[between(-1, 1), upTo(-2), false, false, false, false, upTo(1)], + #[between(-1, 1), upTo(-1), false, true, false, false, upTo(1)], + #[between(-1, 1), upTo(0), false, true, false, true, upTo(1)], + #[between(-1, 1), upTo(1), false, true, false, true, upTo(1)], + #[between(-1, 1), upTo(2), false, true, false, true, upTo(2)], + #[between(-1, 1), above(-2), false, true, false, true, above(-2)], + #[between(-1, 1), above(-1), false, true, false, true, above(-1)], + #[between(-1, 1), above(0), false, true, false, true, above(-1)], + #[between(-1, 1), above(1), false, true, false, true, above(-1)], + #[between(-1, 1), above(2), false, false, true, true, above(-1)], + #[between(-1, 1), between(-3, -2), false, false, false, false, between(-3, 1)], + #[between(-1, 1), between(-2, -2), false, false, false, false, between(-2, 1)], + #[between(-1, 1), between(-2, -1), false, true, false, false, between(-2, 1)], + #[between(-1, 1), between(-2, 0), false, true, false, true, between(-2, 1)], + #[between(-1, 1), between(-2, 1), false, true, false, true, between(-2, 1)], + #[between(-1, 1), between(-2, 2), false, true, false, true, between(-2, 2)], + #[between(-1, 1), between(-1, -1), false, true, false, false, between(-1, 1)], + #[between(-1, 1), between(-1, 0), false, true, false, true, between(-1, 1)], + #[between(-1, 1), between(-1, 1), false, true, false, true, between(-1, 1)], + #[between(-1, 1), between(-1, 2), false, true, false, true, between(-1, 2)], + #[between(-1, 1), between(0, 0), false, true, false, true, between(-1, 1)], + #[between(-1, 1), between(0, 1), false, true, false, true, between(-1, 1)], + #[between(-1, 1), between(0, 2), false, true, false, true, between(-1, 2)], + #[between(-1, 1), between(1, 1), false, true, false, true, between(-1, 1)], + #[between(-1, 1), between(1, 2), false, true, false, true, between(-1, 2)], + #[between(-1, 1), between(2, 2), false, false, true, true, between(-1, 2)], + #[between(-1, 1), between(2, 3), false, false, true, true, between(-1, 3)] ] } @@ -97,6 +97,7 @@ class RelationTest { @Parameter(3) public var boolean mayEqual @Parameter(4) public var boolean mustBeLessThan @Parameter(5) public var boolean mayBeLessThan + @Parameter(6) public var Interval join @Test def void mustEqualTest() { @@ -117,4 +118,9 @@ class RelationTest { def void mayBeLessThanTest() { Assert.assertEquals(mayBeLessThan, a.mayBeLessThan(b)) } + + @Test + def void joinTest() { + Assert.assertEquals(join, a.join(b)) + } } -- cgit v1.2.3-70-g09d2 From ba167247757d76df603a6527d9ad51c3d9f150b9 Mon Sep 17 00:00:00 2001 From: Kristóf Marussy Date: Thu, 9 May 2019 20:24:56 -0400 Subject: Interval aggregation operators --- .../build.properties | 3 +- .../logic2viatra/interval/Interval.xtend | 88 +- .../interval/IntervalAggregationMode.java | 66 + .../interval/IntervalAggregationOperator.xtend | 48 + .../interval/IntervalRedBlackNode.xtend | 177 +++ .../logic2viatra/interval/RedBlackNode.java | 1392 ++++++++++++++++++++ .../logic2viatra/interval/Reference.java | 51 + .../logic2viatra/tests/interval/SumTest.xtend | 140 ++ 8 files changed, 1947 insertions(+), 18 deletions(-) create mode 100644 Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/IntervalAggregationMode.java create mode 100644 Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/IntervalAggregationOperator.xtend create mode 100644 Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/IntervalRedBlackNode.xtend create mode 100644 Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/RedBlackNode.java create mode 100644 Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Reference.java create mode 100644 Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/SumTest.xtend (limited to 'Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval') diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/build.properties b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/build.properties index 585df5ce..9ffc994a 100644 --- a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/build.properties +++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/build.properties @@ -4,6 +4,5 @@ bin.includes = META-INF/,\ source.. = src/,\ patterns/,\ vql-gen/,\ - xtend-gen/,\ - src-gen/ + xtend-gen/ output.. = bin/ diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend index 6ea96866..229656c0 100644 --- a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend +++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend @@ -5,7 +5,7 @@ import java.math.MathContext import java.math.RoundingMode import org.eclipse.xtend.lib.annotations.Data -abstract class Interval { +abstract class Interval implements Comparable { static val PRECISION = 32 static val ROUND_DOWN = new MathContext(PRECISION, RoundingMode.FLOOR) static val ROUND_UP = new MathContext(PRECISION, RoundingMode.CEILING) @@ -24,35 +24,35 @@ abstract class Interval { def mayNotEqual(Interval other) { !mustEqual(other) } - + abstract def boolean mustBeLessThan(Interval other) - + abstract def boolean mayBeLessThan(Interval other) - + def mustBeLessThanOrEqual(Interval other) { !mayBeGreaterThan(other) } - + def mayBeLessThanOrEqual(Interval other) { !mustBeGreaterThan(other) } - + def mustBeGreaterThan(Interval other) { other.mustBeLessThan(this) } - + def mayBeGreaterThan(Interval other) { other.mayBeLessThan(this) } - + def mustBeGreaterThanOrEqual(Interval other) { other.mustBeLessThanOrEqual(this) } - + def mayBeGreaterThanOrEqual(Interval other) { other.mayBeLessThanOrEqual(this) } - + abstract def Interval join(Interval other) def operator_plus() { @@ -65,6 +65,8 @@ abstract class Interval { abstract def Interval operator_minus(Interval other) + abstract def Interval operator_multiply(int count) + abstract def Interval operator_multiply(Interval other) abstract def Interval operator_divide(Interval other) @@ -77,15 +79,15 @@ abstract class Interval { override mayEqual(Interval other) { false } - + override mustBeLessThan(Interval other) { true } - + override mayBeLessThan(Interval other) { false } - + override join(Interval other) { other } @@ -102,6 +104,10 @@ abstract class Interval { EMPTY } + override operator_multiply(int count) { + EMPTY + } + override operator_multiply(Interval other) { EMPTY } @@ -113,6 +119,15 @@ abstract class Interval { override toString() { "∅" } + + override compareTo(Interval o) { + if (o == EMPTY) { + 0 + } else { + -1 + } + } + } public static val Interval ZERO = new NonEmpty(BigDecimal.ZERO, BigDecimal.ZERO) @@ -170,7 +185,7 @@ abstract class Interval { false } } - + override mustBeLessThan(Interval other) { switch (other) { case EMPTY: true @@ -178,7 +193,7 @@ abstract class Interval { default: throw new IllegalArgumentException("") } } - + override mayBeLessThan(Interval other) { if (other instanceof NonEmpty) { lower === null || other.upper === null || lower < other.upper @@ -186,7 +201,7 @@ abstract class Interval { false } } - + override join(Interval other) { switch (other) { case EMPTY: this @@ -245,6 +260,14 @@ abstract class Interval { } } + override operator_multiply(int count) { + val bigCount = new BigDecimal(count) + new NonEmpty( + lower.tryMultiply(bigCount, ROUND_DOWN), + upper.tryMultiply(bigCount, ROUND_UP) + ) + } + override operator_multiply(Interval other) { switch (other) { case EMPTY: EMPTY @@ -431,5 +454,38 @@ abstract class Interval { override toString() { '''«IF lower === null»(-∞«ELSE»[«lower»«ENDIF», «IF upper === null»∞)«ELSE»«upper»]«ENDIF»''' } + + override compareTo(Interval o) { + switch (o) { + case EMPTY: 1 + NonEmpty: compareTo(o) + default: throw new IllegalArgumentException("") + } + } + + def compareTo(NonEmpty o) { + if (lower === null) { + if (o.lower !== null) { + return -1 + } + } else if (o.lower === null) { // lower !== null + return 1 + } else { // both lower and o.lower are finite + val lowerDifference = lower.compareTo(o.lower) + if (lowerDifference != 0) { + return lowerDifference + } + } + if (upper === null) { + if (o.upper === null) { + return 0 + } else { + return 1 + } + } else if (o.upper === null) { // upper !== null + return -1 + } + upper.compareTo(o.upper) + } } } diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/IntervalAggregationMode.java b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/IntervalAggregationMode.java new file mode 100644 index 00000000..f5bd2efc --- /dev/null +++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/IntervalAggregationMode.java @@ -0,0 +1,66 @@ +package hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval; + +import java.util.function.BinaryOperator; + +public enum IntervalAggregationMode implements BinaryOperator { + SUM("intervalSum", "Sum a set of intervals") { + @Override + public IntervalRedBlackNode createNode(Interval interval) { + return new IntervalRedBlackNode(interval) { + public boolean isMultiplicitySensitive() { + return true; + } + + public Interval multiply(Interval interval, int count) { + return interval.operator_multiply(count); + }; + + @Override + public Interval op(Interval left, Interval right) { + return left.operator_plus(right); + } + }; + } + }, + + JOIN("intervalJoin", "Calculate the smallest interval containing all the intervals in a set") { + @Override + public IntervalRedBlackNode createNode(Interval interval) { + return new IntervalRedBlackNode(interval) { + @Override + public Interval op(Interval left, Interval right) { + return left.join(right); + } + }; + } + }; + + private final String modeName; + private final String description; + private final IntervalRedBlackNode empty; + + IntervalAggregationMode(String modeName, String description) { + this.modeName = modeName; + this.description = description; + empty = createNode(null); + } + + public String getModeName() { + return modeName; + } + + public String getDescription() { + return description; + } + + public IntervalRedBlackNode getEmpty() { + return empty; + } + + @Override + public Interval apply(Interval left, Interval right) { + return empty.op(left, right); + } + + public abstract IntervalRedBlackNode createNode(Interval interval); +} diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/IntervalAggregationOperator.xtend b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/IntervalAggregationOperator.xtend new file mode 100644 index 00000000..940c71bb --- /dev/null +++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/IntervalAggregationOperator.xtend @@ -0,0 +1,48 @@ +package hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval + +import java.util.stream.Stream +import org.eclipse.viatra.query.runtime.matchers.psystem.aggregations.IMultisetAggregationOperator +import org.eclipse.xtend.lib.annotations.Accessors +import org.eclipse.xtend.lib.annotations.FinalFieldsConstructor + +@FinalFieldsConstructor +class IntervalAggregationOperator implements IMultisetAggregationOperator { + @Accessors val IntervalAggregationMode mode + + override getName() { + mode.modeName + } + + override getShortDescription() { + mode.description + } + + override createNeutral() { + mode.empty + } + + override isNeutral(IntervalRedBlackNode result) { + result.leaf + } + + override update(IntervalRedBlackNode oldResult, Interval updateValue, boolean isInsertion) { + if (isInsertion) { + val newNode = mode.createNode(updateValue) + oldResult.add(newNode) + } else { + oldResult.remove(updateValue) + } + } + + override getAggregate(IntervalRedBlackNode result) { + if (result.leaf) { + null + } else { + result.result + } + } + + override aggregateStream(Stream stream) { + stream.reduce(mode).orElse(null) + } +} diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/IntervalRedBlackNode.xtend b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/IntervalRedBlackNode.xtend new file mode 100644 index 00000000..3aa575bc --- /dev/null +++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/IntervalRedBlackNode.xtend @@ -0,0 +1,177 @@ +package hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval + +abstract class IntervalRedBlackNode extends RedBlackNode { + public val Interval interval + public var int count = 1 + public var Interval result + + new(Interval interval) { + this.interval = interval + } + + def boolean isMultiplicitySensitive() { + false + } + + def Interval multiply(Interval interval, int count) { + interval + } + + abstract def Interval op(Interval left, Interval right) + + override augment() { + val value = calcualteAugmentation() + if (result == value) { + false + } else { + result = value + true + } + } + + private def calcualteAugmentation() { + var value = multiply(interval, count) + if (!left.leaf) { + value = op(value, left.result) + } + if (!right.leaf) { + value = op(value, right.result) + } + value + } + + override assertNodeIsValid() { + super.assertNodeIsValid() + if (leaf) { + return + } + if (count <= 0) { + throw new IllegalStateException("Node with nonpositive count") + } + val value = calcualteAugmentation() + if (result != value) { + throw new IllegalStateException("Node with invalid augmentation: " + result + " != " + value) + } + } + + override assertSubtreeIsValid() { + super.assertSubtreeIsValid() + assertNodeIsValid() + } + + override compareTo(IntervalRedBlackNode other) { + if (leaf || other.leaf) { + throw new IllegalArgumentException("One of the nodes is a leaf node") + } + interval.compareTo(other.interval) + } + + def add(IntervalRedBlackNode newNode) { + if (parent !== null) { + throw new IllegalArgumentException("This is not the root of a tree") + } + if (leaf) { + newNode.isRed = false + newNode.left = this + newNode.right = this + newNode.parent = null + newNode.augment + return newNode + } + val modifiedNode = addWithoutFixup(newNode) + if (modifiedNode === newNode) { + // Must augment here, because fixInsertion() might call augment() + // on a node repeatedly, which might lose the change notification the + // second time it is called, and the augmentation will fail to + // reach the root. + modifiedNode.augmentRecursively + modifiedNode.isRed = true + return modifiedNode.fixInsertion + } + if (multiplicitySensitive) { + modifiedNode.augmentRecursively + } + this + } + + private def addWithoutFixup(IntervalRedBlackNode newNode) { + var node = this + while (!node.leaf) { + val comparison = node.interval.compareTo(newNode.interval) + if (comparison < 0) { + if (node.left.leaf) { + newNode.left = node.left + newNode.right = node.left + node.left = newNode + newNode.parent = node + return newNode + } else { + node = node.left + } + } else if (comparison > 0) { + if (node.right.leaf) { + newNode.left = node.right + newNode.right = node.right + node.right = newNode + newNode.parent = node + return newNode + } else { + node = node.right + } + } else { // comparison == 0 + newNode.parent = null + node.count++ + return node + } + } + throw new IllegalStateException("Reached leaf node while searching for insertion point") + } + + private def augmentRecursively() { + for (var node = this; node !== null; node = node.parent) { + if (!node.augment) { + return + } + } + } + + def remove(Interval interval) { + val node = find(interval) + node.count-- + if (node.count == 0) { + return node.remove + } + if (multiplicitySensitive) { + node.augmentRecursively + } + this + } + + private def find(Interval interval) { + var node = this + while (!node.leaf) { + val comparison = node.interval.compareTo(interval) + if (comparison < 0) { + node = node.left + } else if (comparison > 0) { + node = node.right + } else { // comparison == 0 + return node + } + } + throw new IllegalStateException("Reached leaf node while searching for interval to remove") + } + + override toString() { + if (leaf) { + "L" + } else { + ''' + «IF isRed»R«ELSE»B«ENDIF» «count»«interval» : «result» + «left» + «right» + ''' + } + } + +} diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/RedBlackNode.java b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/RedBlackNode.java new file mode 100644 index 00000000..8c40816b --- /dev/null +++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/RedBlackNode.java @@ -0,0 +1,1392 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2016 btrekkie + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +package hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval; + +import java.lang.reflect.Array; +import java.util.Collection; +import java.util.Comparator; +import java.util.HashSet; +import java.util.Iterator; +import java.util.Set; + +/** + * A node in a red-black tree ( https://en.wikipedia.org/wiki/Red%E2%80%93black_tree ). Compared to a class like Java's + * TreeMap, RedBlackNode is a low-level data structure. The internals of a node are exposed as public fields, allowing + * clients to directly observe and manipulate the structure of the tree. This gives clients flexibility, although it + * also enables them to violate the red-black or BST properties. The RedBlackNode class provides methods for performing + * various standard operations, such as insertion and removal. + * + * Unlike most implementations of binary search trees, RedBlackNode supports arbitrary augmentation. By subclassing + * RedBlackNode, clients can add arbitrary data and augmentation information to each node. For example, if we were to + * use a RedBlackNode subclass to implement a sorted set, the subclass would have a field storing an element in the set. + * If we wanted to keep track of the number of non-leaf nodes in each subtree, we would store this as a "size" field and + * override augment() to update this field. All RedBlackNode methods (such as "insert" and remove()) call augment() as + * necessary to correctly maintain the augmentation information, unless otherwise indicated. + * + * The values of the tree are stored in the non-leaf nodes. RedBlackNode does not support use cases where values must be + * stored in the leaf nodes. It is recommended that all of the leaf nodes in a given tree be the same (black) + * RedBlackNode instance, to save space. The root of an empty tree is a leaf node, as opposed to null. + * + * For reference, a red-black tree is a binary search tree satisfying the following properties: + * + * - Every node is colored red or black. + * - The leaf nodes, which are dummy nodes that do not store any values, are colored black. + * - The root is black. + * - Both children of each red node are black. + * - Every path from the root to a leaf contains the same number of black nodes. + * + * @param The type of node in the tree. For example, we might have + * "class FooNode extends RedBlackNode>". + * @author Bill Jacobs + */ +public abstract class RedBlackNode> implements Comparable { + /** A Comparator that compares Comparable elements using their natural order. */ + private static final Comparator> NATURAL_ORDER = new Comparator>() { + @Override + public int compare(Comparable value1, Comparable value2) { + return value1.compareTo(value2); + } + }; + + /** The parent of this node, if any. "parent" is null if this is a leaf node. */ + public N parent; + + /** The left child of this node. "left" is null if this is a leaf node. */ + public N left; + + /** The right child of this node. "right" is null if this is a leaf node. */ + public N right; + + /** Whether the node is colored red, as opposed to black. */ + public boolean isRed; + + /** + * Sets any augmentation information about the subtree rooted at this node that is stored in this node. For + * example, if we augment each node by subtree size (the number of non-leaf nodes in the subtree), this method would + * set the size field of this node to be equal to the size field of the left child plus the size field of the right + * child plus one. + * + * "Augmentation information" is information that we can compute about a subtree rooted at some node, preferably + * based only on the augmentation information in the node's two children and the information in the node. Examples + * of augmentation information are the sum of the values in a subtree and the number of non-leaf nodes in a subtree. + * Augmentation information may not depend on the colors of the nodes. + * + * This method returns whether the augmentation information in any of the ancestors of this node might have been + * affected by changes in this subtree since the last call to augment(). In the usual case, where the augmentation + * information depends only on the information in this node and the augmentation information in its immediate + * children, this is equivalent to whether the augmentation information changed as a result of this call to + * augment(). For example, in the case of subtree size, this returns whether the value of the size field prior to + * calling augment() differed from the size field of the left child plus the size field of the right child plus one. + * False positives are permitted. The return value is unspecified if we have not called augment() on this node + * before. + * + * This method may assume that this is not a leaf node. It may not assume that the augmentation information stored + * in any of the tree's nodes is correct. However, if the augmentation information stored in all of the node's + * descendants is correct, then the augmentation information stored in this node must be correct after calling + * augment(). + */ + public boolean augment() { + return false; + } + + /** + * Throws a RuntimeException if we detect that this node locally violates any invariants specific to this subclass + * of RedBlackNode. For example, if this stores the size of the subtree rooted at this node, this should throw a + * RuntimeException if the size field of this is not equal to the size field of the left child plus the size field + * of the right child plus one. Note that we may call this on a leaf node. + * + * assertSubtreeIsValid() calls assertNodeIsValid() on each node, or at least starts to do so until it detects a + * problem. assertNodeIsValid() should assume the node is in a tree that satisfies all properties common to all + * red-black trees, as assertSubtreeIsValid() is responsible for such checks. assertNodeIsValid() should be + * "downward-looking", i.e. it should ignore any information in "parent", and it should be "local", i.e. it should + * only check a constant number of descendants. To include "global" checks, such as verifying the BST property + * concerning ordering, override assertSubtreeIsValid(). assertOrderIsValid is useful for checking the BST + * property. + */ + public void assertNodeIsValid() { + + } + + /** Returns whether this is a leaf node. */ + public boolean isLeaf() { + return left == null; + } + + /** Returns the root of the tree that contains this node. */ + public N root() { + @SuppressWarnings("unchecked") + N node = (N)this; + while (node.parent != null) { + node = node.parent; + } + return node; + } + + /** Returns the first node in the subtree rooted at this node, if any. */ + public N min() { + if (isLeaf()) { + return null; + } + @SuppressWarnings("unchecked") + N node = (N)this; + while (!node.left.isLeaf()) { + node = node.left; + } + return node; + } + + /** Returns the last node in the subtree rooted at this node, if any. */ + public N max() { + if (isLeaf()) { + return null; + } + @SuppressWarnings("unchecked") + N node = (N)this; + while (!node.right.isLeaf()) { + node = node.right; + } + return node; + } + + /** Returns the node immediately before this in the tree that contains this node, if any. */ + public N predecessor() { + if (!left.isLeaf()) { + N node; + for (node = left; !node.right.isLeaf(); node = node.right); + return node; + } else if (parent == null) { + return null; + } else { + @SuppressWarnings("unchecked") + N node = (N)this; + while (node.parent != null && node.parent.left == node) { + node = node.parent; + } + return node.parent; + } + } + + /** Returns the node immediately after this in the tree that contains this node, if any. */ + public N successor() { + if (!right.isLeaf()) { + N node; + for (node = right; !node.left.isLeaf(); node = node.left); + return node; + } else if (parent == null) { + return null; + } else { + @SuppressWarnings("unchecked") + N node = (N)this; + while (node.parent != null && node.parent.right == node) { + node = node.parent; + } + return node.parent; + } + } + + /** + * Performs a left rotation about this node. This method assumes that !isLeaf() && !right.isLeaf(). It calls + * augment() on this node and on its resulting parent. However, it does not call augment() on any of the resulting + * parent's ancestors, because that is normally the responsibility of the caller. + * @return The return value from calling augment() on the resulting parent. + */ + public boolean rotateLeft() { + if (isLeaf() || right.isLeaf()) { + throw new IllegalArgumentException("The node or its right child is a leaf"); + } + N newParent = right; + right = newParent.left; + @SuppressWarnings("unchecked") + N nThis = (N)this; + if (!right.isLeaf()) { + right.parent = nThis; + } + newParent.parent = parent; + parent = newParent; + newParent.left = nThis; + if (newParent.parent != null) { + if (newParent.parent.left == this) { + newParent.parent.left = newParent; + } else { + newParent.parent.right = newParent; + } + } + augment(); + return newParent.augment(); + } + + /** + * Performs a right rotation about this node. This method assumes that !isLeaf() && !left.isLeaf(). It calls + * augment() on this node and on its resulting parent. However, it does not call augment() on any of the resulting + * parent's ancestors, because that is normally the responsibility of the caller. + * @return The return value from calling augment() on the resulting parent. + */ + public boolean rotateRight() { + if (isLeaf() || left.isLeaf()) { + throw new IllegalArgumentException("The node or its left child is a leaf"); + } + N newParent = left; + left = newParent.right; + @SuppressWarnings("unchecked") + N nThis = (N)this; + if (!left.isLeaf()) { + left.parent = nThis; + } + newParent.parent = parent; + parent = newParent; + newParent.right = nThis; + if (newParent.parent != null) { + if (newParent.parent.left == this) { + newParent.parent.left = newParent; + } else { + newParent.parent.right = newParent; + } + } + augment(); + return newParent.augment(); + } + + /** + * Performs red-black insertion fixup. To be more precise, this fixes a tree that satisfies all of the requirements + * of red-black trees, except that this may be a red child of a red node, and if this is the root, the root may be + * red. node.isRed must initially be true. This method assumes that this is not a leaf node. The method performs + * any rotations by calling rotateLeft() and rotateRight(). This method is more efficient than fixInsertion if + * "augment" is false or augment() might return false. + * @param augment Whether to set the augmentation information for "node" and its ancestors, by calling augment(). + */ + public void fixInsertionWithoutGettingRoot(boolean augment) { + if (!isRed) { + throw new IllegalArgumentException("The node must be red"); + } + boolean changed = augment; + if (augment) { + augment(); + } + + RedBlackNode node = this; + while (node.parent != null && node.parent.isRed) { + N parent = node.parent; + N grandparent = parent.parent; + if (grandparent.left.isRed && grandparent.right.isRed) { + grandparent.left.isRed = false; + grandparent.right.isRed = false; + grandparent.isRed = true; + + if (changed) { + changed = parent.augment(); + if (changed) { + changed = grandparent.augment(); + } + } + node = grandparent; + } else { + if (parent.left == node) { + if (grandparent.right == parent) { + parent.rotateRight(); + node = parent; + parent = node.parent; + } + } else if (grandparent.left == parent) { + parent.rotateLeft(); + node = parent; + parent = node.parent; + } + + if (parent.left == node) { + boolean grandparentChanged = grandparent.rotateRight(); + if (augment) { + changed = grandparentChanged; + } + } else { + boolean grandparentChanged = grandparent.rotateLeft(); + if (augment) { + changed = grandparentChanged; + } + } + + parent.isRed = false; + grandparent.isRed = true; + node = parent; + break; + } + } + + if (node.parent == null) { + node.isRed = false; + } + if (changed) { + for (node = node.parent; node != null; node = node.parent) { + if (!node.augment()) { + break; + } + } + } + } + + /** + * Performs red-black insertion fixup. To be more precise, this fixes a tree that satisfies all of the requirements + * of red-black trees, except that this may be a red child of a red node, and if this is the root, the root may be + * red. node.isRed must initially be true. This method assumes that this is not a leaf node. The method performs + * any rotations by calling rotateLeft() and rotateRight(). This method is more efficient than fixInsertion() if + * augment() might return false. + */ + public void fixInsertionWithoutGettingRoot() { + fixInsertionWithoutGettingRoot(true); + } + + /** + * Performs red-black insertion fixup. To be more precise, this fixes a tree that satisfies all of the requirements + * of red-black trees, except that this may be a red child of a red node, and if this is the root, the root may be + * red. node.isRed must initially be true. This method assumes that this is not a leaf node. The method performs + * any rotations by calling rotateLeft() and rotateRight(). + * @param augment Whether to set the augmentation information for "node" and its ancestors, by calling augment(). + * @return The root of the resulting tree. + */ + public N fixInsertion(boolean augment) { + fixInsertionWithoutGettingRoot(augment); + return root(); + } + + /** + * Performs red-black insertion fixup. To be more precise, this fixes a tree that satisfies all of the requirements + * of red-black trees, except that this may be a red child of a red node, and if this is the root, the root may be + * red. node.isRed must initially be true. This method assumes that this is not a leaf node. The method performs + * any rotations by calling rotateLeft() and rotateRight(). + * @return The root of the resulting tree. + */ + public N fixInsertion() { + fixInsertionWithoutGettingRoot(true); + return root(); + } + + /** Returns a Comparator that compares instances of N using their natural order, as in N.compareTo. */ + @SuppressWarnings({"rawtypes", "unchecked"}) + private Comparator naturalOrder() { + Comparator comparator = (Comparator)NATURAL_ORDER; + return (Comparator)comparator; + } + + /** + * Inserts the specified node into the tree rooted at this node. Assumes this is the root. We treat newNode as a + * solitary node that does not belong to any tree, and we ignore its initial "parent", "left", "right", and isRed + * fields. + * + * If it is not efficient or convenient to find the location for a node using a Comparator, then you should manually + * add the node to the appropriate location, color it red, and call fixInsertion(). + * + * @param newNode The node to insert. + * @param allowDuplicates Whether to insert newNode if there is an equal node in the tree. To check whether we + * inserted newNode, check whether newNode.parent is null and the return value differs from newNode. + * @param comparator A comparator indicating where to put the node. If this is null, we use the nodes' natural + * order, as in N.compareTo. If you are passing null, then you must override the compareTo method, because the + * default implementation requires the nodes to already be in the same tree. + * @return The root of the resulting tree. + */ + public N insert(N newNode, boolean allowDuplicates, Comparator comparator) { + if (parent != null) { + throw new IllegalArgumentException("This is not the root of a tree"); + } + @SuppressWarnings("unchecked") + N nThis = (N)this; + if (isLeaf()) { + newNode.isRed = false; + newNode.left = nThis; + newNode.right = nThis; + newNode.parent = null; + newNode.augment(); + return newNode; + } + if (comparator == null) { + comparator = naturalOrder(); + } + + N node = nThis; + int comparison; + while (true) { + comparison = comparator.compare(newNode, node); + if (comparison < 0) { + if (!node.left.isLeaf()) { + node = node.left; + } else { + newNode.left = node.left; + newNode.right = node.left; + node.left = newNode; + newNode.parent = node; + break; + } + } else if (comparison > 0 || allowDuplicates) { + if (!node.right.isLeaf()) { + node = node.right; + } else { + newNode.left = node.right; + newNode.right = node.right; + node.right = newNode; + newNode.parent = node; + break; + } + } else { + newNode.parent = null; + return nThis; + } + } + newNode.isRed = true; + return newNode.fixInsertion(); + } + + /** + * Moves this node to its successor's former position in the tree and vice versa, i.e. sets the "left", "right", + * "parent", and isRed fields of each. This method assumes that this is not a leaf node. + * @return The node with which we swapped. + */ + private N swapWithSuccessor() { + N replacement = successor(); + boolean oldReplacementIsRed = replacement.isRed; + N oldReplacementLeft = replacement.left; + N oldReplacementRight = replacement.right; + N oldReplacementParent = replacement.parent; + + replacement.isRed = isRed; + replacement.left = left; + replacement.right = right; + replacement.parent = parent; + if (parent != null) { + if (parent.left == this) { + parent.left = replacement; + } else { + parent.right = replacement; + } + } + + @SuppressWarnings("unchecked") + N nThis = (N)this; + isRed = oldReplacementIsRed; + left = oldReplacementLeft; + right = oldReplacementRight; + if (oldReplacementParent == this) { + parent = replacement; + parent.right = nThis; + } else { + parent = oldReplacementParent; + parent.left = nThis; + } + + replacement.right.parent = replacement; + if (!replacement.left.isLeaf()) { + replacement.left.parent = replacement; + } + if (!right.isLeaf()) { + right.parent = nThis; + } + return replacement; + } + + /** + * Performs red-black deletion fixup. To be more precise, this fixes a tree that satisfies all of the requirements + * of red-black trees, except that all paths from the root to a leaf that pass through the sibling of this node have + * one fewer black node than all other root-to-leaf paths. This method assumes that this is not a leaf node. + */ + private void fixSiblingDeletion() { + RedBlackNode sibling = this; + boolean changed = true; + boolean haveAugmentedParent = false; + boolean haveAugmentedGrandparent = false; + while (true) { + N parent = sibling.parent; + if (sibling.isRed) { + parent.isRed = true; + sibling.isRed = false; + if (parent.left == sibling) { + changed = parent.rotateRight(); + sibling = parent.left; + } else { + changed = parent.rotateLeft(); + sibling = parent.right; + } + haveAugmentedParent = true; + haveAugmentedGrandparent = true; + } else if (!sibling.left.isRed && !sibling.right.isRed) { + sibling.isRed = true; + if (parent.isRed) { + parent.isRed = false; + break; + } else { + if (changed && !haveAugmentedParent) { + changed = parent.augment(); + } + N grandparent = parent.parent; + if (grandparent == null) { + break; + } else if (grandparent.left == parent) { + sibling = grandparent.right; + } else { + sibling = grandparent.left; + } + haveAugmentedParent = haveAugmentedGrandparent; + haveAugmentedGrandparent = false; + } + } else { + if (sibling == parent.left) { + if (!sibling.left.isRed) { + sibling.rotateLeft(); + sibling = sibling.parent; + } + } else if (!sibling.right.isRed) { + sibling.rotateRight(); + sibling = sibling.parent; + } + sibling.isRed = parent.isRed; + parent.isRed = false; + if (sibling == parent.left) { + sibling.left.isRed = false; + changed = parent.rotateRight(); + } else { + sibling.right.isRed = false; + changed = parent.rotateLeft(); + } + haveAugmentedParent = haveAugmentedGrandparent; + haveAugmentedGrandparent = false; + break; + } + } + + // Update augmentation info + N parent = sibling.parent; + if (changed && parent != null) { + if (!haveAugmentedParent) { + changed = parent.augment(); + } + if (changed && parent.parent != null) { + parent = parent.parent; + if (!haveAugmentedGrandparent) { + changed = parent.augment(); + } + if (changed) { + for (parent = parent.parent; parent != null; parent = parent.parent) { + if (!parent.augment()) { + break; + } + } + } + } + } + } + + /** + * Removes this node from the tree that contains it. The effect of this method on the fields of this node is + * unspecified. This method assumes that this is not a leaf node. This method is more efficient than remove() if + * augment() might return false. + * + * If the node has two children, we begin by moving the node's successor to its former position, by changing the + * successor's "left", "right", "parent", and isRed fields. + */ + public void removeWithoutGettingRoot() { + if (isLeaf()) { + throw new IllegalArgumentException("Attempted to remove a leaf node"); + } + N replacement; + if (left.isLeaf() || right.isLeaf()) { + replacement = null; + } else { + replacement = swapWithSuccessor(); + } + + N child; + if (!left.isLeaf()) { + child = left; + } else if (!right.isLeaf()) { + child = right; + } else { + child = null; + } + + if (child != null) { + // Replace this node with its child + child.parent = parent; + if (parent != null) { + if (parent.left == this) { + parent.left = child; + } else { + parent.right = child; + } + } + child.isRed = false; + + if (child.parent != null) { + N parent; + for (parent = child.parent; parent != null; parent = parent.parent) { + if (!parent.augment()) { + break; + } + } + } + } else if (parent != null) { + // Replace this node with a leaf node + N leaf = left; + N parent = this.parent; + N sibling; + if (parent.left == this) { + parent.left = leaf; + sibling = parent.right; + } else { + parent.right = leaf; + sibling = parent.left; + } + + if (!isRed) { + RedBlackNode siblingNode = sibling; + siblingNode.fixSiblingDeletion(); + } else { + while (parent != null) { + if (!parent.augment()) { + break; + } + parent = parent.parent; + } + } + } + + if (replacement != null) { + replacement.augment(); + for (N parent = replacement.parent; parent != null; parent = parent.parent) { + if (!parent.augment()) { + break; + } + } + } + + // Clear any previously existing links, so that we're more likely to encounter an exception if we attempt to + // access the removed node + parent = null; + left = null; + right = null; + isRed = true; + } + + /** + * Removes this node from the tree that contains it. The effect of this method on the fields of this node is + * unspecified. This method assumes that this is not a leaf node. + * + * If the node has two children, we begin by moving the node's successor to its former position, by changing the + * successor's "left", "right", "parent", and isRed fields. + * + * @return The root of the resulting tree. + */ + public N remove() { + if (isLeaf()) { + throw new IllegalArgumentException("Attempted to remove a leaf node"); + } + + // Find an arbitrary non-leaf node in the tree other than this node + N node; + if (parent != null) { + node = parent; + } else if (!left.isLeaf()) { + node = left; + } else if (!right.isLeaf()) { + node = right; + } else { + return left; + } + + removeWithoutGettingRoot(); + return node.root(); + } + + /** + * Returns the root of a perfectly height-balanced subtree containing the next "size" (non-leaf) nodes from + * "iterator", in iteration order. This method is responsible for setting the "left", "right", "parent", and isRed + * fields of the nodes, and calling augment() as appropriate. It ignores the initial values of the "left", "right", + * "parent", and isRed fields. + * @param iterator The nodes. + * @param size The number of nodes. + * @param height The "height" of the subtree's root node above the deepest leaf in the tree that contains it. Since + * insertion fixup is slow if there are too many red nodes and deleteion fixup is slow if there are too few red + * nodes, we compromise and have red nodes at every fourth level. We color a node red iff its "height" is equal + * to 1 mod 4. + * @param leaf The leaf node. + * @return The root of the subtree. + */ + private static > N createTree( + Iterator iterator, int size, int height, N leaf) { + if (size == 0) { + return leaf; + } else { + N left = createTree(iterator, (size - 1) / 2, height - 1, leaf); + N node = iterator.next(); + N right = createTree(iterator, size / 2, height - 1, leaf); + + node.isRed = height % 4 == 1; + node.left = left; + node.right = right; + if (!left.isLeaf()) { + left.parent = node; + } + if (!right.isLeaf()) { + right.parent = node; + } + + node.augment(); + return node; + } + } + + /** + * Returns the root of a perfectly height-balanced tree containing the specified nodes, in iteration order. This + * method is responsible for setting the "left", "right", "parent", and isRed fields of the nodes (excluding + * "leaf"), and calling augment() as appropriate. It ignores the initial values of the "left", "right", "parent", + * and isRed fields. + * @param nodes The nodes. + * @param leaf The leaf node. + * @return The root of the tree. + */ + public static > N createTree(Collection nodes, N leaf) { + int size = nodes.size(); + if (size == 0) { + return leaf; + } + + int height = 0; + for (int subtreeSize = size; subtreeSize > 0; subtreeSize /= 2) { + height++; + } + + N node = createTree(nodes.iterator(), size, height, leaf); + node.parent = null; + node.isRed = false; + return node; + } + + /** + * Concatenates to the end of the tree rooted at this node. To be precise, given that all of the nodes in this + * precede the node "pivot", which precedes all of the nodes in "last", this returns the root of a tree containing + * all of these nodes. This method destroys the trees rooted at "this" and "last". We treat "pivot" as a solitary + * node that does not belong to any tree, and we ignore its initial "parent", "left", "right", and isRed fields. + * This method assumes that this node and "last" are the roots of their respective trees. + * + * This method takes O(log N) time. It is more efficient than inserting "pivot" and then calling concatenate(last). + * It is considerably more efficient than inserting "pivot" and all of the nodes in "last". + */ + public N concatenate(N last, N pivot) { + // If the black height of "first", where first = this, is less than or equal to that of "last", starting at the + // root of "last", we keep going left until we reach a black node whose black height is equal to that of + // "first". Then, we make "pivot" the parent of that node and of "first", coloring it red, and perform + // insertion fixup on the pivot. If the black height of "first" is greater than that of "last", we do the + // mirror image of the above. + + if (parent != null) { + throw new IllegalArgumentException("This is not the root of a tree"); + } + if (last.parent != null) { + throw new IllegalArgumentException("\"last\" is not the root of a tree"); + } + + // Compute the black height of the trees + int firstBlackHeight = 0; + @SuppressWarnings("unchecked") + N first = (N)this; + for (N node = first; node != null; node = node.right) { + if (!node.isRed) { + firstBlackHeight++; + } + } + int lastBlackHeight = 0; + for (N node = last; node != null; node = node.right) { + if (!node.isRed) { + lastBlackHeight++; + } + } + + // Identify the children and parent of pivot + N firstChild = first; + N lastChild = last; + N parent; + if (firstBlackHeight <= lastBlackHeight) { + parent = null; + int blackHeight = lastBlackHeight; + while (blackHeight > firstBlackHeight) { + if (!lastChild.isRed) { + blackHeight--; + } + parent = lastChild; + lastChild = lastChild.left; + } + if (lastChild.isRed) { + parent = lastChild; + lastChild = lastChild.left; + } + } else { + parent = null; + int blackHeight = firstBlackHeight; + while (blackHeight > lastBlackHeight) { + if (!firstChild.isRed) { + blackHeight--; + } + parent = firstChild; + firstChild = firstChild.right; + } + if (firstChild.isRed) { + parent = firstChild; + firstChild = firstChild.right; + } + } + + // Add "pivot" to the tree + pivot.isRed = true; + pivot.parent = parent; + if (parent != null) { + if (firstBlackHeight < lastBlackHeight) { + parent.left = pivot; + } else { + parent.right = pivot; + } + } + pivot.left = firstChild; + if (!firstChild.isLeaf()) { + firstChild.parent = pivot; + } + pivot.right = lastChild; + if (!lastChild.isLeaf()) { + lastChild.parent = pivot; + } + + // Perform insertion fixup + return pivot.fixInsertion(); + } + + /** + * Concatenates the tree rooted at "last" to the end of the tree rooted at this node. To be precise, given that all + * of the nodes in this precede all of the nodes in "last", this returns the root of a tree containing all of these + * nodes. This method destroys the trees rooted at "this" and "last". It assumes that this node and "last" are the + * roots of their respective trees. This method takes O(log N) time. It is considerably more efficient than + * inserting all of the nodes in "last". + */ + public N concatenate(N last) { + if (parent != null || last.parent != null) { + throw new IllegalArgumentException("The node is not the root of a tree"); + } + if (isLeaf()) { + return last; + } else if (last.isLeaf()) { + @SuppressWarnings("unchecked") + N nThis = (N)this; + return nThis; + } else { + N node = last.min(); + last = node.remove(); + return concatenate(last, node); + } + } + + /** + * Splits the tree rooted at this node into two trees, so that the first element of the return value is the root of + * a tree consisting of the nodes that were before the specified node, and the second element of the return value is + * the root of a tree consisting of the nodes that were equal to or after the specified node. This method is + * destructive, meaning it does not preserve the original tree. It assumes that this node is the root and is in the + * same tree as splitNode. It takes O(log N) time. It is considerably more efficient than removing all of the + * nodes at or after splitNode and then creating a new tree from those nodes. + * @param The node at which to split the tree. + * @return An array consisting of the resulting trees. + */ + public N[] split(N splitNode) { + // To split the tree, we accumulate a pre-split tree and a post-split tree. We walk down the tree toward the + // position where we are splitting. Whenever we go left, we concatenate the right subtree with the post-split + // tree, and whenever we go right, we concatenate the pre-split tree with the left subtree. We use the + // concatenation algorithm described in concatenate(Object, Object). For the pivot, we use the last node where + // we went left in the case of a left move, and the last node where we went right in the case of a right move. + // + // The method uses the following variables: + // + // node: The current node in our walk down the tree. + // first: A node on the right spine of the pre-split tree. At the beginning of each iteration, it is the black + // node with the same black height as "node". If the pre-split tree is empty, this is null instead. + // firstParent: The parent of "first". If the pre-split tree is empty, this is null. Otherwise, this is the + // same as first.parent, unless first.isLeaf(). + // firstPivot: The node where we last went right, i.e. the next node to use as a pivot when concatenating with + // the pre-split tree. + // advanceFirst: Whether to set "first" to be its next black descendant at the end of the loop. + // last, lastParent, lastPivot, advanceLast: Analogous to "first", firstParent, firstPivot, and advanceFirst, + // but for the post-split tree. + if (parent != null) { + throw new IllegalArgumentException("This is not the root of a tree"); + } + if (isLeaf() || splitNode.isLeaf()) { + throw new IllegalArgumentException("The root or the split node is a leaf"); + } + + // Create an array containing the path from the root to splitNode + int depth = 1; + N parent; + for (parent = splitNode; parent.parent != null; parent = parent.parent) { + depth++; + } + if (parent != this) { + throw new IllegalArgumentException("The split node does not belong to this tree"); + } + RedBlackNode[] path = new RedBlackNode[depth]; + for (parent = splitNode; parent != null; parent = parent.parent) { + depth--; + path[depth] = parent; + } + + @SuppressWarnings("unchecked") + N node = (N)this; + N first = null; + N firstParent = null; + N last = null; + N lastParent = null; + N firstPivot = null; + N lastPivot = null; + while (!node.isLeaf()) { + boolean advanceFirst = !node.isRed && firstPivot != null; + boolean advanceLast = !node.isRed && lastPivot != null; + if ((depth + 1 < path.length && path[depth + 1] == node.left) || depth + 1 == path.length) { + // Left move + if (lastPivot == null) { + // The post-split tree is empty + last = node.right; + last.parent = null; + if (last.isRed) { + last.isRed = false; + lastParent = last; + last = last.left; + } + } else { + // Concatenate node.right and the post-split tree + if (node.right.isRed) { + node.right.isRed = false; + } else if (!node.isRed) { + lastParent = last; + last = last.left; + if (last.isRed) { + lastParent = last; + last = last.left; + } + advanceLast = false; + } + lastPivot.isRed = true; + lastPivot.parent = lastParent; + if (lastParent != null) { + lastParent.left = lastPivot; + } + lastPivot.left = node.right; + if (!lastPivot.left.isLeaf()) { + lastPivot.left.parent = lastPivot; + } + lastPivot.right = last; + if (!last.isLeaf()) { + last.parent = lastPivot; + } + last = lastPivot.left; + lastParent = lastPivot; + lastPivot.fixInsertionWithoutGettingRoot(false); + } + lastPivot = node; + node = node.left; + } else { + // Right move + if (firstPivot == null) { + // The pre-split tree is empty + first = node.left; + first.parent = null; + if (first.isRed) { + first.isRed = false; + firstParent = first; + first = first.right; + } + } else { + // Concatenate the post-split tree and node.left + if (node.left.isRed) { + node.left.isRed = false; + } else if (!node.isRed) { + firstParent = first; + first = first.right; + if (first.isRed) { + firstParent = first; + first = first.right; + } + advanceFirst = false; + } + firstPivot.isRed = true; + firstPivot.parent = firstParent; + if (firstParent != null) { + firstParent.right = firstPivot; + } + firstPivot.right = node.left; + if (!firstPivot.right.isLeaf()) { + firstPivot.right.parent = firstPivot; + } + firstPivot.left = first; + if (!first.isLeaf()) { + first.parent = firstPivot; + } + first = firstPivot.right; + firstParent = firstPivot; + firstPivot.fixInsertionWithoutGettingRoot(false); + } + firstPivot = node; + node = node.right; + } + + depth++; + + // Update "first" and "last" to be the nodes at the proper black height + if (advanceFirst) { + firstParent = first; + first = first.right; + if (first.isRed) { + firstParent = first; + first = first.right; + } + } + if (advanceLast) { + lastParent = last; + last = last.left; + if (last.isRed) { + lastParent = last; + last = last.left; + } + } + } + + // Add firstPivot to the pre-split tree + N leaf = node; + if (first == null) { + first = leaf; + } else { + firstPivot.isRed = true; + firstPivot.parent = firstParent; + if (firstParent != null) { + firstParent.right = firstPivot; + } + firstPivot.left = leaf; + firstPivot.right = leaf; + firstPivot.fixInsertionWithoutGettingRoot(false); + for (first = firstPivot; first.parent != null; first = first.parent) { + first.augment(); + } + first.augment(); + } + + // Add lastPivot to the post-split tree + lastPivot.isRed = true; + lastPivot.parent = lastParent; + if (lastParent != null) { + lastParent.left = lastPivot; + } + lastPivot.left = leaf; + lastPivot.right = leaf; + lastPivot.fixInsertionWithoutGettingRoot(false); + for (last = lastPivot; last.parent != null; last = last.parent) { + last.augment(); + } + last.augment(); + + @SuppressWarnings("unchecked") + N[] result = (N[])Array.newInstance(getClass(), 2); + result[0] = first; + result[1] = last; + return result; + } + + /** + * Returns the lowest common ancestor of this node and "other" - the node that is an ancestor of both and is not the + * parent of a node that is an ancestor of both. Assumes that this is in the same tree as "other". Assumes that + * neither "this" nor "other" is a leaf node. This method may return "this" or "other". + * + * Note that while it is possible to compute the lowest common ancestor in O(P) time, where P is the length of the + * path from this node to "other", the "lca" method is not guaranteed to take O(P) time. If your application + * requires this, then you should write your own lowest common ancestor method. + */ + public N lca(N other) { + if (isLeaf() || other.isLeaf()) { + throw new IllegalArgumentException("One of the nodes is a leaf node"); + } + + // Compute the depth of each node + int depth = 0; + for (N parent = this.parent; parent != null; parent = parent.parent) { + depth++; + } + int otherDepth = 0; + for (N parent = other.parent; parent != null; parent = parent.parent) { + otherDepth++; + } + + // Go up to nodes of the same depth + @SuppressWarnings("unchecked") + N parent = (N)this; + N otherParent = other; + if (depth <= otherDepth) { + for (int i = otherDepth; i > depth; i--) { + otherParent = otherParent.parent; + } + } else { + for (int i = depth; i > otherDepth; i--) { + parent = parent.parent; + } + } + + // Find the LCA + while (parent != otherParent) { + parent = parent.parent; + otherParent = otherParent.parent; + } + if (parent != null) { + return parent; + } else { + throw new IllegalArgumentException("The nodes do not belong to the same tree"); + } + } + + /** + * Returns an integer comparing the position of this node in the tree that contains it with that of "other". Returns + * a negative number if this is earlier, a positive number if this is later, and 0 if this is at the same position. + * Assumes that this is in the same tree as "other". Assumes that neither "this" nor "other" is a leaf node. + * + * The base class's implementation takes O(log N) time. If a RedBlackNode subclass stores a value used to order the + * nodes, then it could override compareTo to compare the nodes' values, which would take O(1) time. + * + * Note that while it is possible to compare the positions of two nodes in O(P) time, where P is the length of the + * path from this node to "other", the default implementation of compareTo is not guaranteed to take O(P) time. If + * your application requires this, then you should write your own comparison method. + */ + @Override + public int compareTo(N other) { + if (isLeaf() || other.isLeaf()) { + throw new IllegalArgumentException("One of the nodes is a leaf node"); + } + + // The algorithm operates as follows: compare the depth of this node to that of "other". If the depth of + // "other" is greater, keep moving up from "other" until we find the ancestor at the same depth. Then, keep + // moving up from "this" and from that node until we reach the lowest common ancestor. The node that arrived + // from the left child of the common ancestor is earlier. The algorithm is analogous if the depth of "other" is + // not greater. + if (this == other) { + return 0; + } + + // Compute the depth of each node + int depth = 0; + RedBlackNode parent; + for (parent = this; parent.parent != null; parent = parent.parent) { + depth++; + } + int otherDepth = 0; + N otherParent; + for (otherParent = other; otherParent.parent != null; otherParent = otherParent.parent) { + otherDepth++; + } + + // Go up to nodes of the same depth + if (depth < otherDepth) { + otherParent = other; + for (int i = otherDepth - 1; i > depth; i--) { + otherParent = otherParent.parent; + } + if (otherParent.parent != this) { + otherParent = otherParent.parent; + } else if (left == otherParent) { + return 1; + } else { + return -1; + } + parent = this; + } else if (depth > otherDepth) { + parent = this; + for (int i = depth - 1; i > otherDepth; i--) { + parent = parent.parent; + } + if (parent.parent != other) { + parent = parent.parent; + } else if (other.left == parent) { + return -1; + } else { + return 1; + } + otherParent = other; + } else { + parent = this; + otherParent = other; + } + + // Keep going up until we reach the lowest common ancestor + while (parent.parent != otherParent.parent) { + parent = parent.parent; + otherParent = otherParent.parent; + } + if (parent.parent == null) { + throw new IllegalArgumentException("The nodes do not belong to the same tree"); + } + if (parent.parent.left == parent) { + return -1; + } else { + return 1; + } + } + + /** Throws a RuntimeException if the RedBlackNode fields of this are not correct for a leaf node. */ + private void assertIsValidLeaf() { + if (left != null || right != null || parent != null || isRed) { + throw new RuntimeException("A leaf node's \"left\", \"right\", \"parent\", or isRed field is incorrect"); + } + } + + /** + * Throws a RuntimeException if the subtree rooted at this node does not satisfy the red-black properties, excluding + * the requirement that the root be black, or it contains a repeated node other than a leaf node. + * @param blackHeight The required number of black nodes in each path from this to a leaf node, including this and + * the leaf node. + * @param visited The nodes we have reached thus far, other than leaf nodes. This method adds the non-leaf nodes in + * the subtree rooted at this node to "visited". + */ + private void assertSubtreeIsValidRedBlack(int blackHeight, Set> visited) { + @SuppressWarnings("unchecked") + N nThis = (N)this; + if (left == null || right == null) { + assertIsValidLeaf(); + if (blackHeight != 1) { + throw new RuntimeException("Not all root-to-leaf paths have the same number of black nodes"); + } + return; + } else if (!visited.add(new Reference(nThis))) { + throw new RuntimeException("The tree contains a repeated non-leaf node"); + } else { + int childBlackHeight; + if (isRed) { + if ((!left.isLeaf() && left.isRed) || (!right.isLeaf() && right.isRed)) { + throw new RuntimeException("A red node has a red child"); + } + childBlackHeight = blackHeight; + } else if (blackHeight == 0) { + throw new RuntimeException("Not all root-to-leaf paths have the same number of black nodes"); + } else { + childBlackHeight = blackHeight - 1; + } + + if (!left.isLeaf() && left.parent != this) { + throw new RuntimeException("left.parent != this"); + } + if (!right.isLeaf() && right.parent != this) { + throw new RuntimeException("right.parent != this"); + } + RedBlackNode leftNode = left; + RedBlackNode rightNode = right; + leftNode.assertSubtreeIsValidRedBlack(childBlackHeight, visited); + rightNode.assertSubtreeIsValidRedBlack(childBlackHeight, visited); + } + } + + /** Calls assertNodeIsValid() on every node in the subtree rooted at this node. */ + private void assertNodesAreValid() { + assertNodeIsValid(); + if (left != null) { + RedBlackNode leftNode = left; + RedBlackNode rightNode = right; + leftNode.assertNodesAreValid(); + rightNode.assertNodesAreValid(); + } + } + + /** + * Throws a RuntimeException if the subtree rooted at this node is not a valid red-black tree, e.g. if a red node + * has a red child or it contains a non-leaf node "node" for which node.left.parent != node. (If parent != null, + * it's okay if isRed is true.) This method is useful for debugging. See also assertSubtreeIsValid(). + */ + public void assertSubtreeIsValidRedBlack() { + if (isLeaf()) { + assertIsValidLeaf(); + } else { + if (parent == null && isRed) { + throw new RuntimeException("The root is red"); + } + + // Compute the black height of the tree + Set> nodes = new HashSet>(); + int blackHeight = 0; + @SuppressWarnings("unchecked") + N node = (N)this; + while (node != null) { + if (!nodes.add(new Reference(node))) { + throw new RuntimeException("The tree contains a repeated non-leaf node"); + } + if (!node.isRed) { + blackHeight++; + } + node = node.left; + } + + assertSubtreeIsValidRedBlack(blackHeight, new HashSet>()); + } + } + + /** + * Throws a RuntimeException if we detect a problem with the subtree rooted at this node, such as a red child of a + * red node or a non-leaf descendant "node" for which node.left.parent != node. This method is useful for + * debugging. RedBlackNode subclasses may want to override assertSubtreeIsValid() to call assertOrderIsValid. + */ + public void assertSubtreeIsValid() { + assertSubtreeIsValidRedBlack(); + assertNodesAreValid(); + } + + /** + * Throws a RuntimeException if the nodes in the subtree rooted at this node are not in the specified order or they + * do not lie in the specified range. Assumes that the subtree rooted at this node is a valid binary tree, i.e. it + * has no repeated nodes other than leaf nodes. + * @param comparator A comparator indicating how the nodes should be ordered. + * @param start The lower limit for nodes in the subtree, if any. + * @param end The upper limit for nodes in the subtree, if any. + */ + private void assertOrderIsValid(Comparator comparator, N start, N end) { + if (!isLeaf()) { + @SuppressWarnings("unchecked") + N nThis = (N)this; + if (start != null && comparator.compare(nThis, start) < 0) { + throw new RuntimeException("The nodes are not ordered correctly"); + } + if (end != null && comparator.compare(nThis, end) > 0) { + throw new RuntimeException("The nodes are not ordered correctly"); + } + RedBlackNode leftNode = left; + RedBlackNode rightNode = right; + leftNode.assertOrderIsValid(comparator, start, nThis); + rightNode.assertOrderIsValid(comparator, nThis, end); + } + } + + /** + * Throws a RuntimeException if the nodes in the subtree rooted at this node are not in the specified order. + * Assumes that this is a valid binary tree, i.e. there are no repeated nodes other than leaf nodes. This method is + * useful for debugging. RedBlackNode subclasses may want to override assertSubtreeIsValid() to call + * assertOrderIsValid. + * @param comparator A comparator indicating how the nodes should be ordered. If this is null, we use the nodes' + * natural order, as in N.compareTo. + */ + public void assertOrderIsValid(Comparator comparator) { + if (comparator == null) { + comparator = naturalOrder(); + } + assertOrderIsValid(comparator, null, null); + } +} diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Reference.java b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Reference.java new file mode 100644 index 00000000..a25c167d --- /dev/null +++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Reference.java @@ -0,0 +1,51 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2016 btrekkie + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +package hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval; + +/** + * Wraps a value using reference equality. In other words, two references are equal only if their values are the same + * object instance, as in ==. + * @param The type of value. + */ +class Reference { + /** The value this wraps. */ + private final T value; + + public Reference(T value) { + this.value = value; + } + + public boolean equals(Object obj) { + if (!(obj instanceof Reference)) { + return false; + } + Reference reference = (Reference)obj; + return value == reference.value; + } + + @Override + public int hashCode() { + return System.identityHashCode(value); + } +} diff --git a/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/SumTest.xtend b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/SumTest.xtend new file mode 100644 index 00000000..cbd7e71f --- /dev/null +++ b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/SumTest.xtend @@ -0,0 +1,140 @@ +package hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests.interval + +import com.google.common.collect.HashMultiset +import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval +import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.IntervalAggregationMode +import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.IntervalAggregationOperator +import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.IntervalRedBlackNode +import java.math.BigDecimal +import java.util.Random +import org.junit.Assert +import org.junit.Before +import org.junit.Test + +import static hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval.* + +class SumTest { + val aggregator = new IntervalAggregationOperator(IntervalAggregationMode.SUM) + var IntervalRedBlackNode value = null + + @Before + def void reset() { + value = aggregator.createNeutral + } + + @Test + def void emptyTest() { + assertEquals(null) + } + + @Test + def void addSingleTest() { + add(between(-1, 1)) + assertEquals(between(-1, 1)) + } + + @Test + def void addRemoveTest() { + add(between(-1, 1)) + remove(between(-1, 1)) + assertEquals(null) + } + + @Test + def void addTwoTest() { + add(between(-1, 1)) + add(above(2)) + assertEquals(above(1)) + } + + @Test + def void addTwoRemoveFirstTest() { + add(between(-1, 1)) + add(above(2)) + remove(between(-1, 1)) + assertEquals(above(2)) + } + + @Test + def void addTwoRemoveSecondTest() { + add(between(-1, 1)) + add(above(2)) + remove(above(2)) + assertEquals(between(-1, 1)) + } + + @Test + def void addMultiplicityTest() { + add(between(-1, 1)) + add(between(-1, 1)) + add(between(-1, 1)) + assertEquals(between(-3, 3)) + } + + @Test + def void removeAllMultiplicityTest() { + add(between(-1, 1)) + add(between(-1, 1)) + add(between(-1, 1)) + remove(between(-1, 1)) + remove(between(-1, 1)) + remove(between(-1, 1)) + assertEquals(null) + } + + @Test + def void removeSomeMultiplicityTest() { + add(between(-1, 1)) + add(between(-1, 1)) + add(between(-1, 1)) + remove(between(-1, 1)) + remove(between(-1, 1)) + assertEquals(between(-1, 1)) + } + + @Test + def void largeTest() { + val starts = #[null, new BigDecimal(-3), new BigDecimal(-2), new BigDecimal(-1)] + val ends = #[new BigDecimal(1), new BigDecimal(2), new BigDecimal(3), null] + val current = HashMultiset.create + val random = new Random(1) + for (var int i = 0; i < 1000; i++) { + val start = starts.get(random.nextInt(starts.size)) + val end = ends.get(random.nextInt(ends.size)) + val interval = Interval.of(start, end) + val isInsert = !current.contains(interval) || random.nextInt(3) == 0 + if (isInsert) { + current.add(interval) + } else { + current.remove(interval) + } + val expected = current.stream.reduce(aggregator.mode).orElse(null) + update(interval, isInsert) + assertEquals(expected) + } + } + + private def update(Interval interval, boolean isInsert) { + value = aggregator.update(value, interval, isInsert) + val nodes = newArrayList + var node = value.min + while (node !== null) { + nodes += node + node = node.successor + } + value.assertSubtreeIsValid + } + + private def add(Interval interval) { + update(interval, true) + } + + private def remove(Interval interval) { + update(interval, false) + } + + private def assertEquals(Interval interval) { + val actual = aggregator.getAggregate(value) + Assert.assertEquals(interval, actual) + } +} -- cgit v1.2.3-70-g09d2 From 5a55d0d306e85a697aa86bdf3f9caf243d384faa Mon Sep 17 00:00:00 2001 From: Kristóf Marussy Date: Fri, 10 May 2019 00:01:57 -0400 Subject: Neutral element for sum is [0, 0] --- .../logic2viatra/interval/IntervalAggregationMode.java | 9 +++++++++ .../logic2viatra/interval/IntervalAggregationOperator.xtend | 4 ++-- .../viatrasolver/logic2viatra/tests/interval/SumTest.xtend | 8 ++++---- 3 files changed, 15 insertions(+), 6 deletions(-) (limited to 'Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval') diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/IntervalAggregationMode.java b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/IntervalAggregationMode.java index f5bd2efc..66dcb00f 100644 --- a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/IntervalAggregationMode.java +++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/IntervalAggregationMode.java @@ -21,6 +21,11 @@ public enum IntervalAggregationMode implements BinaryOperator { } }; } + + @Override + public Interval getNeutral() { + return Interval.ZERO; + } }, JOIN("intervalJoin", "Calculate the smallest interval containing all the intervals in a set") { @@ -63,4 +68,8 @@ public enum IntervalAggregationMode implements BinaryOperator { } public abstract IntervalRedBlackNode createNode(Interval interval); + + public Interval getNeutral() { + return Interval.EMPTY; + } } diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/IntervalAggregationOperator.xtend b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/IntervalAggregationOperator.xtend index 940c71bb..21d3d73b 100644 --- a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/IntervalAggregationOperator.xtend +++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/IntervalAggregationOperator.xtend @@ -36,13 +36,13 @@ class IntervalAggregationOperator implements IMultisetAggregationOperator stream) { - stream.reduce(mode).orElse(null) + stream.reduce(mode).orElse(mode.neutral) } } diff --git a/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/SumTest.xtend b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/SumTest.xtend index cbd7e71f..530c081c 100644 --- a/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/SumTest.xtend +++ b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/SumTest.xtend @@ -24,7 +24,7 @@ class SumTest { @Test def void emptyTest() { - assertEquals(null) + assertEquals(ZERO) } @Test @@ -37,7 +37,7 @@ class SumTest { def void addRemoveTest() { add(between(-1, 1)) remove(between(-1, 1)) - assertEquals(null) + assertEquals(ZERO) } @Test @@ -79,7 +79,7 @@ class SumTest { remove(between(-1, 1)) remove(between(-1, 1)) remove(between(-1, 1)) - assertEquals(null) + assertEquals(ZERO) } @Test @@ -108,7 +108,7 @@ class SumTest { } else { current.remove(interval) } - val expected = current.stream.reduce(aggregator.mode).orElse(null) + val expected = current.stream.reduce(aggregator.mode).orElse(ZERO) update(interval, isInsert) assertEquals(expected) } -- cgit v1.2.3-70-g09d2 From 3ab6f907e44d993830bad4d25a8b03811731c481 Mon Sep 17 00:00:00 2001 From: Kristóf Marussy Date: Fri, 10 May 2019 12:05:10 -0400 Subject: More aggregation operators --- .../META-INF/MANIFEST.MF | 1 + .../logic2viatra/interval/Interval.xtend | 54 +++++++++++++++++++--- .../interval/IntervalAggregationMode.java | 24 ++++++++++ .../aggregators/IntervalAggregatorFactory.xtend | 50 ++++++++++++++++++++ 4 files changed, 123 insertions(+), 6 deletions(-) create mode 100644 Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/aggregators/IntervalAggregatorFactory.xtend (limited to 'Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval') diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/META-INF/MANIFEST.MF b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/META-INF/MANIFEST.MF index 2bc35ae6..b2ee3981 100644 --- a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/META-INF/MANIFEST.MF +++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/META-INF/MANIFEST.MF @@ -5,6 +5,7 @@ Bundle-SymbolicName: hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatraquery;s Bundle-Version: 1.0.0.qualifier Export-Package: hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra, hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval, + hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.aggregators, hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.patterns, hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.queries Require-Bundle: hu.bme.mit.inf.dslreasoner.logic.model;bundle-version="1.0.0", diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend index 229656c0..173be0be 100644 --- a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend +++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend @@ -53,6 +53,10 @@ abstract class Interval implements Comparable { other.mayBeLessThanOrEqual(this) } + abstract def Interval min(Interval other) + + abstract def Interval max(Interval other) + abstract def Interval join(Interval other) def operator_plus() { @@ -88,6 +92,14 @@ abstract class Interval implements Comparable { false } + override min(Interval other) { + EMPTY + } + + override max(Interval other) { + EMPTY + } + override join(Interval other) { other } @@ -173,7 +185,7 @@ abstract class Interval implements Comparable { switch (other) { case EMPTY: true NonEmpty: lower == upper && lower == other.lower && lower == other.upper - default: throw new IllegalArgumentException("") + default: throw new IllegalArgumentException("Unknown interval: " + other) } } @@ -190,7 +202,7 @@ abstract class Interval implements Comparable { switch (other) { case EMPTY: true NonEmpty: upper !== null && other.lower !== null && upper < other.lower - default: throw new IllegalArgumentException("") + default: throw new IllegalArgumentException("Unknown interval: " + other) } } @@ -202,11 +214,41 @@ abstract class Interval implements Comparable { } } + override min(Interval other) { + switch (other) { + case EMPTY: this + NonEmpty: min(other) + default: throw new IllegalArgumentException("Unknown interval: " + other) + } + } + + def min(NonEmpty other) { + new NonEmpty( + lower.tryMin(other.lower), + if (other.upper === null) upper else upper?.min(other.upper) + ) + } + + override max(Interval other) { + switch (other) { + case EMPTY: this + NonEmpty: max(other) + default: throw new IllegalArgumentException("Unknown interval: " + other) + } + } + + def max(NonEmpty other) { + new NonEmpty( + if (other.lower === null) lower else lower?.min(other.lower), + upper.tryMax(other.upper) + ) + } + override join(Interval other) { switch (other) { case EMPTY: this NonEmpty: new NonEmpty(lower.tryMin(other.lower), upper.tryMax(other.upper)) - default: throw new IllegalArgumentException("") + default: throw new IllegalArgumentException("Unknown interval: " + other) } } @@ -218,7 +260,7 @@ abstract class Interval implements Comparable { switch (other) { case EMPTY: EMPTY NonEmpty: operator_plus(other) - default: throw new IllegalArgumentException("") + default: throw new IllegalArgumentException("Unknown interval: " + other) } } @@ -241,7 +283,7 @@ abstract class Interval implements Comparable { switch (other) { case EMPTY: EMPTY NonEmpty: operator_minus(other) - default: throw new IllegalArgumentException("") + default: throw new IllegalArgumentException("Unknown interval: " + other) } } @@ -369,7 +411,7 @@ abstract class Interval implements Comparable { switch (other) { case EMPTY: EMPTY NonEmpty: operator_divide(other) - default: throw new IllegalArgumentException("") + default: throw new IllegalArgumentException("Unknown interval: " + other) } } diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/IntervalAggregationMode.java b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/IntervalAggregationMode.java index 66dcb00f..f106e305 100644 --- a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/IntervalAggregationMode.java +++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/IntervalAggregationMode.java @@ -27,6 +27,30 @@ public enum IntervalAggregationMode implements BinaryOperator { return Interval.ZERO; } }, + + MIN("intervalMin", "Find the minimum a set of intervals") { + @Override + public IntervalRedBlackNode createNode(Interval interval) { + return new IntervalRedBlackNode(interval) { + @Override + public Interval op(Interval left, Interval right) { + return left.min(right); + } + }; + } + }, + + MAX("intervalMax", "Find the maximum a set of intervals") { + @Override + public IntervalRedBlackNode createNode(Interval interval) { + return new IntervalRedBlackNode(interval) { + @Override + public Interval op(Interval left, Interval right) { + return left.max(right); + } + }; + } + }, JOIN("intervalJoin", "Calculate the smallest interval containing all the intervals in a set") { @Override diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/aggregators/IntervalAggregatorFactory.xtend b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/aggregators/IntervalAggregatorFactory.xtend new file mode 100644 index 00000000..2b6059da --- /dev/null +++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/aggregators/IntervalAggregatorFactory.xtend @@ -0,0 +1,50 @@ +package hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.aggregators + +import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval +import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.IntervalAggregationMode +import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.IntervalAggregationOperator +import org.eclipse.viatra.query.runtime.matchers.psystem.aggregations.AggregatorType +import org.eclipse.viatra.query.runtime.matchers.psystem.aggregations.BoundAggregator +import org.eclipse.viatra.query.runtime.matchers.psystem.aggregations.IAggregatorFactory +import org.eclipse.xtend.lib.annotations.FinalFieldsConstructor + +@AggregatorType(parameterTypes = #[Interval], returnTypes = #[Interval]) +abstract class IntervalAggregatorFactory implements IAggregatorFactory { + val IntervalAggregationMode mode + + @FinalFieldsConstructor + protected new() { + } + + override getAggregatorLogic(Class domainClass) { + if (domainClass == Interval) { + new BoundAggregator(new IntervalAggregationOperator(mode), Interval, Interval) + } else { + throw new IllegalArgumentException("Unknown domain class: " + domainClass) + } + } +} + +class intervalSum extends IntervalAggregatorFactory { + new() { + super(IntervalAggregationMode.SUM) + } +} + +class intervalMin extends IntervalAggregatorFactory { + new() { + super(IntervalAggregationMode.MIN) + } +} + +class intervalMax extends IntervalAggregatorFactory { + new() { + super(IntervalAggregationMode.MAX) + } +} + +class intervalJoin extends IntervalAggregatorFactory { + new() { + super(IntervalAggregationMode.JOIN) + } +} -- cgit v1.2.3-70-g09d2 From 9670538a0e5630edecab8aaf4ba38ae6c81e8606 Mon Sep 17 00:00:00 2001 From: Kristóf Marussy Date: Fri, 10 May 2019 17:27:13 -0400 Subject: Interval power and aggregator fix --- .../logic2viatra/interval/Interval.xtend | 123 ++++++++++++------ .../tests/interval/MinAggregatorTest.xtend | 67 ++++++++++ .../logic2viatra/tests/interval/PowerTest.xtend | 43 +++++++ .../tests/interval/SumAggregatorTest.xtend | 140 +++++++++++++++++++++ .../logic2viatra/tests/interval/SumTest.xtend | 140 --------------------- 5 files changed, 337 insertions(+), 176 deletions(-) create mode 100644 Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/MinAggregatorTest.xtend create mode 100644 Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/PowerTest.xtend create mode 100644 Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/SumAggregatorTest.xtend delete mode 100644 Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/SumTest.xtend (limited to 'Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval') diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend index 173be0be..4f0f594f 100644 --- a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend +++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend @@ -54,26 +54,28 @@ abstract class Interval implements Comparable { } abstract def Interval min(Interval other) - + abstract def Interval max(Interval other) abstract def Interval join(Interval other) - def operator_plus() { + def +() { this } - abstract def Interval operator_minus() + abstract def Interval -() + + abstract def Interval +(Interval other) - abstract def Interval operator_plus(Interval other) + abstract def Interval -(Interval other) - abstract def Interval operator_minus(Interval other) + abstract def Interval *(int count) - abstract def Interval operator_multiply(int count) + abstract def Interval *(Interval other) - abstract def Interval operator_multiply(Interval other) + abstract def Interval /(Interval other) - abstract def Interval operator_divide(Interval other) + abstract def Interval **(Interval other) public static val EMPTY = new Interval { override mustEqual(Interval other) { @@ -95,7 +97,7 @@ abstract class Interval implements Comparable { override min(Interval other) { EMPTY } - + override max(Interval other) { EMPTY } @@ -104,27 +106,31 @@ abstract class Interval implements Comparable { other } - override operator_minus() { + override -() { EMPTY } - override operator_plus(Interval other) { + override +(Interval other) { EMPTY } - override operator_minus(Interval other) { + override -(Interval other) { EMPTY } - override operator_multiply(int count) { + override *(int count) { EMPTY } - override operator_multiply(Interval other) { + override *(Interval other) { EMPTY } - override operator_divide(Interval other) { + override /(Interval other) { + EMPTY + } + + override **(Interval other) { EMPTY } @@ -221,14 +227,14 @@ abstract class Interval implements Comparable { default: throw new IllegalArgumentException("Unknown interval: " + other) } } - + def min(NonEmpty other) { new NonEmpty( lower.tryMin(other.lower), - if (other.upper === null) upper else upper?.min(other.upper) + if(other.upper === null) upper else if(upper === null) other.upper else upper.min(other.upper) ) } - + override max(Interval other) { switch (other) { case EMPTY: this @@ -236,10 +242,10 @@ abstract class Interval implements Comparable { default: throw new IllegalArgumentException("Unknown interval: " + other) } } - + def max(NonEmpty other) { new NonEmpty( - if (other.lower === null) lower else lower?.min(other.lower), + if(other.lower === null) lower else if(lower === null) other.lower else lower.max(other.lower), upper.tryMax(other.upper) ) } @@ -252,19 +258,19 @@ abstract class Interval implements Comparable { } } - override operator_minus() { + override -() { new NonEmpty(upper?.negate(ROUND_DOWN), lower?.negate(ROUND_UP)) } - override operator_plus(Interval other) { + override +(Interval other) { switch (other) { case EMPTY: EMPTY - NonEmpty: operator_plus(other) + NonEmpty: this + other default: throw new IllegalArgumentException("Unknown interval: " + other) } } - def operator_plus(NonEmpty other) { + def +(NonEmpty other) { new NonEmpty( lower.tryAdd(other.lower, ROUND_DOWN), upper.tryAdd(other.upper, ROUND_UP) @@ -279,15 +285,15 @@ abstract class Interval implements Comparable { } } - override operator_minus(Interval other) { + override -(Interval other) { switch (other) { case EMPTY: EMPTY - NonEmpty: operator_minus(other) + NonEmpty: this - other default: throw new IllegalArgumentException("Unknown interval: " + other) } } - def operator_minus(NonEmpty other) { + def -(NonEmpty other) { new NonEmpty( lower.trySubtract(other.upper, ROUND_DOWN), upper.trySubtract(other.lower, ROUND_UP) @@ -302,7 +308,7 @@ abstract class Interval implements Comparable { } } - override operator_multiply(int count) { + override *(int count) { val bigCount = new BigDecimal(count) new NonEmpty( lower.tryMultiply(bigCount, ROUND_DOWN), @@ -310,15 +316,15 @@ abstract class Interval implements Comparable { ) } - override operator_multiply(Interval other) { + override *(Interval other) { switch (other) { case EMPTY: EMPTY - NonEmpty: operator_multiply(other) - default: throw new IllegalArgumentException("") + NonEmpty: this * other + default: throw new IllegalArgumentException("Unknown interval: " + other) } } - def operator_multiply(NonEmpty other) { + def *(NonEmpty other) { if (this == ZERO || other == ZERO) { ZERO } else if (nonpositive) { @@ -407,15 +413,15 @@ abstract class Interval implements Comparable { } } - override operator_divide(Interval other) { + override /(Interval other) { switch (other) { case EMPTY: EMPTY - NonEmpty: operator_divide(other) + NonEmpty: this / other default: throw new IllegalArgumentException("Unknown interval: " + other) } } - def operator_divide(NonEmpty other) { + def /(NonEmpty other) { if (other == ZERO) { EMPTY } else if (this == ZERO) { @@ -493,6 +499,51 @@ abstract class Interval implements Comparable { } } + override **(Interval other) { + switch (other) { + case EMPTY: EMPTY + NonEmpty: this ** other + default: throw new IllegalArgumentException("Unknown interval: " + other) + } + } + + def **(NonEmpty other) { + // XXX This should use proper rounding for log and exp instead of + // converting to double. + // XXX We should not ignore (integer) powers of negative numbers. + val lowerLog = if (lower === null || lower <= BigDecimal.ZERO) { + null + } else { + new BigDecimal(Math.log(lower.doubleValue), ROUND_DOWN) + } + val upperLog = if (upper === null) { + null + } else if (upper == BigDecimal.ZERO) { + return ZERO + } else if (upper < BigDecimal.ZERO) { + return EMPTY + } else { + new BigDecimal(Math.log(upper.doubleValue), ROUND_UP) + } + val log = new NonEmpty(lowerLog, upperLog) + val product = log * other + if (product instanceof NonEmpty) { + val lowerResult = if (product.lower === null) { + BigDecimal.ZERO + } else { + new BigDecimal(Math.exp(product.lower.doubleValue), ROUND_DOWN) + } + val upperResult = if (product.upper === null) { + null + } else { + new BigDecimal(Math.exp(product.upper.doubleValue), ROUND_UP) + } + new NonEmpty(lowerResult, upperResult) + } else { + throw new IllegalArgumentException("Unknown interval: " + product) + } + } + override toString() { '''«IF lower === null»(-∞«ELSE»[«lower»«ENDIF», «IF upper === null»∞)«ELSE»«upper»]«ENDIF»''' } @@ -501,7 +552,7 @@ abstract class Interval implements Comparable { switch (o) { case EMPTY: 1 NonEmpty: compareTo(o) - default: throw new IllegalArgumentException("") + default: throw new IllegalArgumentException("Unknown interval: " + o) } } diff --git a/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/MinAggregatorTest.xtend b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/MinAggregatorTest.xtend new file mode 100644 index 00000000..7d46e16c --- /dev/null +++ b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/MinAggregatorTest.xtend @@ -0,0 +1,67 @@ +package hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests.interval + +import com.google.common.collect.HashMultiset +import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval +import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.IntervalAggregationMode +import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.IntervalAggregationOperator +import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.IntervalRedBlackNode +import java.math.BigDecimal +import java.util.Random +import org.junit.Assert +import org.junit.Before +import org.junit.Test + +import static hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval.* + +class MinAggregatorTest { + val aggregator = new IntervalAggregationOperator(IntervalAggregationMode.MIN) + var IntervalRedBlackNode value = null + + @Before + def void reset() { + value = aggregator.createNeutral + } + + @Test + def void emptyTest() { + assertEquals(EMPTY) + } + + @Test + def void largeTest() { + val starts = #[null, new BigDecimal(-3), new BigDecimal(-2), new BigDecimal(-1)] + val ends = #[new BigDecimal(1), new BigDecimal(2), new BigDecimal(3), null] + val current = HashMultiset.create + val random = new Random(1) + for (var int i = 0; i < 1000; i++) { + val start = starts.get(random.nextInt(starts.size)) + val end = ends.get(random.nextInt(ends.size)) + val interval = Interval.of(start, end) + val isInsert = !current.contains(interval) || random.nextInt(3) == 0 + if (isInsert) { + current.add(interval) + } else { + current.remove(interval) + } + val expected = current.stream.reduce(aggregator.mode).orElse(EMPTY) + update(interval, isInsert) + assertEquals(expected) + } + } + + private def update(Interval interval, boolean isInsert) { + value = aggregator.update(value, interval, isInsert) + val nodes = newArrayList + var node = value.min + while (node !== null) { + nodes += node + node = node.successor + } + value.assertSubtreeIsValid + } + + private def assertEquals(Interval interval) { + val actual = aggregator.getAggregate(value) + Assert.assertEquals(interval, actual) + } +} diff --git a/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/PowerTest.xtend b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/PowerTest.xtend new file mode 100644 index 00000000..c842d90d --- /dev/null +++ b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/PowerTest.xtend @@ -0,0 +1,43 @@ +package hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests.interval + +import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval +import java.util.Collection +import org.junit.Assert +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import org.junit.runners.Parameterized.Parameter +import org.junit.runners.Parameterized.Parameters + +import static hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval.* + +@RunWith(Parameterized) +class PowerTest { + @Parameters(name="{index}: {0} ** {1} = {2}") + static def Collection data() { + #[ + #[EMPTY, EMPTY, EMPTY], + #[EMPTY, between(-1, 1), EMPTY], + #[between(-1, 1), EMPTY, EMPTY], + #[upTo(-1), between(-1, 2), EMPTY], + #[upTo(0), between(-1, 2), between(0, 0)], + #[upTo(2), between(-1, 2), above(0)], + #[upTo(2), between(1, 2), between(0, 4)], + #[above(1), between(1, 2), above(1)], + #[between(2, 4), upTo(1), between(0, 4)], + #[between(0.25, 0.5), upTo(1), above(0.25)], + #[between(2, 3), above(1), above(2)], + #[between(0.25, 0.5), above(1), between(0, 0.5)], + #[between(1, 2), between(-1, 2), between(0.5, 4)] + ] + } + + @Parameter(0) public var Interval a + @Parameter(1) public var Interval b + @Parameter(2) public var Interval result + + @Test + def void powerTest() { + Assert.assertEquals(result, a ** b) + } +} diff --git a/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/SumAggregatorTest.xtend b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/SumAggregatorTest.xtend new file mode 100644 index 00000000..56172b6c --- /dev/null +++ b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/SumAggregatorTest.xtend @@ -0,0 +1,140 @@ +package hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests.interval + +import com.google.common.collect.HashMultiset +import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval +import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.IntervalAggregationMode +import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.IntervalAggregationOperator +import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.IntervalRedBlackNode +import java.math.BigDecimal +import java.util.Random +import org.junit.Assert +import org.junit.Before +import org.junit.Test + +import static hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval.* + +class SumAggregatorTest { + val aggregator = new IntervalAggregationOperator(IntervalAggregationMode.SUM) + var IntervalRedBlackNode value = null + + @Before + def void reset() { + value = aggregator.createNeutral + } + + @Test + def void emptyTest() { + assertEquals(ZERO) + } + + @Test + def void addSingleTest() { + add(between(-1, 1)) + assertEquals(between(-1, 1)) + } + + @Test + def void addRemoveTest() { + add(between(-1, 1)) + remove(between(-1, 1)) + assertEquals(ZERO) + } + + @Test + def void addTwoTest() { + add(between(-1, 1)) + add(above(2)) + assertEquals(above(1)) + } + + @Test + def void addTwoRemoveFirstTest() { + add(between(-1, 1)) + add(above(2)) + remove(between(-1, 1)) + assertEquals(above(2)) + } + + @Test + def void addTwoRemoveSecondTest() { + add(between(-1, 1)) + add(above(2)) + remove(above(2)) + assertEquals(between(-1, 1)) + } + + @Test + def void addMultiplicityTest() { + add(between(-1, 1)) + add(between(-1, 1)) + add(between(-1, 1)) + assertEquals(between(-3, 3)) + } + + @Test + def void removeAllMultiplicityTest() { + add(between(-1, 1)) + add(between(-1, 1)) + add(between(-1, 1)) + remove(between(-1, 1)) + remove(between(-1, 1)) + remove(between(-1, 1)) + assertEquals(ZERO) + } + + @Test + def void removeSomeMultiplicityTest() { + add(between(-1, 1)) + add(between(-1, 1)) + add(between(-1, 1)) + remove(between(-1, 1)) + remove(between(-1, 1)) + assertEquals(between(-1, 1)) + } + + @Test + def void largeTest() { + val starts = #[null, new BigDecimal(-3), new BigDecimal(-2), new BigDecimal(-1)] + val ends = #[new BigDecimal(1), new BigDecimal(2), new BigDecimal(3), null] + val current = HashMultiset.create + val random = new Random(1) + for (var int i = 0; i < 1000; i++) { + val start = starts.get(random.nextInt(starts.size)) + val end = ends.get(random.nextInt(ends.size)) + val interval = Interval.of(start, end) + val isInsert = !current.contains(interval) || random.nextInt(3) == 0 + if (isInsert) { + current.add(interval) + } else { + current.remove(interval) + } + val expected = current.stream.reduce(aggregator.mode).orElse(ZERO) + update(interval, isInsert) + assertEquals(expected) + } + } + + private def update(Interval interval, boolean isInsert) { + value = aggregator.update(value, interval, isInsert) + val nodes = newArrayList + var node = value.min + while (node !== null) { + nodes += node + node = node.successor + } + value.assertSubtreeIsValid + } + + private def add(Interval interval) { + update(interval, true) + } + + private def remove(Interval interval) { + update(interval, false) + } + + private def assertEquals(Interval interval) { + val actual = aggregator.getAggregate(value) + Assert.assertEquals(interval, actual) + } +} diff --git a/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/SumTest.xtend b/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/SumTest.xtend deleted file mode 100644 index 530c081c..00000000 --- a/Tests/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/tests/interval/SumTest.xtend +++ /dev/null @@ -1,140 +0,0 @@ -package hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.tests.interval - -import com.google.common.collect.HashMultiset -import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval -import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.IntervalAggregationMode -import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.IntervalAggregationOperator -import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.IntervalRedBlackNode -import java.math.BigDecimal -import java.util.Random -import org.junit.Assert -import org.junit.Before -import org.junit.Test - -import static hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval.* - -class SumTest { - val aggregator = new IntervalAggregationOperator(IntervalAggregationMode.SUM) - var IntervalRedBlackNode value = null - - @Before - def void reset() { - value = aggregator.createNeutral - } - - @Test - def void emptyTest() { - assertEquals(ZERO) - } - - @Test - def void addSingleTest() { - add(between(-1, 1)) - assertEquals(between(-1, 1)) - } - - @Test - def void addRemoveTest() { - add(between(-1, 1)) - remove(between(-1, 1)) - assertEquals(ZERO) - } - - @Test - def void addTwoTest() { - add(between(-1, 1)) - add(above(2)) - assertEquals(above(1)) - } - - @Test - def void addTwoRemoveFirstTest() { - add(between(-1, 1)) - add(above(2)) - remove(between(-1, 1)) - assertEquals(above(2)) - } - - @Test - def void addTwoRemoveSecondTest() { - add(between(-1, 1)) - add(above(2)) - remove(above(2)) - assertEquals(between(-1, 1)) - } - - @Test - def void addMultiplicityTest() { - add(between(-1, 1)) - add(between(-1, 1)) - add(between(-1, 1)) - assertEquals(between(-3, 3)) - } - - @Test - def void removeAllMultiplicityTest() { - add(between(-1, 1)) - add(between(-1, 1)) - add(between(-1, 1)) - remove(between(-1, 1)) - remove(between(-1, 1)) - remove(between(-1, 1)) - assertEquals(ZERO) - } - - @Test - def void removeSomeMultiplicityTest() { - add(between(-1, 1)) - add(between(-1, 1)) - add(between(-1, 1)) - remove(between(-1, 1)) - remove(between(-1, 1)) - assertEquals(between(-1, 1)) - } - - @Test - def void largeTest() { - val starts = #[null, new BigDecimal(-3), new BigDecimal(-2), new BigDecimal(-1)] - val ends = #[new BigDecimal(1), new BigDecimal(2), new BigDecimal(3), null] - val current = HashMultiset.create - val random = new Random(1) - for (var int i = 0; i < 1000; i++) { - val start = starts.get(random.nextInt(starts.size)) - val end = ends.get(random.nextInt(ends.size)) - val interval = Interval.of(start, end) - val isInsert = !current.contains(interval) || random.nextInt(3) == 0 - if (isInsert) { - current.add(interval) - } else { - current.remove(interval) - } - val expected = current.stream.reduce(aggregator.mode).orElse(ZERO) - update(interval, isInsert) - assertEquals(expected) - } - } - - private def update(Interval interval, boolean isInsert) { - value = aggregator.update(value, interval, isInsert) - val nodes = newArrayList - var node = value.min - while (node !== null) { - nodes += node - node = node.successor - } - value.assertSubtreeIsValid - } - - private def add(Interval interval) { - update(interval, true) - } - - private def remove(Interval interval) { - update(interval, false) - } - - private def assertEquals(Interval interval) { - val actual = aggregator.getAggregate(value) - Assert.assertEquals(interval, actual) - } -} -- cgit v1.2.3-70-g09d2 From 1f60bda44172f1dedaf30785c88163ba6c36a0b9 Mon Sep 17 00:00:00 2001 From: Kristóf Marussy Date: Sat, 18 May 2019 14:25:00 -0400 Subject: Interval hull aggregation operator --- .../logic2viatra/interval/Interval.xtend | 4 +- .../interval/IntervalHullAggregatorOperator.xtend | 87 ++++++++++++++++++++++ .../aggregators/IntervalAggregatorFactory.xtend | 6 +- .../interval/aggregators/intervalHull.xtend | 74 ++++++++++++++++++ 4 files changed, 166 insertions(+), 5 deletions(-) create mode 100644 Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/IntervalHullAggregatorOperator.xtend create mode 100644 Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/aggregators/intervalHull.xtend (limited to 'Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval') diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend index 4f0f594f..691c8783 100644 --- a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend +++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/Interval.xtend @@ -7,8 +7,8 @@ import org.eclipse.xtend.lib.annotations.Data abstract class Interval implements Comparable { static val PRECISION = 32 - static val ROUND_DOWN = new MathContext(PRECISION, RoundingMode.FLOOR) - static val ROUND_UP = new MathContext(PRECISION, RoundingMode.CEILING) + package static val ROUND_DOWN = new MathContext(PRECISION, RoundingMode.FLOOR) + package static val ROUND_UP = new MathContext(PRECISION, RoundingMode.CEILING) private new() { } diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/IntervalHullAggregatorOperator.xtend b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/IntervalHullAggregatorOperator.xtend new file mode 100644 index 00000000..ce48eca1 --- /dev/null +++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/IntervalHullAggregatorOperator.xtend @@ -0,0 +1,87 @@ +package hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval + +import java.math.BigDecimal +import java.math.MathContext +import java.util.SortedMap +import java.util.TreeMap +import java.util.stream.Stream +import org.eclipse.viatra.query.runtime.matchers.psystem.aggregations.IMultisetAggregationOperator + +abstract class IntervalHullAggregatorOperator> implements IMultisetAggregationOperator, Interval> { + protected new() { + } + + override getName() { + "intervalHull" + } + + override getShortDescription() { + "Calculates the interval hull of a set of numbers" + } + + override createNeutral() { + new TreeMap + } + + override getAggregate(SortedMap result) { + if (result.neutral) { + Interval.EMPTY + } else { + toInterval(result.firstKey, result.lastKey) + } + } + + protected abstract def BigDecimal toBigDecimal(T value, MathContext mc) + + private def toInterval(T min, T max) { + Interval.of(min.toBigDecimal(Interval.ROUND_DOWN), max.toBigDecimal(Interval.ROUND_UP)) + } + + override isNeutral(SortedMap result) { + result.empty + } + + override update(SortedMap oldResult, T updateValue, boolean isInsertion) { + if (isInsertion) { + oldResult.compute(updateValue) [ key, value | + if (value === null) { + 1 + } else if (value > 0) { + value + 1 + } else { + throw new IllegalStateException("Invalid count: " + value) + } + ] + } else { + oldResult.compute(updateValue) [ key, value | + if (value === 1) { + null + } else if (value > 1) { + value - 1 + } else { + throw new IllegalStateException("Invalid count: " + value) + } + ] + } + oldResult + } + + override aggregateStream(Stream stream) { + val iterator = stream.iterator + if (!iterator.hasNext) { + return Interval.EMPTY + } + var min = iterator.next + var max = min + while (iterator.hasNext) { + val element = iterator.next + if (element.compareTo(min) < 0) { + min = element + } + if (element.compareTo(max) > 0) { + max = element + } + } + toInterval(min, max) + } +} diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/aggregators/IntervalAggregatorFactory.xtend b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/aggregators/IntervalAggregatorFactory.xtend index 2b6059da..dee31f67 100644 --- a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/aggregators/IntervalAggregatorFactory.xtend +++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/aggregators/IntervalAggregatorFactory.xtend @@ -8,14 +8,14 @@ import org.eclipse.viatra.query.runtime.matchers.psystem.aggregations.BoundAggre import org.eclipse.viatra.query.runtime.matchers.psystem.aggregations.IAggregatorFactory import org.eclipse.xtend.lib.annotations.FinalFieldsConstructor -@AggregatorType(parameterTypes = #[Interval], returnTypes = #[Interval]) +@AggregatorType(parameterTypes=#[Interval], returnTypes=#[Interval]) abstract class IntervalAggregatorFactory implements IAggregatorFactory { val IntervalAggregationMode mode - + @FinalFieldsConstructor protected new() { } - + override getAggregatorLogic(Class domainClass) { if (domainClass == Interval) { new BoundAggregator(new IntervalAggregationOperator(mode), Interval, Interval) diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/aggregators/intervalHull.xtend b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/aggregators/intervalHull.xtend new file mode 100644 index 00000000..72605f57 --- /dev/null +++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/interval/aggregators/intervalHull.xtend @@ -0,0 +1,74 @@ +package hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.aggregators + +import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.Interval +import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval.IntervalHullAggregatorOperator +import java.math.BigDecimal +import java.math.BigInteger +import java.math.MathContext +import org.eclipse.viatra.query.runtime.matchers.psystem.aggregations.AggregatorType +import org.eclipse.viatra.query.runtime.matchers.psystem.aggregations.BoundAggregator +import org.eclipse.viatra.query.runtime.matchers.psystem.aggregations.IAggregatorFactory + +@AggregatorType(parameterTypes=#[BigDecimal, BigInteger, Byte, Double, Float, Integer, Long, Short], returnTypes=#[ + Interval, Interval, Interval, Interval, Interval, Interval, Interval, Interval]) +class intervalHull implements IAggregatorFactory { + + override getAggregatorLogic(Class domainClass) { + new BoundAggregator(getAggregationOperator(domainClass), domainClass, Interval) + } + + private def getAggregationOperator(Class domainClass) { + switch (domainClass) { + case BigDecimal: + new IntervalHullAggregatorOperator() { + override protected toBigDecimal(BigDecimal value, MathContext mc) { + value.round(mc) + } + } + case BigInteger: + new IntervalHullAggregatorOperator() { + override protected toBigDecimal(BigInteger value, MathContext mc) { + new BigDecimal(value, mc) + } + } + case Byte: + new IntervalHullAggregatorOperator() { + override protected toBigDecimal(Byte value, MathContext mc) { + new BigDecimal(value, mc) + } + } + case Double: + new IntervalHullAggregatorOperator() { + override protected toBigDecimal(Double value, MathContext mc) { + new BigDecimal(value, mc) + } + } + case Float: + new IntervalHullAggregatorOperator() { + override protected toBigDecimal(Float value, MathContext mc) { + new BigDecimal(value, mc) + } + } + case Integer: + new IntervalHullAggregatorOperator() { + override protected toBigDecimal(Integer value, MathContext mc) { + new BigDecimal(value, mc) + } + } + case Long: + new IntervalHullAggregatorOperator() { + override protected toBigDecimal(Long value, MathContext mc) { + new BigDecimal(value, mc) + } + } + case Short: + new IntervalHullAggregatorOperator() { + override protected toBigDecimal(Short value, MathContext mc) { + new BigDecimal(value, mc) + } + } + default: + throw new IllegalArgumentException("Unknown domain class: " + domainClass) + } + } +} -- cgit v1.2.3-70-g09d2