package ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.distance

import ca.mcgill.ecse.dslreasoner.realistic.metrics.calculator.metrics.MetricSampleGroup
import com.google.common.collect.Sets
import java.text.DecimalFormat
import java.util.HashMap
import java.util.List

/**
 * Cost distance that compares sampled metric distributions against reference
 * distributions using the Jensen-Shannon divergence, computed with base-2
 * logarithms. The reference PMFs are built once, in the constructor, from a
 * {@link MetricSampleGroup}.
 */
class JSDistance extends CostDistance {
	// NOTE(review): these maps are declared without type arguments; they are
	// presumably HashMap<String, Double> (formatted sample value / node type ->
	// probability) -- confirm against pmfFromSamples and MetricSampleGroup.
	var HashMap mpcPMF;
	var HashMap naPMF;
	var HashMap outDegreePMF;
	var HashMap nodeTypesPMF;
	var DecimalFormat formatter;

	/**
	 * Builds the reference PMFs for MPC, NA and out-degree from the samples in
	 * {@code g}, and takes the node-type PMF from {@code g} as-is.
	 *
	 * @param g the reference metric samples
	 */
	new(MetricSampleGroup g) {
		val mpcSamples = g.mpcSamples;
		val naSamples = g.naSamples.stream.mapToDouble([it]).toArray();
		val outDegreeSamples = g.outDegreeSamples.stream.mapToDouble([it]).toArray();

		// Numbers are formatted to strings (5 decimals) to avoid floating-point
		// precision issues when building the PMFs.
		formatter = new DecimalFormat("#0.00000");
		mpcPMF = pmfFromSamples(mpcSamples, formatter);
		naPMF = pmfFromSamples(naSamples, formatter);
		outDegreePMF = pmfFromSamples(outDegreeSamples, formatter);
		nodeTypesPMF = g.nodeTypeSamples;
	}

	/**
	 * Returns the mixture M = (P + Q) / 2 over the union of both key sets; a key
	 * missing from one PMF counts as probability 0 in that PMF.
	 */
	def private combinePMF(HashMap pmf1, HashMap pmf2) {
		val pmfMap = new HashMap();
		for (key : Sets.union(pmf1.keySet(), pmf2.keySet())) {
			// corresponds to M in the JS divergence definition
			pmfMap.put(key, 0.5 * (pmf1.getOrDefault(key, 0.0) + pmf2.getOrDefault(key, 0.0)));
		}
		return pmfMap;
	}

	/**
	 * Jensen-Shannon divergence JS(p, q) = KL(p || M) / 2 + KL(q || M) / 2 with
	 * M = (p + q) / 2. Because logs are base 2, the result lies in [0, 1].
	 */
	def private jsDivergence(HashMap p, HashMap q) {
		val m = combinePMF(q, p);
		return 0.5 * klDivergence(p, m) + 0.5 * klDivergence(q, m);
	}

	/**
	 * Kullback-Leibler divergence KL(p || q), in bits.
	 * <p>
	 * Only keys of {@code q} that are also present in {@code p} with positive
	 * probability contribute. A zero-probability entry in {@code p} contributes 0
	 * by the convention 0 * log 0 = 0; evaluating it directly would give
	 * 0 * Infinity = NaN. Keys of {@code p} that are absent from {@code q} are
	 * ignored, although strictly they would make the divergence infinite. If
	 * {@code q} is 0 at a key where {@code p} is positive, the result is
	 * +Infinity.
	 *
	 * @param p the distribution being measured
	 * @param q the reference distribution
	 * @return KL(p || q) in bits
	 */
	def klDivergence(HashMap p, HashMap q) {
		// Math.log is the natural log; divide by ln 2 to convert to bits.
		val ln2 = Math.log(2);
		var distance = 0.0;
		for (key : q.keySet()) {
			if (p.containsKey(key) && p.get(key) > 0) {
				distance -= p.get(key) * Math.log(q.get(key) / p.get(key)) / ln2;
			}
		}
		return distance;
	}

	override double mpcDistance(List samples) {
		return jsDistanceFromSamples(samples, mpcPMF);
	}

	override double naDistance(List samples) {
		return jsDistanceFromSamples(samples, naPMF);
	}

	override double outDegreeDistance(List samples) {
		return jsDistanceFromSamples(samples, outDegreePMF);
	}

	/**
	 * Shared implementation of the per-metric distances. Builds the PMF of
	 * {@code samples} and returns its JS divergence from {@code referencePMF}.
	 * If the sample PMF has fewer than 2 entries, returns 1, the upper bound of
	 * the base-2 JS divergence, without comparing.
	 */
	def private double jsDistanceFromSamples(List samples, HashMap referencePMF) {
		// map list to array, then to a PMF keyed via the shared formatter
		val map = pmfFromSamples(samples.stream().mapToDouble([it]).toArray(), formatter);
		if (map.size < 2) {
			return 1;
		}
		return jsDivergence(map, referencePMF);
	}

	/**
	 * Distance of the given node-type PMF from the reference node-type PMF.
	 * Unlike the other metrics, this uses the KL divergence
	 * KL(samples || reference), not the JS divergence.
	 */
	def nodeTypeDistance(HashMap samples) {
		return klDivergence(samples, nodeTypesPMF);
	}
}