From 69729ebb20fc34f5d836d0ba9dc416114f2c9c4a Mon Sep 17 00:00:00 2001 From: 20001LastOrder Date: Mon, 24 Jun 2019 10:12:34 -0400 Subject: Implement linear regressor using Weka3 --- .../app/PartialInterpretationMetricDistance.xtend | 86 +++++++++++++++++----- .../realistic/metrics/calculator/app/Test.java | 31 ++++++++ 2 files changed, 98 insertions(+), 19 deletions(-) create mode 100644 Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/Test.java (limited to 'Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app') diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend index b63451e8..45986ecf 100644 --- a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend +++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend @@ -8,26 +8,35 @@ import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.Metric import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.MultiplexParticipationCoefficientMetric import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.NodeActivityMetric import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.OutDegreeMetric +import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.predictor.LinearModel import hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage.partialinterpretation.PartialInterpretation import java.util.ArrayList import java.util.HashMap +import java.util.List import java.util.Map -import org.apache.commons.math3.stat.regression.SimpleRegression -import java.util.stream.DoubleStream.Builder +import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression +import org.eclipse.xtend.lib.annotations.Accessors class PartialInterpretationMetricDistance { var KSDistance ks; var JSDistance js; var Map stateAndHistory; - var SimpleRegression regression; + var OLSMultipleLinearRegression regression; + List samples; + + @Accessors(PUBLIC_GETTER) + var LinearModel linearModel; new(){ ks = new KSDistance(Domain.Yakinduum); js = new JSDistance(Domain.Yakinduum); - regression = new SimpleRegression(); + regression = new OLSMultipleLinearRegression(); + regression.noIntercept = false; stateAndHistory = new HashMap(); + samples = new ArrayList(); + linearModel = new LinearModel(0.01); } def MetricDistanceGroup calculateMetricDistanceKS(PartialInterpretation partial){ @@ -63,37 +72,76 @@ class PartialInterpretationMetricDistance { } def resetRegression(Object state){ - regression = new SimpleRegression(); + samples.clear(); if(stateAndHistory.containsKey(state)){ var data = stateAndHistory.get(state); - regression.addData(data.numOfNodeFeature, data.value); - while(stateAndHistory.containsKey(data.lastState)){ + var curState = state; + + samples.add(data); + + while(stateAndHistory.containsKey(data.lastState) && data.lastState != curState){ + curState = data.lastState; data = stateAndHistory.get(data.lastState); - regression.addData(data.numOfNodeFeature, data.value); + samples.add(data); + } + + if(samples.size == 0){ + println('state: ' + state); + println('last state: ' + data.lastState); } } + println("trajectory sample size:" + samples.size) } - def feedData(Object state, int numOfNodes, double value, Object lastState){ - var data = new StateData(numOfNodes, value, lastState); + def feedData(Object state, double[] features, double value, Object lastState){ + var data = new StateData(features, value, lastState); stateAndHistory.put(state, data); - regression.addData(data.numOfNodeFeature, data.value); + samples.add(data); } - def getPredictionForNextDataSample(int numOfNodes, double value, int numberOfNodesToPredict){ - var data = new StateData(numOfNodes, value, null); - regression.addData(data.numOfNodeFeature, data.value); + def getPredictionForNextDataSample(double[] features, double value, double[] featuresToPredict){ + if(samples.size <= 4){ + println('OK'); + } + var data = new StateData(features, value, null); + samples.add(data); + + // create training set from current data + var double[][] xSamples = samples.map[it.features]; + var double[] ySamples = samples.map[it.value]; + - var prediction = predict(numberOfNodesToPredict); - regression.removeData(data.numOfNodeFeature, data.value); + regression.newSampleData(ySamples, xSamples); + var prediction = predict(featuresToPredict); + + //remove the last element just added + samples.remove(samples.size - 1); return prediction; } - def predict(int numOfNodes){ - var data = new StateData(numOfNodes, 0, null); - return regression.predict(data.numOfNodeFeature); + def private predict(double[] featuresToPredict){ + var parameters = regression.estimateRegressionParameters(); + // the regression will add an initial column for 1's, the first parameter is constant term + var result = parameters.get(0); + for(var i = 0; i < featuresToPredict.length; i++){ + result += parameters.get(i+1) * featuresToPredict.get(i); + } + return result; + } + + def double[] calculateFeature(int step, int violations){ + var features = newDoubleArrayOfSize(5); + //constant term + features.set(0, 1); + + features.set(1, 1.0 / step); + features.set(2, violations); + features.set(3, Math.pow(violations, 2)); + features.set(4, Math.pow(violations, 0.5)); + + return features; } } diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/Test.java b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/Test.java new file mode 100644 index 00000000..f06b377f --- /dev/null +++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/Test.java @@ -0,0 +1,31 @@ +package ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app; + +import java.util.ArrayList; +import java.util.List; + +import weka.core.matrix.LinearRegression; +import weka.core.matrix.Matrix; + +public class Test { + public static void main(String[] args) { + linearRegressionTest(); + } + + public static void linearRegressionTest() { + double[][] x = {{1,1,2,3}, {1,2,3,4}, {1,3,5,7}, {1,1,5,7}}; + double[] y = {10, 13, 19, 17}; + double[] valueToPredict = {1,1,1,1}; + Matrix m = new Matrix(x); + Matrix n = new Matrix(y, y.length); + + LinearRegression regression = new LinearRegression(m, n, 0); + double[] coef = regression.getCoefficients(); + + //predict + double a = 0; + for(int i = 0; i < coef.length; i++) { + a += coef[i] * valueToPredict[i]; + } + System.out.println(a); + } +} -- cgit v1.2.3-54-g00ecf