aboutsummaryrefslogtreecommitdiffstats
path: root/subprojects/logic/src/main/java/tools/refinery/logic/literal/AbstractCallLiteral.java
blob: 9ae84547576aa6025cea080afac6a7aaebb42973 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
/*
 * SPDX-FileCopyrightText: 2021-2023 The Refinery Authors <https://refinery.tools/>
 *
 * SPDX-License-Identifier: EPL-2.0
 */
package tools.refinery.logic.literal;

import tools.refinery.logic.Constraint;
import tools.refinery.logic.InvalidQueryException;
import tools.refinery.logic.equality.LiteralEqualityHelper;
import tools.refinery.logic.equality.LiteralHashCodeHelper;
import tools.refinery.logic.substitution.Substitution;
import tools.refinery.logic.term.ParameterDirection;
import tools.refinery.logic.term.Variable;

import java.util.*;

/**
 * Base class for literals that call a {@link Constraint} with a list of argument variables.
 * <p>
 * At construction time the arguments are checked against the arity and parameters of the target constraint and
 * are classified by {@link ParameterDirection}. A variable passed to at least one {@code IN} parameter is an input
 * argument. A variable passed only to {@code OUT} parameters is an output argument.
 */
// {@link Object#equals(Object)} is implemented by {@link AbstractLiteral}.
@SuppressWarnings("squid:S2160")
public abstract class AbstractCallLiteral extends tools.refinery.logic.literal.AbstractLiteral {
	private final Constraint target;
	// Unmodifiable snapshot of the arguments given to the constructor.
	private final List<Variable> arguments;
	// Distinct arguments passed to at least one IN parameter, in first-seen order.
	private final Set<Variable> inArguments;
	// Distinct arguments passed only to OUT parameters, in first-seen order.
	private final Set<Variable> outArguments;

	/**
	 * Creates a call literal and classifies its arguments by parameter direction.
	 *
	 * @param target    the constraint being called
	 * @param arguments the argument variables, one per parameter of {@code target}; the list is copied, so later
	 *                  changes to it do not affect this literal
	 * @throws InvalidQueryException if the number of arguments differs from the arity of {@code target}, or if an
	 *                               argument is not assignable to the corresponding parameter of {@code target}
	 */
	// Use exhaustive switch over enums.
	@SuppressWarnings("squid:S1301")
	protected AbstractCallLiteral(Constraint target, List<Variable> arguments) {
		int arity = target.arity();
		// Snapshot the arguments so that the arity check and the direction sets computed below cannot be
		// invalidated by the caller mutating its list afterwards. A plain copy (rather than List.copyOf) avoids
		// introducing a new NullPointerException for null elements.
		var argumentsCopy = Collections.unmodifiableList(new ArrayList<>(arguments));
		if (argumentsCopy.size() != arity) {
			throw new InvalidQueryException("%s needs %d arguments, but got %d".formatted(target.name(),
					arity, argumentsCopy.size()));
		}
		this.target = target;
		this.arguments = argumentsCopy;
		var mutableInArguments = new LinkedHashSet<Variable>();
		var mutableOutArguments = new LinkedHashSet<Variable>();
		var parameters = target.getParameters();
		for (int i = 0; i < arity; i++) {
			var argument = argumentsCopy.get(i);
			var parameter = parameters.get(i);
			if (!parameter.isAssignable(argument)) {
				throw new InvalidQueryException("Argument %d of %s is not assignable to parameter %s"
						.formatted(i, target, parameter));
			}
			switch (parameter.getDirection()) {
			case IN -> {
				// An IN occurrence wins over any OUT occurrence of the same variable, seen earlier or later.
				mutableOutArguments.remove(argument);
				mutableInArguments.add(argument);
			}
			case OUT -> {
				if (!mutableInArguments.contains(argument)) {
					mutableOutArguments.add(argument);
				}
			}
			}
		}
		inArguments = Collections.unmodifiableSet(mutableInArguments);
		outArguments = Collections.unmodifiableSet(mutableOutArguments);
	}

	/**
	 * Returns the constraint called by this literal.
	 *
	 * @return the target constraint
	 */
	public Constraint getTarget() {
		return target;
	}

	/**
	 * Returns the argument variables of this call, in parameter order.
	 *
	 * @return an unmodifiable list of the arguments
	 */
	public List<Variable> getArguments() {
		return arguments;
	}

	/**
	 * Returns the distinct arguments classified under the given direction by the constructor.
	 *
	 * @param direction the parameter direction to select
	 * @return an unmodifiable set of the input arguments for {@code IN}, or the output-only arguments for
	 * {@code OUT}
	 */
	protected Set<Variable> getArgumentsOfDirection(ParameterDirection direction) {
		return switch (direction) {
			case IN -> inArguments;
			case OUT -> outArguments;
		};
	}

