diff options
8 files changed, 277 insertions, 83 deletions
diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/.classpath b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/.classpath index f4f8357b..006b2686 100644 --- a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/.classpath +++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/.classpath | |||
@@ -6,5 +6,6 @@ | |||
6 | <classpathentry kind="src" path="xtend-gen"/> | 6 | <classpathentry kind="src" path="xtend-gen"/> |
7 | <classpathentry kind="lib" path="C:/Users/chenp/eclipse-workspace/VIATRA-Generator/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/lib/commons-math3-3.6.1.jar"/> | 7 | <classpathentry kind="lib" path="C:/Users/chenp/eclipse-workspace/VIATRA-Generator/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/lib/commons-math3-3.6.1.jar"/> |
8 | <classpathentry kind="lib" path="C:/Users/chenp/eclipse-workspace/VIATRA-Generator/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/lib/commons-math3-3.6.1-javadoc.jar"/> | 8 | <classpathentry kind="lib" path="C:/Users/chenp/eclipse-workspace/VIATRA-Generator/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/lib/commons-math3-3.6.1-javadoc.jar"/> |
9 | <classpathentry kind="lib" path="C:/Users/chenp/eclipse-workspace/VIATRA-Generator/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/lib/weka.jar"/> | ||
9 | <classpathentry kind="output" path="bin"/> | 10 | <classpathentry kind="output" path="bin"/> |
10 | </classpath> | 11 | </classpath> |
diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/META-INF/MANIFEST.MF b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/META-INF/MANIFEST.MF index da19e07c..febd4757 100644 --- a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/META-INF/MANIFEST.MF +++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/META-INF/MANIFEST.MF | |||
@@ -15,4 +15,9 @@ Require-Bundle: com.google.guava, | |||
15 | hu.bme.mit.inf.dslreasoner.domains.yakindu.sgraph;bundle-version="1.0.0", | 15 | hu.bme.mit.inf.dslreasoner.domains.yakindu.sgraph;bundle-version="1.0.0", |
16 | hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage;bundle-version="1.0.0", | 16 | hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage;bundle-version="1.0.0", |
17 | org.eclipse.viatra.dse;bundle-version="0.21.2" | 17 | org.eclipse.viatra.dse;bundle-version="0.21.2" |
18 | Export-Package: ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app | 18 | Export-Package: ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app, |
19 | ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.distance, | ||
20 | ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.graph, | ||
21 | ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.io, | ||
22 | ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics, | ||
23 | ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.predictor | ||
diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend index b63451e8..45986ecf 100644 --- a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend +++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend | |||
@@ -8,26 +8,35 @@ import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.Metric | |||
8 | import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.MultiplexParticipationCoefficientMetric | 8 | import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.MultiplexParticipationCoefficientMetric |
9 | import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.NodeActivityMetric | 9 | import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.NodeActivityMetric |
10 | import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.OutDegreeMetric | 10 | import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.OutDegreeMetric |
11 | import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.predictor.LinearModel | ||
11 | import hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage.partialinterpretation.PartialInterpretation | 12 | import hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage.partialinterpretation.PartialInterpretation |
12 | import java.util.ArrayList | 13 | import java.util.ArrayList |
13 | import java.util.HashMap | 14 | import java.util.HashMap |
15 | import java.util.List | ||
14 | import java.util.Map | 16 | import java.util.Map |
15 | import org.apache.commons.math3.stat.regression.SimpleRegression | 17 | import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression |
16 | import java.util.stream.DoubleStream.Builder | 18 | import org.eclipse.xtend.lib.annotations.Accessors |
17 | 19 | ||
18 | class PartialInterpretationMetricDistance { | 20 | class PartialInterpretationMetricDistance { |
19 | 21 | ||
20 | var KSDistance ks; | 22 | var KSDistance ks; |
21 | var JSDistance js; | 23 | var JSDistance js; |
22 | var Map<Object, StateData> stateAndHistory; | 24 | var Map<Object, StateData> stateAndHistory; |
23 | var SimpleRegression regression; | 25 | var OLSMultipleLinearRegression regression; |
26 | List<StateData> samples; | ||
27 | |||
28 | @Accessors(PUBLIC_GETTER) | ||
29 | var LinearModel linearModel; | ||
24 | 30 | ||
25 | 31 | ||
26 | new(){ | 32 | new(){ |
27 | ks = new KSDistance(Domain.Yakinduum); | 33 | ks = new KSDistance(Domain.Yakinduum); |
28 | js = new JSDistance(Domain.Yakinduum); | 34 | js = new JSDistance(Domain.Yakinduum); |
29 | regression = new SimpleRegression(); | 35 | regression = new OLSMultipleLinearRegression(); |
36 | regression.noIntercept = false; | ||
30 | stateAndHistory = new HashMap<Object, StateData>(); | 37 | stateAndHistory = new HashMap<Object, StateData>(); |
38 | samples = new ArrayList<StateData>(); | ||
39 | linearModel = new LinearModel(0.01); | ||
31 | } | 40 | } |
32 | 41 | ||
33 | def MetricDistanceGroup calculateMetricDistanceKS(PartialInterpretation partial){ | 42 | def MetricDistanceGroup calculateMetricDistanceKS(PartialInterpretation partial){ |
@@ -63,37 +72,76 @@ class PartialInterpretationMetricDistance { | |||
63 | } | 72 | } |
64 | 73 | ||
65 | def resetRegression(Object state){ | 74 | def resetRegression(Object state){ |
66 | regression = new SimpleRegression(); | 75 | samples.clear(); |
67 | 76 | ||
68 | if(stateAndHistory.containsKey(state)){ | 77 | if(stateAndHistory.containsKey(state)){ |
69 | var data = stateAndHistory.get(state); | 78 | var data = stateAndHistory.get(state); |
70 | regression.addData(data.numOfNodeFeature, data.value); | ||
71 | 79 | ||
72 | while(stateAndHistory.containsKey(data.lastState)){ | 80 | var curState = state; |
81 | |||
82 | samples.add(data); | ||
83 | |||
84 | while(stateAndHistory.containsKey(data.lastState) && data.lastState != curState){ | ||
85 | curState = data.lastState; | ||
73 | data = stateAndHistory.get(data.lastState); | 86 | data = stateAndHistory.get(data.lastState); |
74 | regression.addData(data.numOfNodeFeature, data.value); | 87 | samples.add(data); |
88 | } | ||
89 | |||
90 | if(samples.size == 0){ | ||
91 | println('state: ' + state); | ||
92 | println('last state: ' + data.lastState); | ||
75 | } | 93 | } |
76 | } | 94 | } |
95 | println("trajectory sample size:" + samples.size) | ||
77 | } | 96 | } |
78 | 97 | ||
79 | def feedData(Object state, int numOfNodes, double value, Object lastState){ | 98 | def feedData(Object state, double[] features, double value, Object lastState){ |
80 | var data = new StateData(numOfNodes, value, lastState); | 99 | var data = new StateData(features, value, lastState); |
81 | stateAndHistory.put(state, data); | 100 | stateAndHistory.put(state, data); |
82 | regression.addData(data.numOfNodeFeature, data.value); | 101 | samples.add(data); |
83 | } | 102 | } |
84 | 103 | ||
85 | def getPredictionForNextDataSample(int numOfNodes, double value, int numberOfNodesToPredict){ | 104 | def getPredictionForNextDataSample(double[] features, double value, double[] featuresToPredict){ |
86 | var data = new StateData(numOfNodes, value, null); | 105 | if(samples.size <= 4){ |
87 | regression.addData(data.numOfNodeFeature, data.value); | 106 | println('OK'); |
107 | } | ||
108 | var data = new StateData(features, value, null); | ||
109 | samples.add(data); | ||
110 | |||
111 | // create training set from current data | ||
112 | var double[][] xSamples = samples.map[it.features]; | ||
113 | var double[] ySamples = samples.map[it.value]; | ||
114 | |||
88 | 115 | ||
89 | var prediction = predict(numberOfNodesToPredict); | 116 | regression.newSampleData(ySamples, xSamples); |
90 | regression.removeData(data.numOfNodeFeature, data.value); | 117 | var prediction = predict(featuresToPredict); |
118 | |||
119 | //remove the last element just added | ||
120 | samples.remove(samples.size - 1); | ||
91 | return prediction; | 121 | return prediction; |
92 | } | 122 | } |
93 | 123 | ||
94 | def predict(int numOfNodes){ | 124 | def private predict(double[] featuresToPredict){ |
95 | var data = new StateData(numOfNodes, 0, null); | 125 | var parameters = regression.estimateRegressionParameters(); |
96 | return regression.predict(data.numOfNodeFeature); | 126 | // the regression will add an initial column for 1's, the first parameter is constant term |
127 | var result = parameters.get(0); | ||
128 | for(var i = 0; i < featuresToPredict.length; i++){ | ||
129 | result += parameters.get(i+1) * featuresToPredict.get(i); | ||
130 | } | ||
131 | return result; | ||
132 | } | ||
133 | |||
134 | def double[] calculateFeature(int step, int violations){ | ||
135 | var features = newDoubleArrayOfSize(5); | ||
136 | //constant term | ||
137 | features.set(0, 1); | ||
138 | |||
139 | features.set(1, 1.0 / step); | ||
140 | features.set(2, violations); | ||
141 | features.set(3, Math.pow(violations, 2)); | ||
142 | features.set(4, Math.pow(violations, 0.5)); | ||
143 | |||
144 | return features; | ||
97 | } | 145 | } |
98 | } | 146 | } |
99 | 147 | ||
diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/Test.java b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/Test.java new file mode 100644 index 00000000..f06b377f --- /dev/null +++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/Test.java | |||
@@ -0,0 +1,31 @@ | |||
1 | package ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app; | ||
2 | |||
3 | import java.util.ArrayList; | ||
4 | import java.util.List; | ||
5 | |||
6 | import weka.core.matrix.LinearRegression; | ||
7 | import weka.core.matrix.Matrix; | ||
8 | |||
9 | public class Test { | ||
10 | public static void main(String[] args) { | ||
11 | linearRegressionTest(); | ||
12 | } | ||
13 | |||
14 | public static void linearRegressionTest() { | ||
15 | double[][] x = {{1,1,2,3}, {1,2,3,4}, {1,3,5,7}, {1,1,5,7}}; | ||
16 | double[] y = {10, 13, 19, 17}; | ||
17 | double[] valueToPredict = {1,1,1,1}; | ||
18 | Matrix m = new Matrix(x); | ||
19 | Matrix n = new Matrix(y, y.length); | ||
20 | |||
21 | LinearRegression regression = new LinearRegression(m, n, 0); | ||
22 | double[] coef = regression.getCoefficients(); | ||
23 | |||
24 | //predict | ||
25 | double a = 0; | ||
26 | for(int i = 0; i < coef.length; i++) { | ||
27 | a += coef[i] * valueToPredict[i]; | ||
28 | } | ||
29 | System.out.println(a); | ||
30 | } | ||
31 | } | ||
diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/distance/CostDistance.xtend b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/distance/CostDistance.xtend index ee856201..33d10fa3 100644 --- a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/distance/CostDistance.xtend +++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/distance/CostDistance.xtend | |||
@@ -1,28 +1,21 @@ | |||
1 | package ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.distance | 1 | package ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.distance |
2 | 2 | ||
3 | import org.apache.commons.math3.stat.regression.SimpleRegression | ||
4 | import org.eclipse.xtend.lib.annotations.Accessors | 3 | import org.eclipse.xtend.lib.annotations.Accessors |
5 | 4 | ||
6 | class CostDistance { | 5 | class CostDistance { |
7 | 6 | ||
8 | var SimpleRegression regression; | ||
9 | |||
10 | new(){ | ||
11 | regression = new SimpleRegression(true); | ||
12 | } | ||
13 | |||
14 | } | 7 | } |
15 | 8 | ||
16 | class StateData{ | 9 | class StateData{ |
17 | @Accessors(PUBLIC_GETTER) | 10 | @Accessors(PUBLIC_GETTER) |
18 | var double numOfNodeFeature; | 11 | var double[] features; |
19 | @Accessors(PUBLIC_GETTER) | 12 | @Accessors(PUBLIC_GETTER) |
20 | var double value; | 13 | var double value; |
21 | @Accessors(PUBLIC_GETTER) | 14 | @Accessors(PUBLIC_GETTER) |
22 | var Object lastState; | 15 | var Object lastState; |
23 | 16 | ||
24 | new(int numOfNode, double value, Object lastState){ | 17 | new(double[] features, double value, Object lastState){ |
25 | this.numOfNodeFeature = 1.0 / numOfNode; | 18 | this.features = features; |
26 | this.value = value | 19 | this.value = value |
27 | this.lastState = lastState; | 20 | this.lastState = lastState; |
28 | } | 21 | } |
diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/io/CsvFileWriter.xtend b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/io/CsvFileWriter.xtend index 01e3940b..00b38d90 100644 --- a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/io/CsvFileWriter.xtend +++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/io/CsvFileWriter.xtend | |||
@@ -2,19 +2,34 @@ package ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.io; | |||
2 | 2 | ||
3 | import java.io.File | 3 | import java.io.File |
4 | import java.io.FileNotFoundException | 4 | import java.io.FileNotFoundException |
5 | import java.io.FileOutputStream | ||
5 | import java.io.PrintWriter | 6 | import java.io.PrintWriter |
6 | import java.util.ArrayList | 7 | import java.util.ArrayList |
7 | import java.util.List | 8 | import java.util.List |
8 | 9 | ||
9 | class CsvFileWriter { | 10 | class CsvFileWriter { |
11 | |||
10 | def static void write(ArrayList<ArrayList<String>> datas, String uri) { | 12 | def static void write(ArrayList<ArrayList<String>> datas, String uri) { |
11 | if(datas.size() <= 0) { | 13 | if(datas.size() <= 0) { |
12 | return; | 14 | return; |
13 | } | 15 | } |
14 | 16 | val PrintWriter writer = new PrintWriter(new File(uri)); | |
17 | output(writer, datas, uri); | ||
18 | } | ||
19 | |||
20 | def static void append(ArrayList<ArrayList<String>> datas, String uri) { | ||
21 | if(datas.size() <= 0) { | ||
22 | return; | ||
23 | } | ||
24 | val PrintWriter writer = new PrintWriter(new FileOutputStream(new File(uri), true)); | ||
25 | output(writer, datas, uri); | ||
26 | } | ||
27 | |||
28 | |||
29 | def private static void output(PrintWriter writer, ArrayList<ArrayList<String>> datas, String uri) { | ||
15 | //println("Output csv for " + uri); | 30 | //println("Output csv for " + uri); |
16 | try { | 31 | try { |
17 | val PrintWriter writer = new PrintWriter(new File(uri)); | 32 | |
18 | val output = new StringBuilder; | 33 | val output = new StringBuilder; |
19 | for(List<String> datarow : datas){ | 34 | for(List<String> datarow : datas){ |
20 | for(var i = 0; i < datarow.size() - 1; i++){ | 35 | for(var i = 0; i < datarow.size() - 1; i++){ |
diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/predictor/LinearModel.xtend b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/predictor/LinearModel.xtend new file mode 100644 index 00000000..f0ded347 --- /dev/null +++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/predictor/LinearModel.xtend | |||
@@ -0,0 +1,91 @@ | |||
1 | package ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.predictor | ||
2 | |||
3 | import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.distance.StateData | ||
4 | import java.util.ArrayList | ||
5 | import java.util.HashMap | ||
6 | import java.util.List | ||
7 | import java.util.Map | ||
8 | import weka.core.matrix.LinearRegression | ||
9 | import weka.core.matrix.Matrix | ||
10 | |||
11 | class LinearModel { | ||
12 | var double ridge; | ||
13 | var Map<Object, StateData> stateAndHistory; | ||
14 | List<StateData> samples; | ||
15 | |||
16 | new(double ridge){ | ||
17 | this.ridge = ridge; | ||
18 | stateAndHistory = new HashMap<Object, StateData>(); | ||
19 | samples = new ArrayList<StateData>(); | ||
20 | } | ||
21 | |||
22 | /** | ||
23 | * reset the current train data for regression to a new trajectory | ||
24 | * @param state: the last state of the trajectory | ||
25 | */ | ||
26 | def resetRegression(Object state){ | ||
27 | samples.clear(); | ||
28 | |||
29 | if(stateAndHistory.containsKey(state)){ | ||
30 | var data = stateAndHistory.get(state); | ||
31 | var curState = state; | ||
32 | |||
33 | samples.add(data); | ||
34 | |||
35 | //loop through data until the oldest state in the record | ||
36 | while(stateAndHistory.containsKey(data.lastState) && data.lastState != curState){ | ||
37 | curState = data.lastState; | ||
38 | data = stateAndHistory.get(data.lastState); | ||
39 | samples.add(data); | ||
40 | } | ||
41 | } | ||
42 | } | ||
43 | |||
44 | /** | ||
45 | * Add a new data point to the current training set | ||
46 | * @param state: the state on which the new data point is calculated | ||
47 | * @param features: the set of feature value(x) | ||
48 | * @param value: the value of the state (y) | ||
49 | * @param lastState: the state which transformed to current state, used to record the trajectory | ||
50 | */ | ||
51 | def feedData(Object state, double[] features, double value, Object lastState){ | ||
52 | var data = new StateData(features, value, lastState); | ||
53 | stateAndHistory.put(state, data); | ||
54 | samples.add(data); | ||
55 | } | ||
56 | |||
57 | /** | ||
58 | * get prediction for next state, without storing the data point into the training set | ||
59 | * @param features: the feature values of current state | ||
60 | * @param value: the value of the current state | ||
61 | * @param: featuresToPredict: the features of the state wanted to be predected | ||
62 | * @return the value of the state to be predicted | ||
63 | */ | ||
64 | def double getPredictionForNextDataSample(double[] features, double value, double[] featuresToPredict){ | ||
65 | var data = new StateData(features, value, null); | ||
66 | samples.add(data); | ||
67 | |||
68 | // create training set from current data | ||
69 | val double[][] xSamples = samples.map[it.features]; | ||
70 | val double[] ySamples = samples.map[it.value]; | ||
71 | |||
72 | val x = new Matrix(xSamples); | ||
73 | val y = new Matrix(ySamples, ySamples.size()); | ||
74 | |||
75 | val regression = new LinearRegression(x, y, ridge); | ||
76 | var prediction = predict(regression.coefficients, featuresToPredict); | ||
77 | |||
78 | //remove the last element just added | ||
79 | samples.remove(samples.size - 1); | ||
80 | return prediction; | ||
81 | } | ||
82 | |||
83 | def private predict(double[] parameters, double[] featuresToPredict){ | ||
84 | // the regression will add an initial column for 1's, the first parameter is constant term | ||
85 | var result = parameters.get(0); | ||
86 | for(var i = 0; i < featuresToPredict.length; i++){ | ||
87 | result += parameters.get(i) * featuresToPredict.get(i); | ||
88 | } | ||
89 | return result; | ||
90 | } | ||
91 | } \ No newline at end of file | ||
diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/HillClimbingOnRealisticMetricStrategyForModelGeneration.java b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/HillClimbingOnRealisticMetricStrategyForModelGeneration.java index 148cb243..d9f6f2aa 100644 --- a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/HillClimbingOnRealisticMetricStrategyForModelGeneration.java +++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/HillClimbingOnRealisticMetricStrategyForModelGeneration.java | |||
@@ -26,6 +26,7 @@ import org.eclipse.viatra.query.runtime.api.ViatraQueryMatcher; | |||
26 | 26 | ||
27 | import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app.MetricDistanceGroup; | 27 | import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app.MetricDistanceGroup; |
28 | import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app.PartialInterpretationMetric; | 28 | import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app.PartialInterpretationMetric; |
29 | import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app.PartialInterpretationMetricDistance; | ||
29 | import hu.bme.mit.inf.dslreasoner.logic.model.builder.DocumentationLevel; | 30 | import hu.bme.mit.inf.dslreasoner.logic.model.builder.DocumentationLevel; |
30 | import hu.bme.mit.inf.dslreasoner.logic.model.builder.LogicReasoner; | 31 | import hu.bme.mit.inf.dslreasoner.logic.model.builder.LogicReasoner; |
31 | import hu.bme.mit.inf.dslreasoner.logic.model.logicproblem.LogicProblem; | 32 | import hu.bme.mit.inf.dslreasoner.logic.model.logicproblem.LogicProblem; |
@@ -62,12 +63,12 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements | |||
62 | private Collection<ViatraQueryMatcher<? extends IPatternMatch>> mayMatchers; | 63 | private Collection<ViatraQueryMatcher<? extends IPatternMatch>> mayMatchers; |
63 | private Map<Object, List<Object>> stateAndActivations; | 64 | private Map<Object, List<Object>> stateAndActivations; |
64 | private Map<TrajectoryWithFitness, Double> trajectoryFit; | 65 | private Map<TrajectoryWithFitness, Double> trajectoryFit; |
65 | |||
66 | // Statistics | 66 | // Statistics |
67 | private int numberOfStatecoderFail = 0; | 67 | private int numberOfStatecoderFail = 0; |
68 | private int numberOfPrintedModel = 0; | 68 | private int numberOfPrintedModel = 0; |
69 | private int numberOfSolverCalls = 0; | 69 | private int numberOfSolverCalls = 0; |
70 | 70 | private PartialInterpretationMetricDistance metricDistance; | |
71 | |||
71 | public HillClimbingOnRealisticMetricStrategyForModelGeneration( | 72 | public HillClimbingOnRealisticMetricStrategyForModelGeneration( |
72 | ReasonerWorkspace workspace, | 73 | ReasonerWorkspace workspace, |
73 | ViatraReasonerConfiguration configuration, | 74 | ViatraReasonerConfiguration configuration, |
@@ -112,7 +113,7 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements | |||
112 | this.solutionStoreWithCopy = new SolutionStoreWithCopy(); | 113 | this.solutionStoreWithCopy = new SolutionStoreWithCopy(); |
113 | this.solutionStoreWithDiversityDescriptor = new SolutionStoreWithDiversityDescriptor(configuration.diversityRequirement); | 114 | this.solutionStoreWithDiversityDescriptor = new SolutionStoreWithDiversityDescriptor(configuration.diversityRequirement); |
114 | 115 | ||
115 | final ObjectiveComparatorHelper objectiveComparatorHelper = context.getObjectiveComparatorHelper(); | 116 | //final ObjectiveComparatorHelper objectiveComparatorHelper = context.getObjectiveComparatorHelper(); |
116 | trajectoryFit = new HashMap<TrajectoryWithFitness, Double>(); | 117 | trajectoryFit = new HashMap<TrajectoryWithFitness, Double>(); |
117 | 118 | ||
118 | this.comparator = new Comparator<TrajectoryWithFitness>() { | 119 | this.comparator = new Comparator<TrajectoryWithFitness>() { |
@@ -124,6 +125,7 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements | |||
124 | 125 | ||
125 | trajectoiresToExplore = new PriorityQueue<TrajectoryWithFitness>(11, comparator); | 126 | trajectoiresToExplore = new PriorityQueue<TrajectoryWithFitness>(11, comparator); |
126 | stateAndActivations = new HashMap<Object, List<Object>>(); | 127 | stateAndActivations = new HashMap<Object, List<Object>>(); |
128 | metricDistance = new PartialInterpretationMetricDistance(); | ||
127 | } | 129 | } |
128 | 130 | ||
129 | @Override | 131 | @Override |
@@ -145,11 +147,11 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements | |||
145 | TrajectoryWithFitness currentTrajectoryWithFittness = new TrajectoryWithFitness(firstTrajectory, firstFittness); | 147 | TrajectoryWithFitness currentTrajectoryWithFittness = new TrajectoryWithFitness(firstTrajectory, firstFittness); |
146 | trajectoryFit.put(currentTrajectoryWithFittness, Double.MAX_VALUE); | 148 | trajectoryFit.put(currentTrajectoryWithFittness, Double.MAX_VALUE); |
147 | trajectoiresToExplore.add(currentTrajectoryWithFittness); | 149 | trajectoiresToExplore.add(currentTrajectoryWithFittness); |
150 | Object lastState = null; | ||
148 | 151 | ||
149 | //if(configuration) | 152 | //if(configuration) |
150 | visualiseCurrentState(); | 153 | visualiseCurrentState(); |
151 | 154 | ||
152 | PartialInterpretationMetric.initPaths(); | ||
153 | //create matcher | 155 | //create matcher |
154 | int count = 0; | 156 | int count = 0; |
155 | mainLoop: while (!isInterrupted && !configuration.progressMonitor.isCancelled()) { | 157 | mainLoop: while (!isInterrupted && !configuration.progressMonitor.isCancelled()) { |
@@ -165,6 +167,9 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements | |||
165 | logger.debug("New trajectory is chosen: " + currentTrajectoryWithFittness); | 167 | logger.debug("New trajectory is chosen: " + currentTrajectoryWithFittness); |
166 | } | 168 | } |
167 | context.getDesignSpaceManager().executeTrajectoryWithMinimalBacktrackWithoutStateCoding(currentTrajectoryWithFittness.trajectory); | 169 | context.getDesignSpaceManager().executeTrajectoryWithMinimalBacktrackWithoutStateCoding(currentTrajectoryWithFittness.trajectory); |
170 | |||
171 | // reset the regression for this trajectory | ||
172 | metricDistance.getLinearModel().resetRegression(context.getCurrentStateId()); | ||
168 | } | 173 | } |
169 | } | 174 | } |
170 | 175 | ||
@@ -178,10 +183,10 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements | |||
178 | //init epsilon and draw | 183 | //init epsilon and draw |
179 | double epsilon = 0; | 184 | double epsilon = 0; |
180 | double draw = 1; | 185 | double draw = 1; |
181 | MetricDistanceGroup heuristics = PartialInterpretationMetric.calculateMetricDistanceKS(model); | 186 | MetricDistanceGroup heuristics = metricDistance.calculateMetricDistanceKS(model); |
182 | 187 | ||
183 | if(!stateAndActivations.containsKey(model)) { | 188 | if(!stateAndActivations.containsKey(context.getCurrentStateId())) { |
184 | stateAndActivations.put(model, new ArrayList<Object>()); | 189 | stateAndActivations.put(context.getCurrentStateId(), new ArrayList<Object>()); |
185 | } | 190 | } |
186 | 191 | ||
187 | //Output intermediate model | 192 | //Output intermediate model |
@@ -190,41 +195,26 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements | |||
190 | draw = Math.random(); | 195 | draw = Math.random(); |
191 | count++; | 196 | count++; |
192 | // cut off the trajectory for bad graph | 197 | // cut off the trajectory for bad graph |
193 | double distance = heuristics.getMPCDistance(); | ||
194 | System.out.println("KS"); | 198 | System.out.println("KS"); |
195 | System.out.println("NA distance: " + heuristics.getNADistance()); | 199 | System.out.println("NA distance: " + heuristics.getNADistance()); |
196 | System.out.println("MPC distance: " + heuristics.getMPCDistance()); | 200 | System.out.println("MPC distance: " + heuristics.getMPCDistance()); |
197 | System.out.println("Out degree distance:" + heuristics.getOutDegreeDistance()); | 201 | System.out.println("Out degree distance:" + heuristics.getOutDegreeDistance()); |
198 | 202 | ||
199 | MetricDistanceGroup lsHeuristic = PartialInterpretationMetric.calculateMetricDistance(model); | 203 | // MetricDistanceGroup lsHeuristic = metricDistance.calculateMetricDistance(model); |
200 | System.out.println("LS"); | 204 | // System.out.println("LS"); |
201 | System.out.println("NA distance: " + lsHeuristic.getNADistance()); | 205 | // System.out.println("NA distance: " + lsHeuristic.getNADistance()); |
202 | System.out.println("MPC distance: " + lsHeuristic.getMPCDistance()); | 206 | // System.out.println("MPC distance: " + lsHeuristic.getMPCDistance()); |
203 | System.out.println("Out degree distance:" + lsHeuristic.getOutDegreeDistance()); | 207 | // System.out.println("Out degree distance:" + lsHeuristic.getOutDegreeDistance()); |
204 | |||
205 | if(distance <= 0.23880597014925373 + 0.00001 && distance >= 0.23880597014925373 - 0.00001) { | ||
206 | context.backtrack(); | ||
207 | final Fitness nextFitness = context.calculateFitness(); | ||
208 | currentTrajectoryWithFittness = new TrajectoryWithFitness(context.getTrajectory().toArray(), nextFitness); | ||
209 | continue; | ||
210 | } | ||
211 | 208 | ||
212 | //check for next value when doing greedy move | 209 | //check for next value when doing greedy move |
210 | |||
213 | valueMap = sortWithWeight(activationIds, currentTrajectoryWithFittness.trajectory.length+1); | 211 | valueMap = sortWithWeight(activationIds, currentTrajectoryWithFittness.trajectory.length+1); |
214 | // if(activationIds.isEmpty() || (model.getNewElements().size() >= 20 && valueMap.get(activationIds.get(0)) > currentValue && epsilon < draw)) { | ||
215 | // context.backtrack(); | ||
216 | // final Fitness nextFitness = context.calculateFitness(); | ||
217 | // currentTrajectoryWithFittness = new TrajectoryWithFitness(context.getTrajectory().toArray(), nextFitness); | ||
218 | // continue; | ||
219 | // } | ||
220 | } | 212 | } |
213 | lastState = context.getCurrentStateId(); | ||
221 | while (!isInterrupted && !configuration.progressMonitor.isCancelled() && activationIds.size() > 0) { | 214 | while (!isInterrupted && !configuration.progressMonitor.isCancelled() && activationIds.size() > 0) { |
222 | final Object nextActivation = drawWithEpsilonProbabilty(activationIds, valueMap, epsilon, draw); | 215 | final Object nextActivation = drawWithEpsilonProbabilty(activationIds, valueMap, epsilon, draw); |
223 | // if (!iterator.hasNext()) { | 216 | |
224 | // logger.debug("Last untraversed activation of the state."); | 217 | stateAndActivations.get(context.getCurrentStateId()).add(nextActivation); |
225 | // trajectoiresToExplore.remove(currentTrajectoryWithFittness); | ||
226 | // } | ||
227 | stateAndActivations.get(context.getModel()).add(nextActivation); | ||
228 | logger.debug("Executing new activation: " + nextActivation); | 218 | logger.debug("Executing new activation: " + nextActivation); |
229 | context.executeAcitvationId(nextActivation); | 219 | context.executeAcitvationId(nextActivation); |
230 | visualiseCurrentState(); | 220 | visualiseCurrentState(); |
@@ -242,15 +232,17 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements | |||
242 | logger.debug("Global contraint is not satisifed."); | 232 | logger.debug("Global contraint is not satisifed."); |
243 | context.backtrack(); | 233 | context.backtrack(); |
244 | } else*/// { | 234 | } else*/// { |
245 | if(getNumberOfViolations(mustMatchers) > 8) { | 235 | /*if(getNumberOfViolations(mustMatchers) > 0) { |
236 | context.backtrack(); | ||
237 | }else*/ if(model.getNewElements().size() > 90 && heuristics.getNADistance() + heuristics.getMPCDistance() + heuristics.getOutDegreeDistance() > 0.3) { | ||
246 | context.backtrack(); | 238 | context.backtrack(); |
247 | }else if(model.getNewElements().size() > 90 && heuristics.getNADistance() + heuristics.getMPCDistance() + heuristics.getOutDegreeDistance() > 0.18) { | 239 | }else if(model.getNewElements().size() > 70 && heuristics.getNADistance() + heuristics.getMPCDistance() + heuristics.getOutDegreeDistance() > 0.45) { |
248 | context.backtrack(); | 240 | context.backtrack(); |
249 | }else if(model.getNewElements().size() > 70 && heuristics.getNADistance() + heuristics.getMPCDistance() + heuristics.getOutDegreeDistance() > 0.36) { | 241 | } else if(model.getNewElements().size() > 50 && heuristics.getNADistance() + heuristics.getMPCDistance() + heuristics.getOutDegreeDistance() > 0.60) { |
250 | context.backtrack(); | 242 | context.backtrack(); |
251 | } else if(model.getNewElements().size() > 50 && heuristics.getNADistance() + heuristics.getMPCDistance() + heuristics.getOutDegreeDistance() > 0.72) { | 243 | } else if(model.getNewElements().size() > 30 && heuristics.getNADistance() + heuristics.getMPCDistance() + heuristics.getOutDegreeDistance() > 0.70) { |
252 | context.backtrack(); | 244 | context.backtrack(); |
253 | } else { | 245 | }else { |
254 | final Fitness nextFitness = context.calculateFitness(); | 246 | final Fitness nextFitness = context.calculateFitness(); |
255 | 247 | ||
256 | // the only hard objectives are the size | 248 | // the only hard objectives are the size |
@@ -271,9 +263,14 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements | |||
271 | continue; | 263 | continue; |
272 | } | 264 | } |
273 | 265 | ||
266 | |||
274 | TrajectoryWithFitness nextTrajectoryWithFittness = new TrajectoryWithFitness( | 267 | TrajectoryWithFitness nextTrajectoryWithFittness = new TrajectoryWithFitness( |
275 | context.getTrajectory().toArray(), nextFitness); | 268 | context.getTrajectory().toArray(), nextFitness); |
276 | trajectoryFit.put(nextTrajectoryWithFittness, calculateCurrentStateValue(nextTrajectoryWithFittness.trajectory.length)); | 269 | int step = nextTrajectoryWithFittness.trajectory.length; |
270 | int violation = getNumberOfViolations(mustMatchers) + getNumberOfViolations(mayMatchers); | ||
271 | metricDistance.getLinearModel().feedData(context.getCurrentStateId(), metricDistance.calculateFeature(step, violation), calculateCurrentStateValue(step, violation), lastState); | ||
272 | |||
273 | trajectoryFit.put(nextTrajectoryWithFittness, calculateCurrentStateValue(step, violation)); | ||
277 | trajectoiresToExplore.add(nextTrajectoryWithFittness); | 274 | trajectoiresToExplore.add(nextTrajectoryWithFittness); |
278 | 275 | ||
279 | int compare = objectiveComparatorHelper.compare(currentTrajectoryWithFittness.fitness, | 276 | int compare = objectiveComparatorHelper.compare(currentTrajectoryWithFittness.fitness, |
@@ -298,9 +295,9 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements | |||
298 | currentTrajectoryWithFittness = null; | 295 | currentTrajectoryWithFittness = null; |
299 | context.backtrack(); | 296 | context.backtrack(); |
300 | } | 297 | } |
301 | PartialInterpretation model = (PartialInterpretation) context.getModel(); | 298 | // PartialInterpretation model = (PartialInterpretation) context.getModel(); |
302 | PartialInterpretationMetric.calculateMetric(model, "debug/metric/output", context.getCurrentStateId().toString(), count); | 299 | // PartialInterpretationMetric.calculateMetric(model, "debug/metric/output", context.getCurrentStateId().toString(), count); |
303 | count++; | 300 | // count++; |
304 | logger.info("Interrupted."); | 301 | logger.info("Interrupted."); |
305 | } | 302 | } |
306 | 303 | ||
@@ -314,8 +311,8 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements | |||
314 | // do hill climbing | 311 | // do hill climbing |
315 | for(Object id : activationIds) { | 312 | for(Object id : activationIds) { |
316 | context.executeAcitvationId(id); | 313 | context.executeAcitvationId(id); |
317 | 314 | int violation = getNumberOfViolations(mayMatchers) + getNumberOfViolations(mustMatchers); | |
318 | valueMap.put(id, calculateCurrentStateValue(factor)); | 315 | valueMap.put(id, calculateFutureStateValue(factor, violation)); |
319 | context.backtrack(); | 316 | context.backtrack(); |
320 | } | 317 | } |
321 | 318 | ||
@@ -323,14 +320,27 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements | |||
323 | return valueMap; | 320 | return valueMap; |
324 | } | 321 | } |
325 | 322 | ||
326 | private double calculateCurrentStateValue(int factor) { | 323 | private double calculateFutureStateValue(int step, int violation) { |
324 | double currentValue = calculateCurrentStateValue(step, violation); | ||
325 | if(step > 40) { | ||
326 | double[] toPredict = metricDistance.calculateFeature(200, violation); | ||
327 | try { | ||
328 | return metricDistance.getLinearModel().getPredictionForNextDataSample(metricDistance.calculateFeature(step, violation), currentValue, toPredict); | ||
329 | }catch(IllegalArgumentException e) { | ||
330 | return currentValue; | ||
331 | } | ||
332 | }else { | ||
333 | return currentValue; | ||
334 | } | ||
335 | } | ||
336 | |||
337 | private double calculateCurrentStateValue(int factor, int violation) { | ||
327 | PartialInterpretation model = (PartialInterpretation) context.getModel(); | 338 | PartialInterpretation model = (PartialInterpretation) context.getModel(); |
328 | MetricDistanceGroup g = PartialInterpretationMetric.calculateMetricDistanceKS(model); | 339 | MetricDistanceGroup g = metricDistance.calculateMetricDistanceKS(model); |
329 | 340 | ||
330 | int violations = getNumberOfViolations(mayMatchers); | 341 | double consistenceWeights = 1- 1.0/(1+violation); |
331 | double consistenceWeights = 1.0/(1+violations); | ||
332 | 342 | ||
333 | return(2.5 / Math.log(factor) * (g.getNADistance() + g.getMPCDistance() + g.getOutDegreeDistance()) + 1-consistenceWeights); | 343 | return( /*/ Math.log(factor)*/(g.getNADistance() + g.getMPCDistance() + g.getOutDegreeDistance()) + consistenceWeights); |
334 | } | 344 | } |
335 | 345 | ||
336 | private int getNumberOfViolations(Collection<ViatraQueryMatcher<? extends IPatternMatch>> matchers) { | 346 | private int getNumberOfViolations(Collection<ViatraQueryMatcher<? extends IPatternMatch>> matchers) { |
@@ -379,8 +389,8 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements | |||
379 | List<Object> activationIds; | 389 | List<Object> activationIds; |
380 | try { | 390 | try { |
381 | activationIds = new ArrayList<Object>(context.getCurrentActivationIds()); | 391 | activationIds = new ArrayList<Object>(context.getCurrentActivationIds()); |
382 | if(stateAndActivations.containsKey(context.getModel())) { | 392 | if(stateAndActivations.containsKey(context.getCurrentStateId())) { |
383 | activationIds.removeAll(stateAndActivations.get(context.getModel())); | 393 | activationIds.removeAll(stateAndActivations.get(context.getCurrentStateId())); |
384 | } | 394 | } |
385 | Collections.shuffle(activationIds); | 395 | Collections.shuffle(activationIds); |
386 | } catch (NullPointerException e) { | 396 | } catch (NullPointerException e) { |