From a882ad00515e730bad5e52fa29b74f461a5b9cd6 Mon Sep 17 00:00:00 2001 From: 20001LastOrder Date: Tue, 13 Aug 2019 18:10:02 -0400 Subject: change exploration value function --- ...nRealisticMetricStrategyForModelGeneration.java | 136 +++++++++++---------- 1 file changed, 74 insertions(+), 62 deletions(-) (limited to 'Solvers/VIATRA-Solver') diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/HillClimbingOnRealisticMetricStrategyForModelGeneration.java b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/HillClimbingOnRealisticMetricStrategyForModelGeneration.java index 4dff00cd..c817d20b 100644 --- a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/HillClimbingOnRealisticMetricStrategyForModelGeneration.java +++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/HillClimbingOnRealisticMetricStrategyForModelGeneration.java @@ -6,12 +6,14 @@ import java.util.Collection; import java.util.Collections; import java.util.Comparator; import java.util.HashMap; +import java.util.HashSet; import java.util.Iterator; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.PriorityQueue; import java.util.Random; +import java.util.Set; import org.apache.log4j.Logger; import org.eclipse.emf.ecore.util.EcoreUtil; @@ -64,19 +66,21 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements // matchers for detecting the number of violations private Collection> mustMatchers; private Collection> mayMatchers; + // Encode the used activations of a particular state private Map> stateAndActivations; - private Map trajectoryFit; - private boolean allowMustViolation; private Domain domain; - + int targetSize; + // Statistics private int numberOfStatecoderFail = 0; private int numberOfPrintedModel = 0; private int numberOfSolverCalls = 0; private PartialInterpretationMetricDistance metricDistance; - + private double currentStateValue = Double.MAX_VALUE; + private double currentNodeTypeDistance = 1; + private int numNodesToGenerate = 0; public HillClimbingOnRealisticMetricStrategyForModelGeneration( ReasonerWorkspace workspace, @@ -102,12 +106,14 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements public void initStrategy(ThreadContext context) { this.context = context; this.solutionStore = context.getGlobalContext().getSolutionStore(); - + domain = Domain.valueOf(configuration.domain); + ViatraQueryEngine engine = context.getQueryEngine(); // // TODO: visualisation mustMatchers = new LinkedList>(); mayMatchers = new LinkedList>(); - + + // manully restict the number of super types of one class this.method.getInvalidWF().forEach(a ->{ ViatraQueryMatcher matcher = a.getMatcher(engine); mustMatchers.add(matcher); @@ -117,27 +123,27 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements ViatraQueryMatcher matcher = a.getMatcher(engine); mayMatchers.add(matcher); }); + - - this.solutionStoreWithCopy = new SolutionStoreWithCopy(); - this.solutionStoreWithDiversityDescriptor = new SolutionStoreWithDiversityDescriptor(configuration.diversityRequirement); - - trajectoryFit = new HashMap(); + //set up comparator + final ObjectiveComparatorHelper objectiveComparatorHelper = context.getObjectiveComparatorHelper(); this.comparator = new Comparator() { @Override public int compare(TrajectoryWithFitness o1, TrajectoryWithFitness o2) { - return Double.compare(trajectoryFit.get(o1), trajectoryFit.get(o2)); + return objectiveComparatorHelper.compare(o2.fitness, o1.fitness); } }; + this.solutionStoreWithCopy = new SolutionStoreWithCopy(); + this.solutionStoreWithDiversityDescriptor = new SolutionStoreWithDiversityDescriptor(configuration.diversityRequirement); + trajectoiresToExplore = new PriorityQueue(11, comparator); stateAndActivations = new HashMap>(); - - domain = Domain.valueOf(configuration.domain); metricDistance = new PartialInterpretationMetricDistance(domain); //set whether allows must violations during the realistic generation allowMustViolation = configuration.allowMustViolations; + targetSize = configuration.typeScopes.maxNewElements + 2; } @Override @@ -156,14 +162,12 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements //final ObjectiveComparatorHelper objectiveComparatorHelper = context.getObjectiveComparatorHelper(); final Object[] firstTrajectory = context.getTrajectory().toArray(new Object[0]); TrajectoryWithFitness currentTrajectoryWithFittness = new TrajectoryWithFitness(firstTrajectory, firstFittness); - trajectoryFit.put(currentTrajectoryWithFittness, Double.MAX_VALUE); trajectoiresToExplore.add(currentTrajectoryWithFittness); Object lastState = null; //if(configuration) visualiseCurrentState(); // the two is the True and False node generated at the beginning of the generation - int targetSize = configuration.typeScopes.maxNewElements + 2; int count = 0; mainLoop: while (!isInterrupted && !configuration.progressMonitor.isCancelled()) { @@ -202,21 +206,22 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements double epsilon = 1.0/count; double draw = Math.random(); count++; - + this.currentNodeTypeDistance = heuristics.getNodeTypeDistance(); + numNodesToGenerate = model.getMaxNewElements(); System.out.println("NA distance: " + heuristics.getNADistance()); System.out.println("MPC distance: " + heuristics.getMPCDistance()); System.out.println("Out degree distance:" + heuristics.getOutDegreeDistance()); - System.out.println("NodeType :" + heuristics.getNodeTypeDistance()); - System.out.println("Edge :" + heuristics.edgeTypeDistance); + System.out.println("NodeType :" + currentNodeTypeDistance); // System.out.println("FinalState :" + heuristics.getNodeTypePercentage("FinalState")); //TODO: the number of activations to be checked should be configurasble + System.out.println(activationIds.size()); if(activationIds.size() > 50) { activationIds = activationIds.subList(0, 50); } - valueMap = sortWithWeight(activationIds, model.getNewElements().size()); + valueMap = sortWithWeight(activationIds); lastState = context.getCurrentStateId(); while (!isInterrupted && !configuration.progressMonitor.isCancelled() && activationIds.size() > 0) { final Object nextActivation = drawWithEpsilonProbabilty(activationIds, valueMap, epsilon, draw); @@ -230,9 +235,10 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements int currentSize = model.getNewElements().size(); int targetDiff = targetSize - currentSize; + boolean shouldFinish = currentSize >= targetSize; // does not allow must violations - if((getNumberOfViolations(mustMatchers) > 0|| getNumberOfViolations(mayMatchers) > targetDiff) && !allowMustViolation) { + if((getNumberOfViolations(mustMatchers) > 0|| getNumberOfViolations(mayMatchers) > targetDiff) && !allowMustViolation && !shouldFinish) { context.backtrack(); }else { final Fitness nextFitness = context.calculateFitness(); @@ -251,38 +257,21 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements context.getTrajectory().toArray(), nextFitness); int nodeSize = ((PartialInterpretation) context.getModel()).getNewElements().size(); int violation = getNumberOfViolations(mayMatchers); - metricDistance.getLinearModel().feedData(context.getCurrentStateId(), metricDistance.calculateFeature(nodeSize, violation), calculateCurrentStateValue(nodeSize, violation), lastState); - double value = calculateCurrentStateValue(nodeSize, violation); - trajectoryFit.put(nextTrajectoryWithFittness, value); + double currentValue = calculateCurrentStateValue(nodeSize, violation); + metricDistance.getLinearModel().feedData(context.getCurrentStateId(), metricDistance.calculateFeature(nodeSize, violation), currentValue, lastState); trajectoiresToExplore.add(nextTrajectoryWithFittness); - + currentStateValue = currentValue; //Currently, just go to the next state without considering the value of trajectory currentTrajectoryWithFittness = nextTrajectoryWithFittness; continue mainLoop; -// int compare = objectiveComparatorHelper.compare(currentTrajectoryWithFittness.fitness, -// nextTrajectoryWithFittness.fitness); -// if (compare < 0) { -// logger.debug("Better fitness, moving on: " + nextFitness); -// currentTrajectoryWithFittness = nextTrajectoryWithFittness; -// continue mainLoop; -// } else if (compare == 0) { -// logger.debug("Equally good fitness, moving on: " + nextFitness); -// currentTrajectoryWithFittness = nextTrajectoryWithFittness; -// continue mainLoop; -// } else { -// logger.debug("Worse fitness."); -// currentTrajectoryWithFittness = nextTrajectoryWithFittness; -// continue mainLoop; - } } } logger.debug("State is fully traversed."); trajectoiresToExplore.remove(currentTrajectoryWithFittness); currentTrajectoryWithFittness = null; context.backtrack(); -// } - + } logger.info("Interrupted."); } @@ -291,14 +280,23 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements * @param activationIds * @return: activation to value map */ - private Map sortWithWeight(List activationIds, int factor){ + private Map sortWithWeight(List activationIds){ Map valueMap = new HashMap(); - + Object currentId = context.getCurrentStateId(); // check for next states for(Object id : activationIds) { context.executeAcitvationId(id); int violation = getNumberOfViolations(mayMatchers); - valueMap.put(id, calculateFutureStateValue(factor, violation)); + + if(!allowMustViolation && getNumberOfViolations(mustMatchers) > 0) { + valueMap.put(id, Double.MAX_VALUE); + stateAndActivations.get(currentId).add(id); + }else { + valueMap.put(id, calculateFutureStateValue(violation)); + } + + + context.backtrack(); } @@ -308,25 +306,22 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements return valueMap; } - private double calculateFutureStateValue(int step, int violation) { - double currentValue = calculateCurrentStateValue(step, violation); - - if(step > 10 && currentValue < 10000) { - double[] toPredict = metricDistance.calculateFeature(100, violation); - try { - return metricDistance.getLinearModel().getPredictionForNextDataSample(metricDistance.calculateFeature(step, violation), currentValue, toPredict); - }catch(IllegalArgumentException e) { - return currentValue; - } - }else { + private double calculateFutureStateValue(int violation) { + int nodeSize = ((PartialInterpretation) context.getModel()).getNewElements().size(); + double currentValue = calculateCurrentStateValue(nodeSize,violation); + double[] toPredict = metricDistance.calculateFeature(100, violation); + if(Math.abs(currentValue - currentStateValue) < 0.001) { + return Double.MAX_VALUE; + } + try { + return metricDistance.getLinearModel().getPredictionForNextDataSample(metricDistance.calculateFeature(nodeSize, violation), currentValue, toPredict); + }catch(IllegalArgumentException e) { return currentValue; } } - private double calculateCurrentStateValue(int factor, int violation) { PartialInterpretation model = (PartialInterpretation) context.getModel(); MetricDistanceGroup g = metricDistance.calculateMetricDistanceKS(model); - if(configuration.realisticGuidance == RealisticGuidance.MPC) { return g.getMPCDistance(); }else if(configuration.realisticGuidance == RealisticGuidance.NodeActivity) { @@ -337,18 +332,35 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements return g.getNodeTypeDistance(); }else if(configuration.realisticGuidance == RealisticGuidance.Composite) { double consistenceWeights = 5 * factor / (configuration.typeScopes.maxNewElements + 2) * (1- 1.0/(1+violation)); - if(domain == Domain.Yakindumm) { - return(100.0 *(g.getNodeTypeDistance()) + 5*(g.getNADistance() + 5*g.getMPCDistance() +g.getOutDegreeDistance()) + consistenceWeights); + double unfinishFactor = 50 * (1 - (double)factor / targetSize); + double nodeTypeFactor = g.getNodeTypeDistance(); + double normalFactor = 5; + if(currentNodeTypeDistance <= 0.05 || numNodesToGenerate == 1) { + nodeTypeFactor = 0; + normalFactor = 100; + unfinishFactor = 0; + } + + return 100*(nodeTypeFactor) + normalFactor*(2*g.getNADistance() + g.getMPCDistance() + 2*g.getOutDegreeDistance()) + normalFactor / 5*consistenceWeights + unfinishFactor; }else { - return 10*(g.getNodeTypeDistance()) + 5*(g.getNADistance() + g.getMPCDistance() +2*g.getOutDegreeDistance()) + consistenceWeights; + double unfinishFactor = 100 * (1 - (double)factor / targetSize); + double nodeTypeFactor = g.getNodeTypeDistance(); + double normalFactor = 5; + if(currentNodeTypeDistance <= 0.12 || numNodesToGenerate == 1) { + nodeTypeFactor = 0; + normalFactor = 100; + unfinishFactor *= 0.5; + } + + return 100*(nodeTypeFactor) + normalFactor*(2*g.getNADistance() + g.getMPCDistance() + 2*g.getOutDegreeDistance()) + normalFactor / 5*consistenceWeights + unfinishFactor; } }else if(configuration.realisticGuidance == RealisticGuidance.Composite_Without_Violations) { if(domain == Domain.Yakindumm) { return 100.0 *(g.getNodeTypeDistance()) + 5*(g.getNADistance() + g.getMPCDistance() +g.getOutDegreeDistance()); }else { - return 10*(g.getNodeTypeDistance()) + 5*(g.getNADistance() + g.getMPCDistance() + 2*g.getOutDegreeDistance()); + return 15*(g.getNodeTypeDistance()) + 5*(g.getNADistance() + g.getMPCDistance() + 4*g.getOutDegreeDistance()); } }else { return violation; -- cgit v1.2.3-54-g00ecf