aboutsummaryrefslogtreecommitdiffstats
path: root/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend
diff options
context:
space:
mode:
Diffstat (limited to 'Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend')
-rw-r--r--Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend86
1 files changed, 67 insertions, 19 deletions
diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend
index b63451e8..45986ecf 100644
--- a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend
+++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend
@@ -8,26 +8,35 @@ import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.Metric
8import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.MultiplexParticipationCoefficientMetric 8import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.MultiplexParticipationCoefficientMetric
9import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.NodeActivityMetric 9import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.NodeActivityMetric
10import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.OutDegreeMetric 10import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.OutDegreeMetric
11import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.predictor.LinearModel
11import hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage.partialinterpretation.PartialInterpretation 12import hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage.partialinterpretation.PartialInterpretation
12import java.util.ArrayList 13import java.util.ArrayList
13import java.util.HashMap 14import java.util.HashMap
15import java.util.List
14import java.util.Map 16import java.util.Map
15import org.apache.commons.math3.stat.regression.SimpleRegression 17import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression
16import java.util.stream.DoubleStream.Builder 18import org.eclipse.xtend.lib.annotations.Accessors
17 19
18class PartialInterpretationMetricDistance { 20class PartialInterpretationMetricDistance {
19 21
20 var KSDistance ks; 22 var KSDistance ks;
21 var JSDistance js; 23 var JSDistance js;
22 var Map<Object, StateData> stateAndHistory; 24 var Map<Object, StateData> stateAndHistory;
23 var SimpleRegression regression; 25 var OLSMultipleLinearRegression regression;
26 List<StateData> samples;
27
28 @Accessors(PUBLIC_GETTER)
29 var LinearModel linearModel;
24 30
25 31
26 new(){ 32 new(){
27 ks = new KSDistance(Domain.Yakinduum); 33 ks = new KSDistance(Domain.Yakinduum);
28 js = new JSDistance(Domain.Yakinduum); 34 js = new JSDistance(Domain.Yakinduum);
29 regression = new SimpleRegression(); 35 regression = new OLSMultipleLinearRegression();
36 regression.noIntercept = false;
30 stateAndHistory = new HashMap<Object, StateData>(); 37 stateAndHistory = new HashMap<Object, StateData>();
38 samples = new ArrayList<StateData>();
39 linearModel = new LinearModel(0.01);
31 } 40 }
32 41
33 def MetricDistanceGroup calculateMetricDistanceKS(PartialInterpretation partial){ 42 def MetricDistanceGroup calculateMetricDistanceKS(PartialInterpretation partial){
@@ -63,37 +72,76 @@ class PartialInterpretationMetricDistance {
63 } 72 }
64 73
65 def resetRegression(Object state){ 74 def resetRegression(Object state){
66 regression = new SimpleRegression(); 75 samples.clear();
67 76
68 if(stateAndHistory.containsKey(state)){ 77 if(stateAndHistory.containsKey(state)){
69 var data = stateAndHistory.get(state); 78 var data = stateAndHistory.get(state);
70 regression.addData(data.numOfNodeFeature, data.value);
71 79
72 while(stateAndHistory.containsKey(data.lastState)){ 80 var curState = state;
81
82 samples.add(data);
83
84 while(stateAndHistory.containsKey(data.lastState) && data.lastState != curState){
85 curState = data.lastState;
73 data = stateAndHistory.get(data.lastState); 86 data = stateAndHistory.get(data.lastState);
74 regression.addData(data.numOfNodeFeature, data.value); 87 samples.add(data);
88 }
89
90 if(samples.size == 0){
91 println('state: ' + state);
92 println('last state: ' + data.lastState);
75 } 93 }
76 } 94 }
95 println("trajectory sample size:" + samples.size)
77 } 96 }
78 97
79 def feedData(Object state, int numOfNodes, double value, Object lastState){ 98 def feedData(Object state, double[] features, double value, Object lastState){
80 var data = new StateData(numOfNodes, value, lastState); 99 var data = new StateData(features, value, lastState);
81 stateAndHistory.put(state, data); 100 stateAndHistory.put(state, data);
82 regression.addData(data.numOfNodeFeature, data.value); 101 samples.add(data);
83 } 102 }
84 103
85 def getPredictionForNextDataSample(int numOfNodes, double value, int numberOfNodesToPredict){ 104 def getPredictionForNextDataSample(double[] features, double value, double[] featuresToPredict){
86 var data = new StateData(numOfNodes, value, null); 105 if(samples.size <= 4){
87 regression.addData(data.numOfNodeFeature, data.value); 106 println('OK');
107 }
108 var data = new StateData(features, value, null);
109 samples.add(data);
110
111 // create training set from current data
112 var double[][] xSamples = samples.map[it.features];
113 var double[] ySamples = samples.map[it.value];
114
88 115
89 var prediction = predict(numberOfNodesToPredict); 116 regression.newSampleData(ySamples, xSamples);
90 regression.removeData(data.numOfNodeFeature, data.value); 117 var prediction = predict(featuresToPredict);
118
119 //remove the last element just added
120 samples.remove(samples.size - 1);
91 return prediction; 121 return prediction;
92 } 122 }
93 123
94 def predict(int numOfNodes){ 124 def private predict(double[] featuresToPredict){
95 var data = new StateData(numOfNodes, 0, null); 125 var parameters = regression.estimateRegressionParameters();
96 return regression.predict(data.numOfNodeFeature); 126 // the regression will add an initial column for 1's, the first parameter is constant term
127 var result = parameters.get(0);
128 for(var i = 0; i < featuresToPredict.length; i++){
129 result += parameters.get(i+1) * featuresToPredict.get(i);
130 }
131 return result;
132 }
133
134 def double[] calculateFeature(int step, int violations){
135 var features = newDoubleArrayOfSize(5);
136 //constant term
137 features.set(0, 1);
138
139 features.set(1, 1.0 / step);
140 features.set(2, violations);
141 features.set(3, Math.pow(violations, 2));
142 features.set(4, Math.pow(violations, 0.5));
143
144 return features;
97 } 145 }
98} 146}
99 147