aboutsummaryrefslogtreecommitdiffstats
path: root/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/CallLiteral.java
blob: 091b4e043e7feeb457511c0ef0be439dd7846772 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
package tools.refinery.store.query.literal;

import tools.refinery.store.query.RelationLike;
import tools.refinery.store.query.Variable;
import tools.refinery.store.query.equality.LiteralEqualityHelper;
import tools.refinery.store.query.substitution.Substitution;

import java.util.List;
import java.util.Objects;
import java.util.Set;

/**
 * Base class for literals that call a {@link RelationLike} target with a fixed list of argument variables.
 *
 * @param <T> The type of the called target.
 */
public abstract class CallLiteral<T extends RelationLike> implements Literal {
	private final CallPolarity polarity;
	private final T target;
	private final List<Variable> arguments;

	/**
	 * Creates a new call literal.
	 *
	 * @param polarity  The polarity of the call.
	 * @param target    The called target.
	 * @param arguments The argument variables. The list is copied, so later modifications of the passed list do not
	 *                  affect this literal.
	 * @throws IllegalArgumentException if the number of arguments differs from the arity of {@code target}, or if a
	 *                                  transitive polarity is requested for a target whose arity is not 2.
	 * @throws NullPointerException     if {@code polarity}, {@code target}, {@code arguments}, or any element of
	 *                                  {@code arguments} is {@code null}.
	 */
	protected CallLiteral(CallPolarity polarity, T target, List<Variable> arguments) {
		Objects.requireNonNull(polarity, "polarity");
		Objects.requireNonNull(target, "target");
		Objects.requireNonNull(arguments, "arguments");
		int arity = target.arity();
		if (arguments.size() != arity) {
			throw new IllegalArgumentException("%s needs %d arguments, but got %d".formatted(target.name(),
					arity, arguments.size()));
		}
		if (polarity.isTransitive() && arity != 2) {
			throw new IllegalArgumentException("Transitive closures can only take binary relations");
		}
		this.polarity = polarity;
		this.target = target;
		// Copy defensively: equals/hashCode depend on the arguments, and the arity check above must keep holding
		// even if the caller later mutates the list it passed in.
		this.arguments = List.copyOf(arguments);
	}

	/**
	 * @return The polarity of this call.
	 */
	public CallPolarity getPolarity() {
		return polarity;
	}

	/**
	 * @return The runtime class of the called target, used for comparisons and in {@link #toString()}.
	 */
	public abstract Class<T> getTargetType();

	/**
	 * @return The called target.
	 */
	public T getTarget() {
		return target;
	}

	/**
	 * @return An unmodifiable list of the argument variables, in call order.
	 */
	public List<Variable> getArguments() {
		return arguments;
	}

	/**
	 * Adds the arguments of this call to {@code variables}, but only if the polarity of the call is positive.
	 *
	 * @param variables The set to add the variables to.
	 */
	@Override
	public void collectAllVariables(Set<Variable> variables) {
		if (polarity.isPositive()) {
			variables.addAll(arguments);
		}
	}

	/**
	 * Maps each argument through {@code substitution}.
	 *
	 * @param substitution The substitution to apply.
	 * @return An unmodifiable list of the substituted arguments, in call order.
	 */
	protected List<Variable> substituteArguments(Substitution substitution) {
		return arguments.stream().map(substitution::getSubstitute).toList();
	}

	/**
	 * Compares the target of this call literal with another object.
	 * <p>
	 * The default implementation uses {@link Object#equals(Object)} and ignores {@code helper}; subclasses may
	 * override it to compare targets with the help of {@code helper}.
	 *
	 * @param helper      Equality helper for comparing {@link Variable} and {@link tools.refinery.store.query.Dnf}
	 *                    instances.
	 * @param otherTarget The object to compare the target to.
	 * @return {@code true} if {@code otherTarget} is equal to the return value of {@link #getTarget()} according to
	 * {@code helper}, {@code false} otherwise.
	 */
	protected boolean targetEquals(LiteralEqualityHelper helper, T otherTarget) {
		return target.equals(otherTarget);
	}

	/**
	 * Checks whether {@code other} is the same kind of call with the same polarity, pairwise-equal arguments
	 * according to {@code helper}, and an equal target according to {@link #targetEquals}.
	 */
	@Override
	public boolean equalsWithSubstitution(LiteralEqualityHelper helper, Literal other) {
		if (other.getClass() != getClass()) {
			return false;
		}
		var otherCallLiteral = (CallLiteral<?>) other;
		if (getTargetType() != otherCallLiteral.getTargetType() || polarity != otherCallLiteral.polarity) {
			return false;
		}
		var arity = arguments.size();
		if (arity != otherCallLiteral.arguments.size()) {
			return false;
		}
		for (int i = 0; i < arity; i++) {
			if (!helper.variableEqual(arguments.get(i), otherCallLiteral.arguments.get(i))) {
				return false;
			}
		}
		// Safe: both literals have the same runtime class and the same target type, checked above.
		@SuppressWarnings("unchecked")
		var otherTarget = (T) otherCallLiteral.target;
		return targetEquals(helper, otherTarget);
	}

	@Override
	public boolean equals(Object o) {
		if (this == o) return true;
		if (o == null || getClass() != o.getClass()) return false;
		CallLiteral<?> callAtom = (CallLiteral<?>) o;
		return polarity == callAtom.polarity && Objects.equals(target, callAtom.target) &&
				Objects.equals(arguments, callAtom.arguments);
	}

	@Override
	public int hashCode() {
		return Objects.hash(polarity, target, arguments);
	}

	/**
	 * @return A textual representation of the target, consisting of the simple name of {@link #getTargetType()}
	 * and the name of the target.
	 */
	protected String targetToString() {
		return "@%s %s".formatted(getTargetType().getSimpleName(), target.name());
	}

	/**
	 * Renders the call as {@code target(arg1, arg2, ...)}, appending {@code +} after the target for transitive
	 * calls and wrapping the whole call in {@code !( ... )} when the polarity is not positive.
	 */
	@Override
	public String toString() {
		var builder = new StringBuilder();
		boolean negated = !polarity.isPositive();
		if (negated) {
			builder.append("!(");
		}
		builder.append(targetToString());
		if (polarity.isTransitive()) {
			builder.append("+");
		}
		builder.append("(");
		var separator = "";
		for (var argument : arguments) {
			builder.append(separator).append(argument);
			separator = ", ";
		}
		builder.append(")");
		if (negated) {
			builder.append(")");
		}
		return builder.toString();
	}
}