-
Notifications
You must be signed in to change notification settings - Fork 58
/
my_surgery.py
161 lines (132 loc) · 5.5 KB
/
my_surgery.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
from importlib import util
import community as community_louvain
import networkx as nx
import numpy as np
def _is_node_set_partition(clustering):
    """Return True if ``clustering`` is a non-empty sequence whose first item is a set of nodes."""
    try:
        first = clustering[0]
    except (IndexError, KeyError, TypeError):
        # Empty sequence or not indexable at all: not a networkx-style partition.
        return False
    return isinstance(first, (set, frozenset))


def ARI(G, clustering, clustering_label="club"):
    """Compute the Adjusted Rand Index (clustering accuracy) of ``clustering``.

    The ground truth is read from the node attribute ``clustering_label``.

    Parameters
    ----------
    G : NetworkX graph
        A given NetworkX graph with node attribute ``clustering_label`` as ground truth.
    clustering : dict or list of set or list
        Predicted community clustering, in one of three formats:

        * dict mapping node -> community id (python-louvain partition format);
        * sequence of sets/frozensets of nodes (networkx partition format);
        * list of community ids (sklearn format), handed to scikit-learn unchanged.
          NOTE(review): presumably ordered like the labelled nodes of ``G`` -- confirm
          against callers.
    clustering_label : str
        Node attribute name for ground truth. (Default value = "club")

    Returns
    -------
    ari : float
        Adjusted Rand Index for the predicted community, or -1 when scikit-learn
        is not installed or the format of ``clustering`` is not recognised.
    """
    if util.find_spec("sklearn") is None:
        print("scikit-learn not installed, skipped...")
        return -1
    from sklearn import preprocessing, metrics

    # Only nodes that actually carry the ground-truth attribute are compared.
    labels_by_node = nx.get_node_attributes(G, clustering_label)
    encoder = preprocessing.LabelEncoder()
    y_true = encoder.fit_transform(list(labels_by_node.values()))

    if isinstance(clustering, dict):
        # python-louvain partition format
        y_pred = np.array([clustering[node] for node in labels_by_node])
    elif _is_node_set_partition(clustering):
        # networkx partition format: flatten to node -> community index
        community_of = {node: idx for idx, comp in enumerate(clustering) for node in comp}
        y_pred = np.array([community_of[node] for node in labels_by_node])
    elif isinstance(clustering, list):
        # sklearn partition format
        y_pred = clustering
    else:
        return -1

    return metrics.adjusted_rand_score(y_true, y_pred)
def my_surgery(G_origin: nx.Graph, weight="weight", cut=0):
    """A simple surgery function that removes the edges with weight above a threshold.

    The input graph is copied; the original is never modified. A short report
    (edges cut, node/edge counts, modularity and ARI of the resulting connected
    components) is printed.

    Parameters
    ----------
    G_origin : NetworkX graph
        A graph with ``weight`` as Ricci flow metric to cut.
    weight : str
        The edge attribute used as Ricci flow metric. (Default value = "weight")
    cut : float
        Manually assigned cutoff point. Must not be negative. A value of 0
        (the default) means "guess": the cutoff is then placed 60% of the way
        from 1.0 to the maximum edge weight.

    Returns
    -------
    G : NetworkX graph
        A graph after surgery.

    Raises
    ------
    AssertionError
        If ``cut`` is negative.
    ValueError
        If ``cut`` is 0 and no edge carries the ``weight`` attribute, so no
        cutoff can be guessed.
    """
    G = G_origin.copy()
    w = nx.get_edge_attributes(G, weight)

    assert cut >= 0, "Cut value should not be negative."
    if not cut:
        if not w:
            raise ValueError("Cannot guess a cut point: no edge has attribute %r." % weight)
        cut = (max(w.values()) - 1.0) * 0.6 + 1.0  # Guess a cut point as default

    # Reuse the weights fetched above instead of re-indexing the graph per edge;
    # edges without the attribute are simply left untouched.
    to_cut = [edge for edge, value in w.items() if value > cut]

    print("*************** Surgery time ****************")
    print("* Cut %d edges." % len(to_cut))
    G.remove_edges_from(to_cut)
    print("* Number of nodes now: %d" % G.number_of_nodes())
    print("* Number of edges now: %d" % G.number_of_edges())

    # The connected components left after the cut are taken as the clustering.
    cc = list(nx.connected_components(G))
    print("* Modularity now: %f " % nx.algorithms.community.quality.modularity(G, cc))
    print("* ARI now: %f " % ARI(G, cc))
    print("*********************************************")

    return G
