Implemented R-Precision for IR #577

Merged · 8 commits · Oct 22, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -47,6 +47,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added Short Term Objective Intelligibility (`STOI`) ([#353](https://github.com/PyTorchLightning/metrics/issues/353))


- Added `RetrievalRPrecision` metric to retrieval package ([#577](https://github.com/PyTorchLightning/metrics/pull/577/))


- Added `RetrievalHitRate` metric to retrieval package ([#576](https://github.com/PyTorchLightning/metrics/pull/576))


1 change: 1 addition & 0 deletions docs/source/links.rst
@@ -5,6 +5,7 @@
.. _Fall-out: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Fall-out
.. _Normalized Discounted Cumulative Gain: https://en.wikipedia.org/wiki/Discounted_cumulative_gain
.. _IR Precision: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Precision
.. _IR R-Precision: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#R-precision
.. _IR Recall: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Recall
.. _Accuracy: https://en.wikipedia.org/wiki/Accuracy_and_precision
.. _SMAPE: https://en.wikipedia.org/wiki/Symmetric_mean_absolute_percentage_error
7 changes: 7 additions & 0 deletions docs/source/references/functional.rst
@@ -385,6 +385,13 @@ retrieval_precision [func]
    :noindex:


retrieval_r_precision [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.retrieval_r_precision
    :noindex:


retrieval_recall [func]
~~~~~~~~~~~~~~~~~~~~~~~

7 changes: 7 additions & 0 deletions docs/source/references/modules.rst
@@ -554,6 +554,13 @@ RetrievalPrecision
    :noindex:


RetrievalRPrecision
~~~~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.RetrievalRPrecision
    :noindex:


RetrievalRecall
~~~~~~~~~~~~~~~

136 changes: 136 additions & 0 deletions tests/retrieval/test_r_precision.py
@@ -0,0 +1,136 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest
from torch import Tensor

from tests.helpers import seed_all
from tests.retrieval.helpers import (
    RetrievalMetricTester,
    _concat_tests,
    _default_metric_class_input_arguments,
    _default_metric_functional_input_arguments,
    _errors_test_class_metric_parameters_default,
    _errors_test_class_metric_parameters_no_pos_target,
    _errors_test_functional_metric_parameters_default,
)
from torchmetrics.functional.retrieval.r_precision import retrieval_r_precision
from torchmetrics.retrieval.retrieval_r_precision import RetrievalRPrecision

seed_all(42)


def _r_precision(target: np.ndarray, preds: np.ndarray):
    """No reliable implementation of R-Precision for Information Retrieval was found, so it is reimplemented here.

    A good explanation can be found
    `here <https://web.stanford.edu/class/cs276/handouts/EvaluationNew-handout-1-per.pdf>`_.
    """
    assert target.shape == preds.shape
    assert len(target.shape) == 1  # works only with single dimension inputs

    if target.sum() > 0:
        order_indexes = np.argsort(preds, axis=0)[::-1]
        relevant = np.sum(target[order_indexes][: target.sum()])
        return relevant * 1.0 / target.sum()
    return np.NaN


class TestRPrecision(RetrievalMetricTester):
    @pytest.mark.parametrize("ddp", [True, False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
    @pytest.mark.parametrize(**_default_metric_class_input_arguments)
    def test_class_metric(
        self,
        ddp: bool,
        indexes: Tensor,
        preds: Tensor,
        target: Tensor,
        dist_sync_on_step: bool,
        empty_target_action: str,
    ):
        metric_args = {"empty_target_action": empty_target_action}

        self.run_class_metric_test(
            ddp=ddp,
            indexes=indexes,
            preds=preds,
            target=target,
            metric_class=RetrievalRPrecision,
            sk_metric=_r_precision,
            dist_sync_on_step=dist_sync_on_step,
            metric_args=metric_args,
        )

    @pytest.mark.parametrize(**_default_metric_functional_input_arguments)
    def test_functional_metric(self, preds: Tensor, target: Tensor):
        self.run_functional_metric_test(
            preds=preds,
            target=target,
            metric_functional=retrieval_r_precision,
            sk_metric=_r_precision,
            metric_args={},
        )

    @pytest.mark.parametrize(**_default_metric_class_input_arguments)
    def test_precision_cpu(self, indexes: Tensor, preds: Tensor, target: Tensor):
        self.run_precision_test_cpu(
            indexes=indexes,
            preds=preds,
            target=target,
            metric_module=RetrievalRPrecision,
            metric_functional=retrieval_r_precision,
        )

    @pytest.mark.parametrize(**_default_metric_class_input_arguments)
    def test_precision_gpu(self, indexes: Tensor, preds: Tensor, target: Tensor):
        self.run_precision_test_gpu(
            indexes=indexes,
            preds=preds,
            target=target,
            metric_module=RetrievalRPrecision,
            metric_functional=retrieval_r_precision,
        )

    @pytest.mark.parametrize(
        **_concat_tests(
            _errors_test_class_metric_parameters_default,
            _errors_test_class_metric_parameters_no_pos_target,
        )
    )
    def test_arguments_class_metric(
        self, indexes: Tensor, preds: Tensor, target: Tensor, message: str, metric_args: dict
    ):
        self.run_metric_class_arguments_test(
            indexes=indexes,
            preds=preds,
            target=target,
            metric_class=RetrievalRPrecision,
            message=message,
            metric_args=metric_args,
            exception_type=ValueError,
            kwargs_update={},
        )

    @pytest.mark.parametrize(**_errors_test_functional_metric_parameters_default)
    def test_arguments_functional_metric(self, preds: Tensor, target: Tensor, message: str, metric_args: dict):
        self.run_functional_metric_arguments_test(
            preds=preds,
            target=target,
            metric_functional=retrieval_r_precision,
            message=message,
            exception_type=ValueError,
            kwargs_update=metric_args,
        )
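
As a quick sanity check outside the test harness, the NumPy reference above can be compared directly against the new functional metric. A minimal sketch (the toy inputs are illustrative, and importing `_r_precision` assumes the repo's `tests` package is on the path):

import numpy as np
import torch

from tests.retrieval.test_r_precision import _r_precision
from torchmetrics.functional import retrieval_r_precision

# Two of four documents are relevant, so k = 2; the top-2 scores are
# 0.8 and 0.6, and only the first belongs to a relevant document.
preds = np.array([0.8, 0.1, 0.6, 0.3])
target = np.array([1, 0, 0, 1])

expected = _r_precision(target, preds)  # 1 relevant among the top-2 -> 0.5
result = retrieval_r_precision(torch.from_numpy(preds), torch.from_numpy(target))
assert torch.isclose(result, torch.tensor(expected, dtype=result.dtype))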
2 changes: 2 additions & 0 deletions torchmetrics/__init__.py
@@ -63,6 +63,7 @@
    RetrievalNormalizedDCG,
    RetrievalPrecision,
    RetrievalRecall,
    RetrievalRPrecision,
)
from torchmetrics.text import WER, BERTScore, BLEUScore, ROUGEScore, SacreBLEUScore # noqa: E402
from torchmetrics.wrappers import BootStrapper, MetricTracker, MultioutputWrapper # noqa: E402
@@ -123,6 +124,7 @@
"RetrievalNormalizedDCG",
"RetrievalPrecision",
"RetrievalRecall",
"RetrievalRPrecision",
"ROC",
"ROUGEScore",
"SacreBLEUScore",
2 changes: 2 additions & 0 deletions torchmetrics/functional/__init__.py
@@ -61,6 +61,7 @@
from torchmetrics.functional.retrieval.hit_rate import retrieval_hit_rate
from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg
from torchmetrics.functional.retrieval.precision import retrieval_precision
from torchmetrics.functional.retrieval.r_precision import retrieval_r_precision
from torchmetrics.functional.retrieval.recall import retrieval_recall
from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank
from torchmetrics.functional.self_supervised import embedding_similarity
@@ -118,6 +119,7 @@
"retrieval_hit_rate",
"retrieval_normalized_dcg",
"retrieval_precision",
"retrieval_r_precision",
"retrieval_recall",
"retrieval_reciprocal_rank",
"roc",
1 change: 1 addition & 0 deletions torchmetrics/functional/retrieval/__init__.py
@@ -17,5 +17,6 @@
from torchmetrics.functional.retrieval.hit_rate import retrieval_hit_rate # noqa: F401
from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg # noqa: F401
from torchmetrics.functional.retrieval.precision import retrieval_precision # noqa: F401
from torchmetrics.functional.retrieval.r_precision import retrieval_r_precision # noqa: F401
from torchmetrics.functional.retrieval.recall import retrieval_recall # noqa: F401
from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank # noqa: F401
49 changes: 49 additions & 0 deletions torchmetrics/functional/retrieval/r_precision.py
@@ -0,0 +1,49 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor, tensor

from torchmetrics.utilities.checks import _check_retrieval_functional_inputs


def retrieval_r_precision(preds: Tensor, target: Tensor) -> Tensor:
    """Computes the R-Precision metric (for information retrieval). R-Precision is the fraction of relevant
    documents among the top ``k`` retrieved documents, where ``k`` is equal to the total number of relevant
    documents.

    ``preds`` and ``target`` should be of the same shape and live on the same device. If no ``target`` is ``True``,
    ``0`` is returned. ``target`` must be either `bool` or `integer` and ``preds`` must be `float`,
    otherwise an error is raised. ``k`` is determined automatically from ``target`` and is not an argument.

    Args:
        preds: estimated probabilities of each document to be relevant.
        target: ground truth about each document being relevant or not.

    Returns:
        a single-value tensor with the R-Precision of the predictions ``preds`` w.r.t. the labels ``target``.

    Example:
        >>> from torch import tensor
        >>> preds = tensor([0.2, 0.3, 0.5])
        >>> target = tensor([True, False, True])
        >>> retrieval_r_precision(preds, target)
        tensor(0.5000)
    """
    preds, target = _check_retrieval_functional_inputs(preds, target)

    relevant_number = target.sum()
    if not relevant_number:
        return tensor(0.0, device=preds.device)

    relevant = target[torch.argsort(preds, dim=-1, descending=True)][:relevant_number].sum().float()
    return relevant / relevant_number
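
To make the argsort-based computation above concrete, here is the docstring example traced step by step; a sketch for review purposes, not part of the diff:

import torch
from torchmetrics.functional import retrieval_r_precision

preds = torch.tensor([0.2, 0.3, 0.5])
target = torch.tensor([True, False, True])

# relevant_number = target.sum() = 2
# torch.argsort(preds, descending=True) -> indices [2, 1, 0]
# target[[2, 1, 0]] -> [True, False, True]; first 2 entries -> [True, False]
# one relevant document in the top-2 -> 1 / 2 = 0.5
print(retrieval_r_precision(preds, target))  # tensor(0.5000)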
1 change: 1 addition & 0 deletions torchmetrics/retrieval/__init__.py
@@ -18,4 +18,5 @@
from torchmetrics.retrieval.retrieval_metric import RetrievalMetric # noqa: F401
from torchmetrics.retrieval.retrieval_ndcg import RetrievalNormalizedDCG # noqa: F401
from torchmetrics.retrieval.retrieval_precision import RetrievalPrecision # noqa: F401
from torchmetrics.retrieval.retrieval_r_precision import RetrievalRPrecision # noqa: F401
from torchmetrics.retrieval.retrieval_recall import RetrievalRecall # noqa: F401
70 changes: 70 additions & 0 deletions torchmetrics/retrieval/retrieval_r_precision.py
@@ -0,0 +1,70 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torch import Tensor, tensor

from torchmetrics.functional.retrieval.r_precision import retrieval_r_precision
from torchmetrics.retrieval.retrieval_metric import RetrievalMetric


class RetrievalRPrecision(RetrievalMetric):
    """Computes `IR R-Precision`_.

    Works with binary target data. Accepts float predictions from a model output.

    Forward accepts:

    - ``preds`` (float tensor): ``(N, ...)``
    - ``target`` (long or bool tensor): ``(N, ...)``
    - ``indexes`` (long tensor): ``(N, ...)``

    ``indexes``, ``preds`` and ``target`` must have the same dimension.
    ``indexes`` indicate to which query a prediction belongs.
    Predictions will first be grouped by ``indexes`` and then `R-Precision` will be computed as the mean
    of the `R-Precision` over each query.

    Args:
        empty_target_action:
            Specify what to do with queries that do not have at least one positive ``target``. Choose from:

            - ``'neg'``: those queries count as ``0.0`` (default)
            - ``'pos'``: those queries count as ``1.0``
            - ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned
            - ``'error'``: raise a ``ValueError``

        compute_on_step:
            Forward only calls ``update()`` and returns ``None`` if this is set to ``False``. default: ``True``
        dist_sync_on_step:
            Synchronize metric state across processes at each ``forward()``
            before returning the value at the step. default: ``False``
        process_group:
            Specify the process group on which synchronization is called. default: ``None`` (which selects
            the entire world)
        dist_sync_fn:
            Callback that performs the allgather operation on the metric state. When ``None``, DDP
            will be used to perform the allgather. default: ``None``

    Example:
        >>> from torch import tensor
        >>> from torchmetrics import RetrievalRPrecision
        >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
        >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
        >>> target = tensor([False, False, True, False, True, False, True])
        >>> p2 = RetrievalRPrecision()
        >>> p2(preds, target, indexes=indexes)
        tensor(0.7500)
    """

    higher_is_better = True

    def _metric(self, preds: Tensor, target: Tensor) -> Tensor:
        return retrieval_r_precision(preds, target)
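
For context, the module metric follows the standard torchmetrics ``update()``/``compute()`` pattern, so state can be accumulated across batches before the per-query mean is taken. A hedged usage sketch (the batch contents are illustrative, mirroring the docstring example split across two updates):

import torch
from torchmetrics import RetrievalRPrecision

# 'skip' ignores queries with no positive target instead of scoring them 0.0.
metric = RetrievalRPrecision(empty_target_action="skip")

# indexes map each prediction to its query; accumulate over two batches.
metric.update(
    preds=torch.tensor([0.2, 0.3, 0.5]),
    target=torch.tensor([False, False, True]),
    indexes=torch.tensor([0, 0, 0]),
)
metric.update(
    preds=torch.tensor([0.1, 0.3, 0.5, 0.2]),
    target=torch.tensor([False, True, False, True]),
    indexes=torch.tensor([1, 1, 1, 1]),
)

print(metric.compute())  # tensor(0.7500): mean of 1.0 (query 0) and 0.5 (query 1)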