Fix tie breaking in ndcg metric #2031

Merged: 11 commits, Sep 8, 2023
7 changes: 5 additions & 2 deletions CHANGELOG.md
@@ -41,7 +41,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### Fixed
 
-- Fixed performance issues in `RecallAtFixedPrecision` for large batch sizes ([#2042](https://github.com/Lightning-AI/torchmetrics/pull/2042))
+- Fixed tie breaking in ndcg metric ([#2031](https://github.com/Lightning-AI/torchmetrics/pull/2031))
 
 
 - Fixed bug in `BootStrapper` when very few samples were evaluated that could lead to crash ([#2052](https://github.com/Lightning-AI/torchmetrics/pull/2052))
@@ -50,11 +50,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fixed bug when creating multiple plots that lead to not all plots being shown ([#2060](https://github.com/Lightning-AI/torchmetrics/pull/2060))
 
 
+- Fixed performance issues in `RecallAtFixedPrecision` for large batch sizes ([#2042](https://github.com/Lightning-AI/torchmetrics/pull/2042))
+
+
 ## [1.1.1] - 2023-08-29
 
 ### Added
 
-- Added `average` argument to `MeanAveragePrecision` ([#2018](https://github.com/Lightning-AI/torchmetrics/pull/2018)
+- Added `average` argument to `MeanAveragePrecision` ([#2018](https://github.com/Lightning-AI/torchmetrics/pull/2018))
 
 ### Fixed
 
66 changes: 53 additions & 13 deletions src/torchmetrics/functional/retrieval/ndcg.py
@@ -19,10 +19,53 @@
 from torchmetrics.utilities.checks import _check_retrieval_functional_inputs
 
 
-def _dcg(target: Tensor) -> Tensor:
-    """Compute Discounted Cumulative Gain for input tensor."""
-    denom = torch.log2(torch.arange(target.shape[-1], device=target.device) + 2.0)
-    return (target / denom).sum(dim=-1)
+def _tie_average_dcg(target: Tensor, preds: Tensor, discount_cumsum: Tensor) -> Tensor:
+    """Translated version of sklearn's `_tie_average_dcg` function.
+
+    Args:
+        target: ground truth relevance of each document.
+        preds: estimated probability of each document being relevant.
+        discount_cumsum: cumulative sum of the discount.
+
+    Returns:
+        The cumulative gain of the tied elements.
+
+    """
+    _, inv, counts = torch.unique(-preds, return_inverse=True, return_counts=True)
+    ranked = torch.zeros_like(counts, dtype=torch.float32)
+    ranked.scatter_add_(0, inv, target.to(dtype=ranked.dtype))
+    ranked = ranked / counts
+    groups = counts.cumsum(dim=0) - 1
+    discount_sums = torch.zeros_like(counts, dtype=torch.float32)
+    discount_sums[0] = discount_cumsum[groups[0]]
+    discount_sums[1:] = discount_cumsum[groups].diff()
+    return (ranked * discount_sums).sum()


+def _dcg_sample_scores(target: Tensor, preds: Tensor, top_k: int, ignore_ties: bool) -> Tensor:
+    """Translated version of sklearn's `_dcg_sample_scores` function.
+
+    Args:
+        target: ground truth relevance of each document.
+        preds: estimated probability of each document being relevant.
+        top_k: consider only the top k elements.
+        ignore_ties: If True, ties are ignored. If False, tied predictions are averaged.
+
+    Returns:
+        The cumulative gain.
+
+    """
+    discount = 1.0 / (torch.log2(torch.arange(target.shape[-1], device=target.device) + 2.0))
+    discount[top_k:] = 0.0
+
+    if ignore_ties:
+        ranking = preds.argsort(descending=True)
+        ranked = target[ranking]
+        cumulative_gain = (discount * ranked).sum()
+    else:
+        discount_cumsum = discount.cumsum(dim=-1)
+        cumulative_gain = _tie_average_dcg(target, preds, discount_cumsum)
+    return cumulative_gain
 
 
 def retrieval_normalized_dcg(preds: Tensor, target: Tensor, top_k: Optional[int] = None) -> Tensor:
@@ -59,15 +102,12 @@ def retrieval_normalized_dcg(preds: Tensor, target: Tensor, top_k: Optional[int]
     if not (isinstance(top_k, int) and top_k > 0):
         raise ValueError("`top_k` has to be a positive integer or None")
 
-    sorted_target = target[torch.argsort(preds, dim=-1, descending=True)][:top_k]
-    ideal_target = torch.sort(target, descending=True)[0][:top_k]
-
-    ideal_dcg = _dcg(ideal_target)
-    target_dcg = _dcg(sorted_target)
+    gain = _dcg_sample_scores(target, preds, top_k, ignore_ties=False)
+    normalized_gain = _dcg_sample_scores(target, target, top_k, ignore_ties=True)
 
     # filter undefined scores
-    all_irrelevant = ideal_dcg == 0
-    target_dcg[all_irrelevant] = 0
-    target_dcg[~all_irrelevant] /= ideal_dcg[~all_irrelevant]
+    all_irrelevant = normalized_gain == 0
+    gain[all_irrelevant] = 0
+    gain[~all_irrelevant] /= normalized_gain[~all_irrelevant]
 
-    return target_dcg.mean()
+    return gain.mean()
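
For intuition, here is a small, self-contained sketch (not part of the PR) of the tie-averaging idea implemented by `_tie_average_dcg` and `_dcg_sample_scores`: documents with equal prediction scores share the average of their gains, weighted by the sum of the discounts over the rank positions their group occupies. The helper name `naive_tie_averaged_dcg` and the hard-coded values are illustrative only; the final line assumes a torchmetrics build that includes this PR.

```python
import math

import torch
from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg


def naive_tie_averaged_dcg(target, preds, top_k):
    """Loop-based reference: tied documents share the average of their gains."""
    discount = [1.0 / math.log2(rank + 2) if rank < top_k else 0.0 for rank in range(len(target))]
    dcg, position = 0.0, 0
    # walk over tie groups from the highest to the lowest prediction score
    for score in sorted(set(preds), reverse=True):
        idx = [i for i, p in enumerate(preds) if p == score]
        avg_gain = sum(target[i] for i in idx) / len(idx)
        dcg += avg_gain * sum(discount[position : position + len(idx)])
        position += len(idx)
    return dcg


target = [10, 0, 0, 1, 5]          # relevance values from the new test case
preds = [0.1, 0.0, 0.0, 0.0, 0.1]  # two documents tied at score 0.1

dcg = naive_tie_averaged_dcg(target, preds, top_k=5)
ideal = sorted(target, reverse=True)
ideal_dcg = naive_tie_averaged_dcg(ideal, ideal, top_k=5)  # ties in the ideal ranking do not change its value
print(dcg / ideal_dcg)  # ~0.928, the tie-averaged NDCG
print(retrieval_normalized_dcg(torch.tensor([preds]), torch.tensor([target]), top_k=5))  # matches
```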
13 changes: 13 additions & 0 deletions tests/unittests/retrieval/test_ndcg.py
@@ -15,6 +15,7 @@

 import numpy as np
 import pytest
+import torch
 from sklearn.metrics import ndcg_score
 from torch import Tensor
 from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg
@@ -185,3 +186,15 @@ def test_arguments_functional_metric(self, preds: Tensor, target: Tensor, messag
             exception_type=ValueError,
             kwargs_update=metric_args,
         )
+
+
+def test_corner_case_with_tied_scores():
+    """See issue: https://github.com/Lightning-AI/torchmetrics/issues/2022."""
+    target = torch.tensor([[10, 0, 0, 1, 5]])
+    preds = torch.tensor([[0.1, 0, 0, 0, 0.1]])
+
+    for k in [1, 3, 5]:
+        assert torch.allclose(
+            retrieval_normalized_dcg(preds, target, top_k=k),
+            torch.tensor([ndcg_score(target, preds, k=k)], dtype=torch.float32),
+        )
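
Beyond the unit test above, a quick sanity check (illustrative only, not part of the test suite) is that the metric no longer depends on which of two tied documents the sort happens to place first, and that it agrees with `sklearn.metrics.ndcg_score`:

```python
import torch
from sklearn.metrics import ndcg_score
from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg

target = torch.tensor([[10, 0, 0, 1, 5]])
preds = torch.tensor([[0.1, 0.0, 0.0, 0.0, 0.1]])

# swap the two documents tied at score 0.1 (positions 0 and 4)
perm = torch.tensor([4, 1, 2, 3, 0])
target_swapped, preds_swapped = target[:, perm], preds[:, perm]

print(retrieval_normalized_dcg(preds, target, top_k=3))
print(retrieval_normalized_dcg(preds_swapped, target_swapped, top_k=3))  # identical: tied gains are averaged
print(ndcg_score(target.numpy(), preds.numpy(), k=3))                    # sklearn reference value
```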