New metric: Adjusted mutual info score #2058

Merged

29 commits
3bc820e
move average method arg checker to utils
matsumotosan Sep 5, 2023
ff0657c
docs. metric done except expected mutual info.
matsumotosan Sep 5, 2023
e6e0583
Merge branch 'master' into 2003-adjusted-mutual-info-score
SkafteNicki Sep 7, 2023
8de271b
Merge branch 'master' into 2003-adjusted-mutual-info-score
Borda Sep 7, 2023
eb49ff8
changelog
SkafteNicki Sep 8, 2023
f65db04
Merge branch 'master' into 2003-adjusted-mutual-info-score
SkafteNicki Sep 8, 2023
e21de52
Update src/torchmetrics/clustering/adjusted_mutual_info_score.py
matsumotosan Sep 11, 2023
7b9f40b
Update src/torchmetrics/clustering/adjusted_mutual_info_score.py
matsumotosan Sep 11, 2023
dec2a6a
Update src/torchmetrics/clustering/adjusted_mutual_info_score.py
matsumotosan Sep 11, 2023
5b74513
Merge branch 'master' into 2003-adjusted-mutual-info-score
SkafteNicki Sep 11, 2023
cfbde6a
changelog
SkafteNicki Sep 11, 2023
d77d1bf
most tests passing. failing one cluster case.
matsumotosan Sep 12, 2023
79e7091
Merge branch 'master' into 2003-adjusted-mutual-info-score
matsumotosan Sep 12, 2023
fc54b71
Merge branch 'master' into 2003-adjusted-mutual-info-score
matsumotosan Sep 13, 2023
22d2f6f
Merge branch 'master' into 2003-adjusted-mutual-info-score
matsumotosan Sep 13, 2023
4181884
Merge branch 'master' into 2003-adjusted-mutual-info-score
SkafteNicki Sep 15, 2023
4cee286
Fix an inappropriate test expression to remove a logical short circui…
munahaf Sep 15, 2023
b18218f
New metric: Davies bouldin score (#2071)
matsumotosan Sep 17, 2023
33b6f4d
Change broken links (#2083)
SkafteNicki Sep 15, 2023
c47bc69
Improvements to docs on custom implementations (#2061)
SkafteNicki Sep 15, 2023
6c670ac
New metrics: Homogeneity, Completness, V-Measure (#2053)
SkafteNicki Sep 15, 2023
3abd2c6
Merge branch 'master' into 2003-adjusted-mutual-info-score
matsumotosan Sep 17, 2023
5ac2865
add tolerances to allclose test and generalized mean check
matsumotosan Sep 17, 2023
6f0d8b5
docstring spelling mistake
matsumotosan Sep 17, 2023
e6b8ff5
fix issue
SkafteNicki Sep 18, 2023
c609048
fix device placement + remove scipy dependency
SkafteNicki Sep 18, 2023
c333e63
fix zero entropy case
SkafteNicki Sep 18, 2023
0a292bb
account for numerics
SkafteNicki Sep 18, 2023
8b5da2f
fix mypy issues
SkafteNicki Sep 18, 2023
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -33,6 +33,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- `FowlkesMallowsIndex` ([#2066](https://github.com/Lightning-AI/torchmetrics/pull/2066))

- `AdjustedMutualInfoScore` ([#2058](https://github.com/Lightning-AI/torchmetrics/pull/2058))

- `DaviesBouldinScore` ([#2071](https://github.com/Lightning-AI/torchmetrics/pull/2071))


21 changes: 21 additions & 0 deletions docs/source/clustering/adjusted_mutual_info_score.rst
@@ -0,0 +1,21 @@
.. customcarditem::
:header: Adjusted Mutual Information Score
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/default.svg
:tags: Clustering

.. include:: ../links.rst

#################################
Adjusted Mutual Information Score
#################################

Module Interface
________________

.. autoclass:: torchmetrics.clustering.AdjustedMutualInfoScore
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.clustering.adjusted_mutual_info_score
1 change: 1 addition & 0 deletions docs/source/links.rst
@@ -152,6 +152,7 @@
.. _GIOU: https://arxiv.org/abs/1902.09630
.. _Mutual Information Score: https://en.wikipedia.org/wiki/Mutual_information
.. _Normalized Mutual Information Score: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.normalized_mutual_info_score.html
.. _Adjusted Mutual Information Score: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.adjusted_mutual_info_score.html#sklearn.metrics.adjusted_mutual_info_score
.. _pycocotools: https://github.com/cocodataset/cocoapi/tree/master/PythonAPI/pycocotools
.. _Rand Score: https://link.springer.com/article/10.1007/BF01908075
.. _faster-coco-eval: https://github.com/MiXaiLL76/faster_coco_eval
2 changes: 2 additions & 0 deletions src/torchmetrics/clustering/__init__.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.clustering.adjusted_mutual_info_score import AdjustedMutualInfoScore
from torchmetrics.clustering.adjusted_rand_score import AdjustedRandScore
from torchmetrics.clustering.calinski_harabasz_score import CalinskiHarabaszScore
from torchmetrics.clustering.davies_bouldin_score import DaviesBouldinScore
@@ -26,6 +27,7 @@
from torchmetrics.clustering.rand_score import RandScore

__all__ = [
"AdjustedMutualInfoScore",
"AdjustedRandScore",
"CalinskiHarabaszScore",
"CompletenessScore",
127 changes: 127 additions & 0 deletions src/torchmetrics/clustering/adjusted_mutual_info_score.py
@@ -0,0 +1,127 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Literal, Optional, Sequence, Union

from torch import Tensor

from torchmetrics.clustering.mutual_info_score import MutualInfoScore
from torchmetrics.functional.clustering.adjusted_mutual_info_score import (
_validate_average_method_arg,
adjusted_mutual_info_score,
)
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["AdjustedMutualInfoScore.plot"]


class AdjustedMutualInfoScore(MutualInfoScore):
r"""Compute `Adjusted Mutual Information Score`_.

.. math::
AMI(U,V) = \frac{MI(U,V) - E(MI(U,V))}{avg(H(U), H(V)) - E(MI(U,V))}

Where :math:`U` is a tensor of target values, :math:`V` is a tensor of predictions, :math:`MI(U,V)` is the mutual
information score between clusters :math:`U` and :math:`V`, :math:`H(\cdot)` is the entropy of a clustering,
:math:`E(MI(U,V))` is the expected mutual information of two random labelings with the same cluster sizes, and
:math:`avg(H(U), H(V))` is the generalized mean of the two entropies selected by ``average_method``. The metric is
symmetric, therefore swapping :math:`U` and :math:`V` yields the same score.

This clustering metric is an extrinsic measure, because it requires ground-truth clustering labels, which may not
be available in practice since clustering is generally used for unsupervised learning.

As input to ``forward`` and ``update`` the metric accepts the following input:

- ``preds`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with predicted cluster labels
- ``target`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with ground truth cluster labels

As output of ``forward`` and ``compute`` the metric returns the following output:

- ``ami_score`` (:class:`~torch.Tensor`): A tensor with the Adjusted Mutual Information Score

Args:
average_method: Method used to calculate generalized mean for normalization. Choose between
``'min'``, ``'geometric'``, ``'arithmetic'``, ``'max'``.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Example:
>>> import torch
>>> from torchmetrics.clustering import AdjustedMutualInfoScore
>>> preds = torch.tensor([2, 1, 0, 1, 0])
>>> target = torch.tensor([0, 2, 1, 1, 0])
>>> ami_score = AdjustedMutualInfoScore(average_method="arithmetic")
>>> ami_score(preds, target)
tensor(-0.2500)

"""

is_differentiable: bool = True
higher_is_better: Optional[bool] = None
full_state_update: bool = False
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0
preds: List[Tensor]
target: List[Tensor]
contingency: Tensor

def __init__(
self, average_method: Literal["min", "geometric", "arithmetic", "max"] = "arithmetic", **kwargs: Any
) -> None:
super().__init__(**kwargs)
_validate_average_method_arg(average_method)
self.average_method = average_method

def compute(self) -> Tensor:
"""Compute normalized mutual information over state."""
return adjusted_mutual_info_score(dim_zero_cat(self.preds), dim_zero_cat(self.target), self.average_method)

def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: A matplotlib axis object. If provided, will add plot to that axis

Returns:
Figure and Axes object

Raises:
ModuleNotFoundError:
If `matplotlib` is not installed

.. plot::
:scale: 75

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.clustering import AdjustedMutualInfoScore
>>> metric = AdjustedMutualInfoScore()
>>> metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
>>> fig_, ax_ = metric.plot(metric.compute())

.. plot::
:scale: 75

>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.clustering import AdjustedMutualInfoScore
>>> metric = AdjustedMutualInfoScore()
>>> for _ in range(10):
... metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
>>> fig_, ax_ = metric.plot(metric.compute())

"""
return self._plot(val, ax)
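The class above is used like any other torchmetrics module metric: states accumulate across `update`/`forward` calls and a single `compute` produces the score. A minimal usage sketch (random labels, purely illustrative, not taken from the PR's tests):

```python
import torch
from torchmetrics.clustering import AdjustedMutualInfoScore

metric = AdjustedMutualInfoScore(average_method="arithmetic")
for _ in range(4):
    preds = torch.randint(0, 3, (16,))   # predicted cluster ids for one batch
    target = torch.randint(0, 3, (16,))  # ground-truth cluster ids for one batch
    metric.update(preds, target)

# AMI is adjusted for chance: random labelings score near 0, identical partitions score 1.
print(metric.compute())
```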
2 changes: 2 additions & 0 deletions src/torchmetrics/functional/clustering/__init__.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.functional.clustering.adjusted_mutual_info_score import adjusted_mutual_info_score
from torchmetrics.functional.clustering.adjusted_rand_score import adjusted_rand_score
from torchmetrics.functional.clustering.calinski_harabasz_score import calinski_harabasz_score
from torchmetrics.functional.clustering.davies_bouldin_score import davies_bouldin_score
@@ -26,6 +27,7 @@
from torchmetrics.functional.clustering.rand_score import rand_score

__all__ = [
"adjusted_mutual_info_score",
"adjusted_rand_score",
"calinski_harabasz_score",
"completeness_score",
121 changes: 121 additions & 0 deletions src/torchmetrics/functional/clustering/adjusted_mutual_info_score.py
@@ -0,0 +1,121 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Literal

import torch
from torch import Tensor, tensor

from torchmetrics.functional.clustering.mutual_info_score import _mutual_info_score_compute, _mutual_info_score_update
from torchmetrics.functional.clustering.utils import (
_validate_average_method_arg,
calculate_entropy,
calculate_generalized_mean,
)


def adjusted_mutual_info_score(
preds: Tensor, target: Tensor, average_method: Literal["min", "geometric", "arithmetic", "max"] = "arithmetic"
) -> Tensor:
"""Compute adjusted mutual information between two clusterings.

Args:
preds: predicted cluster labels
target: ground truth cluster labels
average_method: normalizer computation method

Returns:
Scalar tensor with the adjusted mutual info score; at most 1.0, and close to 0.0 (possibly negative) for chance-level labelings

Example:
>>> from torchmetrics.functional.clustering import adjusted_mutual_info_score
>>> preds = torch.tensor([2, 1, 0, 1, 0])
>>> target = torch.tensor([0, 2, 1, 1, 0])
>>> adjusted_mutual_info_score(preds, target, "arithmetic")
tensor(-0.2500)

"""
_validate_average_method_arg(average_method)
contingency = _mutual_info_score_update(preds, target)
mutual_info = _mutual_info_score_compute(contingency)
expected_mutual_info = expected_mutual_info_score(contingency, target.numel())
normalizer = calculate_generalized_mean(
torch.stack([calculate_entropy(preds), calculate_entropy(target)]), average_method
)
denominator = normalizer - expected_mutual_info
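# keep ``denominator`` at least ``eps`` in magnitude (preserving its sign) so the division below stays finite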
if denominator < 0:
denominator = torch.min(torch.tensor([denominator, -torch.finfo(denominator.dtype).eps]))
else:
denominator = torch.max(torch.tensor([denominator, torch.finfo(denominator.dtype).eps]))

return (mutual_info - expected_mutual_info) / denominator


def expected_mutual_info_score(contingency: Tensor, n_samples: int) -> Tensor:
"""Calculated expected mutual information score between two clusterings.

Implementation taken from sklearn/metrics/cluster/_expected_mutual_info_fast.pyx.

Args:
contingency: contingency matrix
n_samples: number of samples

Returns:
expected_mutual_info_score: expected mutual information score

"""
n_rows, n_cols = contingency.shape
a = torch.ravel(contingency.sum(dim=1))
b = torch.ravel(contingency.sum(dim=0))

# Check if preds or target labels only have one cluster
if a.numel() == 1 or b.numel() == 1:
return tensor(0.0, device=a.device)

nijs = torch.arange(0, max([a.max().item(), b.max().item()]) + 1, device=a.device)
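# the inner loops below start at nij >= 1, so index 0 is never used; set it to 1 to keep log(nijs) finite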
nijs[0] = 1

term1 = nijs / n_samples
log_a = torch.log(a)
log_b = torch.log(b)

log_nnij = torch.log(torch.tensor(n_samples, device=a.device)) + torch.log(nijs)

gln_a = torch.lgamma(a + 1)
gln_b = torch.lgamma(b + 1)
gln_na = torch.lgamma(n_samples - a + 1)
gln_nb = torch.lgamma(n_samples - b + 1)
gln_nnij = torch.lgamma(nijs + 1) + torch.lgamma(torch.tensor(n_samples + 1, dtype=a.dtype, device=a.device))

emi = tensor(0.0, device=a.device)
for i in range(n_rows):
for j in range(n_cols):
start = int(max(1, a[i].item() - n_samples + b[j].item()))
end = int(min(a[i].item(), b[j].item()) + 1)

for nij in range(start, end):
term2 = log_nnij[nij] - log_a[i] - log_b[j]
gln = (
gln_a[i]
+ gln_b[j]
+ gln_na[i]
+ gln_nb[j]
- gln_nnij[nij]
- torch.lgamma(a[i] - nij + 1)
- torch.lgamma(b[j] - nij + 1)
- torch.lgamma(n_samples - a[i] - b[j] + nij + 1)
)
term3 = torch.exp(gln)
emi += term1[nij] * term2 * term3

return emi
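As an informal sanity check (not part of this PR's test suite), the functional interface can be compared against the scikit-learn reference implementation that the docs link to; this sketch assumes `scikit-learn` is installed:

```python
import torch
from sklearn.metrics import adjusted_mutual_info_score as sk_ami
from torchmetrics.functional.clustering import adjusted_mutual_info_score

preds = torch.tensor([2, 1, 0, 1, 0])
target = torch.tensor([0, 2, 1, 1, 0])

tm_score = adjusted_mutual_info_score(preds, target, "arithmetic")
sk_score = sk_ami(target.numpy(), preds.numpy(), average_method="arithmetic")

# The two implementations should agree up to floating-point tolerance (AMI is symmetric in its arguments).
assert torch.isclose(tm_score, torch.tensor(sk_score, dtype=tm_score.dtype), atol=1e-5)
```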
src/torchmetrics/functional/clustering/normalized_mutual_info_score.py
@@ -17,17 +17,12 @@
from torch import Tensor

from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score
from torchmetrics.functional.clustering.utils import calculate_entropy, calculate_generalized_mean, check_cluster_labels


def _validate_average_method_arg(
average_method: Literal["min", "geometric", "arithmetic", "max"] = "arithmetic"
) -> None:
if average_method not in ("min", "geometric", "arithmetic", "max"):
raise ValueError(
"Expected argument `average_method` to be one of `min`, `geometric`, `arithmetic`, `max`,"
f"but got {average_method}"
)
from torchmetrics.functional.clustering.utils import (
_validate_average_method_arg,
calculate_entropy,
calculate_generalized_mean,
check_cluster_labels,
)


def normalized_mutual_info_score(
26 changes: 25 additions & 1 deletion src/torchmetrics/functional/clustering/utils.py
@@ -20,6 +20,30 @@
from torchmetrics.utilities.checks import _check_same_shape


def is_nonnegative(x: Tensor, atol: float = 1e-5) -> Tensor:
"""Return True if all elements of tensor are nonnegative within certain tolerance.

Args:
x: tensor
atol: absolute tolerance

Returns:
Boolean tensor indicating if all values are nonnegative

"""
return torch.logical_or(x > 0.0, torch.abs(x) < atol).all()


def _validate_average_method_arg(
average_method: Literal["min", "geometric", "arithmetic", "max"] = "arithmetic"
) -> None:
if average_method not in ("min", "geometric", "arithmetic", "max"):
raise ValueError(
"Expected argument `average_method` to be one of `min`, `geometric`, `arithmetic`, `max`,"
f"but got {average_method}"
)


def calculate_entropy(x: Tensor) -> Tensor:
"""Calculate entropy for a tensor of labels.

@@ -74,7 +98,7 @@ def calculate_generalized_mean(x: Tensor, p: Union[int, Literal["min", "geometri
tensor(1.6438)

"""
if torch.is_complex(x) or torch.any(x <= 0.0):
if torch.is_complex(x) or not is_nonnegative(x):
raise ValueError("`x` must contain positive real numbers")

if isinstance(p, str):
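For context on the `average_method` options validated above: they select the generalized mean of the two clustering entropies that appears in the AMI denominator. A standalone illustration (the helper below is hypothetical and only mirrors the behaviour of the private `calculate_generalized_mean` for these four options):

```python
import torch
from torch import Tensor


def entropy_mean(h_u: Tensor, h_v: Tensor, method: str) -> Tensor:
    """Combine two entropies according to the chosen averaging method (illustrative only)."""
    pair = torch.stack([h_u, h_v])
    if method == "min":
        return pair.min()
    if method == "max":
        return pair.max()
    if method == "geometric":
        return torch.sqrt(h_u * h_v)
    if method == "arithmetic":
        return pair.mean()
    raise ValueError(f"unknown method: {method}")


h_u, h_v = torch.tensor(0.6931), torch.tensor(1.0986)  # e.g. entropies of a 2-cluster and a 3-cluster labeling
for m in ("min", "geometric", "arithmetic", "max"):
    print(m, entropy_mean(h_u, h_v, m))
```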