aboutsummaryrefslogtreecommitdiffstats
path: root/subprojects/store-query/src/main/java/tools/refinery/store/query/literal/LeftJoinLiteral.java
blob: bdddf120f0383dd4dc2c20cf48d0a71667d3abf9 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
/*
 * SPDX-FileCopyrightText: 2024 The Refinery Authors <https://refinery.tools/>
 *
 * SPDX-License-Identifier: EPL-2.0
 */
package tools.refinery.store.query.literal;

import tools.refinery.store.query.Constraint;
import tools.refinery.store.query.InvalidQueryException;
import tools.refinery.store.query.equality.LiteralEqualityHelper;
import tools.refinery.store.query.equality.LiteralHashCodeHelper;
import tools.refinery.store.query.substitution.Substitution;
import tools.refinery.store.query.term.ConstantTerm;
import tools.refinery.store.query.term.DataVariable;
import tools.refinery.store.query.term.ParameterDirection;
import tools.refinery.store.query.term.Variable;

import java.util.*;

/**
 * Call literal that left joins the tuples of a target {@link Constraint}.
 * <p>
 * The argument list must contain {@code placeholderVariable} in an {@link ParameterDirection#OUT} position. The
 * placeholder is private to this literal. Every other argument is treated as an input. The only variable this
 * literal exposes to the enclosing clause is {@code resultVariable}, which must not itself appear among the
 * arguments.
 * <p>
 * When the target can never match ({@code ALWAYS_FALSE}), {@link #reduce()} replaces this literal with an
 * assignment of {@code defaultValue} to {@code resultVariable}. NOTE(review): presumably the result otherwise takes
 * the placeholder's value for matching tuples; that evaluation is implemented outside this class.
 *
 * @param <T> the data type of the result and placeholder variables
 */
// {@link Object#equals(Object)} is implemented by {@link AbstractLiteral}.
@SuppressWarnings("squid:S2160")
public class LeftJoinLiteral<T> extends AbstractCallLiteral {
	private final DataVariable<T> resultVariable;
	private final DataVariable<T> placeholderVariable;
	private final T defaultValue;

	/**
	 * Creates a left join literal.
	 *
	 * @param resultVariable      the variable exposed as the output of this literal
	 * @param placeholderVariable the variable bound with {@link ParameterDirection#OUT} direction in the call
	 * @param defaultValue        the value used when the join has no match; must not be {@code null}
	 * @param target              the constraint being called
	 * @param arguments           the call arguments
	 * @throws InvalidQueryException if {@code defaultValue} is {@code null} or is not an instance of the result
	 *                               variable's type, if {@code placeholderVariable} is not an output argument, or
	 *                               if {@code resultVariable} appears in {@code arguments}
	 */
	public LeftJoinLiteral(DataVariable<T> resultVariable, DataVariable<T> placeholderVariable,
						   T defaultValue, Constraint target, List<Variable> arguments) {
		super(target, arguments);
		// Validate everything before the fields are assigned, so a rejected literal is never partially set up.
		if (defaultValue == null) {
			throw new InvalidQueryException("Default value must not be null");
		}
		if (!resultVariable.getType().isInstance(defaultValue)) {
			throw new InvalidQueryException("Default value %s must be assignable to result variable %s type %s"
					.formatted(defaultValue, resultVariable, resultVariable.getType().getName()));
		}
		if (!getArgumentsOfDirection(ParameterDirection.OUT).contains(placeholderVariable)) {
			// Report the placeholder (not the result variable), since the placeholder is what failed the check.
			throw new InvalidQueryException(
					"Placeholder variable %s must be bound with direction %s in the argument list"
							.formatted(placeholderVariable, ParameterDirection.OUT));
		}
		if (arguments.contains(resultVariable)) {
			throw new InvalidQueryException("Result variable must not appear in the argument list");
		}
		this.resultVariable = resultVariable;
		this.placeholderVariable = placeholderVariable;
		this.defaultValue = defaultValue;
	}

	/**
	 * Returns the variable bound by this literal.
	 *
	 * @return the result variable
	 */
	public DataVariable<T> getResultVariable() {
		return resultVariable;
	}

	/**
	 * Returns the variable that appears in an output position of the call.
	 *
	 * @return the placeholder variable
	 */
	public DataVariable<T> getPlaceholderVariable() {
		return placeholderVariable;
	}

