From 69729ebb20fc34f5d836d0ba9dc416114f2c9c4a Mon Sep 17 00:00:00 2001 From: 20001LastOrder Date: Mon, 24 Jun 2019 10:12:34 -0400 Subject: Implement linear regressor using Weka3 --- .../app/PartialInterpretationMetricDistance.xtend | 86 +++++++++++++++++----- 1 file changed, 67 insertions(+), 19 deletions(-) (limited to 'Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend') diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend index b63451e8..45986ecf 100644 --- a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend +++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend @@ -8,26 +8,35 @@ import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.Metric import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.MultiplexParticipationCoefficientMetric import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.NodeActivityMetric import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.OutDegreeMetric +import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.predictor.LinearModel import hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage.partialinterpretation.PartialInterpretation import java.util.ArrayList import java.util.HashMap +import java.util.List import java.util.Map -import org.apache.commons.math3.stat.regression.SimpleRegression -import java.util.stream.DoubleStream.Builder +import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression +import org.eclipse.xtend.lib.annotations.Accessors class PartialInterpretationMetricDistance { var KSDistance ks; var JSDistance js; var Map stateAndHistory; - var SimpleRegression regression; + var OLSMultipleLinearRegression regression; + List samples; + + @Accessors(PUBLIC_GETTER) + var LinearModel linearModel; new(){ ks = new KSDistance(Domain.Yakinduum); js = new JSDistance(Domain.Yakinduum); - regression = new SimpleRegression(); + regression = new OLSMultipleLinearRegression(); + regression.noIntercept = false; stateAndHistory = new HashMap(); + samples = new ArrayList(); + linearModel = new LinearModel(0.01); } def MetricDistanceGroup calculateMetricDistanceKS(PartialInterpretation partial){ @@ -63,37 +72,76 @@ class PartialInterpretationMetricDistance { } def resetRegression(Object state){ - regression = new SimpleRegression(); + samples.clear(); if(stateAndHistory.containsKey(state)){ var data = stateAndHistory.get(state); - regression.addData(data.numOfNodeFeature, data.value); - while(stateAndHistory.containsKey(data.lastState)){ + var curState = state; + + samples.add(data); + + while(stateAndHistory.containsKey(data.lastState) && data.lastState != curState){ + curState = data.lastState; data = stateAndHistory.get(data.lastState); - regression.addData(data.numOfNodeFeature, data.value); + samples.add(data); + } + + if(samples.size == 0){ + println('state: ' + state); + println('last state: ' + data.lastState); } } + println("trajectory sample size:" + samples.size) } - def feedData(Object state, int numOfNodes, double value, Object lastState){ - var data = new StateData(numOfNodes, value, lastState); + def feedData(Object state, double[] features, double value, Object lastState){ + var data = new StateData(features, value, lastState); stateAndHistory.put(state, data); - regression.addData(data.numOfNodeFeature, data.value); + samples.add(data); } - def getPredictionForNextDataSample(int numOfNodes, double value, int numberOfNodesToPredict){ - var data = new StateData(numOfNodes, value, null); - regression.addData(data.numOfNodeFeature, data.value); + def getPredictionForNextDataSample(double[] features, double value, double[] featuresToPredict){ + if(samples.size <= 4){ + println('OK'); + } + var data = new StateData(features, value, null); + samples.add(data); + + // create training set from current data + var double[][] xSamples = samples.map[it.features]; + var double[] ySamples = samples.map[it.value]; + - var prediction = predict(numberOfNodesToPredict); - regression.removeData(data.numOfNodeFeature, data.value); + regression.newSampleData(ySamples, xSamples); + var prediction = predict(featuresToPredict); + + //remove the last element just added + samples.remove(samples.size - 1); return prediction; } - def predict(int numOfNodes){ - var data = new StateData(numOfNodes, 0, null); - return regression.predict(data.numOfNodeFeature); + def private predict(double[] featuresToPredict){ + var parameters = regression.estimateRegressionParameters(); + // the regression will add an initial column for 1's, the first parameter is constant term + var result = parameters.get(0); + for(var i = 0; i < featuresToPredict.length; i++){ + result += parameters.get(i+1) * featuresToPredict.get(i); + } + return result; + } + + def double[] calculateFeature(int step, int violations){ + var features = newDoubleArrayOfSize(5); + //constant term + features.set(0, 1); + + features.set(1, 1.0 / step); + features.set(2, violations); + features.set(3, Math.pow(violations, 2)); + features.set(4, Math.pow(violations, 0.5)); + + return features; } } -- cgit v1.2.3-54-g00ecf