diff options
Diffstat (limited to 'subprojects/logic/src/main/java/tools/refinery/logic/dnf/FunctionalQuery.java')
-rw-r--r-- | subprojects/logic/src/main/java/tools/refinery/logic/dnf/FunctionalQuery.java | 126 |
1 files changed, 126 insertions, 0 deletions
diff --git a/subprojects/logic/src/main/java/tools/refinery/logic/dnf/FunctionalQuery.java b/subprojects/logic/src/main/java/tools/refinery/logic/dnf/FunctionalQuery.java new file mode 100644 index 00000000..1df63fbd --- /dev/null +++ b/subprojects/logic/src/main/java/tools/refinery/logic/dnf/FunctionalQuery.java | |||
@@ -0,0 +1,126 @@ | |||
1 | /* | ||
2 | * SPDX-FileCopyrightText: 2021-2024 The Refinery Authors <https://refinery.tools/> | ||
3 | * | ||
4 | * SPDX-License-Identifier: EPL-2.0 | ||
5 | */ | ||
6 | package tools.refinery.logic.dnf; | ||
7 | |||
8 | import tools.refinery.logic.InvalidQueryException; | ||
9 | import tools.refinery.logic.literal.CallPolarity; | ||
10 | import tools.refinery.logic.term.Aggregator; | ||
11 | import tools.refinery.logic.term.AssignedValue; | ||
12 | import tools.refinery.logic.term.NodeVariable; | ||
13 | import tools.refinery.logic.term.Variable; | ||
14 | |||
15 | import java.util.ArrayList; | ||
16 | import java.util.List; | ||
17 | import java.util.Objects; | ||
18 | |||
19 | public final class FunctionalQuery<T> extends Query<T> { | ||
20 | private final Class<T> type; | ||
21 | |||
22 | FunctionalQuery(Dnf dnf, Class<T> type) { | ||
23 | super(dnf); | ||
24 | var parameters = dnf.getSymbolicParameters(); | ||
25 | int outputIndex = dnf.arity() - 1; | ||
26 | for (int i = 0; i < outputIndex; i++) { | ||
27 | var parameter = parameters.get(i); | ||
28 | var parameterType = parameter.tryGetType(); | ||
29 | if (parameterType.isPresent()) { | ||
30 | throw new InvalidQueryException("Expected parameter %s of %s to be a node variable, got %s instead" | ||
31 | .formatted(parameter, dnf, parameterType.get().getName())); | ||
32 | } | ||
33 | } | ||
34 | var outputParameter = parameters.get(outputIndex); | ||
35 | var outputParameterType = outputParameter.tryGetType(); | ||
36 | if (outputParameterType.isEmpty() || !outputParameterType.get().equals(type)) { | ||
37 | throw new InvalidQueryException("Expected parameter %s of %s to be %s, but got %s instead".formatted( | ||
38 | outputParameter, dnf, type, outputParameterType.map(Class::getName).orElse("node"))); | ||
39 | } | ||
40 | this.type = type; | ||
41 | } | ||
42 | |||
43 | @Override | ||
44 | public int arity() { | ||
45 | return getDnf().arity() - 1; | ||
46 | } | ||
47 | |||
48 | @Override | ||
49 | public Class<T> valueType() { | ||
50 | return type; | ||
51 | } | ||
52 | |||
53 | @Override | ||
54 | public T defaultValue() { | ||
55 | return null; | ||
56 | } | ||
57 | |||
58 | @Override | ||
59 | protected FunctionalQuery<T> withDnfInternal(Dnf newDnf) { | ||
60 | return newDnf.asFunction(type); | ||
61 | } | ||
62 | |||
63 | @Override | ||
64 | public FunctionalQuery<T> withDnf(Dnf newDnf) { | ||
65 | return (FunctionalQuery<T>) super.withDnf(newDnf); | ||
66 | } | ||
67 | |||
68 | public AssignedValue<T> call(List<NodeVariable> arguments) { | ||
69 | return targetVariable -> { | ||
70 | var argumentsWithTarget = new ArrayList<Variable>(arguments.size() + 1); | ||
71 | argumentsWithTarget.addAll(arguments); | ||
72 | argumentsWithTarget.add(targetVariable); | ||
73 | return getDnf().call(CallPolarity.POSITIVE, argumentsWithTarget); | ||
74 | }; | ||
75 | } | ||
76 | |||
77 | public AssignedValue<T> call(NodeVariable... arguments) { | ||
78 | return call(List.of(arguments)); | ||
79 | } | ||
80 | |||
81 | public <R> AssignedValue<R> aggregate(Aggregator<R, T> aggregator, List<NodeVariable> arguments) { | ||
82 | return targetVariable -> { | ||
83 | var placeholderVariable = Variable.of(type); | ||
84 | var argumentsWithPlaceholder = new ArrayList<Variable>(arguments.size() + 1); | ||
85 | argumentsWithPlaceholder.addAll(arguments); | ||
86 | argumentsWithPlaceholder.add(placeholderVariable); | ||
87 | return getDnf() | ||
88 | .aggregateBy(placeholderVariable, aggregator, argumentsWithPlaceholder) | ||
89 | .toLiteral(targetVariable); | ||
90 | }; | ||
91 | } | ||
92 | |||
93 | public <R> AssignedValue<R> aggregate(Aggregator<R, T> aggregator, NodeVariable... arguments) { | ||
94 | return aggregate(aggregator, List.of(arguments)); | ||
95 | } | ||
96 | |||
97 | public AssignedValue<T> leftJoin(T defaultValue, List<NodeVariable> arguments) { | ||
98 | return targetVariable -> { | ||
99 | var placeholderVariable = Variable.of(type); | ||
100 | var argumentsWithPlaceholder = new ArrayList<Variable>(arguments.size() + 1); | ||
101 | argumentsWithPlaceholder.addAll(arguments); | ||
102 | argumentsWithPlaceholder.add(placeholderVariable); | ||
103 | return getDnf() | ||
104 | .leftJoinBy(placeholderVariable, defaultValue, argumentsWithPlaceholder) | ||
105 | .toLiteral(targetVariable); | ||
106 | }; | ||
107 | } | ||
108 | |||
109 | public AssignedValue<T> leftJoin(T defaultValue, NodeVariable... arguments) { | ||
110 | return leftJoin(defaultValue, List.of(arguments)); | ||
111 | } | ||
112 | |||
113 | @Override | ||
114 | public boolean equals(Object o) { | ||
115 | if (this == o) return true; | ||
116 | if (o == null || getClass() != o.getClass()) return false; | ||
117 | if (!super.equals(o)) return false; | ||
118 | FunctionalQuery<?> that = (FunctionalQuery<?>) o; | ||
119 | return Objects.equals(type, that.type); | ||
120 | } | ||
121 | |||
122 | @Override | ||
123 | public int hashCode() { | ||
124 | return Objects.hash(super.hashCode(), type); | ||
125 | } | ||
126 | } | ||