package hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.interval
import java.math.BigDecimal
import java.math.MathContext
import java.math.RoundingMode
import org.eclipse.xtend.lib.annotations.Data
/**
 * Closed real intervals with {@link BigDecimal} bounds, used to over-approximate the possible
 * values of numeric expressions.
 *
 * <p>There are exactly two kinds of intervals: the singleton {@link #EMPTY} (no possible value)
 * and {@code NonEmpty} intervals, whose {@code null} bounds stand for negative / positive infinity.
 * Arithmetic rounds lower bounds with {@link RoundingMode#FLOOR} and upper bounds with
 * {@link RoundingMode#CEILING} to {@code PRECISION} (32) significant digits, so the computed
 * interval always encloses the exact result.
 *
 * <p>{@code must*} predicates hold when the relation holds for every pair of values drawn from the
 * two intervals (vacuously true for {@link #EMPTY}); {@code may*} predicates hold when it holds for
 * at least one pair.
 */
abstract class Interval {
    /** Number of significant digits kept by rounded arithmetic. */
    static val PRECISION = 32
    /** Rounds towards negative infinity; used for every computed lower bound. */
    static val ROUND_DOWN = new MathContext(PRECISION, RoundingMode.FLOOR)
    /** Rounds towards positive infinity; used for every computed upper bound. */
    static val ROUND_UP = new MathContext(PRECISION, RoundingMode.CEILING)

    // Private so that only the nested EMPTY and NonEmpty implementations can exist.
    private new() {
    }

    /** Whether every value of this interval equals every value of {@code other}. */
    abstract def boolean mustEqual(Interval other)

    /** Whether some value of this interval equals some value of {@code other}. */
    abstract def boolean mayEqual(Interval other)

    def mustNotEqual(Interval other) {
        !mayEqual(other)
    }

    def mayNotEqual(Interval other) {
        !mustEqual(other)
    }

    /** Whether every value of this interval is less than every value of {@code other}. */
    abstract def boolean mustBeLessThan(Interval other)

    /** Whether some value of this interval is less than some value of {@code other}. */
    abstract def boolean mayBeLessThan(Interval other)

    def mustBeLessThanOrEqual(Interval other) {
        !mayBeGreaterThan(other)
    }

    def mayBeLessThanOrEqual(Interval other) {
        !mustBeGreaterThan(other)
    }

    def mustBeGreaterThan(Interval other) {
        other.mustBeLessThan(this)
    }

    def mayBeGreaterThan(Interval other) {
        other.mayBeLessThan(this)
    }

    def mustBeGreaterThanOrEqual(Interval other) {
        other.mustBeLessThanOrEqual(this)
    }

    def mayBeGreaterThanOrEqual(Interval other) {
        other.mayBeLessThanOrEqual(this)
    }

    /** Smallest interval containing both this interval and {@code other} (interval hull). */
    abstract def Interval join(Interval other)

    /** Unary plus is the identity. */
    def operator_plus() {
        this
    }

    /** Negation: the interval of all {@code -x}. */
    abstract def Interval operator_minus()

    abstract def Interval operator_plus(Interval other)

    abstract def Interval operator_minus(Interval other)

    abstract def Interval operator_multiply(Interval other)

    /** Division; dividing by exactly [0, 0] yields {@link #EMPTY}. */
    abstract def Interval operator_divide(Interval other)

    /**
     * The empty interval. It is the identity of {@link #join} and absorbs every arithmetic
     * operation. All {@code must*} predicates are vacuously true and all {@code may*} ones false.
     */
    public static val EMPTY = new Interval {
        override mustEqual(Interval other) {
            true
        }

        override mayEqual(Interval other) {
            false
        }

        override mustBeLessThan(Interval other) {
            true
        }

        override mayBeLessThan(Interval other) {
            false
        }

        override join(Interval other) {
            other
        }

        override operator_minus() {
            EMPTY
        }

        override operator_plus(Interval other) {
            EMPTY
        }

        override operator_minus(Interval other) {
            EMPTY
        }

        override operator_multiply(Interval other) {
            EMPTY
        }

        override operator_divide(Interval other) {
            EMPTY
        }

        override toString() {
            "∅"
        }
    }

    /** The singleton interval [0, 0]. */
    public static val Interval ZERO = new NonEmpty(BigDecimal.ZERO, BigDecimal.ZERO)

