From 69729ebb20fc34f5d836d0ba9dc416114f2c9c4a Mon Sep 17 00:00:00 2001 From: 20001LastOrder Date: Mon, 24 Jun 2019 10:12:34 -0400 Subject: Implement linear regressor using Weka3 --- .../.classpath | 1 + .../META-INF/MANIFEST.MF | 7 +- .../app/PartialInterpretationMetricDistance.xtend | 86 ++++++++++++---- .../realistic/metrics/calculator/app/Test.java | 31 ++++++ .../metrics/calculator/distance/CostDistance.xtend | 15 +-- .../metrics/calculator/io/CsvFileWriter.xtend | 19 +++- .../metrics/calculator/predictor/LinearModel.xtend | 91 +++++++++++++++++ ...nRealisticMetricStrategyForModelGeneration.java | 110 +++++++++++---------- 8 files changed, 277 insertions(+), 83 deletions(-) create mode 100644 Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/Test.java create mode 100644 Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/predictor/LinearModel.xtend diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/.classpath b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/.classpath index f4f8357b..006b2686 100644 --- a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/.classpath +++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/.classpath @@ -6,5 +6,6 @@ + diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/META-INF/MANIFEST.MF b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/META-INF/MANIFEST.MF index da19e07c..febd4757 100644 --- a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/META-INF/MANIFEST.MF +++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/META-INF/MANIFEST.MF @@ -15,4 +15,9 @@ Require-Bundle: com.google.guava, hu.bme.mit.inf.dslreasoner.domains.yakindu.sgraph;bundle-version="1.0.0", hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage;bundle-version="1.0.0", org.eclipse.viatra.dse;bundle-version="0.21.2" -Export-Package: ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app +Export-Package: ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app, + ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.distance, + ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.graph, + ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.io, + ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics, + ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.predictor diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend index b63451e8..45986ecf 100644 --- a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend +++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend @@ -8,26 +8,35 @@ import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.Metric import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.MultiplexParticipationCoefficientMetric import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.NodeActivityMetric import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.OutDegreeMetric +import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.predictor.LinearModel import hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage.partialinterpretation.PartialInterpretation import java.util.ArrayList import java.util.HashMap +import java.util.List import java.util.Map -import org.apache.commons.math3.stat.regression.SimpleRegression -import java.util.stream.DoubleStream.Builder +import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression +import org.eclipse.xtend.lib.annotations.Accessors class PartialInterpretationMetricDistance { var KSDistance ks; var JSDistance js; var Map stateAndHistory; - var SimpleRegression regression; + var OLSMultipleLinearRegression regression; + List samples; + + @Accessors(PUBLIC_GETTER) + var LinearModel linearModel; new(){ ks = new KSDistance(Domain.Yakinduum); js = new JSDistance(Domain.Yakinduum); - regression = new SimpleRegression(); + regression = new OLSMultipleLinearRegression(); + regression.noIntercept = false; stateAndHistory = new HashMap(); + samples = new ArrayList(); + linearModel = new LinearModel(0.01); } def MetricDistanceGroup calculateMetricDistanceKS(PartialInterpretation partial){ @@ -63,37 +72,76 @@ class PartialInterpretationMetricDistance { } def resetRegression(Object state){ - regression = new SimpleRegression(); + samples.clear(); if(stateAndHistory.containsKey(state)){ var data = stateAndHistory.get(state); - regression.addData(data.numOfNodeFeature, data.value); - while(stateAndHistory.containsKey(data.lastState)){ + var curState = state; + + samples.add(data); + + while(stateAndHistory.containsKey(data.lastState) && data.lastState != curState){ + curState = data.lastState; data = stateAndHistory.get(data.lastState); - regression.addData(data.numOfNodeFeature, data.value); + samples.add(data); + } + + if(samples.size == 0){ + println('state: ' + state); + println('last state: ' + data.lastState); } } + println("trajectory sample size:" + samples.size) } - def feedData(Object state, int numOfNodes, double value, Object lastState){ - var data = new StateData(numOfNodes, value, lastState); + def feedData(Object state, double[] features, double value, Object lastState){ + var data = new StateData(features, value, lastState); stateAndHistory.put(state, data); - regression.addData(data.numOfNodeFeature, data.value); + samples.add(data); } - def getPredictionForNextDataSample(int numOfNodes, double value, int numberOfNodesToPredict){ - var data = new StateData(numOfNodes, value, null); - regression.addData(data.numOfNodeFeature, data.value); + def getPredictionForNextDataSample(double[] features, double value, double[] featuresToPredict){ + if(samples.size <= 4){ + println('OK'); + } + var data = new StateData(features, value, null); + samples.add(data); + + // create training set from current data + var double[][] xSamples = samples.map[it.features]; + var double[] ySamples = samples.map[it.value]; + - var prediction = predict(numberOfNodesToPredict); - regression.removeData(data.numOfNodeFeature, data.value); + regression.newSampleData(ySamples, xSamples); + var prediction = predict(featuresToPredict); + + //remove the last element just added + samples.remove(samples.size - 1); return prediction; } - def predict(int numOfNodes){ - var data = new StateData(numOfNodes, 0, null); - return regression.predict(data.numOfNodeFeature); + def private predict(double[] featuresToPredict){ + var parameters = regression.estimateRegressionParameters(); + // the regression will add an initial column for 1's, the first parameter is constant term + var result = parameters.get(0); + for(var i = 0; i < featuresToPredict.length; i++){ + result += parameters.get(i+1) * featuresToPredict.get(i); + } + return result; + } + + def double[] calculateFeature(int step, int violations){ + var features = newDoubleArrayOfSize(5); + //constant term + features.set(0, 1); + + features.set(1, 1.0 / step); + features.set(2, violations); + features.set(3, Math.pow(violations, 2)); + features.set(4, Math.pow(violations, 0.5)); + + return features; } } diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/Test.java b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/Test.java new file mode 100644 index 00000000..f06b377f --- /dev/null +++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/Test.java @@ -0,0 +1,31 @@ +package ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app; + +import java.util.ArrayList; +import java.util.List; + +import weka.core.matrix.LinearRegression; +import weka.core.matrix.Matrix; + +public class Test { + public static void main(String[] args) { + linearRegressionTest(); + } + + public static void linearRegressionTest() { + double[][] x = {{1,1,2,3}, {1,2,3,4}, {1,3,5,7}, {1,1,5,7}}; + double[] y = {10, 13, 19, 17}; + double[] valueToPredict = {1,1,1,1}; + Matrix m = new Matrix(x); + Matrix n = new Matrix(y, y.length); + + LinearRegression regression = new LinearRegression(m, n, 0); + double[] coef = regression.getCoefficients(); + + //predict + double a = 0; + for(int i = 0; i < coef.length; i++) { + a += coef[i] * valueToPredict[i]; + } + System.out.println(a); + } +} diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/distance/CostDistance.xtend b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/distance/CostDistance.xtend index ee856201..33d10fa3 100644 --- a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/distance/CostDistance.xtend +++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/distance/CostDistance.xtend @@ -1,28 +1,21 @@ package ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.distance -import org.apache.commons.math3.stat.regression.SimpleRegression import org.eclipse.xtend.lib.annotations.Accessors class CostDistance { - - var SimpleRegression regression; - - new(){ - regression = new SimpleRegression(true); - } - + } class StateData{ @Accessors(PUBLIC_GETTER) - var double numOfNodeFeature; + var double[] features; @Accessors(PUBLIC_GETTER) var double value; @Accessors(PUBLIC_GETTER) var Object lastState; - new(int numOfNode, double value, Object lastState){ - this.numOfNodeFeature = 1.0 / numOfNode; + new(double[] features, double value, Object lastState){ + this.features = features; this.value = value this.lastState = lastState; } diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/io/CsvFileWriter.xtend b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/io/CsvFileWriter.xtend index 01e3940b..00b38d90 100644 --- a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/io/CsvFileWriter.xtend +++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/io/CsvFileWriter.xtend @@ -2,19 +2,34 @@ package ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.io; import java.io.File import java.io.FileNotFoundException +import java.io.FileOutputStream import java.io.PrintWriter import java.util.ArrayList import java.util.List class CsvFileWriter { + def static void write(ArrayList> datas, String uri) { if(datas.size() <= 0) { return; } - + val PrintWriter writer = new PrintWriter(new File(uri)); + output(writer, datas, uri); + } + + def static void append(ArrayList> datas, String uri) { + if(datas.size() <= 0) { + return; + } + val PrintWriter writer = new PrintWriter(new FileOutputStream(new File(uri), true)); + output(writer, datas, uri); + } + + + def private static void output(PrintWriter writer, ArrayList> datas, String uri) { //println("Output csv for " + uri); try { - val PrintWriter writer = new PrintWriter(new File(uri)); + val output = new StringBuilder; for(List datarow : datas){ for(var i = 0; i < datarow.size() - 1; i++){ diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/predictor/LinearModel.xtend b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/predictor/LinearModel.xtend new file mode 100644 index 00000000..f0ded347 --- /dev/null +++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/predictor/LinearModel.xtend @@ -0,0 +1,91 @@ +package ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.predictor + +import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.distance.StateData +import java.util.ArrayList +import java.util.HashMap +import java.util.List +import java.util.Map +import weka.core.matrix.LinearRegression +import weka.core.matrix.Matrix + +class LinearModel { + var double ridge; + var Map stateAndHistory; + List samples; + + new(double ridge){ + this.ridge = ridge; + stateAndHistory = new HashMap(); + samples = new ArrayList(); + } + + /** + * reset the current train data for regression to a new trajectory + * @param state: the last state of the trajectory + */ + def resetRegression(Object state){ + samples.clear(); + + if(stateAndHistory.containsKey(state)){ + var data = stateAndHistory.get(state); + var curState = state; + + samples.add(data); + + //loop through data until the oldest state in the record + while(stateAndHistory.containsKey(data.lastState) && data.lastState != curState){ + curState = data.lastState; + data = stateAndHistory.get(data.lastState); + samples.add(data); + } + } + } + + /** + * Add a new data point to the current training set + * @param state: the state on which the new data point is calculated + * @param features: the set of feature value(x) + * @param value: the value of the state (y) + * @param lastState: the state which transformed to current state, used to record the trajectory + */ + def feedData(Object state, double[] features, double value, Object lastState){ + var data = new StateData(features, value, lastState); + stateAndHistory.put(state, data); + samples.add(data); + } + + /** + * get prediction for next state, without storing the data point into the training set + * @param features: the feature values of current state + * @param value: the value of the current state + * @param: featuresToPredict: the features of the state wanted to be predected + * @return the value of the state to be predicted + */ + def double getPredictionForNextDataSample(double[] features, double value, double[] featuresToPredict){ + var data = new StateData(features, value, null); + samples.add(data); + + // create training set from current data + val double[][] xSamples = samples.map[it.features]; + val double[] ySamples = samples.map[it.value]; + + val x = new Matrix(xSamples); + val y = new Matrix(ySamples, ySamples.size()); + + val regression = new LinearRegression(x, y, ridge); + var prediction = predict(regression.coefficients, featuresToPredict); + + //remove the last element just added + samples.remove(samples.size - 1); + return prediction; + } + + def private predict(double[] parameters, double[] featuresToPredict){ + // the regression will add an initial column for 1's, the first parameter is constant term + var result = parameters.get(0); + for(var i = 0; i < featuresToPredict.length; i++){ + result += parameters.get(i) * featuresToPredict.get(i); + } + return result; + } +} \ No newline at end of file diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/HillClimbingOnRealisticMetricStrategyForModelGeneration.java b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/HillClimbingOnRealisticMetricStrategyForModelGeneration.java index 148cb243..d9f6f2aa 100644 --- a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/HillClimbingOnRealisticMetricStrategyForModelGeneration.java +++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/HillClimbingOnRealisticMetricStrategyForModelGeneration.java @@ -26,6 +26,7 @@ import org.eclipse.viatra.query.runtime.api.ViatraQueryMatcher; import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app.MetricDistanceGroup; import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app.PartialInterpretationMetric; +import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app.PartialInterpretationMetricDistance; import hu.bme.mit.inf.dslreasoner.logic.model.builder.DocumentationLevel; import hu.bme.mit.inf.dslreasoner.logic.model.builder.LogicReasoner; import hu.bme.mit.inf.dslreasoner.logic.model.logicproblem.LogicProblem; @@ -62,12 +63,12 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements private Collection> mayMatchers; private Map> stateAndActivations; private Map trajectoryFit; - // Statistics private int numberOfStatecoderFail = 0; private int numberOfPrintedModel = 0; private int numberOfSolverCalls = 0; - + private PartialInterpretationMetricDistance metricDistance; + public HillClimbingOnRealisticMetricStrategyForModelGeneration( ReasonerWorkspace workspace, ViatraReasonerConfiguration configuration, @@ -112,7 +113,7 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements this.solutionStoreWithCopy = new SolutionStoreWithCopy(); this.solutionStoreWithDiversityDescriptor = new SolutionStoreWithDiversityDescriptor(configuration.diversityRequirement); - final ObjectiveComparatorHelper objectiveComparatorHelper = context.getObjectiveComparatorHelper(); + //final ObjectiveComparatorHelper objectiveComparatorHelper = context.getObjectiveComparatorHelper(); trajectoryFit = new HashMap(); this.comparator = new Comparator() { @@ -124,6 +125,7 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements trajectoiresToExplore = new PriorityQueue(11, comparator); stateAndActivations = new HashMap>(); + metricDistance = new PartialInterpretationMetricDistance(); } @Override @@ -145,11 +147,11 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements TrajectoryWithFitness currentTrajectoryWithFittness = new TrajectoryWithFitness(firstTrajectory, firstFittness); trajectoryFit.put(currentTrajectoryWithFittness, Double.MAX_VALUE); trajectoiresToExplore.add(currentTrajectoryWithFittness); + Object lastState = null; //if(configuration) visualiseCurrentState(); - PartialInterpretationMetric.initPaths(); //create matcher int count = 0; mainLoop: while (!isInterrupted && !configuration.progressMonitor.isCancelled()) { @@ -165,6 +167,9 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements logger.debug("New trajectory is chosen: " + currentTrajectoryWithFittness); } context.getDesignSpaceManager().executeTrajectoryWithMinimalBacktrackWithoutStateCoding(currentTrajectoryWithFittness.trajectory); + + // reset the regression for this trajectory + metricDistance.getLinearModel().resetRegression(context.getCurrentStateId()); } } @@ -178,10 +183,10 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements //init epsilon and draw double epsilon = 0; double draw = 1; - MetricDistanceGroup heuristics = PartialInterpretationMetric.calculateMetricDistanceKS(model); - - if(!stateAndActivations.containsKey(model)) { - stateAndActivations.put(model, new ArrayList()); + MetricDistanceGroup heuristics = metricDistance.calculateMetricDistanceKS(model); + + if(!stateAndActivations.containsKey(context.getCurrentStateId())) { + stateAndActivations.put(context.getCurrentStateId(), new ArrayList()); } //Output intermediate model @@ -190,41 +195,26 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements draw = Math.random(); count++; // cut off the trajectory for bad graph - double distance = heuristics.getMPCDistance(); System.out.println("KS"); System.out.println("NA distance: " + heuristics.getNADistance()); System.out.println("MPC distance: " + heuristics.getMPCDistance()); System.out.println("Out degree distance:" + heuristics.getOutDegreeDistance()); - MetricDistanceGroup lsHeuristic = PartialInterpretationMetric.calculateMetricDistance(model); - System.out.println("LS"); - System.out.println("NA distance: " + lsHeuristic.getNADistance()); - System.out.println("MPC distance: " + lsHeuristic.getMPCDistance()); - System.out.println("Out degree distance:" + lsHeuristic.getOutDegreeDistance()); - - if(distance <= 0.23880597014925373 + 0.00001 && distance >= 0.23880597014925373 - 0.00001) { - context.backtrack(); - final Fitness nextFitness = context.calculateFitness(); - currentTrajectoryWithFittness = new TrajectoryWithFitness(context.getTrajectory().toArray(), nextFitness); - continue; - } +// MetricDistanceGroup lsHeuristic = metricDistance.calculateMetricDistance(model); +// System.out.println("LS"); +// System.out.println("NA distance: " + lsHeuristic.getNADistance()); +// System.out.println("MPC distance: " + lsHeuristic.getMPCDistance()); +// System.out.println("Out degree distance:" + lsHeuristic.getOutDegreeDistance()); //check for next value when doing greedy move + valueMap = sortWithWeight(activationIds, currentTrajectoryWithFittness.trajectory.length+1); -// if(activationIds.isEmpty() || (model.getNewElements().size() >= 20 && valueMap.get(activationIds.get(0)) > currentValue && epsilon < draw)) { -// context.backtrack(); -// final Fitness nextFitness = context.calculateFitness(); -// currentTrajectoryWithFittness = new TrajectoryWithFitness(context.getTrajectory().toArray(), nextFitness); -// continue; -// } } + lastState = context.getCurrentStateId(); while (!isInterrupted && !configuration.progressMonitor.isCancelled() && activationIds.size() > 0) { final Object nextActivation = drawWithEpsilonProbabilty(activationIds, valueMap, epsilon, draw); -// if (!iterator.hasNext()) { -// logger.debug("Last untraversed activation of the state."); -// trajectoiresToExplore.remove(currentTrajectoryWithFittness); -// } - stateAndActivations.get(context.getModel()).add(nextActivation); + + stateAndActivations.get(context.getCurrentStateId()).add(nextActivation); logger.debug("Executing new activation: " + nextActivation); context.executeAcitvationId(nextActivation); visualiseCurrentState(); @@ -242,15 +232,17 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements logger.debug("Global contraint is not satisifed."); context.backtrack(); } else*/// { - if(getNumberOfViolations(mustMatchers) > 8) { + /*if(getNumberOfViolations(mustMatchers) > 0) { + context.backtrack(); + }else*/ if(model.getNewElements().size() > 90 && heuristics.getNADistance() + heuristics.getMPCDistance() + heuristics.getOutDegreeDistance() > 0.3) { context.backtrack(); - }else if(model.getNewElements().size() > 90 && heuristics.getNADistance() + heuristics.getMPCDistance() + heuristics.getOutDegreeDistance() > 0.18) { + }else if(model.getNewElements().size() > 70 && heuristics.getNADistance() + heuristics.getMPCDistance() + heuristics.getOutDegreeDistance() > 0.45) { context.backtrack(); - }else if(model.getNewElements().size() > 70 && heuristics.getNADistance() + heuristics.getMPCDistance() + heuristics.getOutDegreeDistance() > 0.36) { + } else if(model.getNewElements().size() > 50 && heuristics.getNADistance() + heuristics.getMPCDistance() + heuristics.getOutDegreeDistance() > 0.60) { context.backtrack(); - } else if(model.getNewElements().size() > 50 && heuristics.getNADistance() + heuristics.getMPCDistance() + heuristics.getOutDegreeDistance() > 0.72) { + } else if(model.getNewElements().size() > 30 && heuristics.getNADistance() + heuristics.getMPCDistance() + heuristics.getOutDegreeDistance() > 0.70) { context.backtrack(); - } else { + }else { final Fitness nextFitness = context.calculateFitness(); // the only hard objectives are the size @@ -271,9 +263,14 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements continue; } + TrajectoryWithFitness nextTrajectoryWithFittness = new TrajectoryWithFitness( context.getTrajectory().toArray(), nextFitness); - trajectoryFit.put(nextTrajectoryWithFittness, calculateCurrentStateValue(nextTrajectoryWithFittness.trajectory.length)); + int step = nextTrajectoryWithFittness.trajectory.length; + int violation = getNumberOfViolations(mustMatchers) + getNumberOfViolations(mayMatchers); + metricDistance.getLinearModel().feedData(context.getCurrentStateId(), metricDistance.calculateFeature(step, violation), calculateCurrentStateValue(step, violation), lastState); + + trajectoryFit.put(nextTrajectoryWithFittness, calculateCurrentStateValue(step, violation)); trajectoiresToExplore.add(nextTrajectoryWithFittness); int compare = objectiveComparatorHelper.compare(currentTrajectoryWithFittness.fitness, @@ -298,9 +295,9 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements currentTrajectoryWithFittness = null; context.backtrack(); } - PartialInterpretation model = (PartialInterpretation) context.getModel(); - PartialInterpretationMetric.calculateMetric(model, "debug/metric/output", context.getCurrentStateId().toString(), count); - count++; +// PartialInterpretation model = (PartialInterpretation) context.getModel(); +// PartialInterpretationMetric.calculateMetric(model, "debug/metric/output", context.getCurrentStateId().toString(), count); +// count++; logger.info("Interrupted."); } @@ -314,8 +311,8 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements // do hill climbing for(Object id : activationIds) { context.executeAcitvationId(id); - - valueMap.put(id, calculateCurrentStateValue(factor)); + int violation = getNumberOfViolations(mayMatchers) + getNumberOfViolations(mustMatchers); + valueMap.put(id, calculateFutureStateValue(factor, violation)); context.backtrack(); } @@ -323,14 +320,27 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements return valueMap; } - private double calculateCurrentStateValue(int factor) { + private double calculateFutureStateValue(int step, int violation) { + double currentValue = calculateCurrentStateValue(step, violation); + if(step > 40) { + double[] toPredict = metricDistance.calculateFeature(200, violation); + try { + return metricDistance.getLinearModel().getPredictionForNextDataSample(metricDistance.calculateFeature(step, violation), currentValue, toPredict); + }catch(IllegalArgumentException e) { + return currentValue; + } + }else { + return currentValue; + } + } + + private double calculateCurrentStateValue(int factor, int violation) { PartialInterpretation model = (PartialInterpretation) context.getModel(); - MetricDistanceGroup g = PartialInterpretationMetric.calculateMetricDistanceKS(model); + MetricDistanceGroup g = metricDistance.calculateMetricDistanceKS(model); - int violations = getNumberOfViolations(mayMatchers); - double consistenceWeights = 1.0/(1+violations); + double consistenceWeights = 1- 1.0/(1+violation); - return(2.5 / Math.log(factor) * (g.getNADistance() + g.getMPCDistance() + g.getOutDegreeDistance()) + 1-consistenceWeights); + return( /*/ Math.log(factor)*/(g.getNADistance() + g.getMPCDistance() + g.getOutDegreeDistance()) + consistenceWeights); } private int getNumberOfViolations(Collection> matchers) { @@ -379,8 +389,8 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements List activationIds; try { activationIds = new ArrayList(context.getCurrentActivationIds()); - if(stateAndActivations.containsKey(context.getModel())) { - activationIds.removeAll(stateAndActivations.get(context.getModel())); + if(stateAndActivations.containsKey(context.getCurrentStateId())) { + activationIds.removeAll(stateAndActivations.get(context.getCurrentStateId())); } Collections.shuffle(activationIds); } catch (NullPointerException e) { -- cgit v1.2.3-54-g00ecf