Mean Average Precision metric for Information Retrieval (1/5) (#5032)
* init information retrieval metrics

* changed retrieval metrics names, expanded arguments and fixed typo

* added 'Retrieval' prefix to metrics and fixed conflict with already-present 'average_precision' file

* improved code formatting

* pep8 code compatibility

* features/implemented new Mean Average Precision metrics for Information Retrieval + doc

* fixed pep8 compatibility

* removed threshold parameter and fixed typo on types in RetrievalMAP and improved doc

* improved doc, put class-specific args first in RetrievalMetric and transformed RetrievalMetric into an abstract class

* implemented tests for functional and class metric; fixed behaviour when input tensors are empty or when all targets are False

* fixed typos in doc and changed torch.true_divide to torch.div

* fixed typos and pep8 compatibility

* fixed types in long division in ir_average_precision and example in mean_average_precision

* RetrievalMetric states are now lists and the _metric method accepts predictions and targets for easier extension

* updated CHANGELOG file

* added '# noqa: F401' flag to not used imports

* added double space before '# noqa: F401' flag

* Update CHANGELOG.md

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* changed get_mini_groups to get_group_indexes

* added checks on target inputs

* minor refactoring for code cleanness

* split tests over exception raising into a separate function and refactored test code into multiple functions

* fixed pep8 compatibility

* implemented suggestions of @SkafteNicki

* fixed imports for isort and added type annotations to functions in test_map.py

* isort on test_map and fixed typing

* isort on retrieval and on __init__.py and utils.py in metrics package

* fixed typo in pytorch_lightning/metrics/__init__.py regarding code style

* fixed yapf compatibility

* fixed yapf compatibility

* fixed typo in doc

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
4 people authored Mar 15, 2021
1 parent 06756a8 commit 5d73fbb
Showing 12 changed files with 484 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `RetrievalMAP` metric, the corresponding functional version `retrieval_average_precision` and a generic superclass for retrieval metrics `RetrievalMetric` ([#5032](https://github.com/PyTorchLightning/pytorch-lightning/pull/5032))


- Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))

- Added support to checkpoint after training steps in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))
24 changes: 24 additions & 0 deletions docs/source/extensions/metrics.rst
@@ -876,6 +876,30 @@ bleu_score [func]
.. autofunction:: pytorch_lightning.metrics.functional.bleu_score
:noindex:

*****************************
Information Retrieval Metrics
*****************************

Class Metrics (IR)
------------------

Mean Average Precision
~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.retrieval.RetrievalMAP
:noindex:


Functional Metrics (IR)
-----------------------

retrieval_average_precision [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.ir_average_precision.retrieval_average_precision
:noindex:


********
Pairwise
********
1 change: 1 addition & 0 deletions pytorch_lightning/metrics/__init__.py
@@ -37,3 +37,4 @@
    R2Score,
    SSIM,
)
from pytorch_lightning.metrics.retrieval import RetrievalMAP  # noqa: F401
1 change: 1 addition & 0 deletions pytorch_lightning/metrics/functional/__init__.py
@@ -28,6 +28,7 @@
from pytorch_lightning.metrics.functional.hamming_distance import hamming_distance  # noqa: F401
from pytorch_lightning.metrics.functional.image_gradients import image_gradients  # noqa: F401
from pytorch_lightning.metrics.functional.iou import iou  # noqa: F401
from pytorch_lightning.metrics.functional.ir_average_precision import retrieval_average_precision  # noqa: F401
from pytorch_lightning.metrics.functional.mean_absolute_error import mean_absolute_error  # noqa: F401
from pytorch_lightning.metrics.functional.mean_squared_error import mean_squared_error  # noqa: F401
from pytorch_lightning.metrics.functional.mean_squared_log_error import mean_squared_log_error  # noqa: F401
54 changes: 54 additions & 0 deletions pytorch_lightning/metrics/functional/ir_average_precision.py
@@ -0,0 +1,54 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch


def retrieval_average_precision(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    r"""
    Computes average precision (for information retrieval), as explained
    `here <https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision>`_.

    `preds` and `target` should be of the same shape and live on the same device. If no `target` is ``True``,
    ``0`` is returned. `target` must be of type `bool` or `int`, otherwise an error is raised.

    Args:
        preds: estimated probabilities of each document to be relevant.
        target: ground truth about each document being relevant or not. Requires `bool` or `int` tensor.

    Return:
        a single-value tensor with the average precision (AP) of the predictions `preds` w.r.t. the labels `target`.

    Example:
        >>> preds = torch.tensor([0.2, 0.3, 0.5])
        >>> target = torch.tensor([True, False, True])
        >>> retrieval_average_precision(preds, target)
        tensor(0.8333)
    """

    if preds.shape != target.shape or preds.device != target.device:
        raise ValueError("`preds` and `target` must have the same shape and live on the same device")

    if target.dtype not in (torch.bool, torch.int16, torch.int32, torch.int64):
        raise ValueError("`target` must be a tensor of booleans or integers")

    if target.dtype is not torch.bool:
        target = target.bool()

    if target.sum() == 0:
        return torch.tensor(0, device=preds.device)

    # Sort targets by descending predicted relevance, then average the
    # precision@rank values at the positions of the relevant documents.
    target = target[torch.argsort(preds, dim=-1, descending=True)]
    positions = torch.arange(1, len(target) + 1, device=target.device, dtype=torch.float32)[target > 0]
    res = torch.div((torch.arange(len(positions), device=positions.device, dtype=torch.float32) + 1), positions).mean()
    return res
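
As a sanity check on the ranking logic above, the docstring example can be traced by hand: sorting by descending score reorders the targets to [True, False, True], so the relevant documents sit at ranks 1 and 3, the precisions at those ranks are 1/1 and 2/3, and their mean is 0.8333. A minimal standalone sketch of that trace, assuming only `torch`:

import torch

preds = torch.tensor([0.2, 0.3, 0.5])
target = torch.tensor([True, False, True])

# Rank documents by descending predicted relevance: order is [2, 1, 0].
sorted_target = target[torch.argsort(preds, descending=True)]

# 1-based ranks at which relevant documents appear: tensor([1., 3.]).
ranks = torch.arange(1, len(sorted_target) + 1, dtype=torch.float32)[sorted_target]

# Precision at each relevant rank (1/1 and 2/3); AP is their mean.
hits = torch.arange(1, len(ranks) + 1, dtype=torch.float32)
print((hits / ranks).mean())  # tensor(0.8333)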
15 changes: 15 additions & 0 deletions pytorch_lightning/metrics/retrieval/__init__.py
@@ -0,0 +1,15 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.metrics.retrieval.mean_average_precision import RetrievalMAP  # noqa: F401
from pytorch_lightning.metrics.retrieval.retrieval_metric import RetrievalMetric  # noqa: F401
61 changes: 61 additions & 0 deletions pytorch_lightning/metrics/retrieval/mean_average_precision.py
@@ -0,0 +1,61 @@
import torch

from pytorch_lightning.metrics.functional.ir_average_precision import retrieval_average_precision
from pytorch_lightning.metrics.retrieval.retrieval_metric import RetrievalMetric


class RetrievalMAP(RetrievalMetric):
    r"""
    Computes `Mean Average Precision
    <https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Mean_average_precision>`_.

    Works with binary data. Accepts integer or float predictions from a model output.

    Forward accepts:

    - ``indexes`` (long tensor): ``(N, ...)``
    - ``preds`` (float tensor): ``(N, ...)``
    - ``target`` (long or bool tensor): ``(N, ...)``

    `indexes`, `preds` and `target` must have the same dimension.
    `indexes` indicate to which query a prediction belongs.
    Predictions will first be grouped by `indexes` and then MAP will be computed as the mean
    of the Average Precisions over each query.

    Args:
        query_without_relevant_docs:
            Specify what to do with queries that do not have at least one positive target. Choose from:

            - ``'skip'``: skip those queries (default); if all queries are skipped, ``0.0`` is returned
            - ``'error'``: raise a ``ValueError``
            - ``'pos'``: score on those queries is counted as ``1.0``
            - ``'neg'``: score on those queries is counted as ``0.0``

        exclude:
            Do not take into account predictions where the target is equal to this value. default: `-100`
        compute_on_step:
            Forward only calls ``update()`` and returns ``None`` if this is set to ``False``. default: ``True``
        dist_sync_on_step:
            Synchronize metric state across processes at each ``forward()``
            before returning the value at the step. default: ``False``
        process_group:
            Specify the process group on which synchronization is called. default: ``None`` (which selects
            the entire world)
        dist_sync_fn:
            Callback that performs the allgather operation on the metric state. When ``None``, DDP
            will be used to perform the allgather. default: ``None``

    Example:
        >>> from pytorch_lightning.metrics import RetrievalMAP
        >>> indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1])
        >>> preds = torch.tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
        >>> target = torch.tensor([False, False, True, False, True, False, False])
        >>> map = RetrievalMAP()
        >>> map(indexes, preds, target)
        tensor(0.7500)
        >>> map.compute()
        tensor(0.7500)
    """

    def _metric(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        valid_indexes = target != self.exclude
        return retrieval_average_precision(preds[valid_indexes], target[valid_indexes])
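
For intuition, the docstring example can be verified per query with the functional API: query 0 ranks its only relevant document first (AP = 1.0), query 1 ranks its relevant document second (AP = 0.5), and their mean is 0.75. A minimal sketch of that check:

import torch

from pytorch_lightning.metrics.functional import retrieval_average_precision

# Group the flat batch by query index, then score each query separately.
ap_q0 = retrieval_average_precision(
    torch.tensor([0.2, 0.3, 0.5]),
    torch.tensor([False, False, True]),
)  # tensor(1.)
ap_q1 = retrieval_average_precision(
    torch.tensor([0.1, 0.3, 0.5, 0.2]),
    torch.tensor([False, True, False, False]),
)  # tensor(0.5000)

# MAP is the mean of the per-query Average Precisions.
print(torch.stack([ap_q0, ap_q1]).mean())  # tensor(0.7500)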
140 changes: 140 additions & 0 deletions pytorch_lightning/metrics/retrieval/retrieval_metric.py
@@ -0,0 +1,140 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional

import torch

from pytorch_lightning.metrics import Metric
from pytorch_lightning.metrics.utils import get_group_indexes

#: get_group_indexes is used to group predictions belonging to the same query

IGNORE_IDX = -100


class RetrievalMetric(Metric, ABC):
    r"""
    Works with binary data. Accepts integer or float predictions from a model output.

    Forward accepts:

    - ``indexes`` (long tensor): ``(N, ...)``
    - ``preds`` (float or int tensor): ``(N, ...)``
    - ``target`` (long or bool tensor): ``(N, ...)``

    `indexes`, `preds` and `target` must have the same dimension and will be flattened
    to a single dimension once provided.

    `indexes` indicate to which query a prediction belongs.
    Predictions will first be grouped by `indexes`. Then the
    actual metric, defined by overriding the `_metric` method,
    will be computed as the mean of the scores over each query.

    Args:
        query_without_relevant_docs:
            Specify what to do with queries that do not have at least one positive target. Choose from:

            - ``'skip'``: skip those queries (default); if all queries are skipped, ``0.0`` is returned
            - ``'error'``: raise a ``ValueError``
            - ``'pos'``: score on those queries is counted as ``1.0``
            - ``'neg'``: score on those queries is counted as ``0.0``

        exclude:
            Do not take into account predictions where the target is equal to this value. default: `-100`
        compute_on_step:
            Forward only calls ``update()`` and returns ``None`` if this is set to ``False``. default: ``True``
        dist_sync_on_step:
            Synchronize metric state across processes at each ``forward()``
            before returning the value at the step. default: ``False``
        process_group:
            Specify the process group on which synchronization is called. default: ``None`` (which selects
            the entire world)
        dist_sync_fn:
            Callback that performs the allgather operation on the metric state. When ``None``, DDP
            will be used to perform the allgather. default: ``None``
    """

    def __init__(
        self,
        query_without_relevant_docs: str = 'skip',
        exclude: int = IGNORE_IDX,
        compute_on_step: bool = True,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
        dist_sync_fn: Callable = None
    ):
        super().__init__(
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
            dist_sync_fn=dist_sync_fn
        )

        query_without_relevant_docs_options = ('error', 'skip', 'pos', 'neg')
        if query_without_relevant_docs not in query_without_relevant_docs_options:
            raise ValueError(
                f"`query_without_relevant_docs` received a wrong value {query_without_relevant_docs}. "
                f"Allowed values are {query_without_relevant_docs_options}"
            )

        self.query_without_relevant_docs = query_without_relevant_docs
        self.exclude = exclude

        self.add_state("idx", default=[], dist_reduce_fx=None)
        self.add_state("preds", default=[], dist_reduce_fx=None)
        self.add_state("target", default=[], dist_reduce_fx=None)

    def update(self, idx: torch.Tensor, preds: torch.Tensor, target: torch.Tensor) -> None:
        if not (idx.shape == target.shape == preds.shape):
            raise ValueError("`idx`, `preds` and `target` must be of the same shape")

        idx = idx.to(dtype=torch.int64).flatten()
        preds = preds.to(dtype=torch.float32).flatten()
        target = target.to(dtype=torch.int64).flatten()

        self.idx.append(idx)
        self.preds.append(preds)
        self.target.append(target)

    def compute(self) -> torch.Tensor:
        r"""
        First concat state `idx`, `preds` and `target` since they were stored as lists. After that,
        compute the list of groups that keeps together predictions about the same query.
        Finally, for each group compute the `_metric` if the number of positive targets is at least
        1, otherwise behave as specified by `self.query_without_relevant_docs`.
        """

        idx = torch.cat(self.idx, dim=0)
        preds = torch.cat(self.preds, dim=0)
        target = torch.cat(self.target, dim=0)

        res = []
        kwargs = {'device': idx.device, 'dtype': torch.float32}

        groups = get_group_indexes(idx)
        for group in groups:

            mini_preds = preds[group]
            mini_target = target[group]

            if not mini_target.sum():
                if self.query_without_relevant_docs == 'error':
                    raise ValueError(
                        f"`{self.__class__.__name__}.compute()` was provided with "
                        f"a query without positive targets, indexes: {group}"
                    )
                if self.query_without_relevant_docs == 'pos':
                    res.append(torch.tensor(1.0, **kwargs))
                elif self.query_without_relevant_docs == 'neg':
                    res.append(torch.tensor(0.0, **kwargs))
            else:
                res.append(self._metric(mini_preds, mini_target))

        if len(res) > 0:
            return torch.stack(res).mean()
        return torch.tensor(0.0, **kwargs)

    @abstractmethod
    def _metric(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        r"""
        Compute a metric over the predictions and targets of a single group.
        This method should be overridden by subclasses.
        """
31 changes: 30 additions & 1 deletion pytorch_lightning/metrics/utils.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
from typing import List, Optional, Tuple

import torch

@@ -93,6 +93,35 @@ def _input_format_classification_one_hot(
return preds.reshape(num_classes, -1), target.reshape(num_classes, -1)


def get_group_indexes(idx: torch.Tensor) -> List[torch.Tensor]:
    """
    Given an integer `torch.Tensor` `idx`, return a list of `torch.Tensor`s containing the positions
    of each different value in `idx`.

    Args:
        idx: a `torch.Tensor` of integers

    Return:
        A list of integer `torch.Tensor`s

    Example:
        >>> indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1])
        >>> groups = get_group_indexes(indexes)
        >>> groups
        [tensor([0, 1, 2]), tensor([3, 4, 5, 6])]
    """

    indexes = dict()
    for i, _id in enumerate(idx):
        _id = _id.item()
        if _id in indexes:
            indexes[_id] += [i]
        else:
            indexes[_id] = [i]
    return [torch.tensor(x, dtype=torch.int64) for x in indexes.values()]


def to_onehot(
    label_tensor: torch.Tensor,
    num_classes: Optional[int] = None,
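
Design note: the dict-based loop in `get_group_indexes` preserves the first-appearance order of query ids at the cost of one Python-level iteration per element. A vectorized alternative could lean on `torch.unique`; the sketch below is an assumption-level illustration, not what this PR ships, and it returns groups in sorted-id order rather than first-appearance order:

from typing import List

import torch


def get_group_indexes_vectorized(idx: torch.Tensor) -> List[torch.Tensor]:
    # `torch.unique` sorts the ids; `inverse` maps each position to its group.
    unique_ids, inverse = torch.unique(idx, return_inverse=True)
    return [torch.nonzero(inverse == i).flatten() for i in range(len(unique_ids))]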