aboutsummaryrefslogtreecommitdiffstats
path: root/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/HillClimbingOnRealisticMetricStrategyForModelGeneration.java
diff options
context:
space:
mode:
Diffstat (limited to 'Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/HillClimbingOnRealisticMetricStrategyForModelGeneration.java')
-rw-r--r--Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/HillClimbingOnRealisticMetricStrategyForModelGeneration.java121
1 files changed, 50 insertions, 71 deletions
diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/HillClimbingOnRealisticMetricStrategyForModelGeneration.java b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/HillClimbingOnRealisticMetricStrategyForModelGeneration.java
index 6c13e760..9409ac48 100644
--- a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/HillClimbingOnRealisticMetricStrategyForModelGeneration.java
+++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/HillClimbingOnRealisticMetricStrategyForModelGeneration.java
@@ -25,7 +25,6 @@ import org.eclipse.viatra.query.runtime.api.ViatraQueryEngine;
25import org.eclipse.viatra.query.runtime.api.ViatraQueryMatcher; 25import org.eclipse.viatra.query.runtime.api.ViatraQueryMatcher;
26 26
27import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app.MetricDistanceGroup; 27import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app.MetricDistanceGroup;
28import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app.PartialInterpretationMetric;
29import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app.PartialInterpretationMetricDistance; 28import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app.PartialInterpretationMetricDistance;
30import hu.bme.mit.inf.dslreasoner.logic.model.builder.DocumentationLevel; 29import hu.bme.mit.inf.dslreasoner.logic.model.builder.DocumentationLevel;
31import hu.bme.mit.inf.dslreasoner.logic.model.builder.LogicReasoner; 30import hu.bme.mit.inf.dslreasoner.logic.model.builder.LogicReasoner;
@@ -59,10 +58,14 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements
59 private volatile boolean isInterrupted = false; 58 private volatile boolean isInterrupted = false;
60 private ModelResult modelResultByInternalSolver = null; 59 private ModelResult modelResultByInternalSolver = null;
61 private Random random = new Random(); 60 private Random random = new Random();
61
62 // matchers for detecting the number of violations
62 private Collection<ViatraQueryMatcher<? extends IPatternMatch>> mustMatchers; 63 private Collection<ViatraQueryMatcher<? extends IPatternMatch>> mustMatchers;
63 private Collection<ViatraQueryMatcher<? extends IPatternMatch>> mayMatchers; 64 private Collection<ViatraQueryMatcher<? extends IPatternMatch>> mayMatchers;
65 // Encode the used activations of a particular state
64 private Map<Object, List<Object>> stateAndActivations; 66 private Map<Object, List<Object>> stateAndActivations;
65 private Map<TrajectoryWithFitness, Double> trajectoryFit; 67 private Map<TrajectoryWithFitness, Double> trajectoryFit;
68
66 // Statistics 69 // Statistics
67 private int numberOfStatecoderFail = 0; 70 private int numberOfStatecoderFail = 0;
68 private int numberOfPrintedModel = 0; 71 private int numberOfPrintedModel = 0;
@@ -113,16 +116,14 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements
113 this.solutionStoreWithCopy = new SolutionStoreWithCopy(); 116 this.solutionStoreWithCopy = new SolutionStoreWithCopy();
114 this.solutionStoreWithDiversityDescriptor = new SolutionStoreWithDiversityDescriptor(configuration.diversityRequirement); 117 this.solutionStoreWithDiversityDescriptor = new SolutionStoreWithDiversityDescriptor(configuration.diversityRequirement);
115 118
116 //final ObjectiveComparatorHelper objectiveComparatorHelper = context.getObjectiveComparatorHelper();
117 trajectoryFit = new HashMap<TrajectoryWithFitness, Double>(); 119 trajectoryFit = new HashMap<TrajectoryWithFitness, Double>();
118
119 this.comparator = new Comparator<TrajectoryWithFitness>() { 120 this.comparator = new Comparator<TrajectoryWithFitness>() {
120 @Override 121 @Override
121 public int compare(TrajectoryWithFitness o1, TrajectoryWithFitness o2) { 122 public int compare(TrajectoryWithFitness o1, TrajectoryWithFitness o2) {
122 return Double.compare(trajectoryFit.get(o1), trajectoryFit.get(o2)); 123 return Double.compare(trajectoryFit.get(o1), trajectoryFit.get(o2));
123 } 124 }
124 }; 125 };
125 126
126 trajectoiresToExplore = new PriorityQueue<TrajectoryWithFitness>(11, comparator); 127 trajectoiresToExplore = new PriorityQueue<TrajectoryWithFitness>(11, comparator);
127 stateAndActivations = new HashMap<Object, List<Object>>(); 128 stateAndActivations = new HashMap<Object, List<Object>>();
128 metricDistance = new PartialInterpretationMetricDistance(); 129 metricDistance = new PartialInterpretationMetricDistance();
@@ -140,7 +141,6 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements
140 } 141 }
141 142
142 final Fitness firstFittness = context.calculateFitness(); 143 final Fitness firstFittness = context.calculateFitness();
143 //checkForSolution(firstFittness);
144 144
145 final ObjectiveComparatorHelper objectiveComparatorHelper = context.getObjectiveComparatorHelper(); 145 final ObjectiveComparatorHelper objectiveComparatorHelper = context.getObjectiveComparatorHelper();
146 final Object[] firstTrajectory = context.getTrajectory().toArray(new Object[0]); 146 final Object[] firstTrajectory = context.getTrajectory().toArray(new Object[0]);
@@ -152,7 +152,8 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements
152 //if(configuration) 152 //if(configuration)
153 visualiseCurrentState(); 153 visualiseCurrentState();
154 154
155 //create matcher 155 // the two is the True and False node generated at the beginning of the generation
156 int targetSize = configuration.typeScopes.maxNewElements + 2;
156 int count = 0; 157 int count = 0;
157 mainLoop: while (!isInterrupted && !configuration.progressMonitor.isCancelled()) { 158 mainLoop: while (!isInterrupted && !configuration.progressMonitor.isCancelled()) {
158 159
@@ -181,40 +182,38 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements
181 Map<Object, Double> valueMap = new HashMap<Object,Double>(); 182 Map<Object, Double> valueMap = new HashMap<Object,Double>();
182 183
183 //init epsilon and draw 184 //init epsilon and draw
184 double epsilon = 0;
185 double draw = 1;
186 MetricDistanceGroup heuristics = metricDistance.calculateMetricDistanceKS(model); 185 MetricDistanceGroup heuristics = metricDistance.calculateMetricDistanceKS(model);
187 186
188 if(!stateAndActivations.containsKey(context.getCurrentStateId())) { 187 if(!stateAndActivations.containsKey(context.getCurrentStateId())) {
189 stateAndActivations.put(context.getCurrentStateId(), new ArrayList<Object>()); 188 stateAndActivations.put(context.getCurrentStateId(), new ArrayList<Object>());
190 } 189 }
191 190
192 //Output intermediate model 191 // calculate values for epsilon greedy
193 if(model.getNewElements().size() > 0) { 192 double epsilon = 1.0/count;
194 epsilon = 1.0/count; 193 double draw = Math.random();
195 draw = Math.random(); 194 count++;
196 count++; 195
197 // cut off the trajectory for bad graph 196 // output statistics
198 //System.out.println("KS"); 197 //System.out.println("KS");
199 System.out.println("NA distance: " + heuristics.getNADistance()); 198 System.out.println("NA distance: " + heuristics.getNADistance());
200 System.out.println("MPC distance: " + heuristics.getMPCDistance()); 199 System.out.println("MPC distance: " + heuristics.getMPCDistance());
201 System.out.println("Out degree distance:" + heuristics.getOutDegreeDistance()); 200 System.out.println("Out degree distance:" + heuristics.getOutDegreeDistance());
202 System.out.println("Exit :" + heuristics.getNodeTypePercentage("Exit")); 201 System.out.println("NodeType :" + heuristics.getNodeTypeDistance());
202 System.out.println("Transition :" + heuristics.getNodeTypePercentage("Transition"));
203 System.out.println("FinalState :" + heuristics.getNodeTypePercentage("FinalState"));
203 204
204// MetricDistanceGroup lsHeuristic = metricDistance.calculateMetricDistance(model); 205// MetricDistanceGroup lsHeuristic = metricDistance.calculateMetricDistance(model);
205// System.out.println("LS"); 206// System.out.println("LS");
206// System.out.println("NA distance: " + lsHeuristic.getNADistance()); 207// System.out.println("NA distance: " + lsHeuristic.getNADistance());
207// System.out.println("MPC distance: " + lsHeuristic.getMPCDistance()); 208// System.out.println("MPC distance: " + lsHeuristic.getMPCDistance());
208// System.out.println("Out degree distance:" + lsHeuristic.getOutDegreeDistance()); 209// System.out.println("Out degree distance:" + lsHeuristic.getOutDegreeDistance());
209 210
210 //check for next value when doing greedy move 211 //TODO: the number of activations to be checked should be configurasble
211 212 if(activationIds.size() > 50) {
212 if(activationIds.size() > 50) { 213 activationIds = activationIds.subList(0, 50);
213 activationIds = activationIds.subList(0, 50);
214 }
215
216 valueMap = sortWithWeight(activationIds, model.getNewElements().size());
217 } 214 }
215
216 valueMap = sortWithWeight(activationIds, model.getNewElements().size());
218 lastState = context.getCurrentStateId(); 217 lastState = context.getCurrentStateId();
219 while (!isInterrupted && !configuration.progressMonitor.isCancelled() && activationIds.size() > 0) { 218 while (!isInterrupted && !configuration.progressMonitor.isCancelled() && activationIds.size() > 0) {
220 final Object nextActivation = drawWithEpsilonProbabilty(activationIds, valueMap, epsilon, draw); 219 final Object nextActivation = drawWithEpsilonProbabilty(activationIds, valueMap, epsilon, draw);
@@ -224,62 +223,44 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements
224 context.executeAcitvationId(nextActivation); 223 context.executeAcitvationId(nextActivation);
225 visualiseCurrentState(); 224 visualiseCurrentState();
226 225
227 //calculate the metrics for each state
228// logCurrentStateMetric();
229
230 boolean consistencyCheckResult = checkConsistency(currentTrajectoryWithFittness); 226 boolean consistencyCheckResult = checkConsistency(currentTrajectoryWithFittness);
231 if(consistencyCheckResult == true) { continue mainLoop; } 227 if(consistencyCheckResult == true) { continue mainLoop; }
232 boolean shouldFinish = model.getNewElements().size() >= configuration.typeScopes.maxNewElements + 2;
233 228
234 /* if (context.isCurrentStateAlreadyTraversed()) { 229 int currentSize = model.getNewElements().size();
235// logger.info("The new state is already visited."); 230 int targetDiff = targetSize - currentSize;
231 boolean shouldFinish = targetDiff <= 0;
232
233 // does not allow must violations
234// if((getNumberOfViolations(mustMatchers) > 0|| getNumberOfViolations(mayMatchers) > targetDiff) && !shouldFinish) {
236// context.backtrack(); 235// context.backtrack();
237// } else if (!context.checkGlobalConstraints()) { 236// }else {
238 logger.debug("Global contraint is not satisifed.");
239 context.backtrack();
240 } else*/// {
241 if(getNumberOfViolations(mustMatchers) > 0 && !shouldFinish) {
242 context.backtrack();
243 }else if(model.getNewElements().size() > 90 && heuristics.getNADistance() + heuristics.getMPCDistance() + heuristics.getOutDegreeDistance() > 5.3) {
244 context.backtrack();
245 }else if(model.getNewElements().size() > 70 && heuristics.getNADistance() + heuristics.getMPCDistance() + heuristics.getOutDegreeDistance() > 5.45) {
246 context.backtrack();
247 } else if(model.getNewElements().size() > 50 && heuristics.getNADistance() + heuristics.getMPCDistance() + heuristics.getOutDegreeDistance() > 5.60) {
248 context.backtrack();
249 } else if(model.getNewElements().size() > 30 && heuristics.getNADistance() + heuristics.getMPCDistance() + heuristics.getOutDegreeDistance() > 5.70) {
250 context.backtrack();
251 }else {
252 final Fitness nextFitness = context.calculateFitness(); 237 final Fitness nextFitness = context.calculateFitness();
253 238
254 // the only hard objectives are the size 239 // the only hard objectives are the size
255
256 if(shouldFinish) { 240 if(shouldFinish) {
257 System.out.println("Solution Found!!"); 241 System.out.println("Solution Found!!");
258 System.out.println("# violations: " + (getNumberOfViolations(mustMatchers) + getNumberOfViolations(mayMatchers))); 242 System.out.println("# violations: " + (getNumberOfViolations(mustMatchers)));
259 nextFitness.setSatisifiesHardObjectives(true); 243 nextFitness.setSatisifiesHardObjectives(true);
260 } 244 }
261
262 checkForSolution(nextFitness); 245 checkForSolution(nextFitness);
263 246
264
265
266 if (context.getDepth() > configuration.searchSpaceConstraints.maxDepth) { 247 if (context.getDepth() > configuration.searchSpaceConstraints.maxDepth) {
267 logger.debug("Reached max depth."); 248 logger.debug("Reached max depth.");
268 context.backtrack(); 249 context.backtrack();
269 continue; 250 continue;
270 } 251 }
271
272 252
253 //Record value for current trajectory
273 TrajectoryWithFitness nextTrajectoryWithFittness = new TrajectoryWithFitness( 254 TrajectoryWithFitness nextTrajectoryWithFittness = new TrajectoryWithFitness(
274 context.getTrajectory().toArray(), nextFitness); 255 context.getTrajectory().toArray(), nextFitness);
275 int step = nextTrajectoryWithFittness.trajectory.length; 256 int nodeSize = ((PartialInterpretation) context.getModel()).getNewElements().size();
276 int violation = getNumberOfViolations(mustMatchers) + getNumberOfViolations(mayMatchers); 257 int violation = getNumberOfViolations(mayMatchers);
277 metricDistance.getLinearModel().feedData(context.getCurrentStateId(), metricDistance.calculateFeature(step, violation), calculateCurrentStateValue(step, violation), lastState); 258 metricDistance.getLinearModel().feedData(context.getCurrentStateId(), metricDistance.calculateFeature(nodeSize, violation), calculateCurrentStateValue(nodeSize, violation), lastState);
278 double value = calculateCurrentStateValue(step, violation); 259 double value = calculateCurrentStateValue(nodeSize, violation);
279
280 trajectoryFit.put(nextTrajectoryWithFittness, value); 260 trajectoryFit.put(nextTrajectoryWithFittness, value);
281 trajectoiresToExplore.add(nextTrajectoryWithFittness); 261 trajectoiresToExplore.add(nextTrajectoryWithFittness);
282 262
263 //Currently, just go to the next state without considering the value of trajectory
283 int compare = objectiveComparatorHelper.compare(currentTrajectoryWithFittness.fitness, 264 int compare = objectiveComparatorHelper.compare(currentTrajectoryWithFittness.fitness,
284 nextTrajectoryWithFittness.fitness); 265 nextTrajectoryWithFittness.fitness);
285 if (compare < 0) { 266 if (compare < 0) {
@@ -301,10 +282,8 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements
301 trajectoiresToExplore.remove(currentTrajectoryWithFittness); 282 trajectoiresToExplore.remove(currentTrajectoryWithFittness);
302 currentTrajectoryWithFittness = null; 283 currentTrajectoryWithFittness = null;
303 context.backtrack(); 284 context.backtrack();
304 } 285 //}
305// PartialInterpretation model = (PartialInterpretation) context.getModel(); 286
306// PartialInterpretationMetric.calculateMetric(model, "debug/metric/output", context.getCurrentStateId().toString(), count);
307// count++;
308 logger.info("Interrupted."); 287 logger.info("Interrupted.");
309 } 288 }
310 289
@@ -315,7 +294,8 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements
315 */ 294 */
316 private Map<Object, Double> sortWithWeight(List<Object> activationIds, int factor){ 295 private Map<Object, Double> sortWithWeight(List<Object> activationIds, int factor){
317 Map<Object, Double> valueMap = new HashMap<Object, Double>(); 296 Map<Object, Double> valueMap = new HashMap<Object, Double>();
318 // do hill climbing 297
298 // check for next states
319 for(Object id : activationIds) { 299 for(Object id : activationIds) {
320 context.executeAcitvationId(id); 300 context.executeAcitvationId(id);
321 int violation = getNumberOfViolations(mayMatchers); 301 int violation = getNumberOfViolations(mayMatchers);
@@ -323,6 +303,7 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements
323 context.backtrack(); 303 context.backtrack();
324 } 304 }
325 305
306 //remove all the elements having large distance
326 activationIds.removeIf(li -> valueMap.get(li) >= 10000); 307 activationIds.removeIf(li -> valueMap.get(li) >= 10000);
327 Collections.sort(activationIds, Comparator.comparing(li -> valueMap.get(li))); 308 Collections.sort(activationIds, Comparator.comparing(li -> valueMap.get(li)));
328 return valueMap; 309 return valueMap;
@@ -346,12 +327,10 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements
346 private double calculateCurrentStateValue(int factor, int violation) { 327 private double calculateCurrentStateValue(int factor, int violation) {
347 PartialInterpretation model = (PartialInterpretation) context.getModel(); 328 PartialInterpretation model = (PartialInterpretation) context.getModel();
348 MetricDistanceGroup g = metricDistance.calculateMetricDistanceKS(model); 329 MetricDistanceGroup g = metricDistance.calculateMetricDistanceKS(model);
349 if(g.getNodeTypePercentage("Exit") > 0.0) { 330
350 return 10000; 331 double consistenceWeights = 5 * factor / (configuration.typeScopes.maxNewElements + 2) * (1- 1.0/(1+violation));
351 } 332
352 double consistenceWeights = 1- 1.0/(1+violation); 333 return(100.0 *(g.getNodeTypeDistance()) + 5*(g.getNADistance() + g.getMPCDistance() +g.getOutDegreeDistance()) + consistenceWeights);
353
354 return( 5.0 *(g.getNADistance() + g.getMPCDistance() + g.getOutDegreeDistance()) + consistenceWeights);
355 } 334 }
356 335
357 private int getNumberOfViolations(Collection<ViatraQueryMatcher<? extends IPatternMatch>> matchers) { 336 private int getNumberOfViolations(Collection<ViatraQueryMatcher<? extends IPatternMatch>> matchers) {