From 69729ebb20fc34f5d836d0ba9dc416114f2c9c4a Mon Sep 17 00:00:00 2001 From: 20001LastOrder Date: Mon, 24 Jun 2019 10:12:34 -0400 Subject: Implement linear regressor using Weka3 --- .../.classpath | 1 + .../META-INF/MANIFEST.MF | 7 +- .../app/PartialInterpretationMetricDistance.xtend | 86 +++++++++++++++----- .../realistic/metrics/calculator/app/Test.java | 31 ++++++++ .../metrics/calculator/distance/CostDistance.xtend | 15 +--- .../metrics/calculator/io/CsvFileWriter.xtend | 19 ++++- .../metrics/calculator/predictor/LinearModel.xtend | 91 ++++++++++++++++++++++ 7 files changed, 217 insertions(+), 33 deletions(-) create mode 100644 Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/Test.java create mode 100644 Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/predictor/LinearModel.xtend (limited to 'Metrics') diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/.classpath b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/.classpath index f4f8357b..006b2686 100644 --- a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/.classpath +++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/.classpath @@ -6,5 +6,6 @@ + diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/META-INF/MANIFEST.MF b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/META-INF/MANIFEST.MF index da19e07c..febd4757 100644 --- a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/META-INF/MANIFEST.MF +++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/META-INF/MANIFEST.MF @@ -15,4 +15,9 @@ Require-Bundle: com.google.guava, 
hu.bme.mit.inf.dslreasoner.domains.yakindu.sgraph;bundle-version="1.0.0", hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage;bundle-version="1.0.0", org.eclipse.viatra.dse;bundle-version="0.21.2" -Export-Package: ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app +Export-Package: ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app, + ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.distance, + ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.graph, + ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.io, + ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics, + ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.predictor diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend index b63451e8..45986ecf 100644 --- a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend +++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend @@ -8,26 +8,35 @@ import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.Metric import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.MultiplexParticipationCoefficientMetric import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.NodeActivityMetric import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.OutDegreeMetric +import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.predictor.LinearModel import 
hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage.partialinterpretation.PartialInterpretation import java.util.ArrayList import java.util.HashMap +import java.util.List import java.util.Map -import org.apache.commons.math3.stat.regression.SimpleRegression -import java.util.stream.DoubleStream.Builder +import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression +import org.eclipse.xtend.lib.annotations.Accessors class PartialInterpretationMetricDistance { var KSDistance ks; var JSDistance js; var Map stateAndHistory; - var SimpleRegression regression; + var OLSMultipleLinearRegression regression; + List samples; + + @Accessors(PUBLIC_GETTER) + var LinearModel linearModel; new(){ ks = new KSDistance(Domain.Yakinduum); js = new JSDistance(Domain.Yakinduum); - regression = new SimpleRegression(); + regression = new OLSMultipleLinearRegression(); + regression.noIntercept = false; stateAndHistory = new HashMap(); + samples = new ArrayList(); + linearModel = new LinearModel(0.01); } def MetricDistanceGroup calculateMetricDistanceKS(PartialInterpretation partial){ @@ -63,37 +72,76 @@ class PartialInterpretationMetricDistance { } def resetRegression(Object state){ - regression = new SimpleRegression(); + samples.clear(); if(stateAndHistory.containsKey(state)){ var data = stateAndHistory.get(state); - regression.addData(data.numOfNodeFeature, data.value); - while(stateAndHistory.containsKey(data.lastState)){ + var curState = state; + + samples.add(data); + + while(stateAndHistory.containsKey(data.lastState) && data.lastState != curState){ + curState = data.lastState; data = stateAndHistory.get(data.lastState); - regression.addData(data.numOfNodeFeature, data.value); + samples.add(data); + } + + if(samples.size == 0){ + println('state: ' + state); + println('last state: ' + data.lastState); } } + println("trajectory sample size:" + samples.size) } - def feedData(Object state, int numOfNodes, double value, Object lastState){ - var data = new 
StateData(numOfNodes, value, lastState); + def feedData(Object state, double[] features, double value, Object lastState){ + var data = new StateData(features, value, lastState); stateAndHistory.put(state, data); - regression.addData(data.numOfNodeFeature, data.value); + samples.add(data); } - def getPredictionForNextDataSample(int numOfNodes, double value, int numberOfNodesToPredict){ - var data = new StateData(numOfNodes, value, null); - regression.addData(data.numOfNodeFeature, data.value); + def getPredictionForNextDataSample(double[] features, double value, double[] featuresToPredict){ + if(samples.size <= 4){ + println('OK'); + } + var data = new StateData(features, value, null); + samples.add(data); + + // build the training set (x = feature rows, y = values) from every sample collected so far, including the temporary one added above + var double[][] xSamples = samples.map[it.features]; + var double[] ySamples = samples.map[it.value]; + - var prediction = predict(numberOfNodesToPredict); - regression.removeData(data.numOfNodeFeature, data.value); + regression.newSampleData(ySamples, xSamples); + var prediction = predict(featuresToPredict); + + // remove the temporary sample again so that this prediction leaves the training set unchanged + samples.remove(samples.size - 1); return prediction; } - def predict(int numOfNodes){ - var data = new StateData(numOfNodes, 0, null); - return regression.predict(data.numOfNodeFeature); + def private predict(double[] featuresToPredict){ + var parameters = regression.estimateRegressionParameters(); + // the regression will add an initial column for 1's (noIntercept = false), so parameter 0 is the constant term and parameter i+1 belongs to feature i + var result = parameters.get(0); + for(var i = 0; i < featuresToPredict.length; i++){ + result += parameters.get(i+1) * featuresToPredict.get(i); + } + return result; + } + + def double[] calculateFeature(int step, int violations){ + var features = newDoubleArrayOfSize(5); + // constant term, always 1 + features.set(0, 1); + + features.set(1, 1.0 / step); + features.set(2, violations); + features.set(3, Math.pow(violations, 2)); + features.set(4, Math.pow(violations, 0.5)); + + return features; } 
} diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/Test.java b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/Test.java new file mode 100644 index 00000000..f06b377f --- /dev/null +++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/Test.java @@ -0,0 +1,31 @@ +package ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app; + +import java.util.ArrayList; +import java.util.List; + +import weka.core.matrix.LinearRegression; +import weka.core.matrix.Matrix; + +public class Test { + public static void main(String[] args) { + linearRegressionTest(); + } + + public static void linearRegressionTest() { + double[][] x = {{1,1,2,3}, {1,2,3,4}, {1,3,5,7}, {1,1,5,7}}; + double[] y = {10, 13, 19, 17}; + double[] valueToPredict = {1,1,1,1}; + Matrix m = new Matrix(x); + Matrix n = new Matrix(y, y.length); + + LinearRegression regression = new LinearRegression(m, n, 0); + double[] coef = regression.getCoefficients(); + + // predict: dot product of the fitted coefficients and the values to predict + double a = 0; + for(int i = 0; i < coef.length; i++) { + a += coef[i] * valueToPredict[i]; + } + System.out.println(a); + } +} diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/distance/CostDistance.xtend b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/distance/CostDistance.xtend index ee856201..33d10fa3 100644 --- a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/distance/CostDistance.xtend +++ 
b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/distance/CostDistance.xtend @@ -1,28 +1,21 @@ package ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.distance -import org.apache.commons.math3.stat.regression.SimpleRegression import org.eclipse.xtend.lib.annotations.Accessors class CostDistance { - - var SimpleRegression regression; - - new(){ - regression = new SimpleRegression(true); - } - + } class StateData{ @Accessors(PUBLIC_GETTER) - var double numOfNodeFeature; + var double[] features; @Accessors(PUBLIC_GETTER) var double value; @Accessors(PUBLIC_GETTER) var Object lastState; - new(int numOfNode, double value, Object lastState){ - this.numOfNodeFeature = 1.0 / numOfNode; + new(double[] features, double value, Object lastState){ + this.features = features; this.value = value this.lastState = lastState; } diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/io/CsvFileWriter.xtend b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/io/CsvFileWriter.xtend index 01e3940b..00b38d90 100644 --- a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/io/CsvFileWriter.xtend +++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/io/CsvFileWriter.xtend @@ -2,19 +2,34 @@ package ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.io; import java.io.File import java.io.FileNotFoundException +import java.io.FileOutputStream import java.io.PrintWriter import java.util.ArrayList import java.util.List class CsvFileWriter { + def static void write(ArrayList> datas, String uri) { if(datas.size() 
<= 0) { return; } - + val PrintWriter writer = new PrintWriter(new File(uri)); + output(writer, datas, uri); + } + + def static void append(ArrayList> datas, String uri) { + if(datas.size() <= 0) { + return; + } + val PrintWriter writer = new PrintWriter(new FileOutputStream(new File(uri), true)); + output(writer, datas, uri); + } + + + def private static void output(PrintWriter writer, ArrayList> datas, String uri) { //println("Output csv for " + uri); try { - val PrintWriter writer = new PrintWriter(new File(uri)); + val output = new StringBuilder; for(List datarow : datas){ for(var i = 0; i < datarow.size() - 1; i++){ diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/predictor/LinearModel.xtend b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/predictor/LinearModel.xtend new file mode 100644 index 00000000..f0ded347 --- /dev/null +++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/predictor/LinearModel.xtend @@ -0,0 +1,91 @@ +package ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.predictor + +import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.distance.StateData +import java.util.ArrayList +import java.util.HashMap +import java.util.List +import java.util.Map +import weka.core.matrix.LinearRegression +import weka.core.matrix.Matrix + +class LinearModel { + var double ridge; + var Map stateAndHistory; + List samples; + + new(double ridge){ + this.ridge = ridge; + stateAndHistory = new HashMap(); + samples = new ArrayList(); + } + + /** + * reset the current train data for regression to a new trajectory + * @param state: the last state of the trajectory + */ + def resetRegression(Object state){ + samples.clear(); + + 
if(stateAndHistory.containsKey(state)){ + var data = stateAndHistory.get(state); + var curState = state; + + samples.add(data); + + //loop through data until the oldest state in the record + while(stateAndHistory.containsKey(data.lastState) && data.lastState != curState){ + curState = data.lastState; + data = stateAndHistory.get(data.lastState); + samples.add(data); + } + } + } + + /** + * Add a new data point to the current training set + * @param state: the state on which the new data point is calculated + * @param features: the set of feature value(x) + * @param value: the value of the state (y) + * @param lastState: the state which transformed to current state, used to record the trajectory + */ + def feedData(Object state, double[] features, double value, Object lastState){ + var data = new StateData(features, value, lastState); + stateAndHistory.put(state, data); + samples.add(data); + } + + /** + * get prediction for next state, without storing the data point into the training set + * @param features: the feature values of current state + * @param value: the value of the current state + * @param featuresToPredict: the features of the state to be predicted + * @return the value of the state to be predicted + */ + def double getPredictionForNextDataSample(double[] features, double value, double[] featuresToPredict){ + var data = new StateData(features, value, null); + samples.add(data); + + // create training set from current data + val double[][] xSamples = samples.map[it.features]; + val double[] ySamples = samples.map[it.value]; + + val x = new Matrix(xSamples); + val y = new Matrix(ySamples, ySamples.size()); + + val regression = new LinearRegression(x, y, ridge); + var prediction = predict(regression.coefficients, featuresToPredict); + + //remove the last element just added + samples.remove(samples.size - 1); + return prediction; + } + + def private predict(double[] parameters, double[] featuresToPredict){ + // parameter i pairs with feature i here (no separate intercept slot is read), so the prediction is the plain dot product; 
the feature vector is expected to carry its own constant term. Seeding the sum with parameters.get(0) would count the first coefficient twice + var result = 0.0; + for(var i = 0; i < featuresToPredict.length; i++){ + result += parameters.get(i) * featuresToPredict.get(i); + } + return result; + } +} \ No newline at end of file -- cgit v1.2.3-54-g00ecf