	/**
	 * Returns the variables this literal reads as inputs.
	 * <p>
	 * These are all {@code IN} arguments, plus those {@code OUT}-only arguments that also occur in
	 * {@code positiveVariablesInClause}.
	 *
	 * @param positiveVariablesInClause the positive variables of the enclosing clause
	 * @return an unmodifiable set of input variables
	 */
	@Override
	public Set<Variable> getInputVariables(Set<? extends Variable> positiveVariablesInClause) {
		var inputVariables = new LinkedHashSet<>(getArgumentsOfDirection(ParameterDirection.OUT));
		inputVariables.retainAll(positiveVariablesInClause);
		inputVariables.addAll(getArgumentsOfDirection(ParameterDirection.IN));
		return Collections.unmodifiableSet(inputVariables);
	}

	/**
	 * Returns the {@code OUT}-only arguments that do not occur in {@code positiveVariablesInClause}.
	 *
	 * @param positiveVariablesInClause the positive variables of the enclosing clause
	 * @return an unmodifiable set of private variables
	 */
	@Override
	public Set<Variable> getPrivateVariables(Set<? extends Variable> positiveVariablesInClause) {
		var privateVariables = new LinkedHashSet<>(getArgumentsOfDirection(ParameterDirection.OUT));
		privateVariables.removeAll(positiveVariablesInClause);
		return Collections.unmodifiableSet(privateVariables);
	}

	/**
	 * Applies {@code substitution} to every argument and delegates construction of the result to
	 * {@link #doSubstitute(Substitution, List)}.
	 *
	 * @param substitution the substitution to apply
	 * @return the substituted literal
	 */
	@Override
	public tools.refinery.logic.literal.Literal substitute(Substitution substitution) {
		var substitutedArguments = arguments.stream().map(substitution::getSubstitute).toList();
		return doSubstitute(substitution, substitutedArguments);
	}

	/**
	 * Builds the substituted literal for a subclass.
	 *
	 * @param substitution         the substitution being applied
	 * @param substitutedArguments the arguments after applying {@code substitution}, in parameter order
	 * @return the substituted literal
	 */
	protected abstract tools.refinery.logic.literal.Literal doSubstitute(Substitution substitution, List<Variable> substitutedArguments);

	/**
	 * Returns a literal calling {@code newTarget} with the same arguments.
	 *
	 * @param newTarget the constraint to call instead
	 * @return {@code this} if {@code newTarget} equals the current target, otherwise the result of
	 * {@link #withArguments(Constraint, List)}
	 */
	public AbstractCallLiteral withTarget(Constraint newTarget) {
		if (Objects.equals(target, newTarget)) {
			return this;
		}
		return withArguments(newTarget, arguments);
	}

	/**
	 * Creates a literal of the same kind with the given target and arguments.
	 *
	 * @param newTarget    the constraint to call
	 * @param newArguments the argument variables for {@code newTarget}
	 * @return the new literal
	 */
	public abstract AbstractCallLiteral withArguments(Constraint newTarget, List<Variable> newArguments);

	/**
	 * Compares this literal with {@code other} modulo the variable correspondence tracked by {@code helper}.
	 * <p>
	 * Arguments are compared position by position with {@link LiteralEqualityHelper#variableEqual}, and targets
	 * with {@code target.equals(helper, ...)}.
	 *
	 * @param helper the equality helper
	 * @param other  the literal to compare with
	 * @return whether the literals are equal under {@code helper}
	 */
	@Override
	public boolean equalsWithSubstitution(LiteralEqualityHelper helper, tools.refinery.logic.literal.Literal other) {
		if (!super.equalsWithSubstitution(helper, other)) {
			return false;
		}
		// NOTE(review): assumes super.equalsWithSubstitution returns false unless other has the same class as
		// this literal; otherwise this cast could fail.
		var otherCallLiteral = (AbstractCallLiteral) other;
		var arity = arguments.size();
		if (arity != otherCallLiteral.arguments.size()) {
			return false;
		}
		for (int i = 0; i < arity; i++) {
			if (!helper.variableEqual(arguments.get(i), otherCallLiteral.arguments.get(i))) {
				return false;
			}
		}
		return target.equals(helper, otherCallLiteral.target);
	}

	/**
	 * Computes a hash code consistent with the argument comparison of
	 * {@link #equalsWithSubstitution(LiteralEqualityHelper, tools.refinery.logic.literal.Literal)}.
	 *
	 * @param helper supplies hash codes for variables
	 * @return the hash code
	 */
	@Override
	public int hashCodeWithSubstitution(LiteralHashCodeHelper helper) {
		// NOTE(review): equalsWithSubstitution compares targets with target.equals(helper, ...), but this hash uses
		// plain target.hashCode(). If a Constraint's helper-aware equality accepts distinct instances whose
		// hashCode() differs, equal literals could hash differently — confirm against Constraint implementations.
		int result = super.hashCodeWithSubstitution(helper) * 31 + target.hashCode();
		for (var argument : arguments) {
			result = result * 31 + helper.getVariableHashCode(argument);
		}
		return result;
	}
}