diff options
Diffstat (limited to 'subprojects/store-query/src/main/java/tools/refinery/store/query/term/BinaryTerm.java')
-rw-r--r-- | subprojects/store-query/src/main/java/tools/refinery/store/query/term/BinaryTerm.java | 93 |
1 files changed, 93 insertions, 0 deletions
diff --git a/subprojects/store-query/src/main/java/tools/refinery/store/query/term/BinaryTerm.java b/subprojects/store-query/src/main/java/tools/refinery/store/query/term/BinaryTerm.java new file mode 100644 index 00000000..34f48ccc --- /dev/null +++ b/subprojects/store-query/src/main/java/tools/refinery/store/query/term/BinaryTerm.java | |||
@@ -0,0 +1,93 @@ | |||
1 | package tools.refinery.store.query.term; | ||
2 | |||
3 | import tools.refinery.store.query.equality.LiteralEqualityHelper; | ||
4 | import tools.refinery.store.query.substitution.Substitution; | ||
5 | import tools.refinery.store.query.valuation.Valuation; | ||
6 | |||
7 | import java.util.Collections; | ||
8 | import java.util.HashSet; | ||
9 | import java.util.Objects; | ||
10 | import java.util.Set; | ||
11 | |||
12 | public abstract class BinaryTerm<R, T1, T2> implements Term<R> { | ||
13 | private final Term<T1> left; | ||
14 | private final Term<T2> right; | ||
15 | |||
16 | protected BinaryTerm(Term<T1> left, Term<T2> right) { | ||
17 | if (!left.getType().equals(getLeftType())) { | ||
18 | throw new IllegalArgumentException("Expected left %s to be of type %s, got %s instead".formatted(left, | ||
19 | getLeftType().getName(), left.getType().getName())); | ||
20 | } | ||
21 | if (!right.getType().equals(getRightType())) { | ||
22 | throw new IllegalArgumentException("Expected right %s to be of type %s, got %s instead".formatted(right, | ||
23 | getRightType().getName(), right.getType().getName())); | ||
24 | } | ||
25 | this.left = left; | ||
26 | this.right = right; | ||
27 | } | ||
28 | |||
29 | public abstract Class<T1> getLeftType(); | ||
30 | |||
31 | public abstract Class<T2> getRightType(); | ||
32 | |||
33 | public Term<T1> getLeft() { | ||
34 | return left; | ||
35 | } | ||
36 | |||
37 | public Term<T2> getRight() { | ||
38 | return right; | ||
39 | } | ||
40 | |||
41 | @Override | ||
42 | public R evaluate(Valuation valuation) { | ||
43 | var leftValue = left.evaluate(valuation); | ||
44 | if (leftValue == null) { | ||
45 | return null; | ||
46 | } | ||
47 | var rightValue = right.evaluate(valuation); | ||
48 | if (rightValue == null) { | ||
49 | return null; | ||
50 | } | ||
51 | return doEvaluate(leftValue, rightValue); | ||
52 | } | ||
53 | |||
54 | protected abstract R doEvaluate(T1 leftValue, T2 rightValue); | ||
55 | |||
56 | @Override | ||
57 | public boolean equalsWithSubstitution(LiteralEqualityHelper helper, AnyTerm other) { | ||
58 | if (getClass() != other.getClass()) { | ||
59 | return false; | ||
60 | } | ||
61 | var otherBinaryTerm = (BinaryTerm<?, ?, ?>) other; | ||
62 | return left.equalsWithSubstitution(helper, otherBinaryTerm.left) && right.equalsWithSubstitution(helper, | ||
63 | otherBinaryTerm.right); | ||
64 | } | ||
65 | |||
66 | @Override | ||
67 | public Term<R> substitute(Substitution substitution) { | ||
68 | return doSubstitute(substitution, left.substitute(substitution), right.substitute(substitution)); | ||
69 | } | ||
70 | |||
71 | public abstract Term<R> doSubstitute(Substitution substitution, Term<T1> substitutedLeft, | ||
72 | Term<T2> substitutedRight); | ||
73 | |||
74 | @Override | ||
75 | public Set<AnyDataVariable> getInputVariables() { | ||
76 | var inputVariables = new HashSet<>(left.getInputVariables()); | ||
77 | inputVariables.addAll(right.getInputVariables()); | ||
78 | return Collections.unmodifiableSet(inputVariables); | ||
79 | } | ||
80 | |||
81 | @Override | ||
82 | public boolean equals(Object o) { | ||
83 | if (this == o) return true; | ||
84 | if (o == null || getClass() != o.getClass()) return false; | ||
85 | BinaryTerm<?, ?, ?> that = (BinaryTerm<?, ?, ?>) o; | ||
86 | return left.equals(that.left) && right.equals(that.right); | ||
87 | } | ||
88 | |||
89 | @Override | ||
90 | public int hashCode() { | ||
91 | return Objects.hash(getClass(), left, right); | ||
92 | } | ||
93 | } | ||