aboutsummaryrefslogtreecommitdiffstats
path: root/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/predictor/LinearModel.xtend
blob: f0ded347e441f0a184e46e9641c96b7c9d827af5 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
package ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.predictor

import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.distance.StateData
import java.util.ArrayList
import java.util.HashMap
import java.util.List
import java.util.Map
import weka.core.matrix.LinearRegression
import weka.core.matrix.Matrix

class LinearModel {
	var double ridge;
	var Map<Object, StateData> stateAndHistory;
	List<StateData> samples;
	
	new(double ridge){
		this.ridge = ridge;
		stateAndHistory = new HashMap<Object, StateData>();
		samples = new ArrayList<StateData>();
	}
	
	/**
	 * reset the current train data for regression to a new trajectory
	 * @param state: the last state of the trajectory
	 */
	def resetRegression(Object state){
		samples.clear();
		
		if(stateAndHistory.containsKey(state)){
			var data = stateAndHistory.get(state);
			var curState = state;
			
			samples.add(data);
			
			//loop through data until the oldest state in the record
			while(stateAndHistory.containsKey(data.lastState) && data.lastState != curState){
				curState = data.lastState;
				data = stateAndHistory.get(data.lastState);
				samples.add(data);
			}
		}
	}
	
	/**
	 * Add a new data point to the current training set
	 * @param state: the state on which the new data point is calculated
	 * @param features: the set of feature value(x)
	 * @param value: the value of the state (y)
	 * @param lastState: the state which transformed to current state, used to record the trajectory
	 */
	def feedData(Object state, double[] features, double value, Object lastState){
		var data = new StateData(features, value, lastState);
		stateAndHistory.put(state, data);
		samples.add(data);
	}
	
	/**
	 * get prediction for next state, without storing the data point into the training set
	 * @param features: the feature values of current state
	 * @param value: the value of the current state
	 * @param: featuresToPredict: the features of the state wanted to be predected
	 * @return the value of the state to be predicted
	 */
	def double getPredictionForNextDataSample(double[] features, double value, double[] featuresToPredict){		
		var data = new StateData(features, value, null);
		samples.add(data);
		
		// create training set from current data		
		val double[][] xSamples = samples.map[it.features];
		val double[] ySamples = samples.map[it.value];
		
		val x = new Matrix(xSamples);
		val y = new Matrix(ySamples, ySamples.size());  
		
		val regression = new LinearRegression(x, y, ridge);
		var prediction = predict(regression.coefficients, featuresToPredict);	
		
		//remove the last element just added
		samples.remove(samples.size - 1);
		return prediction;
	}
	
	def private predict(double[] parameters, double[] featuresToPredict){
		// the regression will add an initial column for 1's, the first parameter is constant term
		var result = parameters.get(0);
		for(var i = 0; i < featuresToPredict.length; i++){
			result += parameters.get(i) * featuresToPredict.get(i);
		}
		return result;
	}
}