# utils.py
import torch
import random
import numpy as np
from munkres import Munkres
from sklearn import metrics
import matplotlib.pyplot as plt
from sklearn.metrics import adjusted_rand_score as ari_score
from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score


def adjust_learning_rate(optimizer, epoch):
    # Step decay: start from 1e-3 and divide the learning rate by 10 every 50 epochs.
    lr = 0.001 * (0.1 ** (epoch // 50))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
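# Illustrative usage sketch (assumed, not from the original file): call once per epoch in the
# training loop so the step-decay schedule takes effect; the epoch count of 200 below is
# hypothetical:
#
#     for epoch in range(200):
#         adjust_learning_rate(optimizer, epoch)
#         ...  # forward pass, loss, backward pass, optimizer.step()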


def target_distribution(q):
    # Sharpen the soft assignments q (n_samples x n_clusters) into the auxiliary
    # target distribution p_ij = (q_ij^2 / sum_i q_ij), normalised over j.
    weight = q ** 2 / q.sum(0)
    return (weight.t() / weight.sum(1)).t()
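# Illustrative usage sketch (assumed, not from the original file): this is the sharpened
# target distribution commonly used in deep embedded clustering; assuming `q` is an
# (n_samples, n_clusters) soft-assignment tensor produced by a clustering layer:
#
#     q = torch.softmax(torch.randn(8, 3), dim=1)
#     p = target_distribution(q)  # same shape as q, rows sum to 1
#     loss = torch.nn.functional.kl_div(q.log(), p, reduction='batchmean')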


def cluster_acc(y_true, y_pred):
    # Best-match clustering accuracy and macro F1 via the Hungarian (Munkres) algorithm.
    # Shift ground-truth labels so they start at 0.
    y_true = y_true - np.min(y_true)

    l1 = list(set(y_true))
    numclass1 = len(l1)
    l2 = list(set(y_pred))
    numclass2 = len(l2)

    # If some ground-truth classes never appear in the prediction, relabel a few
    # predictions so that both label sets contain the same number of classes.
    ind = 0
    if numclass1 != numclass2:
        for i in l1:
            if i not in l2:
                y_pred[ind] = i
                ind += 1
        l2 = list(set(y_pred))
        numclass2 = len(l2)

    if numclass1 != numclass2:
        print('cluster_acc error: number of predicted clusters does not match number of classes')
        return

    # Contingency matrix: cost[i][j] = number of samples of class l1[i] assigned to cluster l2[j].
    cost = np.zeros((numclass1, numclass2), dtype=int)
    for i, c1 in enumerate(l1):
        mps = [i1 for i1, e1 in enumerate(y_true) if e1 == c1]
        for j, c2 in enumerate(l2):
            mps_d = [i1 for i1 in mps if y_pred[i1] == c2]
            cost[i][j] = len(mps_d)

    # Munkres minimises cost, so negate the counts to find the assignment that
    # maximises the number of correctly matched samples.
    m = Munkres()
    cost = (-cost).tolist()
    indexes = m.compute(cost)

    # Relabel the predictions according to the optimal cluster-to-class matching.
    new_predict = np.zeros(len(y_pred))
    for i, c in enumerate(l1):
        c2 = l2[indexes[i][1]]
        ai = [ind for ind, elm in enumerate(y_pred) if elm == c2]
        new_predict[ai] = c

    acc = metrics.accuracy_score(y_true, new_predict)
    f1_macro = metrics.f1_score(y_true, new_predict, average='macro')
    # The metrics below are computed but only acc and macro F1 are returned.
    precision_macro = metrics.precision_score(y_true, new_predict, average='macro')
    recall_macro = metrics.recall_score(y_true, new_predict, average='macro')
    f1_micro = metrics.f1_score(y_true, new_predict, average='micro')
    precision_micro = metrics.precision_score(y_true, new_predict, average='micro')
    recall_micro = metrics.recall_score(y_true, new_predict, average='micro')
    return acc, f1_macro


def eva(y_true, y_pred, epoch=0):
    # Evaluate a clustering result with accuracy, NMI, ARI and macro F1.
    acc, f1 = cluster_acc(y_true, y_pred)
    nmi = nmi_score(y_true, y_pred, average_method='arithmetic')
    ari = ari_score(y_true, y_pred)
    print('Epoch_{}: acc {:.4f}, nmi {:.4f}, ari {:.4f}, f1 {:.4f}'.format(epoch, acc, nmi, ari, f1))
    return acc, nmi, ari, f1
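# Illustrative usage sketch (assumed, not from the original file): `eva` expects integer
# ground-truth labels and predicted cluster ids of the same length, e.g.
#
#     y_true = np.array([0, 0, 1, 1, 2, 2])
#     y_pred = np.array([1, 1, 0, 0, 2, 2])
#     acc, nmi, ari, f1 = eva(y_true, y_pred, epoch=0)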


def parameter(model):
    # Count the total number of trainable parameters in the model and print the sum.
    params = list(model.parameters())
    k = 0
    for i in params:
        l = 1
        for j in i.size():
            l *= j
        k = k + l
    print("sum:" + str(k))
    return str(k)


def plot_pca_scatter(name, n_clusters, X_pca, y):
    # Scatter plot of the first two principal components, coloured by cluster label.
    # The colour palette depends on the dataset (one colour per cluster).
    if name == "usps":
        colors = ['black', 'blue', 'purple', 'yellow', 'pink', 'red', 'lime', 'cyan', 'orange', 'gray']  # usps: 10
    elif name == "acm":
        colors = ['yellow', 'pink', 'red']  # acm: 3
    elif name == "dblp":
        colors = ['yellow', 'pink', 'red', 'orange']  # dblp: 4
    elif name == "cite":
        colors = ['yellow', 'pink', 'red', 'lime', 'cyan', 'orange']  # cite: 6
    elif name == "hhar":
        colors = ['green', 'blue', 'red', 'pink', 'yellow', 'purple']  # hhar: 6
    elif name == "reut":
        colors = ['green', 'blue', 'red', 'pink']  # reut: 4
    else:
        print("Loading Error: unknown dataset name '{}'".format(name))
        return
    for i in range(len(colors)):
        px = X_pca[:, 0][y == i]
        py = X_pca[:, 1][y == i]
        plt.scatter(px, py, c=colors[i])
    plt.legend(np.arange(n_clusters))
    plt.xlabel('First Principal Component')
    plt.ylabel('Second Principal Component')
    # plt.show()
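# Illustrative usage sketch (assumed, not from the original file): project embeddings to 2-D
# with PCA before plotting; `embeddings` (n_samples x d) and the label array `y` are
# placeholders for whatever the caller has on hand:
#
#     from sklearn.decomposition import PCA
#     X_pca = PCA(n_components=2).fit_transform(embeddings)
#     plot_pca_scatter("acm", 3, X_pca, y)
#     plt.savefig("acm_pca.png")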


def setup_seed(seed):
    # Fix all relevant random number generators for reproducibility.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
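

# Illustrative usage sketch (assumed, not from the original file): call setup_seed once at the
# start of a run, before the model, optimizer and data loaders are built, so that weight
# initialisation and shuffling are reproducible; `MyModel` is a hypothetical model class:
#
#     setup_seed(42)
#     model = MyModel()
#     optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)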