From 69729ebb20fc34f5d836d0ba9dc416114f2c9c4a Mon Sep 17 00:00:00 2001 From: 20001LastOrder Date: Mon, 24 Jun 2019 10:12:34 -0400 Subject: Implement linear regressor using Weka3 --- ...nRealisticMetricStrategyForModelGeneration.java | 110 +++++++++++---------- 1 file changed, 60 insertions(+), 50 deletions(-) (limited to 'Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu') diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/HillClimbingOnRealisticMetricStrategyForModelGeneration.java b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/HillClimbingOnRealisticMetricStrategyForModelGeneration.java index 148cb243..d9f6f2aa 100644 --- a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/HillClimbingOnRealisticMetricStrategyForModelGeneration.java +++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/HillClimbingOnRealisticMetricStrategyForModelGeneration.java @@ -26,6 +26,7 @@ import org.eclipse.viatra.query.runtime.api.ViatraQueryMatcher; import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app.MetricDistanceGroup; import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app.PartialInterpretationMetric; +import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app.PartialInterpretationMetricDistance; import hu.bme.mit.inf.dslreasoner.logic.model.builder.DocumentationLevel; import hu.bme.mit.inf.dslreasoner.logic.model.builder.LogicReasoner; import hu.bme.mit.inf.dslreasoner.logic.model.logicproblem.LogicProblem; @@ -62,12 +63,12 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements private Collection> mayMatchers; private Map> stateAndActivations; private Map trajectoryFit; - // Statistics private int numberOfStatecoderFail = 0; private int numberOfPrintedModel = 0; private int numberOfSolverCalls = 0; - + private PartialInterpretationMetricDistance metricDistance; + public HillClimbingOnRealisticMetricStrategyForModelGeneration( ReasonerWorkspace workspace, ViatraReasonerConfiguration configuration, @@ -112,7 +113,7 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements this.solutionStoreWithCopy = new SolutionStoreWithCopy(); this.solutionStoreWithDiversityDescriptor = new SolutionStoreWithDiversityDescriptor(configuration.diversityRequirement); - final ObjectiveComparatorHelper objectiveComparatorHelper = context.getObjectiveComparatorHelper(); + //final ObjectiveComparatorHelper objectiveComparatorHelper = context.getObjectiveComparatorHelper(); trajectoryFit = new HashMap(); this.comparator = new Comparator() { @@ -124,6 +125,7 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements trajectoiresToExplore = new PriorityQueue(11, comparator); stateAndActivations = new HashMap>(); + metricDistance = new PartialInterpretationMetricDistance(); } @Override @@ -145,11 +147,11 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements TrajectoryWithFitness currentTrajectoryWithFittness = new TrajectoryWithFitness(firstTrajectory, firstFittness); trajectoryFit.put(currentTrajectoryWithFittness, Double.MAX_VALUE); trajectoiresToExplore.add(currentTrajectoryWithFittness); + Object lastState = null; //if(configuration) visualiseCurrentState(); - PartialInterpretationMetric.initPaths(); //create matcher int count = 0; mainLoop: while (!isInterrupted && !configuration.progressMonitor.isCancelled()) { @@ -165,6 +167,9 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements logger.debug("New trajectory is chosen: " + currentTrajectoryWithFittness); } context.getDesignSpaceManager().executeTrajectoryWithMinimalBacktrackWithoutStateCoding(currentTrajectoryWithFittness.trajectory); + + // reset the regression for this trajectory + metricDistance.getLinearModel().resetRegression(context.getCurrentStateId()); } } @@ -178,10 +183,10 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements //init epsilon and draw double epsilon = 0; double draw = 1; - MetricDistanceGroup heuristics = PartialInterpretationMetric.calculateMetricDistanceKS(model); - - if(!stateAndActivations.containsKey(model)) { - stateAndActivations.put(model, new ArrayList()); + MetricDistanceGroup heuristics = metricDistance.calculateMetricDistanceKS(model); + + if(!stateAndActivations.containsKey(context.getCurrentStateId())) { + stateAndActivations.put(context.getCurrentStateId(), new ArrayList()); } //Output intermediate model @@ -190,41 +195,26 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements draw = Math.random(); count++; // cut off the trajectory for bad graph - double distance = heuristics.getMPCDistance(); System.out.println("KS"); System.out.println("NA distance: " + heuristics.getNADistance()); System.out.println("MPC distance: " + heuristics.getMPCDistance()); System.out.println("Out degree distance:" + heuristics.getOutDegreeDistance()); - MetricDistanceGroup lsHeuristic = PartialInterpretationMetric.calculateMetricDistance(model); - System.out.println("LS"); - System.out.println("NA distance: " + lsHeuristic.getNADistance()); - System.out.println("MPC distance: " + lsHeuristic.getMPCDistance()); - System.out.println("Out degree distance:" + lsHeuristic.getOutDegreeDistance()); - - if(distance <= 0.23880597014925373 + 0.00001 && distance >= 0.23880597014925373 - 0.00001) { - context.backtrack(); - final Fitness nextFitness = context.calculateFitness(); - currentTrajectoryWithFittness = new TrajectoryWithFitness(context.getTrajectory().toArray(), nextFitness); - continue; - } +// MetricDistanceGroup lsHeuristic = metricDistance.calculateMetricDistance(model); +// System.out.println("LS"); +// System.out.println("NA distance: " + lsHeuristic.getNADistance()); +// System.out.println("MPC distance: " + lsHeuristic.getMPCDistance()); +// System.out.println("Out degree distance:" + lsHeuristic.getOutDegreeDistance()); //check for next value when doing greedy move + valueMap = sortWithWeight(activationIds, currentTrajectoryWithFittness.trajectory.length+1); -// if(activationIds.isEmpty() || (model.getNewElements().size() >= 20 && valueMap.get(activationIds.get(0)) > currentValue && epsilon < draw)) { -// context.backtrack(); -// final Fitness nextFitness = context.calculateFitness(); -// currentTrajectoryWithFittness = new TrajectoryWithFitness(context.getTrajectory().toArray(), nextFitness); -// continue; -// } } + lastState = context.getCurrentStateId(); while (!isInterrupted && !configuration.progressMonitor.isCancelled() && activationIds.size() > 0) { final Object nextActivation = drawWithEpsilonProbabilty(activationIds, valueMap, epsilon, draw); -// if (!iterator.hasNext()) { -// logger.debug("Last untraversed activation of the state."); -// trajectoiresToExplore.remove(currentTrajectoryWithFittness); -// } - stateAndActivations.get(context.getModel()).add(nextActivation); + + stateAndActivations.get(context.getCurrentStateId()).add(nextActivation); logger.debug("Executing new activation: " + nextActivation); context.executeAcitvationId(nextActivation); visualiseCurrentState(); @@ -242,15 +232,17 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements logger.debug("Global contraint is not satisifed."); context.backtrack(); } else*/// { - if(getNumberOfViolations(mustMatchers) > 8) { + /*if(getNumberOfViolations(mustMatchers) > 0) { + context.backtrack(); + }else*/ if(model.getNewElements().size() > 90 && heuristics.getNADistance() + heuristics.getMPCDistance() + heuristics.getOutDegreeDistance() > 0.3) { context.backtrack(); - }else if(model.getNewElements().size() > 90 && heuristics.getNADistance() + heuristics.getMPCDistance() + heuristics.getOutDegreeDistance() > 0.18) { + }else if(model.getNewElements().size() > 70 && heuristics.getNADistance() + heuristics.getMPCDistance() + heuristics.getOutDegreeDistance() > 0.45) { context.backtrack(); - }else if(model.getNewElements().size() > 70 && heuristics.getNADistance() + heuristics.getMPCDistance() + heuristics.getOutDegreeDistance() > 0.36) { + } else if(model.getNewElements().size() > 50 && heuristics.getNADistance() + heuristics.getMPCDistance() + heuristics.getOutDegreeDistance() > 0.60) { context.backtrack(); - } else if(model.getNewElements().size() > 50 && heuristics.getNADistance() + heuristics.getMPCDistance() + heuristics.getOutDegreeDistance() > 0.72) { + } else if(model.getNewElements().size() > 30 && heuristics.getNADistance() + heuristics.getMPCDistance() + heuristics.getOutDegreeDistance() > 0.70) { context.backtrack(); - } else { + }else { final Fitness nextFitness = context.calculateFitness(); // the only hard objectives are the size @@ -271,9 +263,14 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements continue; } + TrajectoryWithFitness nextTrajectoryWithFittness = new TrajectoryWithFitness( context.getTrajectory().toArray(), nextFitness); - trajectoryFit.put(nextTrajectoryWithFittness, calculateCurrentStateValue(nextTrajectoryWithFittness.trajectory.length)); + int step = nextTrajectoryWithFittness.trajectory.length; + int violation = getNumberOfViolations(mustMatchers) + getNumberOfViolations(mayMatchers); + metricDistance.getLinearModel().feedData(context.getCurrentStateId(), metricDistance.calculateFeature(step, violation), calculateCurrentStateValue(step, violation), lastState); + + trajectoryFit.put(nextTrajectoryWithFittness, calculateCurrentStateValue(step, violation)); trajectoiresToExplore.add(nextTrajectoryWithFittness); int compare = objectiveComparatorHelper.compare(currentTrajectoryWithFittness.fitness, @@ -298,9 +295,9 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements currentTrajectoryWithFittness = null; context.backtrack(); } - PartialInterpretation model = (PartialInterpretation) context.getModel(); - PartialInterpretationMetric.calculateMetric(model, "debug/metric/output", context.getCurrentStateId().toString(), count); - count++; +// PartialInterpretation model = (PartialInterpretation) context.getModel(); +// PartialInterpretationMetric.calculateMetric(model, "debug/metric/output", context.getCurrentStateId().toString(), count); +// count++; logger.info("Interrupted."); } @@ -314,8 +311,8 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements // do hill climbing for(Object id : activationIds) { context.executeAcitvationId(id); - - valueMap.put(id, calculateCurrentStateValue(factor)); + int violation = getNumberOfViolations(mayMatchers) + getNumberOfViolations(mustMatchers); + valueMap.put(id, calculateFutureStateValue(factor, violation)); context.backtrack(); } @@ -323,14 +320,27 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements return valueMap; } - private double calculateCurrentStateValue(int factor) { + private double calculateFutureStateValue(int step, int violation) { + double currentValue = calculateCurrentStateValue(step, violation); + if(step > 40) { + double[] toPredict = metricDistance.calculateFeature(200, violation); + try { + return metricDistance.getLinearModel().getPredictionForNextDataSample(metricDistance.calculateFeature(step, violation), currentValue, toPredict); + }catch(IllegalArgumentException e) { + return currentValue; + } + }else { + return currentValue; + } + } + + private double calculateCurrentStateValue(int factor, int violation) { PartialInterpretation model = (PartialInterpretation) context.getModel(); - MetricDistanceGroup g = PartialInterpretationMetric.calculateMetricDistanceKS(model); + MetricDistanceGroup g = metricDistance.calculateMetricDistanceKS(model); - int violations = getNumberOfViolations(mayMatchers); - double consistenceWeights = 1.0/(1+violations); + double consistenceWeights = 1- 1.0/(1+violation); - return(2.5 / Math.log(factor) * (g.getNADistance() + g.getMPCDistance() + g.getOutDegreeDistance()) + 1-consistenceWeights); + return( /*/ Math.log(factor)*/(g.getNADistance() + g.getMPCDistance() + g.getOutDegreeDistance()) + consistenceWeights); } private int getNumberOfViolations(Collection> matchers) { @@ -379,8 +389,8 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements List activationIds; try { activationIds = new ArrayList(context.getCurrentActivationIds()); - if(stateAndActivations.containsKey(context.getModel())) { - activationIds.removeAll(stateAndActivations.get(context.getModel())); + if(stateAndActivations.containsKey(context.getCurrentStateId())) { + activationIds.removeAll(stateAndActivations.get(context.getCurrentStateId())); } Collections.shuffle(activationIds); } catch (NullPointerException e) { -- cgit v1.2.3-54-g00ecf