diff options
author | Kristóf Marussy <kristof@marussy.com> | 2023-07-25 16:06:36 +0200 |
---|---|---|
committer | Kristóf Marussy <kristof@marussy.com> | 2023-07-25 16:06:36 +0200 |
commit | 6a25ba145844c79d3507f8eabdbed854be2b8097 (patch) | |
tree | 0ea9d4c7a9b5b94a0d4341eaa25eeb7e4d3f4f56 /subprojects/store-reasoning/src/test/java | |
parent | feat: custom connected component RETE node (diff) | |
download | refinery-6a25ba145844c79d3507f8eabdbed854be2b8097.tar.gz refinery-6a25ba145844c79d3507f8eabdbed854be2b8097.tar.zst refinery-6a25ba145844c79d3507f8eabdbed854be2b8097.zip |
feat: concrete count in partial models
Diffstat (limited to 'subprojects/store-reasoning/src/test/java')
-rw-r--r-- | subprojects/store-reasoning/src/test/java/tools/refinery/store/reasoning/ConcreteCountTest.java | 324 |
1 files changed, 324 insertions, 0 deletions
diff --git a/subprojects/store-reasoning/src/test/java/tools/refinery/store/reasoning/ConcreteCountTest.java b/subprojects/store-reasoning/src/test/java/tools/refinery/store/reasoning/ConcreteCountTest.java new file mode 100644 index 00000000..21111d7c --- /dev/null +++ b/subprojects/store-reasoning/src/test/java/tools/refinery/store/reasoning/ConcreteCountTest.java | |||
@@ -0,0 +1,324 @@ | |||
1 | /* | ||
2 | * SPDX-FileCopyrightText: 2023 The Refinery Authors <https://refinery.tools/> | ||
3 | * | ||
4 | * SPDX-License-Identifier: EPL-2.0 | ||
5 | */ | ||
6 | package tools.refinery.store.reasoning; | ||
7 | |||
8 | import org.junit.jupiter.api.Test; | ||
9 | import tools.refinery.store.model.ModelStore; | ||
10 | import tools.refinery.store.query.ModelQueryAdapter; | ||
11 | import tools.refinery.store.query.dnf.Query; | ||
12 | import tools.refinery.store.query.resultset.ResultSet; | ||
13 | import tools.refinery.store.query.term.Variable; | ||
14 | import tools.refinery.store.query.viatra.ViatraModelQueryAdapter; | ||
15 | import tools.refinery.store.reasoning.literal.Concreteness; | ||
16 | import tools.refinery.store.reasoning.literal.CountLowerBoundLiteral; | ||
17 | import tools.refinery.store.reasoning.literal.CountUpperBoundLiteral; | ||
18 | import tools.refinery.store.reasoning.representation.PartialRelation; | ||
19 | import tools.refinery.store.reasoning.seed.ModelSeed; | ||
20 | import tools.refinery.store.reasoning.translator.PartialRelationTranslator; | ||
21 | import tools.refinery.store.reasoning.translator.multiobject.MultiObjectTranslator; | ||
22 | import tools.refinery.store.representation.Symbol; | ||
23 | import tools.refinery.store.representation.TruthValue; | ||
24 | import tools.refinery.store.representation.cardinality.CardinalityIntervals; | ||
25 | import tools.refinery.store.representation.cardinality.UpperCardinalities; | ||
26 | import tools.refinery.store.representation.cardinality.UpperCardinality; | ||
27 | import tools.refinery.store.tuple.Tuple; | ||
28 | |||
29 | import java.util.List; | ||
30 | |||
31 | import static org.hamcrest.MatcherAssert.assertThat; | ||
32 | import static org.hamcrest.Matchers.is; | ||
33 | import static tools.refinery.store.query.literal.Literals.not; | ||
34 | import static tools.refinery.store.reasoning.literal.PartialLiterals.must; | ||
35 | |||
36 | class ConcreteCountTest { | ||
37 | private static final PartialRelation person = new PartialRelation("Person", 1); | ||
38 | private static final PartialRelation friend = new PartialRelation("friend", 2); | ||
39 | |||
40 | @Test | ||
41 | void lowerBoundZeroTest() { | ||
42 | var query = Query.of("LowerBound", Integer.class, (builder, p1, p2, output) -> builder.clause( | ||
43 | must(person.call(p1)), | ||
44 | must(person.call(p2)), | ||
45 | new CountLowerBoundLiteral(output, Concreteness.PARTIAL, friend, List.of(p1, p2)) | ||
46 | )); | ||
47 | |||
48 | var modelSeed = ModelSeed.builder(2) | ||
49 | .seed(MultiObjectTranslator.COUNT_SYMBOL, builder -> builder | ||
50 | .put(Tuple.of(0), CardinalityIntervals.atLeast(3)) | ||
51 | .put(Tuple.of(1), CardinalityIntervals.atMost(7))) | ||
52 | .seed(person, builder -> builder.reducedValue(TruthValue.TRUE)) | ||
53 | .seed(friend, builder -> builder | ||
54 | .put(Tuple.of(0, 1), TruthValue.TRUE) | ||
55 | .put(Tuple.of(1, 0), TruthValue.UNKNOWN) | ||
56 | .put(Tuple.of(1, 1), TruthValue.ERROR)) | ||
57 | .build(); | ||
58 | |||
59 | var resultSet = getResultSet(query, modelSeed); | ||
60 | assertThat(resultSet.get(Tuple.of(0, 0)), is(0)); | ||
61 | assertThat(resultSet.get(Tuple.of(0, 1)), is(1)); | ||
62 | assertThat(resultSet.get(Tuple.of(1, 0)), is(0)); | ||
63 | assertThat(resultSet.get(Tuple.of(1, 1)), is(1)); | ||
64 | } | ||
65 | |||
66 | @Test | ||
67 | void upperBoundZeroTest() { | ||
68 | var query = Query.of("UpperBound", UpperCardinality.class, (builder, p1, p2, output) -> builder.clause( | ||
69 | must(person.call(p1)), | ||
70 | must(person.call(p2)), | ||
71 | new CountUpperBoundLiteral(output, Concreteness.PARTIAL, friend, List.of(p1, p2)) | ||
72 | )); | ||
73 | |||
74 | var modelSeed = ModelSeed.builder(2) | ||
75 | .seed(MultiObjectTranslator.COUNT_SYMBOL, builder -> builder | ||
76 | .put(Tuple.of(0), CardinalityIntervals.atLeast(3)) | ||
77 | .put(Tuple.of(1), CardinalityIntervals.atMost(7))) | ||
78 | .seed(person, builder -> builder.reducedValue(TruthValue.TRUE)) | ||
79 | .seed(friend, builder -> builder | ||
80 | .put(Tuple.of(0, 1), TruthValue.TRUE) | ||
81 | .put(Tuple.of(1, 0), TruthValue.UNKNOWN) | ||
82 | .put(Tuple.of(1, 1), TruthValue.ERROR)) | ||
83 | .build(); | ||
84 | |||
85 | var resultSet = getResultSet(query, modelSeed); | ||
86 | assertThat(resultSet.get(Tuple.of(0, 0)), is(UpperCardinalities.ZERO)); | ||
87 | assertThat(resultSet.get(Tuple.of(0, 1)), is(UpperCardinalities.ONE)); | ||
88 | assertThat(resultSet.get(Tuple.of(1, 0)), is(UpperCardinalities.ONE)); | ||
89 | assertThat(resultSet.get(Tuple.of(1, 1)), is(UpperCardinalities.ZERO)); | ||
90 | } | ||
91 | |||
92 | @Test | ||
93 | void lowerBoundOneTest() { | ||
94 | var query = Query.of("LowerBound", Integer.class, (builder, p1, output) -> builder.clause( | ||
95 | must(person.call(p1)), | ||
96 | new CountLowerBoundLiteral(output, Concreteness.PARTIAL, friend, List.of(p1, Variable.of())) | ||
97 | )); | ||
98 | |||
99 | var modelSeed = ModelSeed.builder(4) | ||
100 | .seed(MultiObjectTranslator.COUNT_SYMBOL, builder -> builder | ||
101 | .reducedValue(CardinalityIntervals.ONE) | ||
102 | .put(Tuple.of(1), CardinalityIntervals.atLeast(3)) | ||
103 | .put(Tuple.of(2), CardinalityIntervals.atMost(7))) | ||
104 | .seed(person, builder -> builder.reducedValue(TruthValue.TRUE)) | ||
105 | .seed(friend, builder -> builder | ||
106 | .put(Tuple.of(0, 1), TruthValue.TRUE) | ||
107 | .put(Tuple.of(0, 2), TruthValue.TRUE) | ||
108 | .put(Tuple.of(0, 3), TruthValue.TRUE) | ||
109 | .put(Tuple.of(1, 0), TruthValue.TRUE) | ||
110 | .put(Tuple.of(1, 2), TruthValue.UNKNOWN) | ||
111 | .put(Tuple.of(1, 3), TruthValue.UNKNOWN) | ||
112 | .put(Tuple.of(2, 0), TruthValue.TRUE) | ||
113 | .put(Tuple.of(2, 1), TruthValue.ERROR)) | ||
114 | .build(); | ||
115 | |||
116 | var resultSet = getResultSet(query, modelSeed); | ||
117 | assertThat(resultSet.get(Tuple.of(0)), is(4)); | ||
118 | assertThat(resultSet.get(Tuple.of(1)), is(1)); | ||
119 | assertThat(resultSet.get(Tuple.of(2)), is(4)); | ||
120 | assertThat(resultSet.get(Tuple.of(3)), is(0)); | ||
121 | } | ||
122 | |||
123 | @Test | ||
124 | void upperBoundOneTest() { | ||
125 | var query = Query.of("UpperBound", UpperCardinality.class, (builder, p1, output) -> builder.clause( | ||
126 | must(person.call(p1)), | ||
127 | new CountUpperBoundLiteral(output, Concreteness.PARTIAL, friend, List.of(p1, Variable.of())) | ||
128 | )); | ||
129 | |||
130 | var modelSeed = ModelSeed.builder(4) | ||
131 | .seed(MultiObjectTranslator.COUNT_SYMBOL, builder -> builder | ||
132 | .reducedValue(CardinalityIntervals.ONE) | ||
133 | .put(Tuple.of(1), CardinalityIntervals.atLeast(3)) | ||
134 | .put(Tuple.of(2), CardinalityIntervals.atMost(7))) | ||
135 | .seed(person, builder -> builder.reducedValue(TruthValue.TRUE)) | ||
136 | .seed(friend, builder -> builder | ||
137 | .put(Tuple.of(0, 1), TruthValue.TRUE) | ||
138 | .put(Tuple.of(0, 2), TruthValue.TRUE) | ||
139 | .put(Tuple.of(0, 3), TruthValue.TRUE) | ||
140 | .put(Tuple.of(1, 0), TruthValue.TRUE) | ||
141 | .put(Tuple.of(1, 2), TruthValue.UNKNOWN) | ||
142 | .put(Tuple.of(1, 3), TruthValue.UNKNOWN) | ||
143 | .put(Tuple.of(2, 0), TruthValue.TRUE) | ||
144 | .put(Tuple.of(2, 1), TruthValue.ERROR)) | ||
145 | .build(); | ||
146 | |||
147 | var resultSet = getResultSet(query, modelSeed); | ||
148 | assertThat(resultSet.get(Tuple.of(0)), is(UpperCardinalities.UNBOUNDED)); | ||
149 | assertThat(resultSet.get(Tuple.of(1)), is(UpperCardinalities.atMost(9))); | ||
150 | assertThat(resultSet.get(Tuple.of(2)), is(UpperCardinalities.ONE)); | ||
151 | assertThat(resultSet.get(Tuple.of(3)), is(UpperCardinalities.ZERO)); | ||
152 | } | ||
153 | |||
154 | @Test | ||
155 | void lowerBoundTwoTest() { | ||
156 | var subQuery = Query.of("SubQuery", (builder, p1, p2, p3) -> builder.clause( | ||
157 | friend.call(p1, p2), | ||
158 | friend.call(p1, p3), | ||
159 | friend.call(p2, p3) | ||
160 | )); | ||
161 | var query = Query.of("LowerBound", Integer.class, (builder, p1, output) -> builder.clause( | ||
162 | must(person.call(p1)), | ||
163 | new CountLowerBoundLiteral(output, Concreteness.PARTIAL, subQuery.getDnf(), | ||
164 | List.of(p1, Variable.of(), Variable.of())) | ||
165 | )); | ||
166 | |||
167 | var modelSeed = ModelSeed.builder(4) | ||
168 | .seed(MultiObjectTranslator.COUNT_SYMBOL, builder -> builder | ||
169 | .reducedValue(CardinalityIntervals.ONE) | ||
170 | .put(Tuple.of(0), CardinalityIntervals.between(5, 9)) | ||
171 | .put(Tuple.of(1), CardinalityIntervals.atLeast(3)) | ||
172 | .put(Tuple.of(2), CardinalityIntervals.atMost(7))) | ||
173 | .seed(person, builder -> builder.reducedValue(TruthValue.TRUE)) | ||
174 | .seed(friend, builder -> builder | ||
175 | .put(Tuple.of(0, 1), TruthValue.TRUE) | ||
176 | .put(Tuple.of(0, 2), TruthValue.TRUE) | ||
177 | .put(Tuple.of(0, 3), TruthValue.TRUE) | ||
178 | .put(Tuple.of(1, 0), TruthValue.TRUE) | ||
179 | .put(Tuple.of(1, 2), TruthValue.TRUE) | ||
180 | .put(Tuple.of(1, 3), TruthValue.TRUE) | ||
181 | .put(Tuple.of(2, 0), TruthValue.TRUE) | ||
182 | .put(Tuple.of(2, 1), TruthValue.ERROR)) | ||
183 | .build(); | ||
184 | |||
185 | var resultSet = getResultSet(query, modelSeed); | ||
186 | assertThat(resultSet.get(Tuple.of(0)), is(3)); | ||
187 | assertThat(resultSet.get(Tuple.of(1)), is(5)); | ||
188 | assertThat(resultSet.get(Tuple.of(2)), is(30)); | ||
189 | assertThat(resultSet.get(Tuple.of(3)), is(0)); | ||
190 | } | ||
191 | |||
192 | @Test | ||
193 | void upperBoundTwoTest() { | ||
194 | var subQuery = Query.of("SubQuery", (builder, p1, p2, p3) -> builder.clause( | ||
195 | friend.call(p1, p2), | ||
196 | friend.call(p1, p3), | ||
197 | friend.call(p2, p3) | ||
198 | )); | ||
199 | var query = Query.of("UpperBound", UpperCardinality.class, (builder, p1, output) -> builder.clause( | ||
200 | must(person.call(p1)), | ||
201 | new CountUpperBoundLiteral(output, Concreteness.PARTIAL, subQuery.getDnf(), | ||
202 | List.of(p1, Variable.of(), Variable.of())) | ||
203 | )); | ||
204 | |||
205 | var modelSeed = ModelSeed.builder(4) | ||
206 | .seed(MultiObjectTranslator.COUNT_SYMBOL, builder -> builder | ||
207 | .reducedValue(CardinalityIntervals.ONE) | ||
208 | .put(Tuple.of(0), CardinalityIntervals.between(5, 9)) | ||
209 | .put(Tuple.of(1), CardinalityIntervals.atLeast(3)) | ||
210 | .put(Tuple.of(2), CardinalityIntervals.atMost(7))) | ||
211 | .seed(person, builder -> builder.reducedValue(TruthValue.TRUE)) | ||
212 | .seed(friend, builder -> builder | ||
213 | .put(Tuple.of(0, 1), TruthValue.TRUE) | ||
214 | .put(Tuple.of(0, 2), TruthValue.TRUE) | ||
215 | .put(Tuple.of(0, 3), TruthValue.TRUE) | ||
216 | .put(Tuple.of(1, 0), TruthValue.TRUE) | ||
217 | .put(Tuple.of(1, 2), TruthValue.UNKNOWN) | ||
218 | .put(Tuple.of(1, 3), TruthValue.UNKNOWN) | ||
219 | .put(Tuple.of(2, 0), TruthValue.TRUE) | ||
220 | .put(Tuple.of(2, 1), TruthValue.ERROR)) | ||
221 | .build(); | ||
222 | |||
223 | var resultSet = getResultSet(query, modelSeed); | ||
224 | assertThat(resultSet.get(Tuple.of(0)), is(UpperCardinalities.UNBOUNDED)); | ||
225 | assertThat(resultSet.get(Tuple.of(1)), is(UpperCardinalities.atMost(135))); | ||
226 | assertThat(resultSet.get(Tuple.of(2)), is(UpperCardinalities.ZERO)); | ||
227 | assertThat(resultSet.get(Tuple.of(3)), is(UpperCardinalities.ZERO)); | ||
228 | } | ||
229 | |||
230 | @Test | ||
231 | void lowerBoundDiagonalTest() { | ||
232 | var subQuery = Query.of("SubQuery", (builder, p1, p2, p3) -> builder.clause( | ||
233 | friend.call(p1, p2), | ||
234 | friend.call(p1, p3), | ||
235 | not(friend.call(p2, p3)) | ||
236 | )); | ||
237 | var query = Query.of("LowerBound", Integer.class, (builder, p1, output) -> builder.clause(v1 -> List.of( | ||
238 | must(person.call(p1)), | ||
239 | new CountLowerBoundLiteral(output, Concreteness.PARTIAL, subQuery.getDnf(), List.of(p1, v1, v1)) | ||
240 | ))); | ||
241 | |||
242 | var modelSeed = ModelSeed.builder(4) | ||
243 | .seed(MultiObjectTranslator.COUNT_SYMBOL, builder -> builder | ||
244 | .reducedValue(CardinalityIntervals.ONE) | ||
245 | .put(Tuple.of(0), CardinalityIntervals.between(5, 9)) | ||
246 | .put(Tuple.of(1), CardinalityIntervals.atLeast(3)) | ||
247 | .put(Tuple.of(2), CardinalityIntervals.atMost(7))) | ||
248 | .seed(person, builder -> builder.reducedValue(TruthValue.TRUE)) | ||
249 | .seed(friend, builder -> builder | ||
250 | .put(Tuple.of(0, 1), TruthValue.TRUE) | ||
251 | .put(Tuple.of(0, 2), TruthValue.TRUE) | ||
252 | .put(Tuple.of(0, 3), TruthValue.TRUE) | ||
253 | .put(Tuple.of(1, 0), TruthValue.TRUE) | ||
254 | .put(Tuple.of(1, 2), TruthValue.UNKNOWN) | ||
255 | .put(Tuple.of(1, 3), TruthValue.UNKNOWN) | ||
256 | .put(Tuple.of(2, 0), TruthValue.TRUE) | ||
257 | .put(Tuple.of(2, 1), TruthValue.ERROR)) | ||
258 | .build(); | ||
259 | |||
260 | var resultSet = getResultSet(query, modelSeed); | ||
261 | assertThat(resultSet.get(Tuple.of(0)), is(4)); | ||
262 | assertThat(resultSet.get(Tuple.of(1)), is(5)); | ||
263 | assertThat(resultSet.get(Tuple.of(2)), is(8)); | ||
264 | assertThat(resultSet.get(Tuple.of(3)), is(0)); | ||
265 | } | ||
266 | |||
267 | @Test | ||
268 | void upperBoundDiagonalTest() { | ||
269 | var subQuery = Query.of("SubQuery", (builder, p1, p2, p3) -> builder.clause( | ||
270 | friend.call(p1, p2), | ||
271 | friend.call(p1, p3), | ||
272 | not(friend.call(p2, p3)) | ||
273 | )); | ||
274 | var query = Query.of("UpperBound", UpperCardinality.class, (builder, p1, output) -> builder | ||
275 | .clause(v1 -> List.of( | ||
276 | must(person.call(p1)), | ||
277 | new CountUpperBoundLiteral(output, Concreteness.PARTIAL, subQuery.getDnf(), | ||
278 | List.of(p1, v1, v1)) | ||
279 | ))); | ||
280 | |||
281 | var modelSeed = ModelSeed.builder(4) | ||
282 | .seed(MultiObjectTranslator.COUNT_SYMBOL, builder -> builder | ||
283 | .reducedValue(CardinalityIntervals.ONE) | ||
284 | .put(Tuple.of(0), CardinalityIntervals.between(5, 9)) | ||
285 | .put(Tuple.of(1), CardinalityIntervals.atLeast(3)) | ||
286 | .put(Tuple.of(2), CardinalityIntervals.atMost(7))) | ||
287 | .seed(person, builder -> builder.reducedValue(TruthValue.TRUE)) | ||
288 | .seed(friend, builder -> builder | ||
289 | .put(Tuple.of(0, 1), TruthValue.TRUE) | ||
290 | .put(Tuple.of(0, 2), TruthValue.TRUE) | ||
291 | .put(Tuple.of(0, 3), TruthValue.TRUE) | ||
292 | .put(Tuple.of(1, 0), TruthValue.TRUE) | ||
293 | .put(Tuple.of(1, 2), TruthValue.UNKNOWN) | ||
294 | .put(Tuple.of(1, 3), TruthValue.UNKNOWN) | ||
295 | .put(Tuple.of(2, 0), TruthValue.TRUE) | ||
296 | .put(Tuple.of(2, 1), TruthValue.ERROR)) | ||
297 | .build(); | ||
298 | |||
299 | var resultSet = getResultSet(query, modelSeed); | ||
300 | assertThat(resultSet.get(Tuple.of(0)), is(UpperCardinalities.UNBOUNDED)); | ||
301 | assertThat(resultSet.get(Tuple.of(1)), is(UpperCardinalities.atMost(17))); | ||
302 | assertThat(resultSet.get(Tuple.of(2)), is(UpperCardinalities.atMost(9))); | ||
303 | assertThat(resultSet.get(Tuple.of(3)), is(UpperCardinalities.ZERO)); | ||
304 | } | ||
305 | |||
306 | private static <T> ResultSet<T> getResultSet(Query<T> query, ModelSeed modelSeed) { | ||
307 | var personStorage = Symbol.of("Person", 1, TruthValue.class, TruthValue.FALSE); | ||
308 | var friendStorage = Symbol.of("friend", 2, TruthValue.class, TruthValue.FALSE); | ||
309 | |||
310 | var store = ModelStore.builder() | ||
311 | .with(ViatraModelQueryAdapter.builder() | ||
312 | .query(query)) | ||
313 | .with(ReasoningAdapter.builder()) | ||
314 | .with(new MultiObjectTranslator()) | ||
315 | .with(PartialRelationTranslator.of(person) | ||
316 | .symbol(personStorage)) | ||
317 | .with(PartialRelationTranslator.of(friend) | ||
318 | .symbol(friendStorage)) | ||
319 | .build(); | ||
320 | |||
321 | var model = store.getAdapter(ReasoningStoreAdapter.class).createInitialModel(modelSeed); | ||
322 | return model.getAdapter(ModelQueryAdapter.class).getResultSet(query); | ||
323 | } | ||
324 | } | ||