	/**
	 * Returns the value used when the join has no match.
	 *
	 * @return the non-null default value
	 */
	public T getDefaultValue() {
		return defaultValue;
	}

	/**
	 * Returns the variables bound by this literal.
	 *
	 * @return a set containing only the result variable
	 */
	@Override
	public Set<Variable> getOutputVariables() {
		return Set.of(resultVariable);
	}

	/**
	 * Returns the variables that must be bound before this literal is evaluated.
	 *
	 * @return all arguments except the placeholder variable, in argument order, as an unmodifiable set
	 */
	@Override
	public Set<Variable> getInputVariables(Set<? extends Variable> positiveVariablesInClause) {
		var inputVariables = new LinkedHashSet<>(getArguments());
		inputVariables.remove(placeholderVariable);
		return Collections.unmodifiableSet(inputVariables);
	}

	/**
	 * Returns the variables local to this literal.
	 *
	 * @return a set containing only the placeholder variable
	 */
	@Override
	public Set<Variable> getPrivateVariables(Set<? extends Variable> positiveVariablesInClause) {
		return Set.of(placeholderVariable);
	}

	/**
	 * Simplifies this literal based on the target's reduction.
	 *
	 * @return an assignment of the default value when the target is always false, otherwise this literal
	 * @throws InvalidQueryException if the target is always true, because the left join would range over an
	 *                               infinite set
	 */
	@Override
	public Literal reduce() {
		var reduction = getTarget().getReduction();
		return switch (reduction) {
			case ALWAYS_FALSE -> resultVariable.assign(new ConstantTerm<>(resultVariable.getType(), defaultValue));
			case ALWAYS_TRUE -> throw new InvalidQueryException("Trying to left join an infinite set");
			case NOT_REDUCIBLE -> this;
		};
	}

	@Override
	protected Literal doSubstitute(Substitution substitution, List<Variable> substitutedArguments) {
		return new LeftJoinLiteral<>(substitution.getTypeSafeSubstitute(resultVariable),
				substitution.getTypeSafeSubstitute(placeholderVariable), defaultValue, getTarget(),
				substitutedArguments);
	}

	@Override
	public AbstractCallLiteral withArguments(Constraint newTarget, List<Variable> newArguments) {
		return new LeftJoinLiteral<>(resultVariable, placeholderVariable, defaultValue, newTarget, newArguments);
	}

	@Override
	public boolean equalsWithSubstitution(LiteralEqualityHelper helper, Literal other) {
		if (!super.equalsWithSubstitution(helper, other)) {
			return false;
		}
		var otherLeftJoinLiteral = (LeftJoinLiteral<?>) other;
		return helper.variableEqual(resultVariable, otherLeftJoinLiteral.resultVariable) &&
				helper.variableEqual(placeholderVariable, otherLeftJoinLiteral.placeholderVariable) &&
				Objects.equals(defaultValue, otherLeftJoinLiteral.defaultValue);
	}

	@Override
	public int hashCodeWithSubstitution(LiteralHashCodeHelper helper) {
		return Objects.hash(super.hashCodeWithSubstitution(helper), helper.getVariableHashCode(resultVariable),
				helper.getVariableHashCode(placeholderVariable), defaultValue);
	}

	/**
	 * Renders this literal as {@code target(arg, ..., @Default(value) result, ...)}, where the placeholder argument
	 * is shown as the result variable annotated with its default value.
	 */
	@Override
	public String toString() {
		var builder = new StringBuilder();
		// Previously only the closing parenthesis was emitted; write the target and the opening one as well.
		builder.append(getTarget().toReferenceString());
		builder.append("(");
		var argumentIterator = getArguments().iterator();
		if (argumentIterator.hasNext()) {
			appendArgument(builder, argumentIterator.next());
			while (argumentIterator.hasNext()) {
				builder.append(", ");
				appendArgument(builder, argumentIterator.next());
			}
		}
		builder.append(")");
		return builder.toString();
	}

	/** Appends one argument, replacing the placeholder with the annotated result variable. */
	private void appendArgument(StringBuilder builder, Variable argument) {
		if (placeholderVariable.equals(argument)) {
			builder.append("@Default(").append(defaultValue).append(") ");
			argument = resultVariable;
		}
		builder.append(argument);
	}
}