aboutsummaryrefslogtreecommitdiffstats
path: root/subprojects/viatra-runtime-localsearch/src/main/java/tools/refinery/viatra/runtime/localsearch/planner/cost/impl/StatisticsBasedConstraintCostFunction.java
diff options
context:
space:
mode:
Diffstat (limited to 'subprojects/viatra-runtime-localsearch/src/main/java/tools/refinery/viatra/runtime/localsearch/planner/cost/impl/StatisticsBasedConstraintCostFunction.java')
-rw-r--r--subprojects/viatra-runtime-localsearch/src/main/java/tools/refinery/viatra/runtime/localsearch/planner/cost/impl/StatisticsBasedConstraintCostFunction.java413
1 files changed, 413 insertions, 0 deletions
diff --git a/subprojects/viatra-runtime-localsearch/src/main/java/tools/refinery/viatra/runtime/localsearch/planner/cost/impl/StatisticsBasedConstraintCostFunction.java b/subprojects/viatra-runtime-localsearch/src/main/java/tools/refinery/viatra/runtime/localsearch/planner/cost/impl/StatisticsBasedConstraintCostFunction.java
new file mode 100644
index 00000000..873be31d
--- /dev/null
+++ b/subprojects/viatra-runtime-localsearch/src/main/java/tools/refinery/viatra/runtime/localsearch/planner/cost/impl/StatisticsBasedConstraintCostFunction.java
@@ -0,0 +1,413 @@
1/**
2 * Copyright (c) 2010-2016, Grill Balázs, IncQuery Labs Ltd.
3 * This program and the accompanying materials are made available under the
4 * terms of the Eclipse Public License v. 2.0 which is available at
5 * http://www.eclipse.org/legal/epl-v20.html.
6 *
7 * SPDX-License-Identifier: EPL-2.0
8 */
9package tools.refinery.viatra.runtime.localsearch.planner.cost.impl;
10
11import static tools.refinery.viatra.runtime.matchers.planning.helpers.StatisticsHelper.min;
12
13import java.util.ArrayList;
14import java.util.Arrays;
15import java.util.Collection;
16import java.util.Collections;
17import java.util.List;
18import java.util.Map;
19import java.util.Optional;
20import java.util.Set;
21
22import tools.refinery.viatra.runtime.localsearch.matcher.integration.AbstractLocalSearchResultProvider;
23import tools.refinery.viatra.runtime.localsearch.planner.cost.IConstraintEvaluationContext;
24import tools.refinery.viatra.runtime.localsearch.planner.cost.ICostFunction;
25import tools.refinery.viatra.runtime.matchers.ViatraQueryRuntimeException;
26import tools.refinery.viatra.runtime.matchers.backend.IQueryResultProvider;
27import tools.refinery.viatra.runtime.matchers.context.IInputKey;
28import tools.refinery.viatra.runtime.matchers.planning.helpers.FunctionalDependencyHelper;
29import tools.refinery.viatra.runtime.matchers.psystem.IQueryReference;
30import tools.refinery.viatra.runtime.matchers.psystem.PConstraint;
31import tools.refinery.viatra.runtime.matchers.psystem.PVariable;
32import tools.refinery.viatra.runtime.matchers.psystem.analysis.QueryAnalyzer;
33import tools.refinery.viatra.runtime.matchers.psystem.basicdeferred.AggregatorConstraint;
34import tools.refinery.viatra.runtime.matchers.psystem.basicdeferred.ExportedParameter;
35import tools.refinery.viatra.runtime.matchers.psystem.basicdeferred.ExpressionEvaluation;
36import tools.refinery.viatra.runtime.matchers.psystem.basicdeferred.Inequality;
37import tools.refinery.viatra.runtime.matchers.psystem.basicdeferred.NegativePatternCall;
38import tools.refinery.viatra.runtime.matchers.psystem.basicdeferred.PatternMatchCounter;
39import tools.refinery.viatra.runtime.matchers.psystem.basicdeferred.TypeFilterConstraint;
40import tools.refinery.viatra.runtime.matchers.psystem.basicenumerables.BinaryReflexiveTransitiveClosure;
41import tools.refinery.viatra.runtime.matchers.psystem.basicenumerables.BinaryTransitiveClosure;
42import tools.refinery.viatra.runtime.matchers.psystem.basicenumerables.ConstantValue;
43import tools.refinery.viatra.runtime.matchers.psystem.basicenumerables.PositivePatternCall;
44import tools.refinery.viatra.runtime.matchers.psystem.basicenumerables.TypeConstraint;
45import tools.refinery.viatra.runtime.matchers.psystem.queries.PParameter;
46import tools.refinery.viatra.runtime.matchers.tuple.TupleMask;
47import tools.refinery.viatra.runtime.matchers.util.Accuracy;
48import tools.refinery.viatra.runtime.matchers.util.Preconditions;
49
50/**
51 * Cost function which calculates cost based on the cardinality of items in the runtime model
52 *
53 * <p> To provide custom statistics, override
54 * {@link #projectionSize(IConstraintEvaluationContext, IInputKey, TupleMask, Accuracy)}
55 * and {@link #bucketSize(IQueryReference, IConstraintEvaluationContext, TupleMask)}.
56 *
57 * @author Grill Balázs
58 * @since 1.4
59 */
60public abstract class StatisticsBasedConstraintCostFunction implements ICostFunction {
61 protected static final double MAX_COST = 250.0;
62
63 protected static final double DEFAULT_COST = StatisticsBasedConstraintCostFunction.MAX_COST - 100.0;
64
65 /**
66 * @since 2.1
67 */
68 public static final double INVERSE_NAVIGATION_PENALTY_DEFAULT = 0.10;
69 /**
70 * @since 2.1
71 */
72 public static final double INVERSE_NAVIGATION_PENALTY_GENERIC = 0.01;
73 /**
74 * @since 2.7
75 */
76 public static final double EVAL_UNWIND_EXTENSION_FACTOR = 3.0;
77
78 private final double inverseNavigationPenalty;
79
80
81 /**
82 * @since 2.1
83 */
84 public StatisticsBasedConstraintCostFunction(double inverseNavigationPenalty) {
85 super();
86 this.inverseNavigationPenalty = inverseNavigationPenalty;
87 }
88 public StatisticsBasedConstraintCostFunction() {
89 this(INVERSE_NAVIGATION_PENALTY_DEFAULT);
90 }
91
92 /**
93 * @deprecated call and implement {@link #projectionSize(IConstraintEvaluationContext, IInputKey, TupleMask, Accuracy)} instead
94 */
95 @Deprecated
96 public long countTuples(final IConstraintEvaluationContext input, final IInputKey supplierKey) {
97 return projectionSize(input, supplierKey, TupleMask.identity(supplierKey.getArity()), Accuracy.EXACT_COUNT).orElse(-1L);
98 }
99
100 /**
101 * Override this to provide custom statistics on edge/node counts.
102 * New implementors shall implement this instead of {@link #countTuples(IConstraintEvaluationContext, IInputKey)}
103 * @since 2.1
104 */
105 public Optional<Long> projectionSize(final IConstraintEvaluationContext input, final IInputKey supplierKey,
106 final TupleMask groupMask, Accuracy requiredAccuracy) {
107 long legacyCount = countTuples(input, supplierKey);
108 return legacyCount < 0 ? Optional.empty() : Optional.of(legacyCount);
109 }
110
111 /**
112 * Override this to provide custom estimates for match set sizes of called patterns.
113 * @since 2.1
114 */
115 public Optional<Double> bucketSize(final IQueryReference patternCall,
116 final IConstraintEvaluationContext input, TupleMask projMask) {
117 IQueryResultProvider resultProvider = input.resultProviderRequestor().requestResultProvider(patternCall, null);
118 // TODO hack: use LS cost instead of true bucket size estimate
119 if (resultProvider instanceof AbstractLocalSearchResultProvider) {
120 double estimatedCost = ((AbstractLocalSearchResultProvider) resultProvider).estimateCost(projMask);
121 return Optional.of(estimatedCost);
122 } else {
123 return resultProvider.estimateAverageBucketSize(projMask, Accuracy.APPROXIMATION);
124 }
125 }
126
127
128
129 @Override
130 public double apply(final IConstraintEvaluationContext input) {
131 return this.calculateCost(input.getConstraint(), input);
132 }
133
134 protected double _calculateCost(final ConstantValue constant, final IConstraintEvaluationContext input) {
135 return 0.0f;
136 }
137
138 protected double _calculateCost(final TypeConstraint constraint, final IConstraintEvaluationContext input) {
139 final Collection<PVariable> freeMaskVariables = input.getFreeVariables();
140 final Collection<PVariable> boundMaskVariables = input.getBoundVariables();
141 IInputKey supplierKey = constraint.getSupplierKey();
142 long arity = supplierKey.getArity();
143
144 if ((arity == 1)) {
145 // unary constraint
146 return calculateUnaryConstraintCost(constraint, input);
147 } else if ((arity == 2)) {
148 // binary constraint
149 PVariable srcVariable = ((PVariable) constraint.getVariablesTuple().get(0));
150 PVariable dstVariable = ((PVariable) constraint.getVariablesTuple().get(1));
151 boolean isInverse = false;
152 // Check if inverse navigation is needed along the edge
153 if ((freeMaskVariables.contains(srcVariable) && boundMaskVariables.contains(dstVariable))) {
154 isInverse = true;
155 }
156 double binaryExtendCost = calculateBinaryCost(supplierKey, srcVariable, dstVariable, isInverse, input);
157 // Make inverse navigation slightly more expensive than forward navigation
158 // See https://bugs.eclipse.org/bugs/show_bug.cgi?id=501078
159 return (isInverse) ? binaryExtendCost + inverseNavigationPenalty : binaryExtendCost;
160 } else {
161 // n-ary constraint
162 throw new UnsupportedOperationException("Cost calculation for arity " + arity + " is not implemented yet");
163 }
164 }
165
166
167 /**
168 * @deprecated use/implement {@link #calculateBinaryCost(IInputKey, PVariable, PVariable, boolean, IConstraintEvaluationContext)} instead
169 */
170 @Deprecated
171 protected double calculateBinaryExtendCost(final IInputKey supplierKey, final PVariable srcVariable,
172 final PVariable dstVariable, final boolean isInverse, long edgeCount /* TODO remove */,
173 final IConstraintEvaluationContext input) {
174 throw new UnsupportedOperationException();
175 }
176
177 /**
178 * @since 2.1
179 */
180 protected double calculateBinaryCost(final IInputKey supplierKey, final PVariable srcVariable,
181 final PVariable dstVariable, final boolean isInverse,
182 final IConstraintEvaluationContext input) {
183 final Collection<PVariable> freeMaskVariables = input.getFreeVariables();
184 final PConstraint constraint = input.getConstraint();
185
186// IQueryMetaContext metaContext = input.getRuntimeContext().getMetaContext();
187// Collection<InputKeyImplication> implications = metaContext.getImplications(supplierKey);
188
189 Optional<Long> edgeUpper = projectionSize(input, supplierKey, TupleMask.identity(2), Accuracy.BEST_UPPER_BOUND);
190 Optional<Long> srcUpper = projectionSize(input, supplierKey, TupleMask.selectSingle(0, 2), Accuracy.BEST_UPPER_BOUND);
191 Optional<Long> dstUpper = projectionSize(input, supplierKey, TupleMask.selectSingle(1, 2), Accuracy.BEST_UPPER_BOUND);
192
193 if (freeMaskVariables.contains(srcVariable) && freeMaskVariables.contains(dstVariable)) {
194 Double branchCount = edgeUpper.map(Long::doubleValue).orElse(
195 srcUpper.map(Long::doubleValue).orElse(DEFAULT_COST)
196 *
197 dstUpper.map(Long::doubleValue).orElse(DEFAULT_COST)
198 );
199 return branchCount;
200
201 } else {
202
203 Optional<Long> srcLower = projectionSize(input, supplierKey, TupleMask.selectSingle(0, 2), Accuracy.APPROXIMATION);
204 Optional<Long> dstLower = projectionSize(input, supplierKey, TupleMask.selectSingle(1, 2), Accuracy.APPROXIMATION);
205
206 List<Optional<Long>> nodeLower = Arrays.asList(srcLower, dstLower);
207 List<Optional<Long>> nodeUpper = Arrays.asList(srcUpper, dstUpper);
208
209 int from = isInverse ? 1 : 0;
210 int to = isInverse ? 0 : 1;
211
212 Optional<Double> costEstimate = Optional.empty();
213
214 if (!freeMaskVariables.contains(srcVariable) && !freeMaskVariables.contains(dstVariable)) {
215 // both variables bound, this is a simple check
216 costEstimate = min(costEstimate, 0.9);
217 } // TODO use bucket size estimation in the runtime context
218 costEstimate = min(costEstimate,
219 edgeUpper.flatMap(edges ->
220 nodeLower.get(from).map(fromNodes ->
221 // amortize edges over start nodes
222 (fromNodes == 0) ? 0.0 : (((double) edges) / fromNodes)
223 )));
224 if (navigatesThroughFunctionalDependencyInverse(input, constraint)) {
225 costEstimate = min(costEstimate, nodeUpper.get(to).flatMap(toNodes ->
226 nodeLower.get(from).map(fromNodes ->
227 // due to a reverse functional dependency, the destination count is an upper bound for the edge count
228 (fromNodes == 0) ? 0.0 : ((double) toNodes) / fromNodes
229 )));
230 }
231 if (! edgeUpper.isPresent()) {
232 costEstimate = min(costEstimate, nodeUpper.get(to).flatMap(toNodes ->
233 nodeLower.get(from).map(fromNodes ->
234 // If count is 0, no such element exists in the model, so there will be no branching
235 // TODO rethink, why dstNodeCount / srcNodeCount instead of dstNodeCount?
236 // The universally valid bound would be something like sparseEdgeEstimate = dstNodeCount + 1.0
237 // If we assume sparseness, we can reduce it by a SPARSENESS_FACTOR (e.g. 0.1).
238 // Alternatively, discount dstNodeCount * srcNodeCount on a SPARSENESS_EXPONENT (e.g 0.75) and then amortize over srcNodeCount.
239 fromNodes != 0 ? Math.max(1.0, ((double) toNodes) / fromNodes) : 1.0
240 )));
241 }
242 if (navigatesThroughFunctionalDependency(input, constraint)) {
243 // At most one destination value
244 costEstimate = min(costEstimate, 1.0);
245 }
246
247 return costEstimate.orElse(DEFAULT_COST);
248
249 }
250 }
251
252 /**
253 * @since 1.7
254 */
255 protected boolean navigatesThroughFunctionalDependency(final IConstraintEvaluationContext input,
256 final PConstraint constraint) {
257 return navigatesThroughFunctionalDependency(input, constraint, input.getBoundVariables(), input.getFreeVariables());
258 }
259 /**
260 * @since 2.1
261 */
262 protected boolean navigatesThroughFunctionalDependencyInverse(final IConstraintEvaluationContext input,
263 final PConstraint constraint) {
264 return navigatesThroughFunctionalDependency(input, constraint, input.getFreeVariables(), input.getBoundVariables());
265 }
266 /**
267 * @since 2.1
268 */
269 protected boolean navigatesThroughFunctionalDependency(final IConstraintEvaluationContext input,
270 final PConstraint constraint, Collection<PVariable> determining, Collection<PVariable> determined) {
271 final QueryAnalyzer queryAnalyzer = input.getQueryAnalyzer();
272 final Map<Set<PVariable>, Set<PVariable>> functionalDependencies = queryAnalyzer
273 .getFunctionalDependencies(Collections.singleton(constraint), false);
274 final Set<PVariable> impliedVariables = FunctionalDependencyHelper.closureOf(determining,
275 functionalDependencies);
276 return ((impliedVariables != null) && impliedVariables.containsAll(determined));
277 }
278
279 protected double calculateUnaryConstraintCost(final TypeConstraint constraint,
280 final IConstraintEvaluationContext input) {
281 PVariable variable = (PVariable) constraint.getVariablesTuple().get(0);
282 if (input.getBoundVariables().contains(variable)) {
283 return 0.9;
284 } else {
285 return projectionSize(input, constraint.getSupplierKey(), TupleMask.identity(1), Accuracy.APPROXIMATION)
286 .map(count -> 1.0 + count).orElse(DEFAULT_COST);
287 }
288 }
289
290 protected double _calculateCost(final ExportedParameter exportedParam, final IConstraintEvaluationContext input) {
291 return 0.0;
292 }
293
294 protected double _calculateCost(final TypeFilterConstraint exportedParam,
295 final IConstraintEvaluationContext input) {
296 return 0.0;
297 }
298
299 protected double _calculateCost(final PositivePatternCall patternCall, final IConstraintEvaluationContext input) {
300 final List<Integer> boundPositions = new ArrayList<>();
301 final List<PParameter> parameters = patternCall.getReferredQuery().getParameters();
302 for (int i = 0; (i < parameters.size()); i++) {
303 final PVariable variable = patternCall.getVariableInTuple(i);
304 if (input.getBoundVariables().contains(variable)) boundPositions.add(i);
305 }
306 TupleMask projMask = TupleMask.fromSelectedIndices(parameters.size(), boundPositions);
307
308 return bucketSize(patternCall, input, projMask).orElse(DEFAULT_COST);
309 }
310
311
312 /**
313 * @since 1.7
314 */
315 protected double _calculateCost(final ExpressionEvaluation evaluation, final IConstraintEvaluationContext input) {
316 // Even if there are multiple results here, if all output variable is bound eval unwind will not result in
317 // multiple branches in search graph
318 final double multiplier = evaluation.isUnwinding() && !input.getFreeVariables().isEmpty()
319 ? EVAL_UNWIND_EXTENSION_FACTOR
320 : 1.0;
321 return _calculateCost((PConstraint) evaluation, input) * multiplier;
322 }
323
324 /**
325 * @since 1.7
326 */
327 protected double _calculateCost(final Inequality inequality, final IConstraintEvaluationContext input) {
328 return _calculateCost((PConstraint)inequality, input);
329 }
330
331 /**
332 * @since 1.7
333 */
334 protected double _calculateCost(final AggregatorConstraint aggregator, final IConstraintEvaluationContext input) {
335 return _calculateCost((PConstraint)aggregator, input);
336 }
337
338 /**
339 * @since 1.7
340 */
341 protected double _calculateCost(final NegativePatternCall call, final IConstraintEvaluationContext input) {
342 return _calculateCost((PConstraint)call, input);
343 }
344
345 /**
346 * @since 1.7
347 */
348 protected double _calculateCost(final PatternMatchCounter counter, final IConstraintEvaluationContext input) {
349 return _calculateCost((PConstraint)counter, input);
350 }
351
352 /**
353 * @since 1.7
354 */
355 protected double _calculateCost(final BinaryTransitiveClosure closure, final IConstraintEvaluationContext input) {
356 // if (input.getFreeVariables().size() == 1) return 3.0;
357 return StatisticsBasedConstraintCostFunction.DEFAULT_COST;
358 }
359
360 /**
361 * @since 2.0
362 */
363 protected double _calculateCost(final BinaryReflexiveTransitiveClosure closure, final IConstraintEvaluationContext input) {
364 // if (input.getFreeVariables().size() == 1) return 3.0;
365 return StatisticsBasedConstraintCostFunction.DEFAULT_COST;
366 }
367
368 /**
369 * Default cost calculation strategy
370 */
371 protected double _calculateCost(final PConstraint constraint, final IConstraintEvaluationContext input) {
372 if (input.getFreeVariables().isEmpty()) {
373 return 1.0;
374 } else {
375 return StatisticsBasedConstraintCostFunction.DEFAULT_COST;
376 }
377 }
378
379 /**
380 * @throws ViatraQueryRuntimeException
381 */
382 public double calculateCost(final PConstraint constraint, final IConstraintEvaluationContext input) {
383 Preconditions.checkArgument(constraint != null, "Set constraint value correctly");
384 if (constraint instanceof ExportedParameter) {
385 return _calculateCost((ExportedParameter) constraint, input);
386 } else if (constraint instanceof TypeFilterConstraint) {
387 return _calculateCost((TypeFilterConstraint) constraint, input);
388 } else if (constraint instanceof ConstantValue) {
389 return _calculateCost((ConstantValue) constraint, input);
390 } else if (constraint instanceof PositivePatternCall) {
391 return _calculateCost((PositivePatternCall) constraint, input);
392 } else if (constraint instanceof TypeConstraint) {
393 return _calculateCost((TypeConstraint) constraint, input);
394 } else if (constraint instanceof ExpressionEvaluation) {
395 return _calculateCost((ExpressionEvaluation) constraint, input);
396 } else if (constraint instanceof Inequality) {
397 return _calculateCost((Inequality) constraint, input);
398 } else if (constraint instanceof AggregatorConstraint) {
399 return _calculateCost((AggregatorConstraint) constraint, input);
400 } else if (constraint instanceof NegativePatternCall) {
401 return _calculateCost((NegativePatternCall) constraint, input);
402 } else if (constraint instanceof PatternMatchCounter) {
403 return _calculateCost((PatternMatchCounter) constraint, input);
404 } else if (constraint instanceof BinaryTransitiveClosure) {
405 return _calculateCost((BinaryTransitiveClosure) constraint, input);
406 } else if (constraint instanceof BinaryReflexiveTransitiveClosure) {
407 return _calculateCost((BinaryReflexiveTransitiveClosure) constraint, input);
408 } else {
409 // Default cost calculation
410 return _calculateCost(constraint, input);
411 }
412 }
413}