aboutsummaryrefslogtreecommitdiffstats
path: root/Metrics/Metrics-Calculation/metrics_plot/model comparison/src/plot_ks_stats.py
diff options
context:
space:
mode:
Diffstat (limited to 'Metrics/Metrics-Calculation/metrics_plot/model comparison/src/plot_ks_stats.py')
-rw-r--r--Metrics/Metrics-Calculation/metrics_plot/model comparison/src/plot_ks_stats.py21
1 files changed, 12 insertions, 9 deletions
diff --git a/Metrics/Metrics-Calculation/metrics_plot/model comparison/src/plot_ks_stats.py b/Metrics/Metrics-Calculation/metrics_plot/model comparison/src/plot_ks_stats.py
index 2f39ca93..a66802d5 100644
--- a/Metrics/Metrics-Calculation/metrics_plot/model comparison/src/plot_ks_stats.py
+++ b/Metrics/Metrics-Calculation/metrics_plot/model comparison/src/plot_ks_stats.py
@@ -8,17 +8,19 @@ import matplotlib.pyplot as plt
8from scipy import stats 8from scipy import stats
9import numpy as np 9import numpy as np
10from GraphType import GraphCollection 10from GraphType import GraphCollection
11import DistributionMetrics as metrics
11 12
12def main(): 13def main():
13 # read models 14 # read models
14 human = GraphCollection('../input/humanOutput/', 500, 'Human') 15 # human = GraphCollection('../input/humanOutput/', 500, 'Human')
15 viatra30 = GraphCollection('../input/viatraOutput30/', 500, 'Viatra (30 nodes)') 16 # viatra30 = GraphCollection('../input/viatraOutput30/', 500,'Viatra (30 nodes)')
16 # viatra60 = GraphCollection('../input/viatraOutput60/', 500, 'Viatra (60 nodes)') 17 # viatra60 = GraphCollection('../input/viatraOutput60/', 500, 'Viatra (60 nodes)')
17 # viatra100 = GraphCollection('../input/viatraOutput100/', 500, 'Viatra (100 nodes)') 18 viatra100 = GraphCollection('../input/viatraOutput100/', 500, 'Viatra (100 nodes)')
18 # random = GraphCollection('../input/randomOutput/', 500, 'Random') 19 # random = GraphCollection('../input/randomOutput/', 500, 'Random')
19 # alloy = GraphCollection('../input/alloyOutput/', 500, 'Alloy (30 nodes)') 20 # alloy = GraphCollection('../input/alloyOutput/', 500, 'Alloy (30 nodes)')
20 21 realistic_viatra = GraphCollection('../input/viatra_output_consistent_100/', 50, 'Realistic Viatra With Some Constraints (100 nodes)')
21 models_to_compare = [human, viatra30] 22 human100 = GraphCollection('../input/human_output_100/', 304, 'Human')
23 models_to_compare = [human100, realistic_viatra, viatra100]
22 24
23 # define output folder 25 # define output folder
24 outputFolder = '../output/' 26 outputFolder = '../output/'
@@ -38,7 +40,7 @@ def calculateKSMatrix(dists):
38 for i in range(len(dist)): 40 for i in range(len(dist)):
39 matrix[i,i] = 0 41 matrix[i,i] = 0
40 for j in range(i+1, len(dist)): 42 for j in range(i+1, len(dist)):
41 value, p = stats.ks_2samp(dist[i], dist[j]) 43 value = metrics.euclidean_distance(dist[i], dist[j])
42 matrix[i, j] = value 44 matrix[i, j] = value
43 matrix[j, i] = value 45 matrix[j, i] = value
44 return matrix 46 return matrix
@@ -50,13 +52,14 @@ def calculateMDS(dissimilarities):
50 return trans 52 return trans
51 53
52def plot(graphTypes, coords, title='',index = 0, savePath = ''): 54def plot(graphTypes, coords, title='',index = 0, savePath = ''):
53 half_length = int(coords.shape[0] / len(graphTypes))
54 color = ['blue', 'red', 'green', 'yellow'] 55 color = ['blue', 'red', 'green', 'yellow']
55 plt.figure(index, figsize=(7, 4)) 56 plt.figure(index, figsize=(7, 4))
56 plt.title(title) 57 plt.title(title)
58 index = 0
57 for i in range(len(graphTypes)): 59 for i in range(len(graphTypes)):
58 x = (coords[(i*half_length):((i+1)*half_length), 0].tolist()) 60 x = (coords[index:index+graphTypes[i].size, 0].tolist())
59 y = (coords[(i*half_length):((i+1)*half_length), 1].tolist()) 61 y = (coords[index:index+graphTypes[i].size, 1].tolist())
62 index += graphTypes[i].size
60 plt.plot(x, y, color=color[i], marker='o', label = graphTypes[i].name, linestyle='', alpha=0.7) 63 plt.plot(x, y, color=color[i], marker='o', label = graphTypes[i].name, linestyle='', alpha=0.7)
61 plt.legend(loc='upper right') 64 plt.legend(loc='upper right')
62 plt.savefig(fname = savePath, dpi=150) 65 plt.savefig(fname = savePath, dpi=150)