aboutsummaryrefslogtreecommitdiffstats
path: root/Metrics/Metrics-Calculation/ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator/src/ca/mcgill/ecse/dslreasoner/realistic/metrics/calculator/distance/JSDistance.xtend
blob: df65b81f5be90ce8cb4527aabea3d77a1e48c4ea (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
package ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.distance

import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.app.Domain
import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.io.RepMetricsReader
import com.google.common.collect.Sets
import java.text.DecimalFormat
import java.util.HashMap
import java.util.List

/**
 * Cost function that compares metric samples of a candidate model against the samples of
 * the reference models of a {@link Domain}, using the Jensen-Shannon (JS) divergence.
 * <p>
 * Three metrics are handled: MPC, NA and out-degree. The reference samples are read once via
 * {@link RepMetricsReader#read} in the constructor and turned into probability mass functions
 * (PMFs). Each {@code *Distance} call builds a PMF from the candidate samples and returns its
 * JS divergence to the matching reference PMF.
 * <p>
 * Note: the returned value is the JS <em>divergence</em> computed with base-2 logarithms (bounded
 * by 1 when both PMFs are normalised), not its square root (the JS metric distance).
 */
class JSDistance extends CostDistance {
	/** Natural logarithm of 2, used to convert natural-log KL terms to base 2. */
	static val double LN_2 = Math.log(2)

	var HashMap<String, Double> mpcPMF;
	var HashMap<String, Double> naPMF;
	var HashMap<String, Double> outDegreePMF;
	var DecimalFormat formatter;

	/**
	 * Loads the reference metric samples for the given domain and precomputes their PMFs.
	 *
	 * @param d the domain whose reference metrics are read via {@link RepMetricsReader#read}
	 */
	new(Domain d) {
		val metrics = RepMetricsReader.read(d);
		// NOTE(review): mpcSamples is passed through unconverted, unlike the other two;
		// presumably it is already a double[] - confirm against RepMetricsReader.
		val mpcSamples = metrics.mpcSamples;
		val naSamples = metrics.naSamples.stream.mapToDouble([it]).toArray();
		val outDegreeSamples = metrics.outDegreeSamples.stream.mapToDouble([it]).toArray();

		// PMF keys are Strings; formatting samples to a fixed 5-decimal representation
		// (presumably inside pmfFromSamples) avoids spurious distinct keys caused by
		// floating-point precision noise.
		formatter = new DecimalFormat("#0.00000");

		mpcPMF = pmfFromSamples(mpcSamples, formatter);
		naPMF = pmfFromSamples(naSamples, formatter);
		outDegreePMF = pmfFromSamples(outDegreeSamples, formatter);
	}

	/**
	 * Builds the mixture distribution M = (pmf1 + pmf2) / 2 over the union of both key sets;
	 * a key missing from one PMF contributes probability 0 from that side.
	 *
	 * @param pmf1 first PMF
	 * @param pmf2 second PMF
	 * @return a new map holding the averaged probabilities
	 */
	def private combinePMF(HashMap<String, Double> pmf1, HashMap<String, Double> pmf2) {
		val pmfMap = new HashMap<String, Double>();
		for (key : Sets.union(pmf1.keySet(), pmf2.keySet())) {
			// corresponds to M in the JS divergence definition
			val value = 1.0 / 2 * (pmf1.getOrDefault(key, 0.0) + pmf2.getOrDefault(key, 0.0));
			pmfMap.put(key, value);
		}
		return pmfMap;
	}

	/**
	 * Computes the base-2 Jensen-Shannon divergence
	 * JS(p, q) = 1/2 KL(p || M) + 1/2 KL(q || M), with M = (p + q) / 2.
	 *
	 * @param p first PMF
	 * @param q second PMF
	 * @return the JS divergence of p and q
	 */
	def private double jsDivergence(HashMap<String, Double> p, HashMap<String, Double> q) {
		val m = combinePMF(q, p);
		return 1.0 / 2 * klDivergence(p, m) + 1.0 / 2 * klDivergence(q, m);
	}

	/**
	 * Computes the base-2 Kullback-Leibler divergence KL(p || q) = sum p(x) log2(p(x) / q(x)),
	 * summed over the keys of {@code q} that also have positive mass in {@code p}.
	 * <p>
	 * Keys of {@code p} that are absent from {@code q} are ignored (not treated as infinite).
	 * That is harmless for the JS use case, where {@code q} is the mixture that covers every
	 * key of {@code p}.
	 *
	 * @param p the distribution being measured
	 * @param q the reference distribution
	 * @return KL(p || q) in bits
	 */
	def double klDivergence(HashMap<String, Double> p, HashMap<String, Double> q) {
		var distance = 0.0;
		for (key : q.keySet()) {
			val pk = p.get(key);
			// By convention 0 * log(0) = 0: zero-mass (or missing) entries of p are skipped
			// instead of producing NaN from 0 * log(q / 0).
			if (pk !== null && pk > 0.0) {
				// divide by ln 2 to convert the natural log to log base 2
				distance -= pk * Math.log(q.get(key) / pk) / LN_2;
			}
		}
		return distance;
	}

	/**
	 * Builds a PMF from the candidate samples and returns its JS divergence to the given
	 * reference PMF.
	 *
	 * @param samples      candidate metric samples
	 * @param referencePMF PMF of the reference samples for the same metric
	 * @return the JS divergence, or 1 when the candidate PMF has fewer than two distinct values
	 */
	def private double divergenceTo(List<Double> samples, HashMap<String, Double> referencePMF) {
		val pmf = pmfFromSamples(samples.stream().mapToDouble([it]).toArray(), formatter);
		// Fewer than two distinct (formatted) sample values: the candidate is given the
		// fixed worst score 1 instead of a computed divergence.
		if (pmf.size < 2) {
			return 1.0;
		}
		return jsDivergence(pmf, referencePMF);
	}

	/** @return JS divergence of the candidate MPC samples to the reference MPC PMF (1 if degenerate) */
	override double mpcDistance(List<Double> samples) {
		return divergenceTo(samples, mpcPMF);
	}

	/** @return JS divergence of the candidate NA samples to the reference NA PMF (1 if degenerate) */
	override double naDistance(List<Double> samples) {
		return divergenceTo(samples, naPMF);
	}

	/** @return JS divergence of the candidate out-degree samples to the reference PMF (1 if degenerate) */
	override double outDegreeDistance(List<Double> samples) {
		return divergenceTo(samples, outDegreePMF);
	}
}