diff options
Diffstat (limited to 'Metrics')
7 files changed, 217 insertions, 33 deletions
diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/.classpath b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/.classpath index f4f8357b..006b2686 100644 --- a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/.classpath +++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/.classpath | |||
@@ -6,5 +6,6 @@ | |||
6 | <classpathentry kind="src" path="xtend-gen"/> | 6 | <classpathentry kind="src" path="xtend-gen"/> |
7 | <classpathentry kind="lib" path="C:/Users/chenp/eclipse-workspace/VIATRA-Generator/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/lib/commons-math3-3.6.1.jar"/> | 7 | <classpathentry kind="lib" path="C:/Users/chenp/eclipse-workspace/VIATRA-Generator/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/lib/commons-math3-3.6.1.jar"/> |
8 | <classpathentry kind="lib" path="C:/Users/chenp/eclipse-workspace/VIATRA-Generator/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/lib/commons-math3-3.6.1-javadoc.jar"/> | 8 | <classpathentry kind="lib" path="C:/Users/chenp/eclipse-workspace/VIATRA-Generator/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/lib/commons-math3-3.6.1-javadoc.jar"/> |
9 | <classpathentry kind="lib" path="C:/Users/chenp/eclipse-workspace/VIATRA-Generator/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/lib/weka.jar"/> | ||
9 | <classpathentry kind="output" path="bin"/> | 10 | <classpathentry kind="output" path="bin"/> |
10 | </classpath> | 11 | </classpath> |
diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/META-INF/MANIFEST.MF b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/META-INF/MANIFEST.MF index da19e07c..febd4757 100644 --- a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/META-INF/MANIFEST.MF +++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/META-INF/MANIFEST.MF | |||
@@ -15,4 +15,9 @@ Require-Bundle: com.google.guava, | |||
15 | hu.bme.mit.inf.dslreasoner.domains.yakindu.sgraph;bundle-version="1.0.0", | 15 | hu.bme.mit.inf.dslreasoner.domains.yakindu.sgraph;bundle-version="1.0.0", |
16 | hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage;bundle-version="1.0.0", | 16 | hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage;bundle-version="1.0.0", |
17 | org.eclipse.viatra.dse;bundle-version="0.21.2" | 17 | org.eclipse.viatra.dse;bundle-version="0.21.2" |
18 | Export-Package: ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app | 18 | Export-Package: ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app, |
19 | ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.distance, | ||
20 | ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.graph, | ||
21 | ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.io, | ||
22 | ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics, | ||
23 | ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.predictor | ||
diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend index b63451e8..45986ecf 100644 --- a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend +++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend | |||
@@ -8,26 +8,35 @@ import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.Metric | |||
8 | import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.MultiplexParticipationCoefficientMetric | 8 | import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.MultiplexParticipationCoefficientMetric |
9 | import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.NodeActivityMetric | 9 | import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.NodeActivityMetric |
10 | import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.OutDegreeMetric | 10 | import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.OutDegreeMetric |
11 | import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.predictor.LinearModel | ||
11 | import hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage.partialinterpretation.PartialInterpretation | 12 | import hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage.partialinterpretation.PartialInterpretation |
12 | import java.util.ArrayList | 13 | import java.util.ArrayList |
13 | import java.util.HashMap | 14 | import java.util.HashMap |
15 | import java.util.List | ||
14 | import java.util.Map | 16 | import java.util.Map |
15 | import org.apache.commons.math3.stat.regression.SimpleRegression | 17 | import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression |
16 | import java.util.stream.DoubleStream.Builder | 18 | import org.eclipse.xtend.lib.annotations.Accessors |
17 | 19 | ||
18 | class PartialInterpretationMetricDistance { | 20 | class PartialInterpretationMetricDistance { |
19 | 21 | ||
20 | var KSDistance ks; | 22 | var KSDistance ks; |
21 | var JSDistance js; | 23 | var JSDistance js; |
22 | var Map<Object, StateData> stateAndHistory; | 24 | var Map<Object, StateData> stateAndHistory; |
23 | var SimpleRegression regression; | 25 | var OLSMultipleLinearRegression regression; |
26 | List<StateData> samples; | ||
27 | |||
28 | @Accessors(PUBLIC_GETTER) | ||
29 | var LinearModel linearModel; | ||
24 | 30 | ||
25 | 31 | ||
26 | new(){ | 32 | new(){ |
27 | ks = new KSDistance(Domain.Yakinduum); | 33 | ks = new KSDistance(Domain.Yakinduum); |
28 | js = new JSDistance(Domain.Yakinduum); | 34 | js = new JSDistance(Domain.Yakinduum); |
29 | regression = new SimpleRegression(); | 35 | regression = new OLSMultipleLinearRegression(); |
36 | regression.noIntercept = false; | ||
30 | stateAndHistory = new HashMap<Object, StateData>(); | 37 | stateAndHistory = new HashMap<Object, StateData>(); |
38 | samples = new ArrayList<StateData>(); | ||
39 | linearModel = new LinearModel(0.01); | ||
31 | } | 40 | } |
32 | 41 | ||
33 | def MetricDistanceGroup calculateMetricDistanceKS(PartialInterpretation partial){ | 42 | def MetricDistanceGroup calculateMetricDistanceKS(PartialInterpretation partial){ |
@@ -63,37 +72,76 @@ class PartialInterpretationMetricDistance { | |||
63 | } | 72 | } |
64 | 73 | ||
65 | def resetRegression(Object state){ | 74 | def resetRegression(Object state){ |
66 | regression = new SimpleRegression(); | 75 | samples.clear(); |
67 | 76 | ||
68 | if(stateAndHistory.containsKey(state)){ | 77 | if(stateAndHistory.containsKey(state)){ |
69 | var data = stateAndHistory.get(state); | 78 | var data = stateAndHistory.get(state); |
70 | regression.addData(data.numOfNodeFeature, data.value); | ||
71 | 79 | ||
72 | while(stateAndHistory.containsKey(data.lastState)){ | 80 | var curState = state; |
81 | |||
82 | samples.add(data); | ||
83 | |||
84 | while(stateAndHistory.containsKey(data.lastState) && data.lastState != curState){ | ||
85 | curState = data.lastState; | ||
73 | data = stateAndHistory.get(data.lastState); | 86 | data = stateAndHistory.get(data.lastState); |
74 | regression.addData(data.numOfNodeFeature, data.value); | 87 | samples.add(data); |
88 | } | ||
89 | |||
90 | if(samples.size == 0){ | ||
91 | println('state: ' + state); | ||
92 | println('last state: ' + data.lastState); | ||
75 | } | 93 | } |
76 | } | 94 | } |
95 | println("trajectory sample size:" + samples.size) | ||
77 | } | 96 | } |
78 | 97 | ||
79 | def feedData(Object state, int numOfNodes, double value, Object lastState){ | 98 | def feedData(Object state, double[] features, double value, Object lastState){ |
80 | var data = new StateData(numOfNodes, value, lastState); | 99 | var data = new StateData(features, value, lastState); |
81 | stateAndHistory.put(state, data); | 100 | stateAndHistory.put(state, data); |
82 | regression.addData(data.numOfNodeFeature, data.value); | 101 | samples.add(data); |
83 | } | 102 | } |
84 | 103 | ||
85 | def getPredictionForNextDataSample(int numOfNodes, double value, int numberOfNodesToPredict){ | 104 | def getPredictionForNextDataSample(double[] features, double value, double[] featuresToPredict){ |
86 | var data = new StateData(numOfNodes, value, null); | 105 | if(samples.size <= 4){ |
87 | regression.addData(data.numOfNodeFeature, data.value); | 106 | println('OK'); |
107 | } | ||
108 | var data = new StateData(features, value, null); | ||
109 | samples.add(data); | ||
110 | |||
111 | // create training set from current data | ||
112 | var double[][] xSamples = samples.map[it.features]; | ||
113 | var double[] ySamples = samples.map[it.value]; | ||
114 | |||
88 | 115 | ||
89 | var prediction = predict(numberOfNodesToPredict); | 116 | regression.newSampleData(ySamples, xSamples); |
90 | regression.removeData(data.numOfNodeFeature, data.value); | 117 | var prediction = predict(featuresToPredict); |
118 | |||
119 | //remove the last element just added | ||
120 | samples.remove(samples.size - 1); | ||
91 | return prediction; | 121 | return prediction; |
92 | } | 122 | } |
93 | 123 | ||
94 | def predict(int numOfNodes){ | 124 | def private predict(double[] featuresToPredict){ |
95 | var data = new StateData(numOfNodes, 0, null); | 125 | var parameters = regression.estimateRegressionParameters(); |
96 | return regression.predict(data.numOfNodeFeature); | 126 | // the regression will add an initial column for 1's, the first parameter is constant term |
127 | var result = parameters.get(0); | ||
128 | for(var i = 0; i < featuresToPredict.length; i++){ | ||
129 | result += parameters.get(i+1) * featuresToPredict.get(i); | ||
130 | } | ||
131 | return result; | ||
132 | } | ||
133 | |||
134 | def double[] calculateFeature(int step, int violations){ | ||
135 | var features = newDoubleArrayOfSize(5); | ||
136 | //constant term | ||
137 | features.set(0, 1); | ||
138 | |||
139 | features.set(1, 1.0 / step); | ||
140 | features.set(2, violations); | ||
141 | features.set(3, Math.pow(violations, 2)); | ||
142 | features.set(4, Math.pow(violations, 0.5)); | ||
143 | |||
144 | return features; | ||
97 | } | 145 | } |
98 | } | 146 | } |
99 | 147 | ||
diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/Test.java b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/Test.java new file mode 100644 index 00000000..f06b377f --- /dev/null +++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/Test.java | |||
@@ -0,0 +1,31 @@ | |||
1 | package ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app; | ||
2 | |||
3 | import java.util.ArrayList; | ||
4 | import java.util.List; | ||
5 | |||
6 | import weka.core.matrix.LinearRegression; | ||
7 | import weka.core.matrix.Matrix; | ||
8 | |||
9 | public class Test { | ||
10 | public static void main(String[] args) { | ||
11 | linearRegressionTest(); | ||
12 | } | ||
13 | |||
14 | public static void linearRegressionTest() { | ||
15 | double[][] x = {{1,1,2,3}, {1,2,3,4}, {1,3,5,7}, {1,1,5,7}}; | ||
16 | double[] y = {10, 13, 19, 17}; | ||
17 | double[] valueToPredict = {1,1,1,1}; | ||
18 | Matrix m = new Matrix(x); | ||
19 | Matrix n = new Matrix(y, y.length); | ||
20 | |||
21 | LinearRegression regression = new LinearRegression(m, n, 0); | ||
22 | double[] coef = regression.getCoefficients(); | ||
23 | |||
24 | //predict | ||
25 | double a = 0; | ||
26 | for(int i = 0; i < coef.length; i++) { | ||
27 | a += coef[i] * valueToPredict[i]; | ||
28 | } | ||
29 | System.out.println(a); | ||
30 | } | ||
31 | } | ||
diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/distance/CostDistance.xtend b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/distance/CostDistance.xtend index ee856201..33d10fa3 100644 --- a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/distance/CostDistance.xtend +++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/distance/CostDistance.xtend | |||
@@ -1,28 +1,21 @@ | |||
1 | package ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.distance | 1 | package ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.distance |
2 | 2 | ||
3 | import org.apache.commons.math3.stat.regression.SimpleRegression | ||
4 | import org.eclipse.xtend.lib.annotations.Accessors | 3 | import org.eclipse.xtend.lib.annotations.Accessors |
5 | 4 | ||
6 | class CostDistance { | 5 | class CostDistance { |
7 | 6 | ||
8 | var SimpleRegression regression; | ||
9 | |||
10 | new(){ | ||
11 | regression = new SimpleRegression(true); | ||
12 | } | ||
13 | |||
14 | } | 7 | } |
15 | 8 | ||
16 | class StateData{ | 9 | class StateData{ |
17 | @Accessors(PUBLIC_GETTER) | 10 | @Accessors(PUBLIC_GETTER) |
18 | var double numOfNodeFeature; | 11 | var double[] features; |
19 | @Accessors(PUBLIC_GETTER) | 12 | @Accessors(PUBLIC_GETTER) |
20 | var double value; | 13 | var double value; |
21 | @Accessors(PUBLIC_GETTER) | 14 | @Accessors(PUBLIC_GETTER) |
22 | var Object lastState; | 15 | var Object lastState; |
23 | 16 | ||
24 | new(int numOfNode, double value, Object lastState){ | 17 | new(double[] features, double value, Object lastState){ |
25 | this.numOfNodeFeature = 1.0 / numOfNode; | 18 | this.features = features; |
26 | this.value = value | 19 | this.value = value |
27 | this.lastState = lastState; | 20 | this.lastState = lastState; |
28 | } | 21 | } |
diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/io/CsvFileWriter.xtend b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/io/CsvFileWriter.xtend index 01e3940b..00b38d90 100644 --- a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/io/CsvFileWriter.xtend +++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/io/CsvFileWriter.xtend | |||
@@ -2,19 +2,34 @@ package ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.io; | |||
2 | 2 | ||
3 | import java.io.File | 3 | import java.io.File |
4 | import java.io.FileNotFoundException | 4 | import java.io.FileNotFoundException |
5 | import java.io.FileOutputStream | ||
5 | import java.io.PrintWriter | 6 | import java.io.PrintWriter |
6 | import java.util.ArrayList | 7 | import java.util.ArrayList |
7 | import java.util.List | 8 | import java.util.List |
8 | 9 | ||
9 | class CsvFileWriter { | 10 | class CsvFileWriter { |
11 | |||
10 | def static void write(ArrayList<ArrayList<String>> datas, String uri) { | 12 | def static void write(ArrayList<ArrayList<String>> datas, String uri) { |
11 | if(datas.size() <= 0) { | 13 | if(datas.size() <= 0) { |
12 | return; | 14 | return; |
13 | } | 15 | } |
14 | 16 | val PrintWriter writer = new PrintWriter(new File(uri)); | |
17 | output(writer, datas, uri); | ||
18 | } | ||
19 | |||
20 | def static void append(ArrayList<ArrayList<String>> datas, String uri) { | ||
21 | if(datas.size() <= 0) { | ||
22 | return; | ||
23 | } | ||
24 | val PrintWriter writer = new PrintWriter(new FileOutputStream(new File(uri), true)); | ||
25 | output(writer, datas, uri); | ||
26 | } | ||
27 | |||
28 | |||
29 | def private static void output(PrintWriter writer, ArrayList<ArrayList<String>> datas, String uri) { | ||
15 | //println("Output csv for " + uri); | 30 | //println("Output csv for " + uri); |
16 | try { | 31 | try { |
17 | val PrintWriter writer = new PrintWriter(new File(uri)); | 32 | |
18 | val output = new StringBuilder; | 33 | val output = new StringBuilder; |
19 | for(List<String> datarow : datas){ | 34 | for(List<String> datarow : datas){ |
20 | for(var i = 0; i < datarow.size() - 1; i++){ | 35 | for(var i = 0; i < datarow.size() - 1; i++){ |
diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/predictor/LinearModel.xtend b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/predictor/LinearModel.xtend new file mode 100644 index 00000000..f0ded347 --- /dev/null +++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/predictor/LinearModel.xtend | |||
@@ -0,0 +1,91 @@ | |||
1 | package ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.predictor | ||
2 | |||
3 | import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.distance.StateData | ||
4 | import java.util.ArrayList | ||
5 | import java.util.HashMap | ||
6 | import java.util.List | ||
7 | import java.util.Map | ||
8 | import weka.core.matrix.LinearRegression | ||
9 | import weka.core.matrix.Matrix | ||
10 | |||
11 | class LinearModel { | ||
12 | var double ridge; | ||
13 | var Map<Object, StateData> stateAndHistory; | ||
14 | List<StateData> samples; | ||
15 | |||
16 | new(double ridge){ | ||
17 | this.ridge = ridge; | ||
18 | stateAndHistory = new HashMap<Object, StateData>(); | ||
19 | samples = new ArrayList<StateData>(); | ||
20 | } | ||
21 | |||
22 | /** | ||
23 | * reset the current train data for regression to a new trajectory | ||
24 | * @param state: the last state of the trajectory | ||
25 | */ | ||
26 | def resetRegression(Object state){ | ||
27 | samples.clear(); | ||
28 | |||
29 | if(stateAndHistory.containsKey(state)){ | ||
30 | var data = stateAndHistory.get(state); | ||
31 | var curState = state; | ||
32 | |||
33 | samples.add(data); | ||
34 | |||
35 | //loop through data until the oldest state in the record | ||
36 | while(stateAndHistory.containsKey(data.lastState) && data.lastState != curState){ | ||
37 | curState = data.lastState; | ||
38 | data = stateAndHistory.get(data.lastState); | ||
39 | samples.add(data); | ||
40 | } | ||
41 | } | ||
42 | } | ||
43 | |||
44 | /** | ||
45 | * Add a new data point to the current training set | ||
46 | * @param state: the state on which the new data point is calculated | ||
47 | * @param features: the set of feature value(x) | ||
48 | * @param value: the value of the state (y) | ||
49 | * @param lastState: the state which transformed to current state, used to record the trajectory | ||
50 | */ | ||
51 | def feedData(Object state, double[] features, double value, Object lastState){ | ||
52 | var data = new StateData(features, value, lastState); | ||
53 | stateAndHistory.put(state, data); | ||
54 | samples.add(data); | ||
55 | } | ||
56 | |||
57 | /** | ||
58 | * get prediction for next state, without storing the data point into the training set | ||
59 | * @param features: the feature values of current state | ||
60 | * @param value: the value of the current state | ||
61 | * @param: featuresToPredict: the features of the state wanted to be predected | ||
62 | * @return the value of the state to be predicted | ||
63 | */ | ||
64 | def double getPredictionForNextDataSample(double[] features, double value, double[] featuresToPredict){ | ||
65 | var data = new StateData(features, value, null); | ||
66 | samples.add(data); | ||
67 | |||
68 | // create training set from current data | ||
69 | val double[][] xSamples = samples.map[it.features]; | ||
70 | val double[] ySamples = samples.map[it.value]; | ||
71 | |||
72 | val x = new Matrix(xSamples); | ||
73 | val y = new Matrix(ySamples, ySamples.size()); | ||
74 | |||
75 | val regression = new LinearRegression(x, y, ridge); | ||
76 | var prediction = predict(regression.coefficients, featuresToPredict); | ||
77 | |||
78 | //remove the last element just added | ||
79 | samples.remove(samples.size - 1); | ||
80 | return prediction; | ||
81 | } | ||
82 | |||
83 | def private predict(double[] parameters, double[] featuresToPredict){ | ||
84 | // the regression will add an initial column for 1's, the first parameter is constant term | ||
85 | var result = parameters.get(0); | ||
86 | for(var i = 0; i < featuresToPredict.length; i++){ | ||
87 | result += parameters.get(i) * featuresToPredict.get(i); | ||
88 | } | ||
89 | return result; | ||
90 | } | ||
91 | } \ No newline at end of file | ||