aboutsummaryrefslogtreecommitdiffstats
path: root/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend
diff options
context:
space:
mode:
Diffstat (limited to 'Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend')
-rw-r--r--Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend216
1 files changed, 216 insertions, 0 deletions
diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend
new file mode 100644
index 00000000..697b2639
--- /dev/null
+++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend
@@ -0,0 +1,216 @@
1package ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app
2
3import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.distance.EuclideanDistance
4import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.distance.JSDistance
5import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.distance.KSDistance
6import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.distance.StateData
7import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.graph.PartialInterpretationGraph
8import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.io.RepMetricsReader
9import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.Metric
10import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.MetricSampleGroup
11import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.MultiplexParticipationCoefficientMetric
12import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.NodeActivityMetric
13import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.NodeTypeMetric
14import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.OutDegreeMetric
15import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.predictor.LinearModel
16import hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage.partialinterpretation.PartialInterpretation
17import java.util.ArrayList
18import java.util.HashMap
19import java.util.List
20import java.util.Map
21import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression
22import org.eclipse.xtend.lib.annotations.Accessors
23import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.EdgeTypeMetric
24
25class PartialInterpretationMetricDistance {
26
27 var KSDistance ks;
28 var JSDistance js;
29 var EuclideanDistance ed;
30 var Map<Object, StateData> stateAndHistory;
31 var OLSMultipleLinearRegression regression;
32 List<StateData> samples;
33 var MetricSampleGroup g;
34 @Accessors(PUBLIC_GETTER)
35 var LinearModel linearModel;
36
37
38 new(Domain d){
39 var metrics = RepMetricsReader.read(d);
40 this.g = metrics;
41 ks = new KSDistance(g);
42 js = new JSDistance(g);
43 ed = new EuclideanDistance(g);
44 regression = new OLSMultipleLinearRegression();
45 regression.noIntercept = false;
46 stateAndHistory = new HashMap<Object, StateData>();
47 samples = new ArrayList<StateData>();
48 linearModel = new LinearModel(0.01);
49 }
50
51 def MetricDistanceGroup calculateMetricDistanceKS(PartialInterpretation partial){
52 val metrics = new ArrayList<Metric>();
53 metrics.add(new OutDegreeMetric());
54 metrics.add(new NodeActivityMetric());
55 metrics.add(new MultiplexParticipationCoefficientMetric());
56 metrics.add(new NodeTypeMetric());
57 val metricCalculator = new PartialInterpretationGraph(partial, metrics, null);
58 var metricSamples = metricCalculator.evaluateAllMetricsToSamples();
59
60 var mpc = ks.mpcDistance(metricSamples.mpcSamples);
61 var na = ks.naDistance(metricSamples.naSamples);
62 var outDegree = ks.outDegreeDistance(metricSamples.outDegreeSamples);
63 var nodeType = ks.nodeTypeDistance(metricSamples.nodeTypeSamples);
64 //var typedOutDegree = ks.typedOutDegreeDistance(metricSamples.typedOutDegreeSamples);
65 var distance = new MetricDistanceGroup(mpc, na, outDegree, nodeType);
66 distance.nodeTypeInfo = metricSamples.nodeTypeSamples;
67 return distance;
68 }
69
70 def MetricDistanceGroup calculateMetricEuclidean(PartialInterpretation partial){
71 val metrics = new ArrayList<Metric>();
72 metrics.add(new OutDegreeMetric());
73 metrics.add(new NodeActivityMetric());
74 metrics.add(new MultiplexParticipationCoefficientMetric());
75
76 val metricCalculator = new PartialInterpretationGraph(partial, metrics, null);
77 var metricSamples = metricCalculator.evaluateAllMetricsToSamples();
78
79 var mpc = ed.mpcDistance(metricSamples.mpcSamples);
80 var na = ed.naDistance(metricSamples.naSamples);
81 var outDegree = ed.outDegreeDistance(metricSamples.outDegreeSamples);
82
83 return new MetricDistanceGroup(mpc, na, outDegree);
84 }
85
86 def MetricDistanceGroup calculateMetricDistance(PartialInterpretation partial){
87 val metrics = new ArrayList<Metric>();
88 metrics.add(new OutDegreeMetric());
89 metrics.add(new NodeActivityMetric());
90 metrics.add(new MultiplexParticipationCoefficientMetric());
91
92 val metricCalculator = new PartialInterpretationGraph(partial, metrics, null);
93 var metricSamples = metricCalculator.evaluateAllMetricsToSamples();
94
95 var mpc = js.mpcDistance(metricSamples.mpcSamples);
96 var na = js.naDistance(metricSamples.naSamples);
97 var outDegree = js.outDegreeDistance(metricSamples.outDegreeSamples);
98
99 return new MetricDistanceGroup(mpc, na, outDegree);
100 }
101
102 def resetRegression(Object state){
103 samples.clear();
104
105 if(stateAndHistory.containsKey(state)){
106 var data = stateAndHistory.get(state);
107
108 var curState = state;
109
110 samples.add(data);
111
112 while(stateAndHistory.containsKey(data.lastState) && data.lastState != curState){
113 curState = data.lastState;
114 data = stateAndHistory.get(data.lastState);
115 samples.add(data);
116 }
117
118 if(samples.size == 0){
119 println('state: ' + state);
120 println('last state: ' + data.lastState);
121 }
122 }
123 println("trajectory sample size:" + samples.size)
124 }
125
126 def feedData(Object state, double[] features, double value, Object lastState){
127 var data = new StateData(features, value, lastState);
128 stateAndHistory.put(state, data);
129 samples.add(data);
130 }
131
132 def getPredictionForNextDataSample(double[] features, double value, double[] featuresToPredict){
133 if(samples.size <= 4){
134 println('OK');
135 }
136 var data = new StateData(features, value, null);
137 samples.add(data);
138
139 // create training set from current data
140 var double[][] xSamples = samples.map[it.features];
141 var double[] ySamples = samples.map[it.value];
142
143
144 regression.newSampleData(ySamples, xSamples);
145 var prediction = predict(featuresToPredict);
146
147 //remove the last element just added
148 samples.remove(samples.size - 1);
149 return prediction;
150 }
151
152 def private predict(double[] featuresToPredict){
153 var parameters = regression.estimateRegressionParameters();
154 // the regression will add an initial column for 1's, the first parameter is constant term
155 var result = parameters.get(0);
156 for(var i = 0; i < featuresToPredict.length; i++){
157 result += parameters.get(i+1) * featuresToPredict.get(i);
158 }
159 return result;
160 }
161
162 def double[] calculateFeature(int step, int violations){
163 var features = newDoubleArrayOfSize(2);
164 //constant term
165 features.set(0, 1); //a
166 features.set(0, Math.sqrt(step) + 30) // b
167 features.set(1, 1.0 / (step + 30) );// c
168
169
170// features.set(2, violations);
171// features.set(3, Math.pow(violations, 2));
172
173 return features;
174 }
175}
176
177class MetricDistanceGroup{
178 var double mpcDistance;
179 var double naDistance;
180 var double outDegreeDistance;
181 var double nodeTypeDistance;
182 protected var HashMap<String, Double> nodeTypeInfo;
183
184 new(double mpcDistance, double naDistance, double outDegreeDistance, double nodeTypeDistance){
185 this.mpcDistance = mpcDistance;
186 this.naDistance = naDistance;
187 this.outDegreeDistance = outDegreeDistance;
188 this.nodeTypeDistance = nodeTypeDistance;
189 }
190
191 new(double mpcDistance, double naDistance, double outDegreeDistance){
192 this.mpcDistance = mpcDistance;
193 this.naDistance = naDistance;
194 this.outDegreeDistance = outDegreeDistance;
195 }
196
197 def double getNodeTypeDistance(){
198 return this.nodeTypeDistance;
199 }
200
201 def double getMPCDistance(){
202 return this.mpcDistance
203 }
204
205 def double getNADistance(){
206 return this.naDistance
207 }
208
209 def double getOutDegreeDistance(){
210 return this.outDegreeDistance
211 }
212
213 def double getNodeTypePercentage(String typeName){
214 return nodeTypeInfo.getOrDefault(typeName, 0.0);
215 }
216} \ No newline at end of file