aboutsummaryrefslogtreecommitdiffstats
path: root/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/distance/JSDistance.xtend
diff options
context:
space:
mode:
Diffstat (limited to 'Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/distance/JSDistance.xtend')
-rw-r--r--Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/distance/JSDistance.xtend88
1 files changed, 88 insertions, 0 deletions
diff --git a/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/distance/JSDistance.xtend b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/distance/JSDistance.xtend
new file mode 100644
index 00000000..4a0a0dc3
--- /dev/null
+++ b/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/distance/JSDistance.xtend
@@ -0,0 +1,88 @@
1package ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.distance
2
3import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.MetricSampleGroup
4import com.google.common.collect.Sets
5import java.text.DecimalFormat
6import java.util.HashMap
7import java.util.List
8
/**
 * Distance measure based on the Jensen-Shannon (JS) divergence between
 * probability mass functions (PMFs).
 *
 * The constructor builds one reference PMF per metric from a
 * MetricSampleGroup; the distance methods then compare a new set of samples
 * against the matching reference PMF. PMFs are maps from a string key to a
 * probability.
 */
class JSDistance extends CostDistance {
    var HashMap<String, Double> mpcPMF;
    var HashMap<String, Double> naPMF;
    var HashMap<String, Double> outDegreePMF;
    var HashMap<String, Double> nodeTypesPMF;
    var DecimalFormat formatter;

    /**
     * Builds the reference PMFs from the samples stored in the given group.
     *
     * @param g sample group providing the mpc, na, out-degree and node type samples
     */
    new(MetricSampleGroup g){
        var mpcSamples = g.mpcSamples;
        var naSamples = g.naSamples.stream.mapToDouble([it]).toArray();
        var outDegreeSamples = g.outDegreeSamples.stream.mapToDouble([it]).toArray();

        // Sample values are formatted to strings with a fixed number of decimals
        // so that floating point precision issues do not affect the PMF keys.
        formatter = new DecimalFormat("#0.00000");

        mpcPMF = pmfFromSamples(mpcSamples, formatter);
        naPMF = pmfFromSamples(naSamples, formatter);
        outDegreePMF = pmfFromSamples(outDegreeSamples, formatter);
        // node type samples are stored as-is, without going through pmfFromSamples
        nodeTypesPMF = g.nodeTypeSamples;
    }

    /**
     * Computes the mixture distribution M = (P1 + P2) / 2 used by the JS divergence.
     * A key missing from one of the PMFs counts as probability 0 in that PMF.
     */
    def private combinePMF(HashMap<String, Double> pmf1, HashMap<String, Double> pmf2){
        var pmfMap = new HashMap<String, Double>();

        var union = Sets.union(pmf1.keySet(), pmf2.keySet());

        for(key : union){
            // corresponding to M in JS distance
            var value = 1.0/2 * (pmf1.getOrDefault(key, 0.0) + pmf2.getOrDefault(key, 0.0));
            pmfMap.put(key, value);
        }
        return pmfMap;
    }

    /**
     * Jensen-Shannon divergence: JSD(P, Q) = 1/2 * KL(P, M) + 1/2 * KL(Q, M),
     * where M is the mixture of P and Q.
     */
    def private jsDivergence(HashMap<String, Double> p, HashMap<String, Double> q){
        val m = combinePMF(q, p);
        var distance = 1.0/2 * klDivergence(p, m) + 1.0/2 * klDivergence(q, m);
        return distance;
    }

    /**
     * Kullback-Leibler divergence KL(P, Q) in bits (logarithm base 2).
     *
     * Only keys present in both maps contribute to the sum; keys of p that are
     * missing from q are ignored. Entries where p is exactly 0 contribute
     * nothing (the usual 0 * log(0) = 0 convention) instead of turning the
     * whole result into NaN.
     *
     * @param p the distribution being measured
     * @param q the distribution p is compared against
     * @return the accumulated divergence
     */
    def klDivergence(HashMap<String, Double> p, HashMap<String, Double> q){
        // Math.log is the natural logarithm; dividing by log(2) converts to base 2
        val log2 = Math.log(2);
        var distance = 0.0;
        for(key : q.keySet()){
            if(p.containsKey(key)){
                val double pValue = p.get(key);
                // 0 * log(q / 0) would evaluate to 0 * Infinity = NaN, so skip the term
                if(pValue != 0.0){
                    val double qValue = q.get(key);
                    distance -= pValue * Math.log(qValue / pValue) / log2;
                }
            }
        }
        return distance;
    }

    /**
     * Shared implementation of the sample based distances: builds a PMF from
     * the samples and returns its JS divergence to the reference PMF.
     * Returns 1 when the sample PMF has fewer than 2 distinct keys.
     */
    def private double jsDistanceTo(List<Double> samples, HashMap<String, Double> referencePMF){
        // map list to array
        var map = pmfFromSamples(samples.stream().mapToDouble([it]).toArray(), formatter);
        // with fewer than 2 distinct values the JS distance is not computed, simply return 1
        if(map.size < 2) return 1;
        return jsDivergence(map, referencePMF);
    }

    /** JS divergence between the PMF of the given samples and the reference mpc PMF. */
    override double mpcDistance(List<Double> samples){
        return jsDistanceTo(samples, mpcPMF);
    }

    /** JS divergence between the PMF of the given samples and the reference na PMF. */
    override double naDistance(List<Double> samples){
        return jsDistanceTo(samples, naPMF);
    }

    /** JS divergence between the PMF of the given samples and the reference out-degree PMF. */
    override double outDegreeDistance(List<Double> samples){
        return jsDistanceTo(samples, outDegreePMF);
    }

    /**
     * Distance between a node type distribution and the reference node type PMF.
     * Unlike the other distances this is the plain KL divergence, not the JS divergence.
     */
    def nodeTypeDistance(HashMap<String, Double> samples){
        return klDivergence(samples, nodeTypesPMF);
    }
}