aboutsummaryrefslogtreecommitdiffstats
path: root/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/predictor/LinearModel.xtend
diff options
context:
space:
mode:
Diffstat (limited to 'Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/predictor/LinearModel.xtend')
-rw-r--r--Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/predictor/LinearModel.xtend91
1 files changed, 91 insertions, 0 deletions
diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/predictor/LinearModel.xtend b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/predictor/LinearModel.xtend
new file mode 100644
index 00000000..f0ded347
--- /dev/null
+++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/predictor/LinearModel.xtend
@@ -0,0 +1,91 @@
1package ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.predictor
2
3import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.distance.StateData
4import java.util.ArrayList
5import java.util.HashMap
6import java.util.List
7import java.util.Map
8import weka.core.matrix.LinearRegression
9import weka.core.matrix.Matrix
10
11class LinearModel {
12 var double ridge;
13 var Map<Object, StateData> stateAndHistory;
14 List<StateData> samples;
15
16 new(double ridge){
17 this.ridge = ridge;
18 stateAndHistory = new HashMap<Object, StateData>();
19 samples = new ArrayList<StateData>();
20 }
21
22 /**
23 * reset the current train data for regression to a new trajectory
24 * @param state: the last state of the trajectory
25 */
26 def resetRegression(Object state){
27 samples.clear();
28
29 if(stateAndHistory.containsKey(state)){
30 var data = stateAndHistory.get(state);
31 var curState = state;
32
33 samples.add(data);
34
35 //loop through data until the oldest state in the record
36 while(stateAndHistory.containsKey(data.lastState) && data.lastState != curState){
37 curState = data.lastState;
38 data = stateAndHistory.get(data.lastState);
39 samples.add(data);
40 }
41 }
42 }
43
44 /**
45 * Add a new data point to the current training set
46 * @param state: the state on which the new data point is calculated
47 * @param features: the set of feature value(x)
48 * @param value: the value of the state (y)
49 * @param lastState: the state which transformed to current state, used to record the trajectory
50 */
51 def feedData(Object state, double[] features, double value, Object lastState){
52 var data = new StateData(features, value, lastState);
53 stateAndHistory.put(state, data);
54 samples.add(data);
55 }
56
57 /**
58 * get prediction for next state, without storing the data point into the training set
59 * @param features: the feature values of current state
60 * @param value: the value of the current state
61 * @param: featuresToPredict: the features of the state wanted to be predected
62 * @return the value of the state to be predicted
63 */
64 def double getPredictionForNextDataSample(double[] features, double value, double[] featuresToPredict){
65 var data = new StateData(features, value, null);
66 samples.add(data);
67
68 // create training set from current data
69 val double[][] xSamples = samples.map[it.features];
70 val double[] ySamples = samples.map[it.value];
71
72 val x = new Matrix(xSamples);
73 val y = new Matrix(ySamples, ySamples.size());
74
75 val regression = new LinearRegression(x, y, ridge);
76 var prediction = predict(regression.coefficients, featuresToPredict);
77
78 //remove the last element just added
79 samples.remove(samples.size - 1);
80 return prediction;
81 }
82
83 def private predict(double[] parameters, double[] featuresToPredict){
84 // the regression will add an initial column for 1's, the first parameter is constant term
85 var result = parameters.get(0);
86 for(var i = 0; i < featuresToPredict.length; i++){
87 result += parameters.get(i) * featuresToPredict.get(i);
88 }
89 return result;
90 }
91} \ No newline at end of file