Skip to content

Commit

Permalink
feat: add cal_external_cluster_validation_metrics();
Browse files Browse the repository at this point in the history
  • Loading branch information
WenjieDu committed Sep 24, 2023
1 parent 6cf2caa commit 1984734
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions pypots/utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,37 @@ def cal_cluster_purity(
return cluster_purity


def cal_external_cluster_validation_metrics(class_predictions, targets):
"""Computer all external cluster validation metrics available in PyPOTS and return as a dictionary.
Parameters
----------
class_predictions :
Clustering results returned by a clusterer.
targets :
Ground truth (correct) clustering results.
Returns
-------
external_cluster_validation_metrics : dict
A dictionary contains all external cluster validation metrics available in PyPOTS.
"""

ri = cal_rand_index(class_predictions, targets)
ari = cal_adjusted_rand_index(class_predictions, targets)
nmi = cal_nmi(class_predictions, targets)
cp = cal_cluster_purity(class_predictions, targets)

external_cluster_validation_metrics = {
"rand_index": ri,
"adjusted_rand_index": ari,
"nmi": nmi,
"cluster_purity": cp,
}
return external_cluster_validation_metrics


def cal_silhouette(X: np.ndarray, predicted_labels: np.ndarray) -> float:
"""Compute the mean Silhouette Coefficient of all samples.
Expand Down

0 comments on commit 1984734

Please sign in to comment.