diff options
Diffstat (limited to 'Metrics/Metrics-Calculation/metrics_plot/model comparison/src/plot_ks_stats.py')
-rw-r--r-- | Metrics/Metrics-Calculation/metrics_plot/model comparison/src/plot_ks_stats.py | 21 |
1 files changed, 12 insertions, 9 deletions
diff --git a/Metrics/Metrics-Calculation/metrics_plot/model comparison/src/plot_ks_stats.py b/Metrics/Metrics-Calculation/metrics_plot/model comparison/src/plot_ks_stats.py index 2f39ca93..a66802d5 100644 --- a/Metrics/Metrics-Calculation/metrics_plot/model comparison/src/plot_ks_stats.py +++ b/Metrics/Metrics-Calculation/metrics_plot/model comparison/src/plot_ks_stats.py | |||
@@ -8,17 +8,19 @@ import matplotlib.pyplot as plt | |||
8 | from scipy import stats | 8 | from scipy import stats |
9 | import numpy as np | 9 | import numpy as np |
10 | from GraphType import GraphCollection | 10 | from GraphType import GraphCollection |
11 | import DistributionMetrics as metrics | ||
11 | 12 | ||
12 | def main(): | 13 | def main(): |
13 | # read models | 14 | # read models |
14 | human = GraphCollection('../input/humanOutput/', 500, 'Human') | 15 | # human = GraphCollection('../input/humanOutput/', 500, 'Human') |
15 | viatra30 = GraphCollection('../input/viatraOutput30/', 500, 'Viatra (30 nodes)') | 16 | # viatra30 = GraphCollection('../input/viatraOutput30/', 500,'Viatra (30 nodes)') |
16 | # viatra60 = GraphCollection('../input/viatraOutput60/', 500, 'Viatra (60 nodes)') | 17 | # viatra60 = GraphCollection('../input/viatraOutput60/', 500, 'Viatra (60 nodes)') |
17 | # viatra100 = GraphCollection('../input/viatraOutput100/', 500, 'Viatra (100 nodes)') | 18 | viatra100 = GraphCollection('../input/viatraOutput100/', 500, 'Viatra (100 nodes)') |
18 | # random = GraphCollection('../input/randomOutput/', 500, 'Random') | 19 | # random = GraphCollection('../input/randomOutput/', 500, 'Random') |
19 | # alloy = GraphCollection('../input/alloyOutput/', 500, 'Alloy (30 nodes)') | 20 | # alloy = GraphCollection('../input/alloyOutput/', 500, 'Alloy (30 nodes)') |
20 | 21 | realistic_viatra = GraphCollection('../input/viatra_output_consistent_100/', 50, 'Realistic Viatra With Some Constraints (100 nodes)') | |
21 | models_to_compare = [human, viatra30] | 22 | human100 = GraphCollection('../input/human_output_100/', 304, 'Human') |
23 | models_to_compare = [human100, realistic_viatra, viatra100] | ||
22 | 24 | ||
23 | # define output folder | 25 | # define output folder |
24 | outputFolder = '../output/' | 26 | outputFolder = '../output/' |
@@ -38,7 +40,7 @@ def calculateKSMatrix(dists): | |||
38 | for i in range(len(dist)): | 40 | for i in range(len(dist)): |
39 | matrix[i,i] = 0 | 41 | matrix[i,i] = 0 |
40 | for j in range(i+1, len(dist)): | 42 | for j in range(i+1, len(dist)): |
41 | value, p = stats.ks_2samp(dist[i], dist[j]) | 43 | value = metrics.euclidean_distance(dist[i], dist[j]) |
42 | matrix[i, j] = value | 44 | matrix[i, j] = value |
43 | matrix[j, i] = value | 45 | matrix[j, i] = value |
44 | return matrix | 46 | return matrix |
@@ -50,13 +52,14 @@ def calculateMDS(dissimilarities): | |||
50 | return trans | 52 | return trans |
51 | 53 | ||
52 | def plot(graphTypes, coords, title='',index = 0, savePath = ''): | 54 | def plot(graphTypes, coords, title='',index = 0, savePath = ''): |
53 | half_length = int(coords.shape[0] / len(graphTypes)) | ||
54 | color = ['blue', 'red', 'green', 'yellow'] | 55 | color = ['blue', 'red', 'green', 'yellow'] |
55 | plt.figure(index, figsize=(7, 4)) | 56 | plt.figure(index, figsize=(7, 4)) |
56 | plt.title(title) | 57 | plt.title(title) |
58 | index = 0 | ||
57 | for i in range(len(graphTypes)): | 59 | for i in range(len(graphTypes)): |
58 | x = (coords[(i*half_length):((i+1)*half_length), 0].tolist()) | 60 | x = (coords[index:index+graphTypes[i].size, 0].tolist()) |
59 | y = (coords[(i*half_length):((i+1)*half_length), 1].tolist()) | 61 | y = (coords[index:index+graphTypes[i].size, 1].tolist()) |
62 | index += graphTypes[i].size | ||
60 | plt.plot(x, y, color=color[i], marker='o', label = graphTypes[i].name, linestyle='', alpha=0.7) | 63 | plt.plot(x, y, color=color[i], marker='o', label = graphTypes[i].name, linestyle='', alpha=0.7) |
61 | plt.legend(loc='upper right') | 64 | plt.legend(loc='upper right') |
62 | plt.savefig(fname = savePath, dpi=150) | 65 | plt.savefig(fname = savePath, dpi=150) |