diff --git a/CHANGELOG.md b/CHANGELOG.md index 169e0e4597d..e1882145264 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -47,6 +47,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added Short Term Objective Intelligibility (`STOI`) ([#353](https://github.com/PyTorchLightning/metrics/issues/353)) +- Added `RetrievalRPrecision` metric to retrieval package ([#577](https://github.com/PyTorchLightning/metrics/pull/577/)) + + - Added `RetrievalHitRate` metric to retrieval package ([#576](https://github.com/PyTorchLightning/metrics/pull/576)) diff --git a/docs/source/links.rst b/docs/source/links.rst index 2560562131c..497157a7cdc 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -5,6 +5,7 @@ .. _Fall-out: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Fall-out .. _Normalized Discounted Cumulative Gain: https://en.wikipedia.org/wiki/Discounted_cumulative_gain .. _IR Precision: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Precision +.. _IR R-Precision: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#R-precision .. _IR Recall: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Recall .. _Accuracy: https://en.wikipedia.org/wiki/Accuracy_and_precision .. _SMAPE: https://en.wikipedia.org/wiki/Symmetric_mean_absolute_percentage_error diff --git a/docs/source/references/functional.rst b/docs/source/references/functional.rst index fbb582e3fc0..a11a0f30bf7 100644 --- a/docs/source/references/functional.rst +++ b/docs/source/references/functional.rst @@ -385,6 +385,13 @@ retrieval_precision [func] :noindex: +retrieval_r_precision [func] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: torchmetrics.functional.retrieval_r_precision + :noindex: + + retrieval_recall [func] ~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index 68bb911bbb1..69226e43fe2 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -554,6 +554,13 @@ RetrievalPrecision :noindex: +RetrievalRPrecision +~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: torchmetrics.RetrievalRPrecision + :noindex: + + RetrievalRecall ~~~~~~~~~~~~~~~ diff --git a/tests/retrieval/test_r_precision.py b/tests/retrieval/test_r_precision.py new file mode 100644 index 00000000000..7822991b9a5 --- /dev/null +++ b/tests/retrieval/test_r_precision.py @@ -0,0 +1,136 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import numpy as np +import pytest +from torch import Tensor + +from tests.helpers import seed_all +from tests.retrieval.helpers import ( + RetrievalMetricTester, + _concat_tests, + _default_metric_class_input_arguments, + _default_metric_functional_input_arguments, + _errors_test_class_metric_parameters_default, + _errors_test_class_metric_parameters_no_pos_target, + _errors_test_functional_metric_parameters_default, +) +from torchmetrics.functional.retrieval.r_precision import retrieval_r_precision +from torchmetrics.retrieval.retrieval_r_precision import RetrievalRPrecision + +seed_all(42) + + +def _r_precision(target: np.ndarray, preds: np.ndarray): + """Didn't find a reliable implementation of R-Precision in Information Retrieval, so, reimplementing here. + + A good explanation can be found + `here _`. + """ + assert target.shape == preds.shape + assert len(target.shape) == 1 # works only with single dimension inputs + + if target.sum() > 0: + order_indexes = np.argsort(preds, axis=0)[::-1] + relevant = np.sum(target[order_indexes][: target.sum()]) + return relevant * 1.0 / target.sum() + return np.NaN + + +class TestRPrecision(RetrievalMetricTester): + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) + @pytest.mark.parametrize(**_default_metric_class_input_arguments) + def test_class_metric( + self, + ddp: bool, + indexes: Tensor, + preds: Tensor, + target: Tensor, + dist_sync_on_step: bool, + empty_target_action: str, + ): + metric_args = {"empty_target_action": empty_target_action} + + self.run_class_metric_test( + ddp=ddp, + indexes=indexes, + preds=preds, + target=target, + metric_class=RetrievalRPrecision, + sk_metric=_r_precision, + dist_sync_on_step=dist_sync_on_step, + metric_args=metric_args, + ) + + @pytest.mark.parametrize(**_default_metric_functional_input_arguments) + def test_functional_metric(self, preds: Tensor, target: Tensor): + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=retrieval_r_precision, + sk_metric=_r_precision, + metric_args={}, + ) + + @pytest.mark.parametrize(**_default_metric_class_input_arguments) + def test_precision_cpu(self, indexes: Tensor, preds: Tensor, target: Tensor): + self.run_precision_test_cpu( + indexes=indexes, + preds=preds, + target=target, + metric_module=RetrievalRPrecision, + metric_functional=retrieval_r_precision, + ) + + @pytest.mark.parametrize(**_default_metric_class_input_arguments) + def test_precision_gpu(self, indexes: Tensor, preds: Tensor, target: Tensor): + self.run_precision_test_gpu( + indexes=indexes, + preds=preds, + target=target, + metric_module=RetrievalRPrecision, + metric_functional=retrieval_r_precision, + ) + + @pytest.mark.parametrize( + **_concat_tests( + _errors_test_class_metric_parameters_default, + _errors_test_class_metric_parameters_no_pos_target, + ) + ) + def test_arguments_class_metric( + self, indexes: Tensor, preds: Tensor, target: Tensor, message: str, metric_args: dict + ): + self.run_metric_class_arguments_test( + indexes=indexes, + preds=preds, + target=target, + metric_class=RetrievalRPrecision, + message=message, + metric_args=metric_args, + exception_type=ValueError, + kwargs_update={}, + ) + + @pytest.mark.parametrize(**_errors_test_functional_metric_parameters_default) + def test_arguments_functional_metric(self, preds: Tensor, target: Tensor, message: str, metric_args: dict): + 
self.run_functional_metric_arguments_test( + preds=preds, + target=target, + metric_functional=retrieval_r_precision, + message=message, + exception_type=ValueError, + kwargs_update=metric_args, + ) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 8ab7fe64c58..7e1e111a24f 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -63,6 +63,7 @@ RetrievalNormalizedDCG, RetrievalPrecision, RetrievalRecall, + RetrievalRPrecision, ) from torchmetrics.text import WER, BERTScore, BLEUScore, ROUGEScore, SacreBLEUScore # noqa: E402 from torchmetrics.wrappers import BootStrapper, MetricTracker, MultioutputWrapper # noqa: E402 @@ -123,6 +124,7 @@ "RetrievalNormalizedDCG", "RetrievalPrecision", "RetrievalRecall", + "RetrievalRPrecision", "ROC", "ROUGEScore", "SacreBLEUScore", diff --git a/torchmetrics/functional/__init__.py b/torchmetrics/functional/__init__.py index 911857a7ced..98091530686 100644 --- a/torchmetrics/functional/__init__.py +++ b/torchmetrics/functional/__init__.py @@ -61,6 +61,7 @@ from torchmetrics.functional.retrieval.hit_rate import retrieval_hit_rate from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg from torchmetrics.functional.retrieval.precision import retrieval_precision +from torchmetrics.functional.retrieval.r_precision import retrieval_r_precision from torchmetrics.functional.retrieval.recall import retrieval_recall from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank from torchmetrics.functional.self_supervised import embedding_similarity @@ -118,6 +119,7 @@ "retrieval_hit_rate", "retrieval_normalized_dcg", "retrieval_precision", + "retrieval_r_precision", "retrieval_recall", "retrieval_reciprocal_rank", "roc", diff --git a/torchmetrics/functional/retrieval/__init__.py b/torchmetrics/functional/retrieval/__init__.py index 0b3d2caa420..d9d31a41f43 100644 --- a/torchmetrics/functional/retrieval/__init__.py +++ b/torchmetrics/functional/retrieval/__init__.py @@ -17,5 +17,6 @@ from torchmetrics.functional.retrieval.hit_rate import retrieval_hit_rate # noqa: F401 from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg # noqa: F401 from torchmetrics.functional.retrieval.precision import retrieval_precision # noqa: F401 +from torchmetrics.functional.retrieval.r_precision import retrieval_r_precision # noqa: F401 from torchmetrics.functional.retrieval.recall import retrieval_recall # noqa: F401 from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank # noqa: F401 diff --git a/torchmetrics/functional/retrieval/r_precision.py b/torchmetrics/functional/retrieval/r_precision.py new file mode 100644 index 00000000000..ed2787086f0 --- /dev/null +++ b/torchmetrics/functional/retrieval/r_precision.py @@ -0,0 +1,49 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import torch
+from torch import Tensor, tensor
+
+from torchmetrics.utilities.checks import _check_retrieval_functional_inputs
+
+
+def retrieval_r_precision(preds: Tensor, target: Tensor) -> Tensor:
+    """Computes the r-precision metric (for information retrieval). R-Precision is the fraction of relevant
+    documents among all the top ``k`` retrieved documents where ``k`` is equal to the total number of relevant
+    documents.
+
+    ``preds`` and ``target`` should be of the same shape and live on the same device. If no ``target`` is ``True``,
+    ``0`` is returned. ``target`` must be either `bool` or `integers` and ``preds`` must be `float`,
+    otherwise an error is raised.
+
+    Args:
+        preds: estimated probabilities of each document to be relevant.
+        target: ground truth about each document being relevant or not.
+
+    Returns:
+        a single-value tensor with the r-precision of the predictions ``preds`` w.r.t. the labels ``target``.
+
+    Example:
+        >>> preds = tensor([0.2, 0.3, 0.5])
+        >>> target = tensor([True, False, True])
+        >>> retrieval_r_precision(preds, target)
+        tensor(0.5000)
+    """
+    preds, target = _check_retrieval_functional_inputs(preds, target)
+
+    relevant_number = target.sum()
+    if not relevant_number:
+        return tensor(0.0, device=preds.device)
+
+    relevant = target[torch.argsort(preds, dim=-1, descending=True)][:relevant_number].sum().float()
+    return relevant / relevant_number
diff --git a/torchmetrics/retrieval/__init__.py b/torchmetrics/retrieval/__init__.py
index a245edf8163..84b3ad83d19 100644
--- a/torchmetrics/retrieval/__init__.py
+++ b/torchmetrics/retrieval/__init__.py
@@ -18,4 +18,5 @@
 from torchmetrics.retrieval.retrieval_metric import RetrievalMetric  # noqa: F401
 from torchmetrics.retrieval.retrieval_ndcg import RetrievalNormalizedDCG  # noqa: F401
 from torchmetrics.retrieval.retrieval_precision import RetrievalPrecision  # noqa: F401
+from torchmetrics.retrieval.retrieval_r_precision import RetrievalRPrecision  # noqa: F401
 from torchmetrics.retrieval.retrieval_recall import RetrievalRecall  # noqa: F401
diff --git a/torchmetrics/retrieval/retrieval_r_precision.py b/torchmetrics/retrieval/retrieval_r_precision.py
new file mode 100644
index 00000000000..f46be5d43b3
--- /dev/null
+++ b/torchmetrics/retrieval/retrieval_r_precision.py
@@ -0,0 +1,70 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from torch import Tensor, tensor
+
+from torchmetrics.functional.retrieval.r_precision import retrieval_r_precision
+from torchmetrics.retrieval.retrieval_metric import RetrievalMetric
+
+
+class RetrievalRPrecision(RetrievalMetric):
+    """Computes `IR R-Precision`_.
+
+    Works with binary target data. Accepts float predictions from a model output.
+
+    Forward accepts:
+
+    - ``preds`` (float tensor): ``(N, ...)``
+    - ``target`` (long or bool tensor): ``(N, ...)``
+    - ``indexes`` (long tensor): ``(N, ...)``
+
+    ``indexes``, ``preds`` and ``target`` must have the same dimension.
+    ``indexes`` indicate to which query a prediction belongs.
+    Predictions will be first grouped by ``indexes`` and then `R-Precision` will be computed as the mean
+    of the `R-Precision` over each query.
+
+    Args:
+        empty_target_action:
+            Specify what to do with queries that do not have at least one positive ``target``. Choose from:
+
+            - ``'neg'``: those queries count as ``0.0`` (default)
+            - ``'pos'``: those queries count as ``1.0``
+            - ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned
+            - ``'error'``: raise a ``ValueError``
+
+        compute_on_step:
+            Forward only calls ``update()`` and returns ``None`` if this is set to ``False``. default: True
+        dist_sync_on_step:
+            Synchronize metric state across processes at each ``forward()``
+            before returning the value at the step. default: False
+        process_group:
+            Specify the process group on which synchronization is called. default: None (which selects
+            the entire world)
+        dist_sync_fn:
+            Callback that performs the allgather operation on the metric state. When ``None``, DDP
+            will be used to perform the allgather. default: None
+
+    Example:
+        >>> from torchmetrics import RetrievalRPrecision
+        >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
+        >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
+        >>> target = tensor([False, False, True, False, True, False, True])
+        >>> p2 = RetrievalRPrecision()
+        >>> p2(preds, target, indexes=indexes)
+        tensor(0.7500)
+    """
+
+    higher_is_better = True
+
+    def _metric(self, preds: Tensor, target: Tensor) -> Tensor:
+        return retrieval_r_precision(preds, target)
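The class-based metric added above can also accumulate state over several batches before the final score is computed. The sketch below is illustrative only (the batch split and variable names are not part of the patch; the values mirror the docstring example, so the expected result is the same 0.75):

import torch

from torchmetrics import RetrievalRPrecision

metric = RetrievalRPrecision()

# Batch 1: all predictions belong to query 0, which has one relevant document.
metric.update(
    preds=torch.tensor([0.2, 0.3, 0.5]),
    target=torch.tensor([False, False, True]),
    indexes=torch.tensor([0, 0, 0]),
)

# Batch 2: all predictions belong to query 1, which has two relevant documents.
metric.update(
    preds=torch.tensor([0.1, 0.3, 0.5, 0.2]),
    target=torch.tensor([False, True, False, True]),
    indexes=torch.tensor([1, 1, 1, 1]),
)

# Query 0: the single top-ranked document is relevant            -> 1.0
# Query 1: one of the two top-ranked documents is relevant       -> 0.5
print(metric.compute())  # tensor(0.7500), the mean over both queries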