aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/.classpath1
-rw-r--r--Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/META-INF/MANIFEST.MF7
-rw-r--r--Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend86
-rw-r--r--Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/Test.java31
-rw-r--r--Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/distance/CostDistance.xtend15
-rw-r--r--Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/io/CsvFileWriter.xtend19
-rw-r--r--Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/predictor/LinearModel.xtend91
-rw-r--r--Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/HillClimbingOnRealisticMetricStrategyForModelGeneration.java110
8 files changed, 277 insertions, 83 deletions
diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/.classpath b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/.classpath
index f4f8357b..006b2686 100644
--- a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/.classpath
+++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/.classpath
@@ -6,5 +6,6 @@
6 <classpathentry kind="src" path="xtend-gen"/> 6 <classpathentry kind="src" path="xtend-gen"/>
7 <classpathentry kind="lib" path="C:/Users/chenp/eclipse-workspace/VIATRA-Generator/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/lib/commons-math3-3.6.1.jar"/> 7 <classpathentry kind="lib" path="C:/Users/chenp/eclipse-workspace/VIATRA-Generator/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/lib/commons-math3-3.6.1.jar"/>
8 <classpathentry kind="lib" path="C:/Users/chenp/eclipse-workspace/VIATRA-Generator/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/lib/commons-math3-3.6.1-javadoc.jar"/> 8 <classpathentry kind="lib" path="C:/Users/chenp/eclipse-workspace/VIATRA-Generator/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/lib/commons-math3-3.6.1-javadoc.jar"/>
9 <classpathentry kind="lib" path="C:/Users/chenp/eclipse-workspace/VIATRA-Generator/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/lib/weka.jar"/>
9 <classpathentry kind="output" path="bin"/> 10 <classpathentry kind="output" path="bin"/>
10</classpath> 11</classpath>
diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/META-INF/MANIFEST.MF b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/META-INF/MANIFEST.MF
index da19e07c..febd4757 100644
--- a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/META-INF/MANIFEST.MF
+++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/META-INF/MANIFEST.MF
@@ -15,4 +15,9 @@ Require-Bundle: com.google.guava,
15 hu.bme.mit.inf.dslreasoner.domains.yakindu.sgraph;bundle-version="1.0.0", 15 hu.bme.mit.inf.dslreasoner.domains.yakindu.sgraph;bundle-version="1.0.0",
16 hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage;bundle-version="1.0.0", 16 hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage;bundle-version="1.0.0",
17 org.eclipse.viatra.dse;bundle-version="0.21.2" 17 org.eclipse.viatra.dse;bundle-version="0.21.2"
18Export-Package: ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app 18Export-Package: ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app,
19 ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.distance,
20 ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.graph,
21 ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.io,
22 ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics,
23 ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.predictor
diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend
index b63451e8..45986ecf 100644
--- a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend
+++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/PartialInterpretationMetricDistance.xtend
@@ -8,26 +8,35 @@ import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.Metric
8import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.MultiplexParticipationCoefficientMetric 8import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.MultiplexParticipationCoefficientMetric
9import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.NodeActivityMetric 9import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.NodeActivityMetric
10import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.OutDegreeMetric 10import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.OutDegreeMetric
11import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.predictor.LinearModel
11import hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage.partialinterpretation.PartialInterpretation 12import hu.bme.mit.inf.dslreasoner.viatrasolver.partialinterpretationlanguage.partialinterpretation.PartialInterpretation
12import java.util.ArrayList 13import java.util.ArrayList
13import java.util.HashMap 14import java.util.HashMap
15import java.util.List
14import java.util.Map 16import java.util.Map
15import org.apache.commons.math3.stat.regression.SimpleRegression 17import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression
16import java.util.stream.DoubleStream.Builder 18import org.eclipse.xtend.lib.annotations.Accessors
17 19
18class PartialInterpretationMetricDistance { 20class PartialInterpretationMetricDistance {
19 21
20 var KSDistance ks; 22 var KSDistance ks;
21 var JSDistance js; 23 var JSDistance js;
22 var Map<Object, StateData> stateAndHistory; 24 var Map<Object, StateData> stateAndHistory;
23 var SimpleRegression regression; 25 var OLSMultipleLinearRegression regression;
26 List<StateData> samples;
27
28 @Accessors(PUBLIC_GETTER)
29 var LinearModel linearModel;
24 30
25 31
26 new(){ 32 new(){
27 ks = new KSDistance(Domain.Yakinduum); 33 ks = new KSDistance(Domain.Yakinduum);
28 js = new JSDistance(Domain.Yakinduum); 34 js = new JSDistance(Domain.Yakinduum);
29 regression = new SimpleRegression(); 35 regression = new OLSMultipleLinearRegression();
36 regression.noIntercept = false;
30 stateAndHistory = new HashMap<Object, StateData>(); 37 stateAndHistory = new HashMap<Object, StateData>();
38 samples = new ArrayList<StateData>();
39 linearModel = new LinearModel(0.01);
31 } 40 }
32 41
33 def MetricDistanceGroup calculateMetricDistanceKS(PartialInterpretation partial){ 42 def MetricDistanceGroup calculateMetricDistanceKS(PartialInterpretation partial){
@@ -63,37 +72,76 @@ class PartialInterpretationMetricDistance {
63 } 72 }
64 73
65 def resetRegression(Object state){ 74 def resetRegression(Object state){
66 regression = new SimpleRegression(); 75 samples.clear();
67 76
68 if(stateAndHistory.containsKey(state)){ 77 if(stateAndHistory.containsKey(state)){
69 var data = stateAndHistory.get(state); 78 var data = stateAndHistory.get(state);
70 regression.addData(data.numOfNodeFeature, data.value);
71 79
72 while(stateAndHistory.containsKey(data.lastState)){ 80 var curState = state;
81
82 samples.add(data);
83
84 while(stateAndHistory.containsKey(data.lastState) && data.lastState != curState){
85 curState = data.lastState;
73 data = stateAndHistory.get(data.lastState); 86 data = stateAndHistory.get(data.lastState);
74 regression.addData(data.numOfNodeFeature, data.value); 87 samples.add(data);
88 }
89
90 if(samples.size == 0){
91 println('state: ' + state);
92 println('last state: ' + data.lastState);
75 } 93 }
76 } 94 }
95 println("trajectory sample size:" + samples.size)
77 } 96 }
78 97
79 def feedData(Object state, int numOfNodes, double value, Object lastState){ 98 def feedData(Object state, double[] features, double value, Object lastState){
80 var data = new StateData(numOfNodes, value, lastState); 99 var data = new StateData(features, value, lastState);
81 stateAndHistory.put(state, data); 100 stateAndHistory.put(state, data);
82 regression.addData(data.numOfNodeFeature, data.value); 101 samples.add(data);
83 } 102 }
84 103
85 def getPredictionForNextDataSample(int numOfNodes, double value, int numberOfNodesToPredict){ 104 def getPredictionForNextDataSample(double[] features, double value, double[] featuresToPredict){
86 var data = new StateData(numOfNodes, value, null); 105 if(samples.size <= 4){
87 regression.addData(data.numOfNodeFeature, data.value); 106 println('OK');
107 }
108 var data = new StateData(features, value, null);
109 samples.add(data);
110
111 // create training set from current data
112 var double[][] xSamples = samples.map[it.features];
113 var double[] ySamples = samples.map[it.value];
114
88 115
89 var prediction = predict(numberOfNodesToPredict); 116 regression.newSampleData(ySamples, xSamples);
90 regression.removeData(data.numOfNodeFeature, data.value); 117 var prediction = predict(featuresToPredict);
118
119 //remove the last element just added
120 samples.remove(samples.size - 1);
91 return prediction; 121 return prediction;
92 } 122 }
93 123
94 def predict(int numOfNodes){ 124 def private predict(double[] featuresToPredict){
95 var data = new StateData(numOfNodes, 0, null); 125 var parameters = regression.estimateRegressionParameters();
96 return regression.predict(data.numOfNodeFeature); 126 // the regression will add an initial column for 1's, the first parameter is constant term
127 var result = parameters.get(0);
128 for(var i = 0; i < featuresToPredict.length; i++){
129 result += parameters.get(i+1) * featuresToPredict.get(i);
130 }
131 return result;
132 }
133
134 def double[] calculateFeature(int step, int violations){
135 var features = newDoubleArrayOfSize(5);
136 //constant term
137 features.set(0, 1);
138
139 features.set(1, 1.0 / step);
140 features.set(2, violations);
141 features.set(3, Math.pow(violations, 2));
142 features.set(4, Math.pow(violations, 0.5));
143
144 return features;
97 } 145 }
98} 146}
99 147
diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/Test.java b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/Test.java
new file mode 100644
index 00000000..f06b377f
--- /dev/null
+++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/app/Test.java
@@ -0,0 +1,31 @@
1package ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app;
2
3import java.util.ArrayList;
4import java.util.List;
5
6import weka.core.matrix.LinearRegression;
7import weka.core.matrix.Matrix;
8
9public class Test {
10 public static void main(String[] args) {
11 linearRegressionTest();
12 }
13
14 public static void linearRegressionTest() {
15 double[][] x = {{1,1,2,3}, {1,2,3,4}, {1,3,5,7}, {1,1,5,7}};
16 double[] y = {10, 13, 19, 17};
17 double[] valueToPredict = {1,1,1,1};
18 Matrix m = new Matrix(x);
19 Matrix n = new Matrix(y, y.length);
20
21 LinearRegression regression = new LinearRegression(m, n, 0);
22 double[] coef = regression.getCoefficients();
23
24 //predict
25 double a = 0;
26 for(int i = 0; i < coef.length; i++) {
27 a += coef[i] * valueToPredict[i];
28 }
29 System.out.println(a);
30 }
31}
diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/distance/CostDistance.xtend b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/distance/CostDistance.xtend
index ee856201..33d10fa3 100644
--- a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/distance/CostDistance.xtend
+++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/distance/CostDistance.xtend
@@ -1,28 +1,21 @@
1package ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.distance 1package ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.distance
2 2
3import org.apache.commons.math3.stat.regression.SimpleRegression
4import org.eclipse.xtend.lib.annotations.Accessors 3import org.eclipse.xtend.lib.annotations.Accessors
5 4
6class CostDistance { 5class CostDistance {
7 6
8 var SimpleRegression regression;
9
10 new(){
11 regression = new SimpleRegression(true);
12 }
13
14} 7}
15 8
16class StateData{ 9class StateData{
17 @Accessors(PUBLIC_GETTER) 10 @Accessors(PUBLIC_GETTER)
18 var double numOfNodeFeature; 11 var double[] features;
19 @Accessors(PUBLIC_GETTER) 12 @Accessors(PUBLIC_GETTER)
20 var double value; 13 var double value;
21 @Accessors(PUBLIC_GETTER) 14 @Accessors(PUBLIC_GETTER)
22 var Object lastState; 15 var Object lastState;
23 16
24 new(int numOfNode, double value, Object lastState){ 17 new(double[] features, double value, Object lastState){
25 this.numOfNodeFeature = 1.0 / numOfNode; 18 this.features = features;
26 this.value = value 19 this.value = value
27 this.lastState = lastState; 20 this.lastState = lastState;
28 } 21 }
diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/io/CsvFileWriter.xtend b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/io/CsvFileWriter.xtend
index 01e3940b..00b38d90 100644
--- a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/io/CsvFileWriter.xtend
+++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/io/CsvFileWriter.xtend
@@ -2,19 +2,34 @@ package ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.io;
2 2
3import java.io.File 3import java.io.File
4import java.io.FileNotFoundException 4import java.io.FileNotFoundException
5import java.io.FileOutputStream
5import java.io.PrintWriter 6import java.io.PrintWriter
6import java.util.ArrayList 7import java.util.ArrayList
7import java.util.List 8import java.util.List
8 9
9class CsvFileWriter { 10class CsvFileWriter {
11
10 def static void write(ArrayList<ArrayList<String>> datas, String uri) { 12 def static void write(ArrayList<ArrayList<String>> datas, String uri) {
11 if(datas.size() <= 0) { 13 if(datas.size() <= 0) {
12 return; 14 return;
13 } 15 }
14 16 val PrintWriter writer = new PrintWriter(new File(uri));
17 output(writer, datas, uri);
18 }
19
20 def static void append(ArrayList<ArrayList<String>> datas, String uri) {
21 if(datas.size() <= 0) {
22 return;
23 }
24 val PrintWriter writer = new PrintWriter(new FileOutputStream(new File(uri), true));
25 output(writer, datas, uri);
26 }
27
28
29 def private static void output(PrintWriter writer, ArrayList<ArrayList<String>> datas, String uri) {
15 //println("Output csv for " + uri); 30 //println("Output csv for " + uri);
16 try { 31 try {
17 val PrintWriter writer = new PrintWriter(new File(uri)); 32
18 val output = new StringBuilder; 33 val output = new StringBuilder;
19 for(List<String> datarow : datas){ 34 for(List<String> datarow : datas){
20 for(var i = 0; i < datarow.size() - 1; i++){ 35 for(var i = 0; i < datarow.size() - 1; i++){
diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/predictor/LinearModel.xtend b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/predictor/LinearModel.xtend
new file mode 100644
index 00000000..f0ded347
--- /dev/null
+++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/predictor/LinearModel.xtend
@@ -0,0 +1,91 @@
1package ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.predictor
2
3import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.distance.StateData
4import java.util.ArrayList
5import java.util.HashMap
6import java.util.List
7import java.util.Map
8import weka.core.matrix.LinearRegression
9import weka.core.matrix.Matrix
10
11class LinearModel {
12 var double ridge;
13 var Map<Object, StateData> stateAndHistory;
14 List<StateData> samples;
15
16 new(double ridge){
17 this.ridge = ridge;
18 stateAndHistory = new HashMap<Object, StateData>();
19 samples = new ArrayList<StateData>();
20 }
21
22 /**
23 * reset the current train data for regression to a new trajectory
24 * @param state: the last state of the trajectory
25 */
26 def resetRegression(Object state){
27 samples.clear();
28
29 if(stateAndHistory.containsKey(state)){
30 var data = stateAndHistory.get(state);
31 var curState = state;
32
33 samples.add(data);
34
35 //loop through data until the oldest state in the record
36 while(stateAndHistory.containsKey(data.lastState) && data.lastState != curState){
37 curState = data.lastState;
38 data = stateAndHistory.get(data.lastState);
39 samples.add(data);
40 }
41 }
42 }
43
44 /**
45 * Add a new data point to the current training set
46 * @param state: the state on which the new data point is calculated
47 * @param features: the set of feature value(x)
48 * @param value: the value of the state (y)
49 * @param lastState: the state which transformed to current state, used to record the trajectory
50 */
51 def feedData(Object state, double[] features, double value, Object lastState){
52 var data = new StateData(features, value, lastState);
53 stateAndHistory.put(state, data);
54 samples.add(data);
55 }
56
57 /**
58 * get prediction for next state, without storing the data point into the training set
59 * @param features: the feature values of current state
60 * @param value: the value of the current state
61 * @param: featuresToPredict: the features of the state wanted to be predected
62 * @return the value of the state to be predicted
63 */
64 def double getPredictionForNextDataSample(double[] features, double value, double[] featuresToPredict){
65 var data = new StateData(features, value, null);
66 samples.add(data);
67
68 // create training set from current data
69 val double[][] xSamples = samples.map[it.features];
70 val double[] ySamples = samples.map[it.value];
71
72 val x = new Matrix(xSamples);
73 val y = new Matrix(ySamples, ySamples.size());
74
75 val regression = new LinearRegression(x, y, ridge);
76 var prediction = predict(regression.coefficients, featuresToPredict);
77
78 //remove the last element just added
79 samples.remove(samples.size - 1);
80 return prediction;
81 }
82
83 def private predict(double[] parameters, double[] featuresToPredict){
84 // the regression will add an initial column for 1's, the first parameter is constant term
85 var result = parameters.get(0);
86 for(var i = 0; i < featuresToPredict.length; i++){
87 result += parameters.get(i) * featuresToPredict.get(i);
88 }
89 return result;
90 }
91} \ No newline at end of file
diff --git a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/HillClimbingOnRealisticMetricStrategyForModelGeneration.java b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/HillClimbingOnRealisticMetricStrategyForModelGeneration.java
index 148cb243..d9f6f2aa 100644
--- a/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/HillClimbingOnRealisticMetricStrategyForModelGeneration.java
+++ b/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.reasoner/src/hu/bme/mit/inf/dslreasoner/viatrasolver/reasoner/dse/HillClimbingOnRealisticMetricStrategyForModelGeneration.java
@@ -26,6 +26,7 @@ import org.eclipse.viatra.query.runtime.api.ViatraQueryMatcher;
26 26
27import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app.MetricDistanceGroup; 27import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app.MetricDistanceGroup;
28import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app.PartialInterpretationMetric; 28import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app.PartialInterpretationMetric;
29import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app.PartialInterpretationMetricDistance;
29import hu.bme.mit.inf.dslreasoner.logic.model.builder.DocumentationLevel; 30import hu.bme.mit.inf.dslreasoner.logic.model.builder.DocumentationLevel;
30import hu.bme.mit.inf.dslreasoner.logic.model.builder.LogicReasoner; 31import hu.bme.mit.inf.dslreasoner.logic.model.builder.LogicReasoner;
31import hu.bme.mit.inf.dslreasoner.logic.model.logicproblem.LogicProblem; 32import hu.bme.mit.inf.dslreasoner.logic.model.logicproblem.LogicProblem;
@@ -62,12 +63,12 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements
62 private Collection<ViatraQueryMatcher<? extends IPatternMatch>> mayMatchers; 63 private Collection<ViatraQueryMatcher<? extends IPatternMatch>> mayMatchers;
63 private Map<Object, List<Object>> stateAndActivations; 64 private Map<Object, List<Object>> stateAndActivations;
64 private Map<TrajectoryWithFitness, Double> trajectoryFit; 65 private Map<TrajectoryWithFitness, Double> trajectoryFit;
65
66 // Statistics 66 // Statistics
67 private int numberOfStatecoderFail = 0; 67 private int numberOfStatecoderFail = 0;
68 private int numberOfPrintedModel = 0; 68 private int numberOfPrintedModel = 0;
69 private int numberOfSolverCalls = 0; 69 private int numberOfSolverCalls = 0;
70 70 private PartialInterpretationMetricDistance metricDistance;
71
71 public HillClimbingOnRealisticMetricStrategyForModelGeneration( 72 public HillClimbingOnRealisticMetricStrategyForModelGeneration(
72 ReasonerWorkspace workspace, 73 ReasonerWorkspace workspace,
73 ViatraReasonerConfiguration configuration, 74 ViatraReasonerConfiguration configuration,
@@ -112,7 +113,7 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements
112 this.solutionStoreWithCopy = new SolutionStoreWithCopy(); 113 this.solutionStoreWithCopy = new SolutionStoreWithCopy();
113 this.solutionStoreWithDiversityDescriptor = new SolutionStoreWithDiversityDescriptor(configuration.diversityRequirement); 114 this.solutionStoreWithDiversityDescriptor = new SolutionStoreWithDiversityDescriptor(configuration.diversityRequirement);
114 115
115 final ObjectiveComparatorHelper objectiveComparatorHelper = context.getObjectiveComparatorHelper(); 116 //final ObjectiveComparatorHelper objectiveComparatorHelper = context.getObjectiveComparatorHelper();
116 trajectoryFit = new HashMap<TrajectoryWithFitness, Double>(); 117 trajectoryFit = new HashMap<TrajectoryWithFitness, Double>();
117 118
118 this.comparator = new Comparator<TrajectoryWithFitness>() { 119 this.comparator = new Comparator<TrajectoryWithFitness>() {
@@ -124,6 +125,7 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements
124 125
125 trajectoiresToExplore = new PriorityQueue<TrajectoryWithFitness>(11, comparator); 126 trajectoiresToExplore = new PriorityQueue<TrajectoryWithFitness>(11, comparator);
126 stateAndActivations = new HashMap<Object, List<Object>>(); 127 stateAndActivations = new HashMap<Object, List<Object>>();
128 metricDistance = new PartialInterpretationMetricDistance();
127 } 129 }
128 130
129 @Override 131 @Override
@@ -145,11 +147,11 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements
145 TrajectoryWithFitness currentTrajectoryWithFittness = new TrajectoryWithFitness(firstTrajectory, firstFittness); 147 TrajectoryWithFitness currentTrajectoryWithFittness = new TrajectoryWithFitness(firstTrajectory, firstFittness);
146 trajectoryFit.put(currentTrajectoryWithFittness, Double.MAX_VALUE); 148 trajectoryFit.put(currentTrajectoryWithFittness, Double.MAX_VALUE);
147 trajectoiresToExplore.add(currentTrajectoryWithFittness); 149 trajectoiresToExplore.add(currentTrajectoryWithFittness);
150 Object lastState = null;
148 151
149 //if(configuration) 152 //if(configuration)
150 visualiseCurrentState(); 153 visualiseCurrentState();
151 154
152 PartialInterpretationMetric.initPaths();
153 //create matcher 155 //create matcher
154 int count = 0; 156 int count = 0;
155 mainLoop: while (!isInterrupted && !configuration.progressMonitor.isCancelled()) { 157 mainLoop: while (!isInterrupted && !configuration.progressMonitor.isCancelled()) {
@@ -165,6 +167,9 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements
165 logger.debug("New trajectory is chosen: " + currentTrajectoryWithFittness); 167 logger.debug("New trajectory is chosen: " + currentTrajectoryWithFittness);
166 } 168 }
167 context.getDesignSpaceManager().executeTrajectoryWithMinimalBacktrackWithoutStateCoding(currentTrajectoryWithFittness.trajectory); 169 context.getDesignSpaceManager().executeTrajectoryWithMinimalBacktrackWithoutStateCoding(currentTrajectoryWithFittness.trajectory);
170
171 // reset the regression for this trajectory
172 metricDistance.getLinearModel().resetRegression(context.getCurrentStateId());
168 } 173 }
169 } 174 }
170 175
@@ -178,10 +183,10 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements
178 //init epsilon and draw 183 //init epsilon and draw
179 double epsilon = 0; 184 double epsilon = 0;
180 double draw = 1; 185 double draw = 1;
181 MetricDistanceGroup heuristics = PartialInterpretationMetric.calculateMetricDistanceKS(model); 186 MetricDistanceGroup heuristics = metricDistance.calculateMetricDistanceKS(model);
182 187
183 if(!stateAndActivations.containsKey(model)) { 188 if(!stateAndActivations.containsKey(context.getCurrentStateId())) {
184 stateAndActivations.put(model, new ArrayList<Object>()); 189 stateAndActivations.put(context.getCurrentStateId(), new ArrayList<Object>());
185 } 190 }
186 191
187 //Output intermediate model 192 //Output intermediate model
@@ -190,41 +195,26 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements
190 draw = Math.random(); 195 draw = Math.random();
191 count++; 196 count++;
192 // cut off the trajectory for bad graph 197 // cut off the trajectory for bad graph
193 double distance = heuristics.getMPCDistance();
194 System.out.println("KS"); 198 System.out.println("KS");
195 System.out.println("NA distance: " + heuristics.getNADistance()); 199 System.out.println("NA distance: " + heuristics.getNADistance());
196 System.out.println("MPC distance: " + heuristics.getMPCDistance()); 200 System.out.println("MPC distance: " + heuristics.getMPCDistance());
197 System.out.println("Out degree distance:" + heuristics.getOutDegreeDistance()); 201 System.out.println("Out degree distance:" + heuristics.getOutDegreeDistance());
198 202
199 MetricDistanceGroup lsHeuristic = PartialInterpretationMetric.calculateMetricDistance(model); 203// MetricDistanceGroup lsHeuristic = metricDistance.calculateMetricDistance(model);
200 System.out.println("LS"); 204// System.out.println("LS");
201 System.out.println("NA distance: " + lsHeuristic.getNADistance()); 205// System.out.println("NA distance: " + lsHeuristic.getNADistance());
202 System.out.println("MPC distance: " + lsHeuristic.getMPCDistance()); 206// System.out.println("MPC distance: " + lsHeuristic.getMPCDistance());
203 System.out.println("Out degree distance:" + lsHeuristic.getOutDegreeDistance()); 207// System.out.println("Out degree distance:" + lsHeuristic.getOutDegreeDistance());
204
205 if(distance <= 0.23880597014925373 + 0.00001 && distance >= 0.23880597014925373 - 0.00001) {
206 context.backtrack();
207 final Fitness nextFitness = context.calculateFitness();
208 currentTrajectoryWithFittness = new TrajectoryWithFitness(context.getTrajectory().toArray(), nextFitness);
209 continue;
210 }
211 208
212 //check for next value when doing greedy move 209 //check for next value when doing greedy move
210
213 valueMap = sortWithWeight(activationIds, currentTrajectoryWithFittness.trajectory.length+1); 211 valueMap = sortWithWeight(activationIds, currentTrajectoryWithFittness.trajectory.length+1);
214// if(activationIds.isEmpty() || (model.getNewElements().size() >= 20 && valueMap.get(activationIds.get(0)) > currentValue && epsilon < draw)) {
215// context.backtrack();
216// final Fitness nextFitness = context.calculateFitness();
217// currentTrajectoryWithFittness = new TrajectoryWithFitness(context.getTrajectory().toArray(), nextFitness);
218// continue;
219// }
220 } 212 }
213 lastState = context.getCurrentStateId();
221 while (!isInterrupted && !configuration.progressMonitor.isCancelled() && activationIds.size() > 0) { 214 while (!isInterrupted && !configuration.progressMonitor.isCancelled() && activationIds.size() > 0) {
222 final Object nextActivation = drawWithEpsilonProbabilty(activationIds, valueMap, epsilon, draw); 215 final Object nextActivation = drawWithEpsilonProbabilty(activationIds, valueMap, epsilon, draw);
223// if (!iterator.hasNext()) { 216
224// logger.debug("Last untraversed activation of the state."); 217 stateAndActivations.get(context.getCurrentStateId()).add(nextActivation);
225// trajectoiresToExplore.remove(currentTrajectoryWithFittness);
226// }
227 stateAndActivations.get(context.getModel()).add(nextActivation);
228 logger.debug("Executing new activation: " + nextActivation); 218 logger.debug("Executing new activation: " + nextActivation);
229 context.executeAcitvationId(nextActivation); 219 context.executeAcitvationId(nextActivation);
230 visualiseCurrentState(); 220 visualiseCurrentState();
@@ -242,15 +232,17 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements
242 logger.debug("Global contraint is not satisifed."); 232 logger.debug("Global contraint is not satisifed.");
243 context.backtrack(); 233 context.backtrack();
244 } else*/// { 234 } else*/// {
245 if(getNumberOfViolations(mustMatchers) > 8) { 235 /*if(getNumberOfViolations(mustMatchers) > 0) {
236 context.backtrack();
237 }else*/ if(model.getNewElements().size() > 90 && heuristics.getNADistance() + heuristics.getMPCDistance() + heuristics.getOutDegreeDistance() > 0.3) {
246 context.backtrack(); 238 context.backtrack();
247 }else if(model.getNewElements().size() > 90 && heuristics.getNADistance() + heuristics.getMPCDistance() + heuristics.getOutDegreeDistance() > 0.18) { 239 }else if(model.getNewElements().size() > 70 && heuristics.getNADistance() + heuristics.getMPCDistance() + heuristics.getOutDegreeDistance() > 0.45) {
248 context.backtrack(); 240 context.backtrack();
249 }else if(model.getNewElements().size() > 70 && heuristics.getNADistance() + heuristics.getMPCDistance() + heuristics.getOutDegreeDistance() > 0.36) { 241 } else if(model.getNewElements().size() > 50 && heuristics.getNADistance() + heuristics.getMPCDistance() + heuristics.getOutDegreeDistance() > 0.60) {
250 context.backtrack(); 242 context.backtrack();
251 } else if(model.getNewElements().size() > 50 && heuristics.getNADistance() + heuristics.getMPCDistance() + heuristics.getOutDegreeDistance() > 0.72) { 243 } else if(model.getNewElements().size() > 30 && heuristics.getNADistance() + heuristics.getMPCDistance() + heuristics.getOutDegreeDistance() > 0.70) {
252 context.backtrack(); 244 context.backtrack();
253 } else { 245 }else {
254 final Fitness nextFitness = context.calculateFitness(); 246 final Fitness nextFitness = context.calculateFitness();
255 247
256 // the only hard objectives are the size 248 // the only hard objectives are the size
@@ -271,9 +263,14 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements
271 continue; 263 continue;
272 } 264 }
273 265
266
274 TrajectoryWithFitness nextTrajectoryWithFittness = new TrajectoryWithFitness( 267 TrajectoryWithFitness nextTrajectoryWithFittness = new TrajectoryWithFitness(
275 context.getTrajectory().toArray(), nextFitness); 268 context.getTrajectory().toArray(), nextFitness);
276 trajectoryFit.put(nextTrajectoryWithFittness, calculateCurrentStateValue(nextTrajectoryWithFittness.trajectory.length)); 269 int step = nextTrajectoryWithFittness.trajectory.length;
270 int violation = getNumberOfViolations(mustMatchers) + getNumberOfViolations(mayMatchers);
271 metricDistance.getLinearModel().feedData(context.getCurrentStateId(), metricDistance.calculateFeature(step, violation), calculateCurrentStateValue(step, violation), lastState);
272
273 trajectoryFit.put(nextTrajectoryWithFittness, calculateCurrentStateValue(step, violation));
277 trajectoiresToExplore.add(nextTrajectoryWithFittness); 274 trajectoiresToExplore.add(nextTrajectoryWithFittness);
278 275
279 int compare = objectiveComparatorHelper.compare(currentTrajectoryWithFittness.fitness, 276 int compare = objectiveComparatorHelper.compare(currentTrajectoryWithFittness.fitness,
@@ -298,9 +295,9 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements
298 currentTrajectoryWithFittness = null; 295 currentTrajectoryWithFittness = null;
299 context.backtrack(); 296 context.backtrack();
300 } 297 }
301 PartialInterpretation model = (PartialInterpretation) context.getModel(); 298// PartialInterpretation model = (PartialInterpretation) context.getModel();
302 PartialInterpretationMetric.calculateMetric(model, "debug/metric/output", context.getCurrentStateId().toString(), count); 299// PartialInterpretationMetric.calculateMetric(model, "debug/metric/output", context.getCurrentStateId().toString(), count);
303 count++; 300// count++;
304 logger.info("Interrupted."); 301 logger.info("Interrupted.");
305 } 302 }
306 303
@@ -314,8 +311,8 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements
314 // do hill climbing 311 // do hill climbing
315 for(Object id : activationIds) { 312 for(Object id : activationIds) {
316 context.executeAcitvationId(id); 313 context.executeAcitvationId(id);
317 314 int violation = getNumberOfViolations(mayMatchers) + getNumberOfViolations(mustMatchers);
318 valueMap.put(id, calculateCurrentStateValue(factor)); 315 valueMap.put(id, calculateFutureStateValue(factor, violation));
319 context.backtrack(); 316 context.backtrack();
320 } 317 }
321 318
@@ -323,14 +320,27 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements
323 return valueMap; 320 return valueMap;
324 } 321 }
325 322
326 private double calculateCurrentStateValue(int factor) { 323 private double calculateFutureStateValue(int step, int violation) {
324 double currentValue = calculateCurrentStateValue(step, violation);
325 if(step > 40) {
326 double[] toPredict = metricDistance.calculateFeature(200, violation);
327 try {
328 return metricDistance.getLinearModel().getPredictionForNextDataSample(metricDistance.calculateFeature(step, violation), currentValue, toPredict);
329 }catch(IllegalArgumentException e) {
330 return currentValue;
331 }
332 }else {
333 return currentValue;
334 }
335 }
336
337 private double calculateCurrentStateValue(int factor, int violation) {
327 PartialInterpretation model = (PartialInterpretation) context.getModel(); 338 PartialInterpretation model = (PartialInterpretation) context.getModel();
328 MetricDistanceGroup g = PartialInterpretationMetric.calculateMetricDistanceKS(model); 339 MetricDistanceGroup g = metricDistance.calculateMetricDistanceKS(model);
329 340
330 int violations = getNumberOfViolations(mayMatchers); 341 double consistenceWeights = 1- 1.0/(1+violation);
331 double consistenceWeights = 1.0/(1+violations);
332 342
333 return(2.5 / Math.log(factor) * (g.getNADistance() + g.getMPCDistance() + g.getOutDegreeDistance()) + 1-consistenceWeights); 343 return( /*/ Math.log(factor)*/(g.getNADistance() + g.getMPCDistance() + g.getOutDegreeDistance()) + consistenceWeights);
334 } 344 }
335 345
336 private int getNumberOfViolations(Collection<ViatraQueryMatcher<? extends IPatternMatch>> matchers) { 346 private int getNumberOfViolations(Collection<ViatraQueryMatcher<? extends IPatternMatch>> matchers) {
@@ -379,8 +389,8 @@ public class HillClimbingOnRealisticMetricStrategyForModelGeneration implements
379 List<Object> activationIds; 389 List<Object> activationIds;
380 try { 390 try {
381 activationIds = new ArrayList<Object>(context.getCurrentActivationIds()); 391 activationIds = new ArrayList<Object>(context.getCurrentActivationIds());
382 if(stateAndActivations.containsKey(context.getModel())) { 392 if(stateAndActivations.containsKey(context.getCurrentStateId())) {
383 activationIds.removeAll(stateAndActivations.get(context.getModel())); 393 activationIds.removeAll(stateAndActivations.get(context.getCurrentStateId()));
384 } 394 }
385 Collections.shuffle(activationIds); 395 Collections.shuffle(activationIds);
386 } catch (NullPointerException e) { 396 } catch (NullPointerException e) {