Implemented HitRate for IR #576

Merged
6 commits, merged Oct 22, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -47,6 +47,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added Short Term Objective Intelligibility (`STOI`) ([#353](https://github.com/PyTorchLightning/metrics/issues/353))


- Added `RetrievalHitRate` metric to retrieval package ([#576](https://github.com/PyTorchLightning/metrics/pull/576))


### Changed

- `AveragePrecision` will now as default output the `macro` average for multilabel and multiclass problems ([#477](https://github.com/PyTorchLightning/metrics/pull/477))
7 changes: 7 additions & 0 deletions docs/source/references/functional.rst
@@ -405,6 +405,13 @@ retrieval_normalized_dcg [func]
.. autofunction:: torchmetrics.functional.retrieval_normalized_dcg
    :noindex:


retrieval_hit_rate [func]
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.retrieval_hit_rate
    :noindex:

****
Text
****
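For context next to the new docs entry, here is a minimal usage sketch of the functional interface; the tensors and the commented results are illustrative, worked out by hand rather than taken from the PR:

import torch
from torchmetrics.functional import retrieval_hit_rate

# scores predicted for the documents returned for a single query
preds = torch.tensor([0.2, 0.3, 0.5])
# ground-truth relevance of the same documents
target = torch.tensor([True, False, True])

# HitRate@2: is at least one relevant document among the two highest-scored ones?
retrieval_hit_rate(preds, target, k=2)  # tensor(1.)
# HitRate@1: the single top-scored document (0.5) is relevant, so the result is again 1.0
retrieval_hit_rate(preds, target, k=1)  # tensor(1.)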
7 changes: 7 additions & 0 deletions docs/source/references/modules.rst
@@ -574,6 +574,13 @@ RetrievalNormalizedDCG
.. autoclass:: torchmetrics.RetrievalNormalizedDCG
    :noindex:


RetrievalHitRate
~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.RetrievalHitRate
    :noindex:

****
Text
****
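To go with the module docs entry, a short sketch of how the class-based metric uses ``indexes`` to group predictions by query and then averages HitRate@k across queries; the values below are invented for illustration and the expected outputs are computed by hand:

import torch
from torchmetrics import RetrievalHitRate

# three documents belong to query 0 and three to query 1
indexes = torch.tensor([0, 0, 0, 1, 1, 1])
preds = torch.tensor([0.9, 0.2, 0.1, 0.4, 0.8, 0.3])
target = torch.tensor([True, False, False, True, False, False])

# query 0 hits at k=1 (its top score 0.9 is relevant), query 1 does not (0.8 is not relevant)
print(RetrievalHitRate(k=1)(preds, target, indexes=indexes))  # tensor(0.5000)
# at k=2 both queries have a relevant document inside the cut-off
print(RetrievalHitRate(k=2)(preds, target, indexes=indexes))  # tensor(1.)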
147 changes: 147 additions & 0 deletions tests/retrieval/test_hit_rate.py
@@ -0,0 +1,147 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest
from torch import Tensor

from tests.helpers import seed_all
from tests.retrieval.helpers import (
    RetrievalMetricTester,
    _concat_tests,
    _default_metric_class_input_arguments,
    _default_metric_functional_input_arguments,
    _errors_test_class_metric_parameters_default,
    _errors_test_class_metric_parameters_k,
    _errors_test_class_metric_parameters_no_pos_target,
    _errors_test_functional_metric_parameters_default,
    _errors_test_functional_metric_parameters_k,
)
from torchmetrics.functional.retrieval.hit_rate import retrieval_hit_rate
from torchmetrics.retrieval.retrieval_hit_rate import RetrievalHitRate

seed_all(42)


def _hit_rate_at_k(target: np.ndarray, preds: np.ndarray, k: int = None):
    """Reference implementation: no reliable public implementation of the hit rate for information retrieval was
    found, so it is reimplemented here."""
    assert target.shape == preds.shape
    assert len(target.shape) == 1  # works only with single-dimension inputs

    if k is None:
        k = len(preds)

    if target.sum() > 0:
        order_indexes = np.argsort(preds, axis=0)[::-1]
        relevant = np.sum(target[order_indexes][:k])
        return float(relevant > 0.0)
    return np.NaN


class TestHitRate(RetrievalMetricTester):
    @pytest.mark.parametrize("ddp", [True, False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
    @pytest.mark.parametrize("k", [None, 1, 4, 10])
    @pytest.mark.parametrize(**_default_metric_class_input_arguments)
    def test_class_metric(
        self,
        ddp: bool,
        indexes: Tensor,
        preds: Tensor,
        target: Tensor,
        dist_sync_on_step: bool,
        empty_target_action: str,
        k: int,
    ):
        metric_args = {"empty_target_action": empty_target_action, "k": k}

        self.run_class_metric_test(
            ddp=ddp,
            indexes=indexes,
            preds=preds,
            target=target,
            metric_class=RetrievalHitRate,
            sk_metric=_hit_rate_at_k,
            dist_sync_on_step=dist_sync_on_step,
            metric_args=metric_args,
        )

    @pytest.mark.parametrize(**_default_metric_functional_input_arguments)
    @pytest.mark.parametrize("k", [None, 1, 4, 10])
    def test_functional_metric(self, preds: Tensor, target: Tensor, k: int):
        self.run_functional_metric_test(
            preds=preds,
            target=target,
            metric_functional=retrieval_hit_rate,
            sk_metric=_hit_rate_at_k,
            metric_args={},
            k=k,
        )

    @pytest.mark.parametrize(**_default_metric_class_input_arguments)
    def test_precision_cpu(self, indexes: Tensor, preds: Tensor, target: Tensor):
        self.run_precision_test_cpu(
            indexes=indexes,
            preds=preds,
            target=target,
            metric_module=RetrievalHitRate,
            metric_functional=retrieval_hit_rate,
        )

    @pytest.mark.parametrize(**_default_metric_class_input_arguments)
    def test_precision_gpu(self, indexes: Tensor, preds: Tensor, target: Tensor):
        self.run_precision_test_gpu(
            indexes=indexes,
            preds=preds,
            target=target,
            metric_module=RetrievalHitRate,
            metric_functional=retrieval_hit_rate,
        )

    @pytest.mark.parametrize(
        **_concat_tests(
            _errors_test_class_metric_parameters_default,
            _errors_test_class_metric_parameters_no_pos_target,
            _errors_test_class_metric_parameters_k,
        )
    )
    def test_arguments_class_metric(
        self, indexes: Tensor, preds: Tensor, target: Tensor, message: str, metric_args: dict
    ):
        self.run_metric_class_arguments_test(
            indexes=indexes,
            preds=preds,
            target=target,
            metric_class=RetrievalHitRate,
            message=message,
            metric_args=metric_args,
            exception_type=ValueError,
            kwargs_update={},
        )

    @pytest.mark.parametrize(
        **_concat_tests(
            _errors_test_functional_metric_parameters_default,
            _errors_test_functional_metric_parameters_k,
        )
    )
    def test_arguments_functional_metric(self, preds: Tensor, target: Tensor, message: str, metric_args: dict):
        self.run_functional_metric_arguments_test(
            preds=preds,
            target=target,
            metric_functional=retrieval_hit_rate,
            message=message,
            exception_type=ValueError,
            kwargs_update=metric_args,
        )
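Outside the shared test harness, the NumPy reference above can be checked against the new functional directly; a quick sketch reusing the _hit_rate_at_k helper defined in this test file, with invented inputs:

import numpy as np
import torch
from torchmetrics.functional import retrieval_hit_rate

preds_np = np.array([0.4, 0.9, 0.1, 0.7])
target_np = np.array([0, 0, 1, 1])

for k in (1, 2, None):
    reference = _hit_rate_at_k(target_np, preds_np, k=k)
    result = retrieval_hit_rate(torch.from_numpy(preds_np), torch.from_numpy(target_np), k=k)
    # both should agree: no hit at k=1 (the top score 0.9 points to a non-relevant document),
    # and a hit from k=2 onwards (0.7 points to a relevant one)
    assert reference == result.item()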
2 changes: 2 additions & 0 deletions torchmetrics/__init__.py
@@ -57,6 +57,7 @@
)
from torchmetrics.retrieval import ( # noqa: E402
    RetrievalFallOut,
    RetrievalHitRate,
    RetrievalMAP,
    RetrievalMRR,
    RetrievalNormalizedDCG,
@@ -116,6 +117,7 @@
"R2Score",
"Recall",
"RetrievalFallOut",
"RetrievalHitRate",
"RetrievalMAP",
"RetrievalMRR",
"RetrievalNormalizedDCG",
2 changes: 2 additions & 0 deletions torchmetrics/functional/__init__.py
@@ -58,6 +58,7 @@
from torchmetrics.functional.regression.tweedie_deviance import tweedie_deviance_score
from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision
from torchmetrics.functional.retrieval.fall_out import retrieval_fall_out
from torchmetrics.functional.retrieval.hit_rate import retrieval_hit_rate
from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg
from torchmetrics.functional.retrieval.precision import retrieval_precision
from torchmetrics.functional.retrieval.recall import retrieval_recall
@@ -114,6 +115,7 @@
"recall",
"retrieval_average_precision",
"retrieval_fall_out",
"retrieval_hit_rate",
"retrieval_normalized_dcg",
"retrieval_precision",
"retrieval_recall",
1 change: 1 addition & 0 deletions torchmetrics/functional/retrieval/__init__.py
@@ -14,6 +14,7 @@

from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision # noqa: F401
from torchmetrics.functional.retrieval.fall_out import retrieval_fall_out # noqa: F401
from torchmetrics.functional.retrieval.hit_rate import retrieval_hit_rate # noqa: F401
from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg # noqa: F401
from torchmetrics.functional.retrieval.precision import retrieval_precision # noqa: F401
from torchmetrics.functional.retrieval.recall import retrieval_recall # noqa: F401
10 changes: 8 additions & 2 deletions torchmetrics/functional/retrieval/fall_out.py
@@ -11,13 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

import torch
from torch import Tensor, tensor

from torchmetrics.utilities.checks import _check_retrieval_functional_inputs


def retrieval_fall_out(preds: Tensor, target: Tensor, k: int = None) -> Tensor:
def retrieval_fall_out(preds: Tensor, target: Tensor, k: Optional[int] = None) -> Tensor:
"""Computes the Fall-out (for information retrieval), as explained in `IR Fall-out`_ Fall-out is the fraction
of non-relevant documents retrieved among all the non-relevant documents.

@@ -28,11 +30,15 @@ def retrieval_fall_out(preds: Tensor, target: Tensor, k: int = None) -> Tensor:
Args:
preds: estimated probabilities of each document to be relevant.
target: ground truth about each document being relevant or not.
k: consider only the top k elements (default: None)
k: consider only the top k elements (default: None, which considers them all)

Returns:
a single-value tensor with the fall-out (at ``k``) of the predictions ``preds`` w.r.t. the labels ``target``.

Raises:
ValueError:
If ``k`` is neither ``None`` nor a positive integer.

Example:
>>> from torchmetrics.functional import retrieval_fall_out
>>> preds = tensor([0.2, 0.3, 0.5])
57 changes: 57 additions & 0 deletions torchmetrics/functional/retrieval/hit_rate.py
@@ -0,0 +1,57 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

import torch
from torch import Tensor, tensor

from torchmetrics.utilities.checks import _check_retrieval_functional_inputs


def retrieval_hit_rate(preds: Tensor, target: Tensor, k: Optional[int] = None) -> Tensor:
    """Computes the hit rate (for information retrieval). The hit rate is 1.0 if there is at least one relevant
    document among the top ``k`` retrieved documents.

    ``preds`` and ``target`` should be of the same shape and live on the same device. If no ``target`` is ``True``,
    ``0`` is returned. ``target`` must be either `bool` or `integer` and ``preds`` must be `float`,
    otherwise an error is raised. If you want to measure HitRate@K, ``k`` must be a positive integer.

    Args:
        preds: estimated probabilities of each document to be relevant.
        target: ground truth about each document being relevant or not.
        k: consider only the top k elements (default: ``None``, which considers them all)

    Returns:
        a single-value tensor with the hit rate (at ``k``) of the predictions ``preds`` w.r.t. the labels ``target``.

    Raises:
        ValueError:
            If ``k`` is neither ``None`` nor a positive integer.

    Example:
        >>> preds = tensor([0.2, 0.3, 0.5])
        >>> target = tensor([True, False, True])
        >>> retrieval_hit_rate(preds, target, k=2)
        tensor(1.)
    """
    preds, target = _check_retrieval_functional_inputs(preds, target)

    if k is None:
        k = preds.shape[-1]

    if not (isinstance(k, int) and k > 0):
        raise ValueError("`k` has to be a positive integer or None")

    relevant = target[torch.argsort(preds, dim=-1, descending=True)][:k].sum()
    return (relevant > 0).float()
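The core of the computation is the last two lines above: sort the targets by descending score, keep the top ``k``, and check whether any relevant document survived the cut-off. A step-by-step sketch with invented values:

import torch

preds = torch.tensor([0.10, 0.80, 0.25, 0.60])
target = torch.tensor([False, False, True, False])
k = 2

order = torch.argsort(preds, dim=-1, descending=True)  # tensor([1, 3, 2, 0])
top_k_relevant = target[order][:k]                      # tensor([False, False])
hit = (top_k_relevant.sum() > 0).float()                # tensor(0.): no relevant document in the top 2
# with k=3 the relevant document (score 0.25) enters the cut-off and the hit rate becomes 1.0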
6 changes: 5 additions & 1 deletion torchmetrics/functional/retrieval/ndcg.py
@@ -35,11 +35,15 @@ def retrieval_normalized_dcg(preds: Tensor, target: Tensor, k: Optional[int] = N
Args:
preds: estimated probabilities of each document to be relevant.
target: ground truth about each document relevance.
k: consider only the top k elements (default: None)
k: consider only the top k elements (default: None, which considers them all)

Return:
a single-value tensor with the nDCG of the predictions ``preds`` w.r.t. the labels ``target``.

Raises:
ValueError:
If ``k`` is neither ``None`` nor a positive integer.

Example:
>>> from torchmetrics.functional import retrieval_normalized_dcg
>>> preds = torch.tensor([.1, .2, .3, 4, 70])
6 changes: 5 additions & 1 deletion torchmetrics/functional/retrieval/precision.py
@@ -30,11 +30,15 @@ def retrieval_precision(preds: Tensor, target: Tensor, k: Optional[int] = None)
Args:
preds: estimated probabilities of each document to be relevant.
target: ground truth about each document being relevant or not.
k: consider only the top k elements (default: None)
k: consider only the top k elements (default: None, which considers them all)

Returns:
a single-value tensor with the precision (at ``k``) of the predictions ``preds`` w.r.t. the labels ``target``.

Raises:
ValueError:
If ``k`` is neither ``None`` nor a positive integer.

Example:
>>> preds = tensor([0.2, 0.3, 0.5])
>>> target = tensor([True, False, True])
6 changes: 5 additions & 1 deletion torchmetrics/functional/retrieval/recall.py
@@ -30,11 +30,15 @@ def retrieval_recall(preds: Tensor, target: Tensor, k: Optional[int] = None) ->
Args:
preds: estimated probabilities of each document to be relevant.
target: ground truth about each document being relevant or not.
k: consider only the top k elements (default: None)
k: consider only the top k elements (default: None, which considers them all)

Returns:
a single-value tensor with the recall (at ``k``) of the predictions ``preds`` w.r.t. the labels ``target``.

Raises:
ValueError:
If ``k`` is neither ``None`` nor a positive integer.

Example:
>>> from torchmetrics.functional import retrieval_recall
>>> preds = tensor([0.2, 0.3, 0.5])
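The Raises sections added to these four functionals describe the ValueError raised for an invalid ``k``. Assuming the same ``k`` validation as in hit_rate.py above, a brief sketch of that behavior using retrieval_precision; the printed message is indicative only:

import torch
from torchmetrics.functional import retrieval_precision

preds = torch.tensor([0.2, 0.3, 0.5])
target = torch.tensor([True, False, True])

try:
    retrieval_precision(preds, target, k=0)  # k must be a positive integer or None
except ValueError as err:
    print(err)  # e.g. "`k` has to be a positive integer or None"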
1 change: 1 addition & 0 deletions torchmetrics/retrieval/__init__.py
@@ -14,6 +14,7 @@
from torchmetrics.retrieval.mean_average_precision import RetrievalMAP # noqa: F401
from torchmetrics.retrieval.mean_reciprocal_rank import RetrievalMRR # noqa: F401
from torchmetrics.retrieval.retrieval_fallout import RetrievalFallOut # noqa: F401
from torchmetrics.retrieval.retrieval_hit_rate import RetrievalHitRate # noqa: F401
from torchmetrics.retrieval.retrieval_metric import RetrievalMetric # noqa: F401
from torchmetrics.retrieval.retrieval_ndcg import RetrievalNormalizedDCG # noqa: F401
from torchmetrics.retrieval.retrieval_precision import RetrievalPrecision # noqa: F401
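With the exports above in place, the metric is reachable from both the class and the functional namespaces; a final sketch (invented inputs) checking that the two agree on a single query:

import torch
from torchmetrics import RetrievalHitRate
from torchmetrics.functional import retrieval_hit_rate

preds = torch.tensor([0.3, 0.7, 0.1])
target = torch.tensor([False, True, False])

functional_result = retrieval_hit_rate(preds, target, k=1)
# a single query, so all documents share the same index
module_result = RetrievalHitRate(k=1)(preds, target, indexes=torch.zeros(3, dtype=torch.long))
assert functional_result == module_result  # both tensor(1.): the top-scored document is relevant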