    /** The interval (-∞, ∞). */
    public static val Interval UNBOUNDED = new NonEmpty(null, null)

    /**
     * Creates the interval [lower, upper].
     *
     * @param lower lower bound, or {@code null} for negative infinity
     * @param upper upper bound, or {@code null} for positive infinity
     * @throws IllegalArgumentException if both bounds are finite and {@code lower > upper}
     */
    static def Interval of(BigDecimal lower, BigDecimal upper) {
        new NonEmpty(lower, upper)
    }

    /** Interval enclosing [lower, upper], with the double bounds rounded outwards. */
    static def between(double lower, double upper) {
        of(new BigDecimal(lower, ROUND_DOWN), new BigDecimal(upper, ROUND_UP))
    }

    /** Interval (-∞, upper], with the double bound rounded upwards. */
    static def upTo(double upper) {
        of(null, new BigDecimal(upper, ROUND_UP))
    }

    /** Interval [lower, ∞), with the double bound rounded downwards. */
    static def above(double lower) {
        of(new BigDecimal(lower, ROUND_DOWN), null)
    }

    @Data
    private static class NonEmpty extends Interval {
        /** Inclusive lower bound, or {@code null} for negative infinity. */
        val BigDecimal lower
        /** Inclusive upper bound, or {@code null} for positive infinity. */
        val BigDecimal upper

        /**
         * Construct a new non-empty interval.
         *
         * @param lower The lower bound of the interval. Use {@code null} for negative infinity.
         * @param upper The upper bound of the interval. Use {@code null} for positive infinity.
         * @throws IllegalArgumentException if both bounds are finite and {@code lower > upper}
         */
        new(BigDecimal lower, BigDecimal upper) {
            if (lower !== null && upper !== null && lower > upper) {
                throw new IllegalArgumentException("Lower bound of interval must not be larger than upper bound")
            }
            this.lower = lower
            this.upper = upper
        }

        override mustEqual(Interval other) {
            switch (other) {
                case EMPTY: true
                // Only two finite singletons with numerically equal values are guaranteed equal;
                // compareTo ignores BigDecimal scale, and infinite bounds never qualify.
                NonEmpty: singleton && other.singleton && lower.compareTo(other.lower) == 0
                default: throw new IllegalArgumentException("Unknown kind of interval: " + other)
            }
        }

        override mayEqual(Interval other) {
            if (other instanceof NonEmpty) {
                // The two intervals overlap.
                (lower === null || other.upper === null || lower <= other.upper) &&
                    (other.lower === null || upper === null || other.lower <= upper)
            } else {
                false
            }
        }

        override mustBeLessThan(Interval other) {
            switch (other) {
                case EMPTY: true
                NonEmpty: upper !== null && other.lower !== null && upper < other.lower
                default: throw new IllegalArgumentException("Unknown kind of interval: " + other)
            }
        }

        override mayBeLessThan(Interval other) {
            if (other instanceof NonEmpty) {
                lower === null || other.upper === null || lower < other.upper
            } else {
                false
            }
        }

        override join(Interval other) {
            switch (other) {
                case EMPTY: this
                // Hull: the smaller lower bound and the LARGER upper bound (null = infinity wins).
                NonEmpty: new NonEmpty(lower.tryMin(other.lower), upper.tryMax(other.upper))
                default: throw new IllegalArgumentException("Unknown kind of interval: " + other)
            }
        }

        override operator_minus() {
            new NonEmpty(upper?.negate(ROUND_DOWN), lower?.negate(ROUND_UP))
        }

        override operator_plus(Interval other) {
            switch (other) {
                case EMPTY: EMPTY
                NonEmpty: operator_plus(other)
                default: throw new IllegalArgumentException("Unknown kind of interval: " + other)
            }
        }

        def operator_plus(NonEmpty other) {
            new NonEmpty(
                lower.tryAdd(other.lower, ROUND_DOWN),
                upper.tryAdd(other.upper, ROUND_UP)
            )
        }

        /** {@code a + b} rounded with {@code mc}; {@code null} (an infinite bound) if either is {@code null}. */
        private static def tryAdd(BigDecimal a, BigDecimal b, MathContext mc) {
            if (b === null) {
                null
            } else {
                a?.add(b, mc)
            }
        }

        override operator_minus(Interval other) {
            switch (other) {
                case EMPTY: EMPTY
                NonEmpty: operator_minus(other)
                default: throw new IllegalArgumentException("Unknown kind of interval: " + other)
            }
        }

