diff --git a/metric.py b/metric.py index eada084..840f7aa 100644 --- a/metric.py +++ b/metric.py @@ -1,4 +1,4 @@ -from sklearn.metrics import v_measure_score, adjusted_rand_score, accuracy_score +from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score, accuracy_score from sklearn.cluster import KMeans from scipy.optimize import linear_sum_assignment from torch.utils.data import DataLoader @@ -36,7 +36,7 @@ def purity(y_true, y_pred): def evaluate(label, pred): - nmi = v_measure_score(label, pred) + nmi = normalized_mutual_info_score(label, pred) ari = adjusted_rand_score(label, pred) acc = cluster_acc(label, pred) pur = purity(label, pred)