diff --git a/CHANGELOG.md b/CHANGELOG.md index b349e5a3d13..2805105ee6c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,17 +12,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added `MutualInformationScore` metric to cluster package ([#2008](https://github.com/Lightning-AI/torchmetrics/pull/2008)) - - - Added `RandScore` metric to cluster package ([#2025](https://github.com/Lightning-AI/torchmetrics/pull/2025)) - - -- Added `CalinskiHarabaszScore` metric to cluster package ([#2036](https://github.com/Lightning-AI/torchmetrics/pull/2036)) - - - Added `NormalizedMutualInfoScore` metric to cluster package ([#2029](https://github.com/Lightning-AI/torchmetrics/pull/2029)) - - +- Added `CalinskiHarabaszScore` metric to cluster package ([#2036](https://github.com/Lightning-AI/torchmetrics/pull/2036)) +- Added `DunnIndex` metric to cluster package ([#2049](https://github.com/Lightning-AI/torchmetrics/pull/2049)) ### Changed diff --git a/docs/source/clustering/dunn_index.rst b/docs/source/clustering/dunn_index.rst new file mode 100644 index 00000000000..69246661a60 --- /dev/null +++ b/docs/source/clustering/dunn_index.rst @@ -0,0 +1,21 @@ +.. customcarditem:: + :header: Dunn Index + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/clustering.svg + :tags: Clustering + +.. include:: ../links.rst + +########## +Dunn Index +########## + +Module Interface +________________ + +.. autoclass:: torchmetrics.clustering.DunnIndex + :exclude-members: update, compute + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.clustering.dunn_index diff --git a/docs/source/links.rst b/docs/source/links.rst index 78a2b34d764..e6c85b2994a 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -154,3 +154,4 @@ .. _Normalized Mutual Information Score: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.normalized_mutual_info_score.html .. _pycocotools: https://github.com/cocodataset/cocoapi/tree/master/PythonAPI/pycocotools .. _Rand Score: https://link.springer.com/article/10.1007/BF01908075 +.. _Dunn Index: https://en.wikipedia.org/wiki/Dunn_index diff --git a/src/torchmetrics/clustering/__init__.py b/src/torchmetrics/clustering/__init__.py index 939f52b3c69..6f4e67e1197 100644 --- a/src/torchmetrics/clustering/__init__.py +++ b/src/torchmetrics/clustering/__init__.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. from torchmetrics.clustering.calinski_harabasz_score import CalinskiHarabaszScore +from torchmetrics.clustering.dunn_index import DunnIndex from torchmetrics.clustering.mutual_info_score import MutualInfoScore from torchmetrics.clustering.normalized_mutual_info_score import NormalizedMutualInfoScore from torchmetrics.clustering.rand_score import RandScore __all__ = [ "CalinskiHarabaszScore", + "DunnIndex", "MutualInfoScore", "NormalizedMutualInfoScore", "RandScore", diff --git a/src/torchmetrics/clustering/dunn_index.py b/src/torchmetrics/clustering/dunn_index.py new file mode 100644 index 00000000000..11b8b5cf3f5 --- /dev/null +++ b/src/torchmetrics/clustering/dunn_index.py @@ -0,0 +1,116 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, List, Optional, Sequence, Union + +from torch import Tensor + +from torchmetrics.functional.clustering.dunn_index import dunn_index +from torchmetrics.metric import Metric +from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["DunnIndex.plot"] + + +class DunnIndex(Metric): + r"""Compute `Dunn Index`_. + + .. math:: + DI_m = \frac{\min_{1\leq i>> import torch + >>> from torchmetrics.clustering import DunnIndex + >>> data = torch.tensor([[0, 0], [0.5, 0], [1, 0], [0.5, 1]]) + >>> labels = torch.tensor([0, 0, 0, 1]) + >>> dunn_index = DunnIndex(p=2) + >>> dunn_index(data, labels) + tensor(2.) + + """ + + is_differentiable: bool = True + higher_is_better: bool = True + full_state_update: bool = True + plot_lower_bound: float = 0.0 + data: List[Tensor] + labels: List[Tensor] + + def __init__(self, p: float = 2, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.p = p + + self.add_state("data", default=[], dist_reduce_fx="cat") + self.add_state("labels", default=[], dist_reduce_fx="cat") + + def update(self, data: Tensor, labels: Tensor) -> None: + """Update state with predictions and targets.""" + self.data.append(data) + self.labels.append(labels) + + def compute(self) -> Tensor: + """Compute mutual information over state.""" + return dunn_index(dim_zero_cat(self.data), dim_zero_cat(self.labels), self.p) + + def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics.clustering import DunnIndex + >>> data = torch.tensor([[0, 0], [0.5, 0], [1, 0], [0.5, 1]]) + >>> labels = torch.tensor([0, 0, 0, 1]) + >>> metric = DunnIndex(p=2) + >>> metric.update(data, labels) + >>> fig_, ax_ = metric.plot(metric.compute()) + + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/functional/clustering/__init__.py b/src/torchmetrics/functional/clustering/__init__.py index 00d8c1e788f..08656e9e5e4 100644 --- a/src/torchmetrics/functional/clustering/__init__.py +++ b/src/torchmetrics/functional/clustering/__init__.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. from torchmetrics.functional.clustering.calinski_harabasz_score import calinski_harabasz_score +from torchmetrics.functional.clustering.dunn_index import dunn_index from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score from torchmetrics.functional.clustering.normalized_mutual_info_score import normalized_mutual_info_score from torchmetrics.functional.clustering.rand_score import rand_score __all__ = [ "calinski_harabasz_score", + "dunn_index", "mutual_info_score", "normalized_mutual_info_score", "rand_score", diff --git a/src/torchmetrics/functional/clustering/dunn_index.py b/src/torchmetrics/functional/clustering/dunn_index.py new file mode 100644 index 00000000000..b3b8d5df50c --- /dev/null +++ b/src/torchmetrics/functional/clustering/dunn_index.py @@ -0,0 +1,83 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from itertools import combinations +from typing import Tuple + +import torch +from torch import Tensor + + +def _dunn_index_update(data: Tensor, labels: Tensor, p: float) -> Tuple[Tensor, Tensor]: + """Update and return variables required to compute the Dunn index. + + Args: + data: feature vectors of shape (n_samples, n_features) + labels: cluster labels + p: p-norm (distance metric) + + Returns: + intercluster_distance: intercluster distances + max_intracluster_distance: max intracluster distances + + """ + unique_labels, inverse_indices = labels.unique(return_inverse=True) + clusters = [data[inverse_indices == label_idx] for label_idx in range(len(unique_labels))] + centroids = [c.mean(dim=0) for c in clusters] + + intercluster_distance = torch.linalg.norm( + torch.stack([a - b for a, b in combinations(centroids, 2)], dim=0), ord=p, dim=1 + ) + + max_intracluster_distance = torch.stack( + [torch.linalg.norm(ci - mu, ord=p, dim=1).max() for ci, mu in zip(clusters, centroids)] + ) + + return intercluster_distance, max_intracluster_distance + + +def _dunn_index_compute(intercluster_distance: Tensor, max_intracluster_distance: Tensor) -> Tensor: + """Compute the Dunn index based on updated state. + + Args: + intercluster_distance: intercluster distances + max_intracluster_distance: max intracluster distances + + Returns: + scalar tensor with the dunn index + + """ + return intercluster_distance.min() / max_intracluster_distance.max() + + +def dunn_index(data: Tensor, labels: Tensor, p: float = 2) -> Tensor: + """Compute the Dunn index. + + Args: + data: feature vectors + labels: cluster labels + p: p-norm used for distance metric + + Returns: + scalar tensor with the dunn index + + Example: + >>> from torchmetrics.functional.clustering import dunn_index + >>> data = torch.tensor([[0, 0], [0.5, 0], [1, 0], [0.5, 1]]) + >>> labels = torch.tensor([0, 0, 0, 1]) + >>> dunn_index(data, labels) + tensor(2.) + + """ + pairwise_distance, max_distance = _dunn_index_update(data, labels, p) + return _dunn_index_compute(pairwise_distance, max_distance) diff --git a/src/torchmetrics/functional/clustering/utils.py b/src/torchmetrics/functional/clustering/utils.py index 23ece71cbaf..f5581ac3429 100644 --- a/src/torchmetrics/functional/clustering/utils.py +++ b/src/torchmetrics/functional/clustering/utils.py @@ -151,6 +151,13 @@ def calculate_contingency_matrix( return contingency +def _is_real_discrete_label(x: Tensor) -> bool: + """Check if tensor of labels is real and discrete.""" + if x.ndim != 1: + raise ValueError(f"Expected arguments to be 1-d tensors but got {x.ndim}-d tensors.") + return not (torch.is_floating_point(x) or torch.is_complex(x)) + + def check_cluster_labels(preds: Tensor, target: Tensor) -> None: """Check shape of input tensors and if they are real, discrete tensors. @@ -160,18 +167,8 @@ def check_cluster_labels(preds: Tensor, target: Tensor) -> None: """ _check_same_shape(preds, target) - if preds.ndim != 1: - raise ValueError(f"Expected arguments to be 1d tensors but got {preds.ndim} and {target.ndim}") - if ( - torch.is_floating_point(preds) - or torch.is_complex(preds) - or torch.is_floating_point(target) - or torch.is_complex(target) - ): - raise ValueError( - f"Expected real, discrete values but received {preds.dtype} for" - f"predictions and {target.dtype} for target labels instead." - ) + if not (_is_real_discrete_label(preds) and _is_real_discrete_label(target)): + raise ValueError(f"Expected real, discrete values for x but received {preds.dtype} and {target.dtype}.") def calcualte_pair_cluster_confusion_matrix( diff --git a/tests/unittests/clustering/inputs.py b/tests/unittests/clustering/inputs.py index 61fff9eed31..15b24298f7d 100644 --- a/tests/unittests/clustering/inputs.py +++ b/tests/unittests/clustering/inputs.py @@ -14,38 +14,44 @@ from collections import namedtuple import torch +from sklearn.datasets import make_blobs -from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES +from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES from unittests.helpers import seed_all seed_all(42) -Input = namedtuple("Input", ["preds", "target"]) -NUM_CLASSES = 10 - # extrinsic input for clustering metrics that requires predicted clustering labels and target clustering labels -_single_target_extrinsic1 = Input( +ExtrinsicInput = namedtuple("ExtrinsicInput", ["preds", "target"]) + +# intrinsic input for clustering metrics that requires only predicted clustering labels and the cluster embeddings +IntrinsicInput = namedtuple("IntrinsicInput", ["data", "labels"]) + + +def _batch_blobs(num_batches, num_samples, num_features, num_classes): + data, labels = [], [] + for _ in range(num_batches): + _data, _labels = make_blobs(num_samples, num_features, centers=num_classes) + data.append(torch.tensor(_data)) + labels.append(torch.tensor(_labels)) + + return IntrinsicInput(data=torch.stack(data), labels=torch.stack(labels)) + + +_single_target_extrinsic1 = ExtrinsicInput( preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), ) -_single_target_extrinsic2 = Input( +_single_target_extrinsic2 = ExtrinsicInput( preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), ) -_float_inputs_extrinsic = Input( +_float_inputs_extrinsic = ExtrinsicInput( preds=torch.rand((NUM_BATCHES, BATCH_SIZE)), target=torch.rand((NUM_BATCHES, BATCH_SIZE)) ) -# intrinsic input for clustering metrics that requires only predicted clustering labels and the cluster embeddings -_single_target_intrinsic1 = Input( - preds=torch.randn(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), - target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), -) - -_single_target_intrinsic2 = Input( - preds=torch.randn(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), - target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), -) +_single_target_intrinsic1 = _batch_blobs(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, NUM_CLASSES) +_single_target_intrinsic2 = _batch_blobs(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, NUM_CLASSES) diff --git a/tests/unittests/clustering/test_calinski_harabasz_score.py b/tests/unittests/clustering/test_calinski_harabasz_score.py index cc063d1ebb5..98c86218b04 100644 --- a/tests/unittests/clustering/test_calinski_harabasz_score.py +++ b/tests/unittests/clustering/test_calinski_harabasz_score.py @@ -24,10 +24,10 @@ @pytest.mark.parametrize( - "preds, target", + "data, labels", [ - (_single_target_intrinsic1.preds, _single_target_intrinsic1.target), - (_single_target_intrinsic2.preds, _single_target_intrinsic2.target), + (_single_target_intrinsic1.data, _single_target_intrinsic1.labels), + (_single_target_intrinsic2.data, _single_target_intrinsic2.labels), ], ) class TestCalinskiHarabaszScore(MetricTester): @@ -36,21 +36,21 @@ class TestCalinskiHarabaszScore(MetricTester): atol = 1e-5 @pytest.mark.parametrize("ddp", [True, False]) - def test_calinski_harabasz_score(self, preds, target, ddp): + def test_calinski_harabasz_score(self, data, labels, ddp): """Test class implementation of metric.""" self.run_class_metric_test( ddp=ddp, - preds=preds, - target=target, + preds=data, + target=labels, metric_class=CalinskiHarabaszScore, reference_metric=sklearn_calinski_harabasz_score, ) - def test_calinski_harabasz_score_functional(self, preds, target): + def test_calinski_harabasz_score_functional(self, data, labels): """Test functional implementation of metric.""" self.run_functional_metric_test( - preds=preds, - target=target, + preds=data, + target=labels, metric_functional=calinski_harabasz_score, reference_metric=sklearn_calinski_harabasz_score, ) diff --git a/tests/unittests/clustering/test_dunn_index.py b/tests/unittests/clustering/test_dunn_index.py new file mode 100644 index 00000000000..b035dc3d48a --- /dev/null +++ b/tests/unittests/clustering/test_dunn_index.py @@ -0,0 +1,84 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial +from itertools import combinations + +import numpy as np +import pytest +from torchmetrics.clustering.dunn_index import DunnIndex +from torchmetrics.functional.clustering.dunn_index import dunn_index + +from unittests.clustering.inputs import ( + _single_target_intrinsic1, + _single_target_intrinsic2, +) +from unittests.helpers import seed_all +from unittests.helpers.testers import MetricTester + +seed_all(42) + + +def _np_dunn_index(data, labels, p): + unique_labels, inverse_indices = np.unique(labels, return_inverse=True) + clusters = [data[inverse_indices == label_idx] for label_idx in range(len(unique_labels))] + centroids = [c.mean(axis=0) for c in clusters] + + intercluster_distance = np.linalg.norm( + np.stack([a - b for a, b in combinations(centroids, 2)], axis=0), ord=p, axis=1 + ) + + max_intracluster_distance = np.stack( + [np.linalg.norm(ci - mu, ord=p, axis=1).max() for ci, mu in zip(clusters, centroids)] + ) + + return intercluster_distance.min() / max_intracluster_distance.max() + + +@pytest.mark.parametrize( + "data, labels", + [ + (_single_target_intrinsic1.data, _single_target_intrinsic1.labels), + (_single_target_intrinsic2.data, _single_target_intrinsic2.labels), + ], +) +@pytest.mark.parametrize( + "p", + [0, 1, 2], +) +class TestDunnIndex(MetricTester): + """Test class for `DunnIndex` metric.""" + + atol = 1e-5 + + @pytest.mark.parametrize("ddp", [True, False]) + def test_dunn_index(self, data, labels, p, ddp): + """Test class implementation of metric.""" + self.run_class_metric_test( + ddp=ddp, + preds=data, + target=labels, + metric_class=DunnIndex, + reference_metric=partial(_np_dunn_index, p=p), + metric_args={"p": p}, + ) + + def test_dunn_index_functional(self, data, labels, p): + """Test functional implementation of metric.""" + self.run_functional_metric_test( + preds=data, + target=labels, + metric_functional=dunn_index, + reference_metric=partial(_np_dunn_index, p=p), + p=p, + ) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 44259e69083..ff80729cc28 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -91,7 +91,13 @@ MultilabelROC, MultilabelSpecificity, ) -from torchmetrics.clustering import CalinskiHarabaszScore, MutualInfoScore, NormalizedMutualInfoScore, RandScore +from torchmetrics.clustering import ( + CalinskiHarabaszScore, + DunnIndex, + MutualInfoScore, + NormalizedMutualInfoScore, + RandScore, +) from torchmetrics.detection import PanopticQuality from torchmetrics.detection.mean_ap import MeanAveragePrecision from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio @@ -620,6 +626,7 @@ pytest.param(RandScore, _nominal_input, _nominal_input, id="rand score"), pytest.param(CalinskiHarabaszScore, lambda: torch.randn(100, 3), _nominal_input, id="calinski harabasz score"), pytest.param(NormalizedMutualInfoScore, _nominal_input, _nominal_input, id="normalized mutual info score"), + pytest.param(DunnIndex, lambda: torch.randn(100, 3), _nominal_input, id="dunn index"), ], ) @pytest.mark.parametrize("num_vals", [1, 3])