        def operator_minus(NonEmpty other) {
            new NonEmpty(
                lower.trySubtract(other.upper, ROUND_DOWN),
                upper.trySubtract(other.lower, ROUND_UP)
            )
        }

        /** {@code a - b} rounded with {@code mc}; {@code null} (an infinite bound) if either is {@code null}. */
        private static def trySubtract(BigDecimal a, BigDecimal b, MathContext mc) {
            if (b === null) {
                null
            } else {
                a?.subtract(b, mc)
            }
        }

        override operator_multiply(Interval other) {
            switch (other) {
                case EMPTY: EMPTY
                NonEmpty: operator_multiply(other)
                default: throw new IllegalArgumentException("Unknown kind of interval: " + other)
            }
        }

        /**
         * Product of two non-empty intervals, selected by the sign of each operand so that only
         * the bound products that can be extremal are computed.
         */
        def operator_multiply(NonEmpty other) {
            // Numeric (scale-insensitive) zero test: [0.0, 0.0] must also short-circuit here,
            // otherwise 0 * ∞ below would widen the result to an infinite bound.
            if (isZero() || other.isZero()) {
                ZERO
            } else if (nonpositive) {
                if (other.nonpositive) {
                    new NonEmpty(
                        upper.multiply(other.upper, ROUND_DOWN),
                        lower.tryMultiply(other.lower, ROUND_UP)
                    )
                } else if (other.nonnegative) {
                    new NonEmpty(
                        lower.tryMultiply(other.upper, ROUND_DOWN),
                        upper.multiply(other.lower, ROUND_UP)
                    )
                } else { // other.lower < 0 < other.upper
                    new NonEmpty(
                        lower.tryMultiply(other.upper, ROUND_DOWN),
                        lower.tryMultiply(other.lower, ROUND_UP)
                    )
                }
            } else if (nonnegative) {
                if (other.nonpositive) {
                    new NonEmpty(
                        upper.tryMultiply(other.lower, ROUND_DOWN),
                        lower.multiply(other.upper, ROUND_UP)
                    )
                } else if (other.nonnegative) {
                    new NonEmpty(
                        lower.multiply(other.lower, ROUND_DOWN),
                        upper.tryMultiply(other.upper, ROUND_UP)
                    )
                } else { // other.lower < 0 < other.upper
                    new NonEmpty(
                        upper.tryMultiply(other.lower, ROUND_DOWN),
                        upper.tryMultiply(other.upper, ROUND_UP)
                    )
                }
            } else { // lower < 0 < upper
                if (other.nonpositive) {
                    new NonEmpty(
                        upper.tryMultiply(other.lower, ROUND_DOWN),
                        lower.tryMultiply(other.lower, ROUND_UP)
                    )
                } else if (other.nonnegative) {
                    new NonEmpty(
                        lower.tryMultiply(other.upper, ROUND_DOWN),
                        upper.tryMultiply(other.upper, ROUND_UP)
                    )
                } else { // both operands straddle zero: either cross product may be extremal
                    new NonEmpty(
                        lower.tryMultiply(other.upper, ROUND_DOWN).tryMin(upper.tryMultiply(other.lower, ROUND_DOWN)),
                        lower.tryMultiply(other.lower, ROUND_UP).tryMax(upper.tryMultiply(other.upper, ROUND_UP))
                    )
                }
            }
        }

        /** Whether every value is {@code <= 0}. */
        private def isNonpositive() {
            upper !== null && upper <= BigDecimal.ZERO
        }

        /** Whether every value is {@code >= 0}. */
        private def isNonnegative() {
            lower !== null && lower >= BigDecimal.ZERO
        }

        /** Whether both bounds are finite and numerically equal (BigDecimal scale is ignored). */
        private def isSingleton() {
            lower !== null && upper !== null && lower.compareTo(upper) == 0
        }

        /** Whether this interval is numerically [0, 0], whatever the scale of its bounds. */
        private def isZero() {
            singleton && lower.signum() == 0
        }

        /** {@code a * b} rounded with {@code mc}; {@code null} (an infinite bound) if either is {@code null}. */
        private static def tryMultiply(BigDecimal a, BigDecimal b, MathContext mc) {
            if (b === null) {
                null
            } else {
                a?.multiply(b, mc)
            }
        }

