aboutsummaryrefslogtreecommitdiffstats
path: root/Solvers
diff options
context:
space:
mode:
authorLibravatar 20001LastOrder <boqi.chen@mail.mcgill.ca>2019-08-13 18:10:02 -0400
committerLibravatar 20001LastOrder <boqi.chen@mail.mcgill.ca>2019-08-13 18:10:02 -0400
commita882ad00515e730bad5e52fa29b74f461a5b9cd6 (patch)
tree5c892f1dc5b501aa3f7355e97f24cca277a473db /Solvers
parentConfigurations for generation and new domain for generation ecore model (diff)
downloadVIATRA-Generator-a882ad00515e730bad5e52fa29b74f461a5b9cd6.tar.gz
VIATRA-Generator-a882ad00515e730bad5e52fa29b74f461a5b9cd6.tar.zst
VIATRA-Generator-a882ad00515e730bad5e52fa29b74f461a5b9cd6.zip
change exploration value function
Diffstat (limited to 'Solvers')
-rw-r--r--Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/HillClimbingOnRealisticMetricStrategyForModelGeneration.java136
1 files changed, 74 insertions, 62 deletions
diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/HillClimbingOnRealisticMetricStrategyForModelGeneration.java b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/HillClimbingOnRealisticMetricStrategyForModelGeneration.java
index 4dff00cd..c817d20b 100644
--- a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/HillClimbingOnRealisticMetricStrategyForModelGeneration.java
+++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/HillClimbingOnRealisticMetricStrategyForModelGeneration.java
@@ -6,12 +6,14 @@ import java.util.Collection;
6import java.util.Collections; 6import java.util.Collections;
7import java.util.Comparator; 7import java.util.Comparator;
8import java.util.HashMap; 8import java.util.HashMap;
9import java.util.HashSet;
9import java.util.Iterator; 10import java.util.Iterator;
10import java.util.LinkedList; 11import java.util.LinkedList;
11import java.util.List; 12import java.util.List;
12import java.util.Map; 13import java.util.Map;
13import java.util.PriorityQueue; 14import java.util.PriorityQueue;
14import java.util.Random; 15import java.util.Random;
16import java.util.Set;
15 17
16import org.apache.log4j.Logger; 18import org.apache.log4j.Logger;
17import org.eclipse.emf.ecore.util.EcoreUtil; 19import org.eclipse.emf.ecore.util.EcoreUtil;
@@ -64,19 +66,21 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements
64 // matchers for detecting the number of violations 66 // matchers for detecting the number of violations
65 private Collection<ViatraQueryMatcher<? extends IPatternMatch>> mustMatchers; 67 private Collection<ViatraQueryMatcher<? extends IPatternMatch>> mustMatchers;
66 private Collection<ViatraQueryMatcher<? extends IPatternMatch>> mayMatchers; 68 private Collection<ViatraQueryMatcher<? extends IPatternMatch>> mayMatchers;
69
67 // Encode the used activations of a particular state 70 // Encode the used activations of a particular state
68 private Map<Object, List<Object>> stateAndActivations; 71 private Map<Object, List<Object>> stateAndActivations;
69 private Map<TrajectoryWithFitness, Double> trajectoryFit;
70
71 private boolean allowMustViolation; 72 private boolean allowMustViolation;
72 private Domain domain; 73 private Domain domain;
73 74 int targetSize;
75
74 // Statistics 76 // Statistics
75 private int numberOfStatecoderFail = 0; 77 private int numberOfStatecoderFail = 0;
76 private int numberOfPrintedModel = 0; 78 private int numberOfPrintedModel = 0;
77 private int numberOfSolverCalls = 0; 79 private int numberOfSolverCalls = 0;
78 private PartialInterpretationMetricDistance metricDistance; 80 private PartialInterpretationMetricDistance metricDistance;
79 81 private double currentStateValue = Double.MAX_VALUE;
82 private double currentNodeTypeDistance = 1;
83 private int numNodesToGenerate = 0;
80 84
81 public HillClimbingOnRealisticMetricStrategyForModelGeneration( 85 public HillClimbingOnRealisticMetricStrategyForModelGeneration(
82 ReasonerWorkspace workspace, 86 ReasonerWorkspace workspace,
@@ -102,12 +106,14 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements
102 public void initStrategy(ThreadContext context) { 106 public void initStrategy(ThreadContext context) {
103 this.context = context; 107 this.context = context;
104 this.solutionStore = context.getGlobalContext().getSolutionStore(); 108 this.solutionStore = context.getGlobalContext().getSolutionStore();
105 109 domain = Domain.valueOf(configuration.domain);
110
106 ViatraQueryEngine engine = context.getQueryEngine(); 111 ViatraQueryEngine engine = context.getQueryEngine();
107// // TODO: visualisation 112// // TODO: visualisation
108 mustMatchers = new LinkedList<ViatraQueryMatcher<? extends IPatternMatch>>(); 113 mustMatchers = new LinkedList<ViatraQueryMatcher<? extends IPatternMatch>>();
109 mayMatchers = new LinkedList<ViatraQueryMatcher<? extends IPatternMatch>>(); 114 mayMatchers = new LinkedList<ViatraQueryMatcher<? extends IPatternMatch>>();
110 115
116 // manully restict the number of super types of one class
111 this.method.getInvalidWF().forEach(a ->{ 117 this.method.getInvalidWF().forEach(a ->{
112 ViatraQueryMatcher<? extends IPatternMatch> matcher = a.getMatcher(engine); 118 ViatraQueryMatcher<? extends IPatternMatch> matcher = a.getMatcher(engine);
113 mustMatchers.add(matcher); 119 mustMatchers.add(matcher);
@@ -117,27 +123,27 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements
117 ViatraQueryMatcher<? extends IPatternMatch> matcher = a.getMatcher(engine); 123 ViatraQueryMatcher<? extends IPatternMatch> matcher = a.getMatcher(engine);
118 mayMatchers.add(matcher); 124 mayMatchers.add(matcher);
119 }); 125 });
126
120 127
121 128 //set up comparator
122 this.solutionStoreWithCopy = new SolutionStoreWithCopy(); 129 final ObjectiveComparatorHelper objectiveComparatorHelper = context.getObjectiveComparatorHelper();
123 this.solutionStoreWithDiversityDescriptor = new SolutionStoreWithDiversityDescriptor(configuration.diversityRequirement);
124
125 trajectoryFit = new HashMap<TrajectoryWithFitness, Double>();
126 this.comparator = new Comparator<TrajectoryWithFitness>() { 130 this.comparator = new Comparator<TrajectoryWithFitness>() {
127 @Override 131 @Override
128 public int compare(TrajectoryWithFitness o1, TrajectoryWithFitness o2) { 132 public int compare(TrajectoryWithFitness o1, TrajectoryWithFitness o2) {
129 return Double.compare(trajectoryFit.get(o1), trajectoryFit.get(o2)); 133 return objectiveComparatorHelper.compare(o2.fitness, o1.fitness);
130 } 134 }
131 }; 135 };
132 136
137 this.solutionStoreWithCopy = new SolutionStoreWithCopy();
138 this.solutionStoreWithDiversityDescriptor = new SolutionStoreWithDiversityDescriptor(configuration.diversityRequirement);
139
133 trajectoiresToExplore = new PriorityQueue<TrajectoryWithFitness>(11, comparator); 140 trajectoiresToExplore = new PriorityQueue<TrajectoryWithFitness>(11, comparator);
134 stateAndActivations = new HashMap<Object, List<Object>>(); 141 stateAndActivations = new HashMap<Object, List<Object>>();
135
136 domain = Domain.valueOf(configuration.domain);
137 metricDistance = new PartialInterpretationMetricDistance(domain); 142 metricDistance = new PartialInterpretationMetricDistance(domain);
138 143
139 //set whether allows must violations during the realistic generation 144 //set whether allows must violations during the realistic generation
140 allowMustViolation = configuration.allowMustViolations; 145 allowMustViolation = configuration.allowMustViolations;
146 targetSize = configuration.typeScopes.maxNewElements + 2;
141 } 147 }
142 148
143 @Override 149 @Override
@@ -156,14 +162,12 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements
156 //final ObjectiveComparatorHelper objectiveComparatorHelper = context.getObjectiveComparatorHelper(); 162 //final ObjectiveComparatorHelper objectiveComparatorHelper = context.getObjectiveComparatorHelper();
157 final Object[] firstTrajectory = context.getTrajectory().toArray(new Object[0]); 163 final Object[] firstTrajectory = context.getTrajectory().toArray(new Object[0]);
158 TrajectoryWithFitness currentTrajectoryWithFittness = new TrajectoryWithFitness(firstTrajectory, firstFittness); 164 TrajectoryWithFitness currentTrajectoryWithFittness = new TrajectoryWithFitness(firstTrajectory, firstFittness);
159 trajectoryFit.put(currentTrajectoryWithFittness, Double.MAX_VALUE);
160 trajectoiresToExplore.add(currentTrajectoryWithFittness); 165 trajectoiresToExplore.add(currentTrajectoryWithFittness);
161 Object lastState = null; 166 Object lastState = null;
162 167
163 //if(configuration) 168 //if(configuration)
164 visualiseCurrentState(); 169 visualiseCurrentState();
165 // the two is the True and False node generated at the beginning of the generation 170 // the two is the True and False node generated at the beginning of the generation
166 int targetSize = configuration.typeScopes.maxNewElements + 2;
167 int count = 0; 171 int count = 0;
168 mainLoop: while (!isInterrupted && !configuration.progressMonitor.isCancelled()) { 172 mainLoop: while (!isInterrupted && !configuration.progressMonitor.isCancelled()) {
169 173
@@ -202,21 +206,22 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements
202 double epsilon = 1.0/count; 206 double epsilon = 1.0/count;
203 double draw = Math.random(); 207 double draw = Math.random();
204 count++; 208 count++;
205 209 this.currentNodeTypeDistance = heuristics.getNodeTypeDistance();
210 numNodesToGenerate = model.getMaxNewElements();
206 System.out.println("NA distance: " + heuristics.getNADistance()); 211 System.out.println("NA distance: " + heuristics.getNADistance());
207 System.out.println("MPC distance: " + heuristics.getMPCDistance()); 212 System.out.println("MPC distance: " + heuristics.getMPCDistance());
208 System.out.println("Out degree distance:" + heuristics.getOutDegreeDistance()); 213 System.out.println("Out degree distance:" + heuristics.getOutDegreeDistance());
209 System.out.println("NodeType :" + heuristics.getNodeTypeDistance()); 214 System.out.println("NodeType :" + currentNodeTypeDistance);
210 System.out.println("Edge :" + heuristics.edgeTypeDistance);
211 215
212// System.out.println("FinalState :" + heuristics.getNodeTypePercentage("FinalState")); 216// System.out.println("FinalState :" + heuristics.getNodeTypePercentage("FinalState"));
213 217
214 //TODO: the number of activations to be checked should be configurasble 218 //TODO: the number of activations to be checked should be configurasble
219 System.out.println(activationIds.size());
215 if(activationIds.size() > 50) { 220 if(activationIds.size() > 50) {
216 activationIds = activationIds.subList(0, 50); 221 activationIds = activationIds.subList(0, 50);
217 } 222 }
218 223
219 valueMap = sortWithWeight(activationIds, model.getNewElements().size()); 224 valueMap = sortWithWeight(activationIds);
220 lastState = context.getCurrentStateId(); 225 lastState = context.getCurrentStateId();
221 while (!isInterrupted && !configuration.progressMonitor.isCancelled() && activationIds.size() > 0) { 226 while (!isInterrupted && !configuration.progressMonitor.isCancelled() && activationIds.size() > 0) {
222 final Object nextActivation = drawWithEpsilonProbabilty(activationIds, valueMap, epsilon, draw); 227 final Object nextActivation = drawWithEpsilonProbabilty(activationIds, valueMap, epsilon, draw);
@@ -230,9 +235,10 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements
230 235
231 int currentSize = model.getNewElements().size(); 236 int currentSize = model.getNewElements().size();
232 int targetDiff = targetSize - currentSize; 237 int targetDiff = targetSize - currentSize;
238 boolean shouldFinish = currentSize >= targetSize;
233 239
234 // does not allow must violations 240 // does not allow must violations
235 if((getNumberOfViolations(mustMatchers) > 0|| getNumberOfViolations(mayMatchers) > targetDiff) && !allowMustViolation) { 241 if((getNumberOfViolations(mustMatchers) > 0|| getNumberOfViolations(mayMatchers) > targetDiff) && !allowMustViolation && !shouldFinish) {
236 context.backtrack(); 242 context.backtrack();
237 }else { 243 }else {
238 final Fitness nextFitness = context.calculateFitness(); 244 final Fitness nextFitness = context.calculateFitness();
@@ -251,38 +257,21 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements
251 context.getTrajectory().toArray(), nextFitness); 257 context.getTrajectory().toArray(), nextFitness);
252 int nodeSize = ((PartialInterpretation) context.getModel()).getNewElements().size(); 258 int nodeSize = ((PartialInterpretation) context.getModel()).getNewElements().size();
253 int violation = getNumberOfViolations(mayMatchers); 259 int violation = getNumberOfViolations(mayMatchers);
254 metricDistance.getLinearModel().feedData(context.getCurrentStateId(), metricDistance.calculateFeature(nodeSize, violation), calculateCurrentStateValue(nodeSize, violation), lastState); 260 double currentValue = calculateCurrentStateValue(nodeSize, violation);
255 double value = calculateCurrentStateValue(nodeSize, violation); 261 metricDistance.getLinearModel().feedData(context.getCurrentStateId(), metricDistance.calculateFeature(nodeSize, violation), currentValue, lastState);
256 trajectoryFit.put(nextTrajectoryWithFittness, value);
257 trajectoiresToExplore.add(nextTrajectoryWithFittness); 262 trajectoiresToExplore.add(nextTrajectoryWithFittness);
258 263 currentStateValue = currentValue;
259 //Currently, just go to the next state without considering the value of trajectory 264 //Currently, just go to the next state without considering the value of trajectory
260 currentTrajectoryWithFittness = nextTrajectoryWithFittness; 265 currentTrajectoryWithFittness = nextTrajectoryWithFittness;
261 continue mainLoop; 266 continue mainLoop;
262 267
263// int compare = objectiveComparatorHelper.compare(currentTrajectoryWithFittness.fitness,
264// nextTrajectoryWithFittness.fitness);
265// if (compare < 0) {
266// logger.debug("Better fitness, moving on: " + nextFitness);
267// currentTrajectoryWithFittness = nextTrajectoryWithFittness;
268// continue mainLoop;
269// } else if (compare == 0) {
270// logger.debug("Equally good fitness, moving on: " + nextFitness);
271// currentTrajectoryWithFittness = nextTrajectoryWithFittness;
272// continue mainLoop;
273// } else {
274// logger.debug("Worse fitness.");
275// currentTrajectoryWithFittness = nextTrajectoryWithFittness;
276// continue mainLoop;
277 }
278 } 268 }
279 } 269 }
280 logger.debug("State is fully traversed."); 270 logger.debug("State is fully traversed.");
281 trajectoiresToExplore.remove(currentTrajectoryWithFittness); 271 trajectoiresToExplore.remove(currentTrajectoryWithFittness);
282 currentTrajectoryWithFittness = null; 272 currentTrajectoryWithFittness = null;
283 context.backtrack(); 273 context.backtrack();
284// } 274 }
285
286 logger.info("Interrupted."); 275 logger.info("Interrupted.");
287 } 276 }
288 277
@@ -291,14 +280,23 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements
291 * @param activationIds 280 * @param activationIds
292 * @return: activation to value map 281 * @return: activation to value map
293 */ 282 */
294 private Map<Object, Double> sortWithWeight(List<Object> activationIds, int factor){ 283 private Map<Object, Double> sortWithWeight(List<Object> activationIds){
295 Map<Object, Double> valueMap = new HashMap<Object, Double>(); 284 Map<Object, Double> valueMap = new HashMap<Object, Double>();
296 285 Object currentId = context.getCurrentStateId();
297 // check for next states 286 // check for next states
298 for(Object id : activationIds) { 287 for(Object id : activationIds) {
299 context.executeAcitvationId(id); 288 context.executeAcitvationId(id);
300 int violation = getNumberOfViolations(mayMatchers); 289 int violation = getNumberOfViolations(mayMatchers);
301 valueMap.put(id, calculateFutureStateValue(factor, violation)); 290
291 if(!allowMustViolation && getNumberOfViolations(mustMatchers) > 0) {
292 valueMap.put(id, Double.MAX_VALUE);
293 stateAndActivations.get(currentId).add(id);
294 }else {
295 valueMap.put(id, calculateFutureStateValue(violation));
296 }
297
298
299
302 context.backtrack(); 300 context.backtrack();
303 } 301 }
304 302
@@ -308,25 +306,22 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements
308 return valueMap; 306 return valueMap;
309 } 307 }
310 308
311 private double calculateFutureStateValue(int step, int violation) { 309 private double calculateFutureStateValue(int violation) {
312 double currentValue = calculateCurrentStateValue(step, violation); 310 int nodeSize = ((PartialInterpretation) context.getModel()).getNewElements().size();
313 311 double currentValue = calculateCurrentStateValue(nodeSize,violation);
314 if(step > 10 && currentValue < 10000) { 312 double[] toPredict = metricDistance.calculateFeature(100, violation);
315 double[] toPredict = metricDistance.calculateFeature(100, violation); 313 if(Math.abs(currentValue - currentStateValue) < 0.001) {
316 try { 314 return Double.MAX_VALUE;
317 return metricDistance.getLinearModel().getPredictionForNextDataSample(metricDistance.calculateFeature(step, violation), currentValue, toPredict); 315 }
318 }catch(IllegalArgumentException e) { 316 try {
319 return currentValue; 317 return metricDistance.getLinearModel().getPredictionForNextDataSample(metricDistance.calculateFeature(nodeSize, violation), currentValue, toPredict);
320 } 318 }catch(IllegalArgumentException e) {
321 }else {
322 return currentValue; 319 return currentValue;
323 } 320 }
324 } 321 }
325
326 private double calculateCurrentStateValue(int factor, int violation) { 322 private double calculateCurrentStateValue(int factor, int violation) {
327 PartialInterpretation model = (PartialInterpretation) context.getModel(); 323 PartialInterpretation model = (PartialInterpretation) context.getModel();
328 MetricDistanceGroup g = metricDistance.calculateMetricDistanceKS(model); 324 MetricDistanceGroup g = metricDistance.calculateMetricDistanceKS(model);
329
330 if(configuration.realisticGuidance == RealisticGuidance.MPC) { 325 if(configuration.realisticGuidance == RealisticGuidance.MPC) {
331 return g.getMPCDistance(); 326 return g.getMPCDistance();
332 }else if(configuration.realisticGuidance == RealisticGuidance.NodeActivity) { 327 }else if(configuration.realisticGuidance == RealisticGuidance.NodeActivity) {
@@ -337,18 +332,35 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements
337 return g.getNodeTypeDistance(); 332 return g.getNodeTypeDistance();
338 }else if(configuration.realisticGuidance == RealisticGuidance.Composite) { 333 }else if(configuration.realisticGuidance == RealisticGuidance.Composite) {
339 double consistenceWeights = 5 * factor / (configuration.typeScopes.maxNewElements + 2) * (1- 1.0/(1+violation)); 334 double consistenceWeights = 5 * factor / (configuration.typeScopes.maxNewElements + 2) * (1- 1.0/(1+violation));
340
341 if(domain == Domain.Yakindumm) { 335 if(domain == Domain.Yakindumm) {
342 return(100.0 *(g.getNodeTypeDistance()) + 5*(g.getNADistance() + 5*g.getMPCDistance() +g.getOutDegreeDistance()) + consistenceWeights); 336 double unfinishFactor = 50 * (1 - (double)factor / targetSize);
337 double nodeTypeFactor = g.getNodeTypeDistance();
338 double normalFactor = 5;
339 if(currentNodeTypeDistance <= 0.05 || numNodesToGenerate == 1) {
340 nodeTypeFactor = 0;
341 normalFactor = 100;
342 unfinishFactor = 0;
343 }
344
345 return 100*(nodeTypeFactor) + normalFactor*(2*g.getNADistance() + g.getMPCDistance() + 2*g.getOutDegreeDistance()) + normalFactor / 5*consistenceWeights + unfinishFactor;
343 }else { 346 }else {
344 return 10*(g.getNodeTypeDistance()) + 5*(g.getNADistance() + g.getMPCDistance() +2*g.getOutDegreeDistance()) + consistenceWeights; 347 double unfinishFactor = 100 * (1 - (double)factor / targetSize);
348 double nodeTypeFactor = g.getNodeTypeDistance();
349 double normalFactor = 5;
350 if(currentNodeTypeDistance <= 0.12 || numNodesToGenerate == 1) {
351 nodeTypeFactor = 0;
352 normalFactor = 100;
353 unfinishFactor *= 0.5;
354 }
355
356 return 100*(nodeTypeFactor) + normalFactor*(2*g.getNADistance() + g.getMPCDistance() + 2*g.getOutDegreeDistance()) + normalFactor / 5*consistenceWeights + unfinishFactor;
345 } 357 }
346 358
347 }else if(configuration.realisticGuidance == RealisticGuidance.Composite_Without_Violations) { 359 }else if(configuration.realisticGuidance == RealisticGuidance.Composite_Without_Violations) {
348 if(domain == Domain.Yakindumm) { 360 if(domain == Domain.Yakindumm) {
349 return 100.0 *(g.getNodeTypeDistance()) + 5*(g.getNADistance() + g.getMPCDistance() +g.getOutDegreeDistance()); 361 return 100.0 *(g.getNodeTypeDistance()) + 5*(g.getNADistance() + g.getMPCDistance() +g.getOutDegreeDistance());
350 }else { 362 }else {
351 return 10*(g.getNodeTypeDistance()) + 5*(g.getNADistance() + g.getMPCDistance() + 2*g.getOutDegreeDistance()); 363 return 15*(g.getNodeTypeDistance()) + 5*(g.getNADistance() + g.getMPCDistance() + 4*g.getOutDegreeDistance());
352 } 364 }
353 }else { 365 }else {
354 return violation; 366 return violation;