aboutsummaryrefslogtreecommitdiffstats
path: root/subprojects/logic/src/main/java/tools/refinery/logic/literal/AbstractCallLiteral.java
diff options
context:
space:
mode:
Diffstat (limited to 'subprojects/logic/src/main/java/tools/refinery/logic/literal/AbstractCallLiteral.java')
-rw-r--r--subprojects/logic/src/main/java/tools/refinery/logic/literal/AbstractCallLiteral.java135
1 files changed, 135 insertions, 0 deletions
diff --git a/subprojects/logic/src/main/java/tools/refinery/logic/literal/AbstractCallLiteral.java b/subprojects/logic/src/main/java/tools/refinery/logic/literal/AbstractCallLiteral.java
new file mode 100644
index 00000000..9ae84547
--- /dev/null
+++ b/subprojects/logic/src/main/java/tools/refinery/logic/literal/AbstractCallLiteral.java
@@ -0,0 +1,135 @@
1/*
2 * SPDX-FileCopyrightText: 2021-2023 The Refinery Authors <https://refinery.tools/>
3 *
4 * SPDX-License-Identifier: EPL-2.0
5 */
6package tools.refinery.logic.literal;
7
8import tools.refinery.logic.Constraint;
9import tools.refinery.logic.InvalidQueryException;
10import tools.refinery.logic.equality.LiteralEqualityHelper;
11import tools.refinery.logic.equality.LiteralHashCodeHelper;
12import tools.refinery.logic.substitution.Substitution;
13import tools.refinery.logic.term.ParameterDirection;
14import tools.refinery.logic.term.Variable;
15
16import java.util.*;
17
18// {@link Object#equals(Object)} is implemented by {@link AbstractLiteral}.
19@SuppressWarnings("squid:S2160")
20public abstract class AbstractCallLiteral extends tools.refinery.logic.literal.AbstractLiteral {
21 private final Constraint target;
22 private final List<Variable> arguments;
23 private final Set<Variable> inArguments;
24 private final Set<Variable> outArguments;
25
26 // Use exhaustive switch over enums.
27 @SuppressWarnings("squid:S1301")
28 protected AbstractCallLiteral(Constraint target, List<Variable> arguments) {
29 int arity = target.arity();
30 if (arguments.size() != arity) {
31 throw new InvalidQueryException("%s needs %d arguments, but got %s".formatted(target.name(),
32 target.arity(), arguments.size()));
33 }
34 this.target = target;
35 this.arguments = arguments;
36 var mutableInArguments = new LinkedHashSet<Variable>();
37 var mutableOutArguments = new LinkedHashSet<Variable>();
38 var parameters = target.getParameters();
39 for (int i = 0; i < arity; i++) {
40 var argument = arguments.get(i);
41 var parameter = parameters.get(i);
42 if (!parameter.isAssignable(argument)) {
43 throw new InvalidQueryException("Argument %d of %s is not assignable to parameter %s"
44 .formatted(i, target, parameter));
45 }
46 switch (parameter.getDirection()) {
47 case IN -> {
48 mutableOutArguments.remove(argument);
49 mutableInArguments.add(argument);
50 }
51 case OUT -> {
52 if (!mutableInArguments.contains(argument)) {
53 mutableOutArguments.add(argument);
54 }
55 }
56 }
57 }
58 inArguments = Collections.unmodifiableSet(mutableInArguments);
59 outArguments = Collections.unmodifiableSet(mutableOutArguments);
60 }
61
62 public Constraint getTarget() {
63 return target;
64 }
65
66 public List<Variable> getArguments() {
67 return arguments;
68 }
69
70 protected Set<Variable> getArgumentsOfDirection(ParameterDirection direction) {
71 return switch (direction) {
72 case IN -> inArguments;
73 case OUT -> outArguments;
74 };
75 }
76
77 @Override
78 public Set<Variable> getInputVariables(Set<? extends Variable> positiveVariablesInClause) {
79 var inputVariables = new LinkedHashSet<>(getArgumentsOfDirection(ParameterDirection.OUT));
80 inputVariables.retainAll(positiveVariablesInClause);
81 inputVariables.addAll(getArgumentsOfDirection(ParameterDirection.IN));
82 return Collections.unmodifiableSet(inputVariables);
83 }
84
85 @Override
86 public Set<Variable> getPrivateVariables(Set<? extends Variable> positiveVariablesInClause) {
87 var privateVariables = new LinkedHashSet<>(getArgumentsOfDirection(ParameterDirection.OUT));
88 privateVariables.removeAll(positiveVariablesInClause);
89 return Collections.unmodifiableSet(privateVariables);
90 }
91
92 @Override
93 public tools.refinery.logic.literal.Literal substitute(Substitution substitution) {
94 var substitutedArguments = arguments.stream().map(substitution::getSubstitute).toList();
95 return doSubstitute(substitution, substitutedArguments);
96 }
97
98 protected abstract tools.refinery.logic.literal.Literal doSubstitute(Substitution substitution, List<Variable> substitutedArguments);
99
100 public AbstractCallLiteral withTarget(Constraint newTarget) {
101 if (Objects.equals(target, newTarget)) {
102 return this;
103 }
104 return withArguments(newTarget, arguments);
105 }
106
107 public abstract AbstractCallLiteral withArguments(Constraint newTarget, List<Variable> newArguments);
108
109 @Override
110 public boolean equalsWithSubstitution(LiteralEqualityHelper helper, tools.refinery.logic.literal.Literal other) {
111 if (!super.equalsWithSubstitution(helper, other)) {
112 return false;
113 }
114 var otherCallLiteral = (AbstractCallLiteral) other;
115 var arity = arguments.size();
116 if (arity != otherCallLiteral.arguments.size()) {
117 return false;
118 }
119 for (int i = 0; i < arity; i++) {
120 if (!helper.variableEqual(arguments.get(i), otherCallLiteral.arguments.get(i))) {
121 return false;
122 }
123 }
124 return target.equals(helper, otherCallLiteral.target);
125 }
126
127 @Override
128 public int hashCodeWithSubstitution(LiteralHashCodeHelper helper) {
129 int result = super.hashCodeWithSubstitution(helper) * 31 + target.hashCode();
130 for (var argument : arguments) {
131 result = result * 31 + helper.getVariableHashCode(argument);
132 }
133 return result;
134 }
135}