        /** Minimum of two lower bounds; {@code null} (negative infinity) if either is {@code null}. */
        private static def tryMin(BigDecimal a, BigDecimal b) {
            if (b === null) {
                null
            } else {
                a?.min(b)
            }
        }

        /** Maximum of two upper bounds; {@code null} (positive infinity) if either is {@code null}. */
        private static def tryMax(BigDecimal a, BigDecimal b) {
            if (b === null) {
                null
            } else {
                a?.max(b)
            }
        }

        override operator_divide(Interval other) {
            switch (other) {
                case EMPTY: EMPTY
                NonEmpty: operator_divide(other)
                default: throw new IllegalArgumentException("Unknown kind of interval: " + other)
            }
        }

        /**
         * Quotient of two non-empty intervals. Division by exactly [0, 0] has no possible value
         * and yields {@link #EMPTY}; a divisor that merely contains 0 gives an unbounded side.
         */
        def operator_divide(NonEmpty other) {
            if (other.isZero()) {
                EMPTY
            } else if (isZero()) {
                ZERO
            } else if (other.strictlyNegative) {
                if (nonpositive) {
                    new NonEmpty(
                        upper.tryDivide(other.lower, ROUND_DOWN),
                        lower.tryDivide(other.upper, ROUND_UP)
                    )
                } else if (nonnegative) {
                    new NonEmpty(
                        upper.tryDivide(other.upper, ROUND_DOWN),
                        lower.tryDivide(other.lower, ROUND_UP)
                    )
                } else { // lower < 0 < upper
                    new NonEmpty(
                        upper.tryDivide(other.upper, ROUND_DOWN),
                        lower.tryDivide(other.upper, ROUND_UP)
                    )
                }
            } else if (other.strictlyPositive) {
                if (nonpositive) {
                    new NonEmpty(
                        lower.tryDivide(other.lower, ROUND_DOWN),
                        upper.tryDivide(other.upper, ROUND_UP)
                    )
                } else if (nonnegative) {
                    new NonEmpty(
                        lower.tryDivide(other.upper, ROUND_DOWN),
                        upper.tryDivide(other.lower, ROUND_UP)
                    )
                } else { // lower < 0 < upper
                    new NonEmpty(
                        lower.tryDivide(other.lower, ROUND_DOWN),
                        upper.tryDivide(other.lower, ROUND_UP)
                    )
                }
            } else { // other contains 0
                if (other.lower !== null && other.lower.signum() == 0) {
                    // 0 == other.lower < other.upper, because [0, 0] was excluded earlier
                    if (nonpositive) {
                        new NonEmpty(null, upper.tryDivide(other.upper, ROUND_UP))
                    } else if (nonnegative) {
                        new NonEmpty(lower.tryDivide(other.upper, ROUND_DOWN), null)
                    } else { // lower < 0 < upper
                        UNBOUNDED
                    }
                } else if (other.upper !== null && other.upper.signum() == 0) {
                    // other.lower < other.upper == 0
                    if (nonpositive) {
                        new NonEmpty(upper.tryDivide(other.lower, ROUND_DOWN), null)
                    } else if (nonnegative) {
                        new NonEmpty(null, lower.tryDivide(other.lower, ROUND_UP))
                    } else { // lower < 0 < upper
                        UNBOUNDED
                    }
                } else { // other.lower < 0 < other.upper
                    UNBOUNDED
                }
            }
        }

        /** Whether every value is {@code < 0}. */
        private def isStrictlyNegative() {
            upper !== null && upper < BigDecimal.ZERO
        }

        /** Whether every value is {@code > 0}. */
        private def isStrictlyPositive() {
            lower !== null && lower > BigDecimal.ZERO
        }

        /**
         * {@code a / b} rounded with {@code mc}. An infinite ({@code null}) divisor gives 0, which
         * is only correct for a finite dividend; every call site in operator_divide passes a
         * finite {@code a} whenever {@code b} may be infinite. An infinite dividend with a finite
         * divisor gives {@code null}.
         */
        private static def tryDivide(BigDecimal a, BigDecimal b, MathContext mc) {
            if (b === null) {
                BigDecimal.ZERO
            } else {
                a?.divide(b, mc)
            }
        }

        override toString() {
            '''«IF lower === null»(-∞«ELSE»[«lower»«ENDIF», «IF upper === null»∞)«ELSE»«upper»]«ENDIF»'''
        }
    }
}