diff options
Diffstat (limited to 'Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/optimization/ThreeValuedCostObjective.xtend')
-rw-r--r-- | Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/optimization/ThreeValuedCostObjective.xtend | 99 |
1 files changed, 47 insertions, 52 deletions
diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/optimization/ThreeValuedCostObjective.xtend b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/optimization/ThreeValuedCostObjective.xtend index 0a6fd55b..9b1a7e9f 100644 --- a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/optimization/ThreeValuedCostObjective.xtend +++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/optimization/ThreeValuedCostObjective.xtend | |||
@@ -1,85 +1,80 @@ | |||
1 | package hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner.optimization | 1 | package hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner.optimization |
2 | 2 | ||
3 | import com.google.common.collect.ImmutableList | 3 | import hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.cardinality.BoundSaturationListener |
4 | import java.util.Collection | 4 | import java.util.Map |
5 | import org.eclipse.viatra.dse.base.ThreadContext | 5 | import org.eclipse.viatra.dse.base.ThreadContext |
6 | import org.eclipse.viatra.query.runtime.api.IPatternMatch | 6 | import org.eclipse.xtend.lib.annotations.Accessors |
7 | import org.eclipse.viatra.query.runtime.api.IQuerySpecification | ||
8 | import org.eclipse.viatra.query.runtime.api.ViatraQueryMatcher | ||
9 | import org.eclipse.xtend.lib.annotations.Data | ||
10 | 7 | ||
11 | @Data | 8 | class ThreeValuedCostObjective extends AbstractThreeValuedObjective implements BoundSaturationListener { |
12 | class ThreeValuedCostElement { | 9 | @Accessors val Map<String, CostElementMatchers> matchers |
13 | val IQuerySpecification<? extends ViatraQueryMatcher<? extends IPatternMatch>> currentMatchQuery | 10 | double lowerBoundHint = Double.NEGATIVE_INFINITY |
14 | val IQuerySpecification<? extends ViatraQueryMatcher<? extends IPatternMatch>> mayMatchQuery | 11 | double upperBoundHint = Double.POSITIVE_INFINITY |
15 | val IQuerySpecification<? extends ViatraQueryMatcher<? extends IPatternMatch>> mustMatchQuery | ||
16 | val int weight | ||
17 | } | ||
18 | |||
19 | class ThreeValuedCostObjective extends AbstractThreeValuedObjective { | ||
20 | val Collection<ThreeValuedCostElement> costElements | ||
21 | Collection<CostElementMatchers> matchers | ||
22 | 12 | ||
23 | new(String name, Collection<ThreeValuedCostElement> costElements, ObjectiveKind kind, ObjectiveThreshold threshold, | 13 | new(String name, Map<String, CostElementMatchers> matchers, ObjectiveKind kind, ObjectiveThreshold threshold, |
24 | int level) { | 14 | int level) { |
25 | super(name, kind, threshold, level) | 15 | super(name, kind, threshold, level) |
26 | this.costElements = costElements | 16 | this.matchers = matchers |
27 | } | 17 | } |
28 | 18 | ||
29 | override createNew() { | 19 | override createNew() { |
30 | new ThreeValuedCostObjective(name, costElements, kind, threshold, level) | 20 | // new ThreeValuedCostObjective(name, matchers, kind, threshold, level) |
21 | throw new UnsupportedOperationException("ThreeValuedCostObjective can only be used from a single thread") | ||
31 | } | 22 | } |
32 | 23 | ||
33 | override init(ThreadContext context) { | 24 | override init(ThreadContext context) { |
34 | val queryEngine = context.queryEngine | ||
35 | matchers = ImmutableList.copyOf(costElements.map [ element | | ||
36 | new CostElementMatchers( | ||
37 | queryEngine.getMatcher(element.currentMatchQuery), | ||
38 | queryEngine.getMatcher(element.mayMatchQuery), | ||
39 | queryEngine.getMatcher(element.mustMatchQuery), | ||
40 | element.weight | ||
41 | ) | ||
42 | ]) | ||
43 | } | 25 | } |
44 | 26 | ||
45 | override getRawFitness(ThreadContext context) { | 27 | override getRawFitness(ThreadContext context) { |
46 | var int cost = 0 | 28 | var double cost = 0 |
47 | for (matcher : matchers) { | 29 | for (matcher : matchers.values) { |
48 | cost += matcher.weight * matcher.currentMatcher.countMatches | 30 | cost += matcher.weight * matcher.currentNumberOfMatches |
49 | } | 31 | } |
50 | cost as double | 32 | cost |
51 | } | 33 | } |
52 | 34 | ||
53 | override getLowestPossibleFitness(ThreadContext threadContext) { | 35 | override getLowestPossibleFitness(ThreadContext threadContext) { |
54 | var int cost = 0 | 36 | var double cost = 0 |
55 | for (matcher : matchers) { | 37 | for (matcher : matchers.values) { |
56 | if (matcher.weight >= 0) { | 38 | if (matcher.weight >= 0) { |
57 | cost += matcher.weight * matcher.mustMatcher.countMatches | 39 | cost += matcher.weight * matcher.minimumNumberOfMatches |
58 | } else if (matcher.mayMatcher.countMatches > 0) { | 40 | } else { |
59 | // TODO Count may matches. | 41 | cost += matcher.weight * matcher.maximumNumberOfMatches |
60 | return Double.NEGATIVE_INFINITY | ||
61 | } | 42 | } |
62 | } | 43 | } |
63 | cost as double | 44 | val boundWithHint = Math.max(lowerBoundHint, cost) |
45 | if (boundWithHint > upperBoundHint) { | ||
46 | throw new IllegalStateException("Inconsistent cost bounds") | ||
47 | } | ||
48 | boundWithHint | ||
64 | } | 49 | } |
65 | 50 | ||
66 | override getHighestPossibleFitness(ThreadContext threadContext) { | 51 | override getHighestPossibleFitness(ThreadContext threadContext) { |
67 | var int cost = 0 | 52 | var double cost = 0 |
68 | for (matcher : matchers) { | 53 | for (matcher : matchers.values) { |
69 | if (matcher.weight <= 0) { | 54 | if (matcher.weight <= 0) { |
70 | cost += matcher.weight * matcher.mustMatcher.countMatches | 55 | cost += matcher.weight * matcher.minimumNumberOfMatches |
71 | } else if (matcher.mayMatcher.countMatches > 0) { | 56 | } else { |
72 | return Double.POSITIVE_INFINITY | 57 | cost += matcher.weight * matcher.maximumNumberOfMatches |
73 | } | 58 | } |
74 | } | 59 | } |
75 | cost as double | 60 | val boundWithHint = Math.min(upperBoundHint, cost) |
61 | if (boundWithHint < lowerBoundHint) { | ||
62 | throw new IllegalStateException("Inconsistent cost bounds") | ||
63 | } | ||
64 | boundWithHint | ||
76 | } | 65 | } |
77 | 66 | ||
78 | @Data | 67 | override boundsSaturated(Integer lower, Integer upper) { |
79 | private static class CostElementMatchers { | 68 | lowerBoundHint = if (lower === null) { |
80 | val ViatraQueryMatcher<? extends IPatternMatch> currentMatcher | 69 | Double.NEGATIVE_INFINITY |
81 | val ViatraQueryMatcher<? extends IPatternMatch> mayMatcher | 70 | } else { |
82 | val ViatraQueryMatcher<? extends IPatternMatch> mustMatcher | 71 | lower |
83 | val int weight | 72 | } |
73 | upperBoundHint = if (upper === null) { | ||
74 | Double.POSITIVE_INFINITY | ||
75 | } else { | ||
76 | upper | ||
77 | } | ||
78 | println('''Bounds saturated: «lower»..«upper»''') | ||
84 | } | 79 | } |
85 | } | 80 | } |