diff options
Diffstat (limited to 'Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/predictor/LinearModel.xtend')
-rw-r--r-- | Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/predictor/LinearModel.xtend | 91 |
1 files changed, 91 insertions, 0 deletions
diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/predictor/LinearModel.xtend b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/predictor/LinearModel.xtend new file mode 100644 index 00000000..f0ded347 --- /dev/null +++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/predictor/LinearModel.xtend | |||
@@ -0,0 +1,91 @@ | |||
1 | package ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.predictor | ||
2 | |||
3 | import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.distance.StateData | ||
4 | import java.util.ArrayList | ||
5 | import java.util.HashMap | ||
6 | import java.util.List | ||
7 | import java.util.Map | ||
8 | import weka.core.matrix.LinearRegression | ||
9 | import weka.core.matrix.Matrix | ||
10 | |||
11 | class LinearModel { | ||
12 | var double ridge; | ||
13 | var Map<Object, StateData> stateAndHistory; | ||
14 | List<StateData> samples; | ||
15 | |||
16 | new(double ridge){ | ||
17 | this.ridge = ridge; | ||
18 | stateAndHistory = new HashMap<Object, StateData>(); | ||
19 | samples = new ArrayList<StateData>(); | ||
20 | } | ||
21 | |||
22 | /** | ||
23 | * reset the current train data for regression to a new trajectory | ||
24 | * @param state: the last state of the trajectory | ||
25 | */ | ||
26 | def resetRegression(Object state){ | ||
27 | samples.clear(); | ||
28 | |||
29 | if(stateAndHistory.containsKey(state)){ | ||
30 | var data = stateAndHistory.get(state); | ||
31 | var curState = state; | ||
32 | |||
33 | samples.add(data); | ||
34 | |||
35 | //loop through data until the oldest state in the record | ||
36 | while(stateAndHistory.containsKey(data.lastState) && data.lastState != curState){ | ||
37 | curState = data.lastState; | ||
38 | data = stateAndHistory.get(data.lastState); | ||
39 | samples.add(data); | ||
40 | } | ||
41 | } | ||
42 | } | ||
43 | |||
44 | /** | ||
45 | * Add a new data point to the current training set | ||
46 | * @param state: the state on which the new data point is calculated | ||
47 | * @param features: the set of feature value(x) | ||
48 | * @param value: the value of the state (y) | ||
49 | * @param lastState: the state which transformed to current state, used to record the trajectory | ||
50 | */ | ||
51 | def feedData(Object state, double[] features, double value, Object lastState){ | ||
52 | var data = new StateData(features, value, lastState); | ||
53 | stateAndHistory.put(state, data); | ||
54 | samples.add(data); | ||
55 | } | ||
56 | |||
57 | /** | ||
58 | * get prediction for next state, without storing the data point into the training set | ||
59 | * @param features: the feature values of current state | ||
60 | * @param value: the value of the current state | ||
61 | * @param: featuresToPredict: the features of the state wanted to be predected | ||
62 | * @return the value of the state to be predicted | ||
63 | */ | ||
64 | def double getPredictionForNextDataSample(double[] features, double value, double[] featuresToPredict){ | ||
65 | var data = new StateData(features, value, null); | ||
66 | samples.add(data); | ||
67 | |||
68 | // create training set from current data | ||
69 | val double[][] xSamples = samples.map[it.features]; | ||
70 | val double[] ySamples = samples.map[it.value]; | ||
71 | |||
72 | val x = new Matrix(xSamples); | ||
73 | val y = new Matrix(ySamples, ySamples.size()); | ||
74 | |||
75 | val regression = new LinearRegression(x, y, ridge); | ||
76 | var prediction = predict(regression.coefficients, featuresToPredict); | ||
77 | |||
78 | //remove the last element just added | ||
79 | samples.remove(samples.size - 1); | ||
80 | return prediction; | ||
81 | } | ||
82 | |||
83 | def private predict(double[] parameters, double[] featuresToPredict){ | ||
84 | // the regression will add an initial column for 1's, the first parameter is constant term | ||
85 | var result = parameters.get(0); | ||
86 | for(var i = 0; i < featuresToPredict.length; i++){ | ||
87 | result += parameters.get(i) * featuresToPredict.get(i); | ||
88 | } | ||
89 | return result; | ||
90 | } | ||
91 | } \ No newline at end of file | ||