aboutsummaryrefslogtreecommitdiffstats
path: root/subprojects/logic/src/main/java/tools/refinery/logic/dnf/FunctionalQuery.java
diff options
context:
space:
mode:
Diffstat (limited to 'subprojects/logic/src/main/java/tools/refinery/logic/dnf/FunctionalQuery.java')
-rw-r--r--subprojects/logic/src/main/java/tools/refinery/logic/dnf/FunctionalQuery.java126
1 files changed, 126 insertions, 0 deletions
diff --git a/subprojects/logic/src/main/java/tools/refinery/logic/dnf/FunctionalQuery.java b/subprojects/logic/src/main/java/tools/refinery/logic/dnf/FunctionalQuery.java
new file mode 100644
index 00000000..1df63fbd
--- /dev/null
+++ b/subprojects/logic/src/main/java/tools/refinery/logic/dnf/FunctionalQuery.java
@@ -0,0 +1,126 @@
1/*
2 * SPDX-FileCopyrightText: 2021-2024 The Refinery Authors <https://refinery.tools/>
3 *
4 * SPDX-License-Identifier: EPL-2.0
5 */
6package tools.refinery.logic.dnf;
7
8import tools.refinery.logic.InvalidQueryException;
9import tools.refinery.logic.literal.CallPolarity;
10import tools.refinery.logic.term.Aggregator;
11import tools.refinery.logic.term.AssignedValue;
12import tools.refinery.logic.term.NodeVariable;
13import tools.refinery.logic.term.Variable;
14
15import java.util.ArrayList;
16import java.util.List;
17import java.util.Objects;
18
19public final class FunctionalQuery<T> extends Query<T> {
20 private final Class<T> type;
21
22 FunctionalQuery(Dnf dnf, Class<T> type) {
23 super(dnf);
24 var parameters = dnf.getSymbolicParameters();
25 int outputIndex = dnf.arity() - 1;
26 for (int i = 0; i < outputIndex; i++) {
27 var parameter = parameters.get(i);
28 var parameterType = parameter.tryGetType();
29 if (parameterType.isPresent()) {
30 throw new InvalidQueryException("Expected parameter %s of %s to be a node variable, got %s instead"
31 .formatted(parameter, dnf, parameterType.get().getName()));
32 }
33 }
34 var outputParameter = parameters.get(outputIndex);
35 var outputParameterType = outputParameter.tryGetType();
36 if (outputParameterType.isEmpty() || !outputParameterType.get().equals(type)) {
37 throw new InvalidQueryException("Expected parameter %s of %s to be %s, but got %s instead".formatted(
38 outputParameter, dnf, type, outputParameterType.map(Class::getName).orElse("node")));
39 }
40 this.type = type;
41 }
42
43 @Override
44 public int arity() {
45 return getDnf().arity() - 1;
46 }
47
48 @Override
49 public Class<T> valueType() {
50 return type;
51 }
52
53 @Override
54 public T defaultValue() {
55 return null;
56 }
57
58 @Override
59 protected FunctionalQuery<T> withDnfInternal(Dnf newDnf) {
60 return newDnf.asFunction(type);
61 }
62
63 @Override
64 public FunctionalQuery<T> withDnf(Dnf newDnf) {
65 return (FunctionalQuery<T>) super.withDnf(newDnf);
66 }
67
68 public AssignedValue<T> call(List<NodeVariable> arguments) {
69 return targetVariable -> {
70 var argumentsWithTarget = new ArrayList<Variable>(arguments.size() + 1);
71 argumentsWithTarget.addAll(arguments);
72 argumentsWithTarget.add(targetVariable);
73 return getDnf().call(CallPolarity.POSITIVE, argumentsWithTarget);
74 };
75 }
76
77 public AssignedValue<T> call(NodeVariable... arguments) {
78 return call(List.of(arguments));
79 }
80
81 public <R> AssignedValue<R> aggregate(Aggregator<R, T> aggregator, List<NodeVariable> arguments) {
82 return targetVariable -> {
83 var placeholderVariable = Variable.of(type);
84 var argumentsWithPlaceholder = new ArrayList<Variable>(arguments.size() + 1);
85 argumentsWithPlaceholder.addAll(arguments);
86 argumentsWithPlaceholder.add(placeholderVariable);
87 return getDnf()
88 .aggregateBy(placeholderVariable, aggregator, argumentsWithPlaceholder)
89 .toLiteral(targetVariable);
90 };
91 }
92
93 public <R> AssignedValue<R> aggregate(Aggregator<R, T> aggregator, NodeVariable... arguments) {
94 return aggregate(aggregator, List.of(arguments));
95 }
96
97 public AssignedValue<T> leftJoin(T defaultValue, List<NodeVariable> arguments) {
98 return targetVariable -> {
99 var placeholderVariable = Variable.of(type);
100 var argumentsWithPlaceholder = new ArrayList<Variable>(arguments.size() + 1);
101 argumentsWithPlaceholder.addAll(arguments);
102 argumentsWithPlaceholder.add(placeholderVariable);
103 return getDnf()
104 .leftJoinBy(placeholderVariable, defaultValue, argumentsWithPlaceholder)
105 .toLiteral(targetVariable);
106 };
107 }
108
109 public AssignedValue<T> leftJoin(T defaultValue, NodeVariable... arguments) {
110 return leftJoin(defaultValue, List.of(arguments));
111 }
112
113 @Override
114 public boolean equals(Object o) {
115 if (this == o) return true;
116 if (o == null || getClass() != o.getClass()) return false;
117 if (!super.equals(o)) return false;
118 FunctionalQuery<?> that = (FunctionalQuery<?>) o;
119 return Objects.equals(type, that.type);
120 }
121
122 @Override
123 public int hashCode() {
124 return Objects.hash(super.hashCode(), type);
125 }
126}