diff options
Diffstat (limited to 'Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend')
1 files changed, 67 insertions, 19 deletions
diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend index b63451e8..45986ecf 100644 --- a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend +++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend | |||
@@ -8,26 +8,35 @@ import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.Metric | |||
8 | import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.MultiplexParticipationCoefficientMetric | 8 | import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.MultiplexParticipationCoefficientMetric |
9 | import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.NodeActivityMetric | 9 | import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.NodeActivityMetric |
10 | import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.OutDegreeMetric | 10 | import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.OutDegreeMetric |
11 | import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.predictor.LinearModel | ||
11 | import hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage.partialinterpretation.PartialInterpretation | 12 | import hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage.partialinterpretation.PartialInterpretation |
12 | import java.util.ArrayList | 13 | import java.util.ArrayList |
13 | import java.util.HashMap | 14 | import java.util.HashMap |
15 | import java.util.List | ||
14 | import java.util.Map | 16 | import java.util.Map |
15 | import org.apache.commons.math3.stat.regression.SimpleRegression | 17 | import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression |
16 | import java.util.stream.DoubleStream.Builder | 18 | import org.eclipse.xtend.lib.annotations.Accessors |
17 | 19 | ||
18 | class PartialInterpretationMetricDistance { | 20 | class PartialInterpretationMetricDistance { |
19 | 21 | ||
20 | var KSDistance ks; | 22 | var KSDistance ks; |
21 | var JSDistance js; | 23 | var JSDistance js; |
22 | var Map<Object, StateData> stateAndHistory; | 24 | var Map<Object, StateData> stateAndHistory; |
23 | var SimpleRegression regression; | 25 | var OLSMultipleLinearRegression regression; |
26 | List<StateData> samples; | ||
27 | |||
28 | @Accessors(PUBLIC_GETTER) | ||
29 | var LinearModel linearModel; | ||
24 | 30 | ||
25 | 31 | ||
26 | new(){ | 32 | new(){ |
27 | ks = new KSDistance(Domain.Yakinduum); | 33 | ks = new KSDistance(Domain.Yakinduum); |
28 | js = new JSDistance(Domain.Yakinduum); | 34 | js = new JSDistance(Domain.Yakinduum); |
29 | regression = new SimpleRegression(); | 35 | regression = new OLSMultipleLinearRegression(); |
36 | regression.noIntercept = false; | ||
30 | stateAndHistory = new HashMap<Object, StateData>(); | 37 | stateAndHistory = new HashMap<Object, StateData>(); |
38 | samples = new ArrayList<StateData>(); | ||
39 | linearModel = new LinearModel(0.01); | ||
31 | } | 40 | } |
32 | 41 | ||
33 | def MetricDistanceGroup calculateMetricDistanceKS(PartialInterpretation partial){ | 42 | def MetricDistanceGroup calculateMetricDistanceKS(PartialInterpretation partial){ |
@@ -63,37 +72,76 @@ class PartialInterpretationMetricDistance { | |||
63 | } | 72 | } |
64 | 73 | ||
65 | def resetRegression(Object state){ | 74 | def resetRegression(Object state){ |
66 | regression = new SimpleRegression(); | 75 | samples.clear(); |
67 | 76 | ||
68 | if(stateAndHistory.containsKey(state)){ | 77 | if(stateAndHistory.containsKey(state)){ |
69 | var data = stateAndHistory.get(state); | 78 | var data = stateAndHistory.get(state); |
70 | regression.addData(data.numOfNodeFeature, data.value); | ||
71 | 79 | ||
72 | while(stateAndHistory.containsKey(data.lastState)){ | 80 | var curState = state; |
81 | |||
82 | samples.add(data); | ||
83 | |||
84 | while(stateAndHistory.containsKey(data.lastState) && data.lastState != curState){ | ||
85 | curState = data.lastState; | ||
73 | data = stateAndHistory.get(data.lastState); | 86 | data = stateAndHistory.get(data.lastState); |
74 | regression.addData(data.numOfNodeFeature, data.value); | 87 | samples.add(data); |
88 | } | ||
89 | |||
90 | if(samples.size == 0){ | ||
91 | println('state: ' + state); | ||
92 | println('last state: ' + data.lastState); | ||
75 | } | 93 | } |
76 | } | 94 | } |
95 | println("trajectory sample size:" + samples.size) | ||
77 | } | 96 | } |
78 | 97 | ||
79 | def feedData(Object state, int numOfNodes, double value, Object lastState){ | 98 | def feedData(Object state, double[] features, double value, Object lastState){ |
80 | var data = new StateData(numOfNodes, value, lastState); | 99 | var data = new StateData(features, value, lastState); |
81 | stateAndHistory.put(state, data); | 100 | stateAndHistory.put(state, data); |
82 | regression.addData(data.numOfNodeFeature, data.value); | 101 | samples.add(data); |
83 | } | 102 | } |
84 | 103 | ||
85 | def getPredictionForNextDataSample(int numOfNodes, double value, int numberOfNodesToPredict){ | 104 | def getPredictionForNextDataSample(double[] features, double value, double[] featuresToPredict){ |
86 | var data = new StateData(numOfNodes, value, null); | 105 | if(samples.size <= 4){ |
87 | regression.addData(data.numOfNodeFeature, data.value); | 106 | println('OK'); |
107 | } | ||
108 | var data = new StateData(features, value, null); | ||
109 | samples.add(data); | ||
110 | |||
111 | // create training set from current data | ||
112 | var double[][] xSamples = samples.map[it.features]; | ||
113 | var double[] ySamples = samples.map[it.value]; | ||
114 | |||
88 | 115 | ||
89 | var prediction = predict(numberOfNodesToPredict); | 116 | regression.newSampleData(ySamples, xSamples); |
90 | regression.removeData(data.numOfNodeFeature, data.value); | 117 | var prediction = predict(featuresToPredict); |
118 | |||
119 | //remove the last element just added | ||
120 | samples.remove(samples.size - 1); | ||
91 | return prediction; | 121 | return prediction; |
92 | } | 122 | } |
93 | 123 | ||
94 | def predict(int numOfNodes){ | 124 | def private predict(double[] featuresToPredict){ |
95 | var data = new StateData(numOfNodes, 0, null); | 125 | var parameters = regression.estimateRegressionParameters(); |
96 | return regression.predict(data.numOfNodeFeature); | 126 | // the regression will add an initial column for 1's, the first parameter is constant term |
127 | var result = parameters.get(0); | ||
128 | for(var i = 0; i < featuresToPredict.length; i++){ | ||
129 | result += parameters.get(i+1) * featuresToPredict.get(i); | ||
130 | } | ||
131 | return result; | ||
132 | } | ||
133 | |||
134 | def double[] calculateFeature(int step, int violations){ | ||
135 | var features = newDoubleArrayOfSize(5); | ||
136 | //constant term | ||
137 | features.set(0, 1); | ||
138 | |||
139 | features.set(1, 1.0 / step); | ||
140 | features.set(2, violations); | ||
141 | features.set(3, Math.pow(violations, 2)); | ||
142 | features.set(4, Math.pow(violations, 0.5)); | ||
143 | |||
144 | return features; | ||
97 | } | 145 | } |
98 | } | 146 | } |
99 | 147 | ||