def check_accuracy(G_origin, weight="weight", clustering_label="value", plot_cut=True):
    """Check the clustering quality while cutting edges at a range of weight thresholds.

    Starting from the maximum edge weight and stepping down by 0.025 towards 1,
    edges heavier than the current cutoff are removed from a copy of the graph
    (removals accumulate from one cutoff to the next). After each step the
    connected components are taken as the clustering, and its modularity and
    Adjusted Rand Index are recorded and finally plotted against the cutoff.

    Parameters
    ----------
    G_origin : NetworkX graph
        A graph with ``weight`` as Ricci flow metric to cut.
    weight : str
        The edge attribute used as Ricci flow metric. (Default value = "weight")
    clustering_label : str
        Node attribute name for ground truth. (Default value = "value")
    plot_cut : bool
        To plot the good guessed cut or not. (Default value = True)

    Returns
    -------
    -1 if matplotlib is not installed; otherwise nothing (the curves are drawn
    on the current matplotlib figure).
    """
    if util.find_spec("matplotlib") is None:
        print("matplotlib not installed, skipped to show the cut graph...")
        return -1
    import matplotlib.pyplot as plt

    G = G_origin.copy()
    modularity, ari = [], []
    maxw = max(nx.get_edge_attributes(G, weight).values())
    cutoff_range = np.arange(maxw, 1, -0.025)

    for cutoff in cutoff_range:
        edge_trim_list = [(n1, n2) for n1, n2 in G.edges() if G[n1][n2][weight] > cutoff]
        G.remove_edges_from(edge_trim_list)

        # Get connected component after cut as clustering
        clustering = {c: idx for idx, comp in enumerate(nx.connected_components(G)) for c in comp}

        # Compute modularity and ari
        modularity.append(community_louvain.modularity(clustering, G, weight))
        ari.append(ARI(G, clustering, clustering_label=clustering_label))

    plt.xlim(maxw, 0)
    plt.xlabel("Edge weight cutoff")
    plt.plot(cutoff_range, modularity, alpha=0.8)
    plt.plot(cutoff_range, ari, alpha=0.8)

    legend_labels = ['Modularity', 'Adjust Rand Index']

    # ``modularity`` is empty when the maximum weight is <= 1 (empty cutoff
    # range); there is then nothing to scan for a good cut.
    if plot_cut and modularity:
        good_cut = -1
        mod_last = modularity[-1]
        drop_threshold = 0.01  # at least drop this much to considered as a drop for good_cut

        # check drop from 1 -> maxw
        for i in range(len(modularity) - 1, 0, -1):
            mod_now = modularity[i]
            # ``mod_last > mod_now > 1e-4`` also guarantees a positive divisor.
            if mod_last > mod_now > 1e-4 and abs(mod_last - mod_now) / mod_last > drop_threshold:
                # Index i + 1 is the cutoff visited just before the drop. It is
                # always in range: no drop can be detected on the first
                # iteration, where mod_now equals mod_last.
                if good_cut != -1:
                    print("Other cut:%f, diff:%f, mod_now:%f, mod_last:%f, ari:%f" % (
                        cutoff_range[i + 1], mod_last - mod_now, mod_now, mod_last, ari[i + 1]))
                else:
                    good_cut = cutoff_range[i + 1]
                    print("*Good Cut:%f, diff:%f, mod_now:%f, mod_last:%f, ari:%f" % (
                        good_cut, mod_last - mod_now, mod_now, mod_last, ari[i + 1]))
            mod_last = mod_now

        # Only mark and label a cut that was actually found; -1 is the
        # "not found" sentinel and is never a valid cutoff (all cutoffs are > 1).
        if good_cut != -1:
            plt.axvline(x=good_cut, color="red")
            legend_labels.append('Good cut')

    plt.legend(legend_labels)