aboutsummaryrefslogtreecommitdiffstats
path: root/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/distance/JSDistance.xtend
blob: 4a0a0dc34df8368478223cff96ecd0a130321684 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
package ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.distance

import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.MetricSampleGroup
import com.google.common.collect.Sets
import java.text.DecimalFormat
import java.util.HashMap
import java.util.List

/**
 * Cost distance based on the Jensen-Shannon (JS) divergence, computed with base-2 logarithms.
 * <p>
 * Reference probability mass functions (PMFs) for the MPC, NA and out-degree metrics are built
 * once from the {@link MetricSampleGroup} passed to the constructor. Each {@code *Distance}
 * method builds a PMF from the given samples and compares it against the matching reference PMF.
 */
class JSDistance extends CostDistance {
	/** Natural logarithm of 2, used to convert {@code Math.log} results to base 2. */
	static val double LN2 = Math.log(2);

	var HashMap<String, Double> mpcPMF;
	var HashMap<String, Double> naPMF;
	var HashMap<String, Double> outDegreePMF;
	var HashMap<String, Double> nodeTypesPMF;
	var DecimalFormat formatter;

	/**
	 * Builds the reference PMFs from the given sample group.
	 *
	 * @param g reference metric samples; its node-type samples are stored as-is as the node-type PMF
	 */
	new(MetricSampleGroup g){
		var mpcSamples = g.mpcSamples;
		var naSamples = g.naSamples.stream.mapToDouble([it]).toArray();
		var outDegreeSamples = g.outDegreeSamples.stream.mapToDouble([it]).toArray();

		// PMF keys are Strings: sample values are formatted to a fixed 5-decimal precision
		// (this formatter is handed to pmfFromSamples) to avoid floating-point precision issues
		formatter = new DecimalFormat("#0.00000");

		mpcPMF = pmfFromSamples(mpcSamples, formatter);
		naPMF = pmfFromSamples(naSamples, formatter);
		outDegreePMF = pmfFromSamples(outDegreeSamples, formatter);
		nodeTypesPMF = g.nodeTypeSamples;
	}

	/**
	 * Returns the mixture M = (pmf1 + pmf2) / 2 over the union of both key sets.
	 * A key missing from one of the maps counts as probability 0 in that map.
	 */
	def private combinePMF(HashMap<String, Double> pmf1, HashMap<String, Double> pmf2){
		var pmfMap = new HashMap<String, Double>();

		var union = Sets.union(pmf1.keySet(), pmf2.keySet());

		for(key : union){
			// corresponding to M in JS distance
			var value = 1.0/2 * (pmf1.getOrDefault(key, 0.0) + pmf2.getOrDefault(key, 0.0));
			pmfMap.put(key, value);
		}
		return pmfMap;
	}

	/**
	 * Jensen-Shannon divergence JS(p, q) = KL(p || M) / 2 + KL(q || M) / 2, where M is the
	 * mixture returned by {@link #combinePMF}. Symmetric in p and q; with base-2 logs it is
	 * bounded by 1 when p and q are normalized PMFs.
	 */
	def private double jsDivergence(HashMap<String, Double> p, HashMap<String, Double> q){
		val m = combinePMF(q, p);
		var distance = 1.0/2 * klDivergence(p, m) + 1.0/2 * klDivergence(q, m);
		return distance;
	}

	/**
	 * Kullback-Leibler divergence KL(p || q) in bits: the sum of p(x) * log2(p(x) / q(x)).
	 * <p>
	 * The sum runs over the keys of {@code q}. Keys of {@code p} that are absent from {@code q}
	 * are ignored (mathematically they would make the divergence infinite). Terms where
	 * p(x) == 0 contribute 0, following the convention 0 * log 0 = 0. Evaluating them directly
	 * would produce NaN and poison the whole sum.
	 */
	def double klDivergence(HashMap<String, Double> p, HashMap<String, Double> q){
		var distance = 0.0;
		for(key : q.keySet()){
			// absent keys default to 0 and are therefore skipped, as before
			val double pValue = p.getOrDefault(key, 0.0);
			if(pValue != 0.0){
				distance -= pValue * Math.log(q.get(key) / pValue) / LN2;
			}
		}
		return distance;
	}

	/**
	 * Shared implementation of the per-metric distances: builds a PMF from {@code samples}
	 * and returns its JS divergence from {@code referencePMF}.
	 */
	def private double sampleDistance(List<Double> samples, HashMap<String, Double> referencePMF){
		val map = pmfFromSamples(samples.stream().mapToDouble([it]).toArray(), formatter);
		// fewer than two distinct (formatted) sample values: report 1, the upper bound of
		// base-2 JS divergence for normalized PMFs, instead of computing a divergence
		if(map.size < 2) return 1.0;
		return jsDivergence(map, referencePMF);
	}

	override double mpcDistance(List<Double> samples){
		return sampleDistance(samples, mpcPMF);
	}

	override double naDistance(List<Double> samples){
		return sampleDistance(samples, naPMF);
	}

	override double outDegreeDistance(List<Double> samples){
		return sampleDistance(samples, outDegreePMF);
	}

	/**
	 * KL divergence of the given node-type distribution from the reference node-type PMF.
	 * NOTE(review): unlike the other metrics this uses plain KL (asymmetric, unbounded), not JS.
	 * Confirm that this is intended.
	 */
	def nodeTypeDistance(HashMap<String, Double> samples){
		return klDivergence(samples, nodeTypesPMF);
	}
}