Information Retrieval (2/5) #119

Merged: 48 commits, Mar 25, 2021

Commits
c4facab
fixed typo in comments
lucadiliello Mar 22, 2021
aec5b9b
Merge branch 'master' into feature-ir_map
lucadiliello Mar 22, 2021
be731fc
implemented Mean Reciprocal Rank. Added checks on input dtypes also t…
lucadiliello Mar 22, 2021
e29440d
fixed merge conflicts. move reciprocal rank functional version to ded…
lucadiliello Mar 22, 2021
93e35e5
fixed flake8 and isort compatibility
lucadiliello Mar 22, 2021
27935a9
added tests for MRR and refactored tests for MAP
lucadiliello Mar 22, 2021
8c9513d
fixed flake8 compatibility
lucadiliello Mar 22, 2021
40d92eb
added docs entries. added other arguments dtype checks
lucadiliello Mar 22, 2021
1f984f1
moved docs entries to specific section
lucadiliello Mar 22, 2021
88e0877
added tests with wrong input dtypes
lucadiliello Mar 22, 2021
f898da9
improved test code coverage
lucadiliello Mar 22, 2021
5811772
improved test code coverage
lucadiliello Mar 22, 2021
bf8a72e
improved test code coverage
lucadiliello Mar 22, 2021
a4541ff
improved test code coverage
lucadiliello Mar 22, 2021
a02e1ef
Merge branch 'master' into feature-ir_map
lucadiliello Mar 24, 2021
fe7c65a
moved constants to global, refactored and reordered tests
lucadiliello Mar 24, 2021
fc9f395
Merge branch 'feature-ir_map' of https://github.com/lucadiliello/metr…
lucadiliello Mar 24, 2021
d6ec72d
Merge branch 'master' into feature-ir_map
lucadiliello Mar 24, 2021
e06ed9e
Merge branch 'master' into feature-ir_map
lucadiliello Mar 24, 2021
b9ce579
fixed compatibility with torch 1.4.0 and improved error messages in t…
lucadiliello Mar 24, 2021
8194ceb
Merge branch 'feature-ir_map' of https://github.com/lucadiliello/metr…
lucadiliello Mar 24, 2021
631322a
implemented improvements requested by @SkafteNicki
lucadiliello Mar 24, 2021
ef69fae
format
Borda Mar 24, 2021
4395410
separated retrieval metric test files and removed useless loops
lucadiliello Mar 24, 2021
68bb040
separated retrieval metric test files and removed useless loops
lucadiliello Mar 24, 2021
72cf61a
fix tests
Borda Mar 25, 2021
e6ceea9
update CHANGELOG, implemented pytest.raises to check exceptions in tests
lucadiliello Mar 25, 2021
766ad40
merge conflicts resolved
lucadiliello Mar 25, 2021
103e3c3
restricted allowed dtypes for IR metrics and updated doc
lucadiliello Mar 25, 2021
0915ba5
fixed merge conflicts
lucadiliello Mar 25, 2021
b4fa99b
Apply suggestions from code review
Borda Mar 25, 2021
c26c41f
Apply suggestions from code review
Borda Mar 25, 2021
d7a3f25
removed lightning seed_everything
lucadiliello Mar 25, 2021
27bab89
Merge branch 'feature-ir_map' of https://github.com/lucadiliello/metr…
lucadiliello Mar 25, 2021
cdddbd8
fixed code format
lucadiliello Mar 25, 2021
68f961e
fixed typo in doc
lucadiliello Mar 25, 2021
f89a235
changed in
lucadiliello Mar 25, 2021
72dd620
fixed typos in doc
lucadiliello Mar 25, 2021
405d4fb
fixed typos in doc and test on windows
lucadiliello Mar 25, 2021
cbcf32a
fixed typos in doc
lucadiliello Mar 25, 2021
d52351a
Merge branch 'master' into feature-ir_map
lucadiliello Mar 25, 2021
d7df917
updated seeding fn
lucadiliello Mar 25, 2021
2b9d7db
fixed typo in tests regarding randomness
lucadiliello Mar 25, 2021
5a0e02f
Merge branch 'master' into feature-ir_map
lucadiliello Mar 25, 2021
e6fde35
Merge branch 'master' into feature-ir_map
lucadiliello Mar 25, 2021
69716ca
tensor
Borda Mar 25, 2021
c4eaf00
tensor
Borda Mar 25, 2021
6ed117c
Apply suggestions from code review
Borda Mar 25, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -18,6 +18,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `RetrievalMAP` metric for Information Retrieval ([#5032](https://github.com/PyTorchLightning/pytorch-lightning/pull/5032))


- Added `RetrievalMRR` metric for Information Retrieval ([#119](https://github.com/PyTorchLightning/metrics/pull/119))


- Added `average='micro'` as an option in AUROC for multilabel problems ([#110](https://github.com/PyTorchLightning/metrics/pull/110))


24 changes: 17 additions & 7 deletions docs/source/references/functional.rst
@@ -143,13 +143,6 @@ stat_scores [func]
:noindex:


retrieval_average_precision [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.retrieval_average_precision
:noindex:


to_categorical [func]
~~~~~~~~~~~~~~~~~~~~~

@@ -242,3 +235,20 @@ embedding_similarity [func]

.. autofunction:: torchmetrics.functional.embedding_similarity
:noindex:

*********
Retrieval
*********

retrieval_average_precision [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.retrieval_average_precision
:noindex:


retrieval_reciprocal_rank [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.retrieval_reciprocal_rank
:noindex:
72 changes: 65 additions & 7 deletions docs/source/references/modules.rst
@@ -218,13 +218,6 @@ StatScores
:noindex:


RetrievalMAP
~~~~~~~~~~~~

.. autoclass:: torchmetrics.RetrievalMAP
:noindex:


******************
Regression Metrics
******************
@@ -276,3 +269,68 @@ R2Score

.. autoclass:: torchmetrics.R2Score
:noindex:


*********
Retrieval
*********

Input details
~~~~~~~~~~~~~

For the purposes of retrieval metrics, the inputs (``indexes``, ``preds`` and ``target``) must all have the same shape
(``N`` stands for the batch size) and the following dtypes:

.. csv-table::
    :header: "indexes shape", "indexes dtype", "preds shape", "preds dtype", "target shape", "target dtype"
    :widths: 10, 10, 10, 10, 10, 10

    "(N,...)", "``long``", "(N,...)", "``float``", "(N,...)", "``long`` or ``bool``"

.. note::
    All dimensions are flattened at the beginning, so
    that, for example, a tensor of shape ``(N, M)`` is treated as ``(N * M, )``.
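
As a minimal sketch of this flattening rule (illustrative only, assuming ``RetrievalMAP`` accepts
multi-dimensional inputs as described in the note above), passing 2-dimensional tensors should be
equivalent to passing their flattened views:

.. code-block:: python

    import torch
    from torchmetrics import RetrievalMAP

    # two queries (rows), three documents each; shape (2, 3) is treated as (6,)
    indexes = torch.tensor([[0, 0, 0], [1, 1, 1]])
    preds = torch.tensor([[0.2, 0.3, 0.5], [0.1, 0.3, 0.5]])
    target = torch.tensor([[False, False, True], [False, True, False]])

    rmap_2d, rmap_1d = RetrievalMAP(), RetrievalMAP()
    assert torch.allclose(
        rmap_2d(indexes, preds, target),
        rmap_1d(indexes.flatten(), preds.flatten(), target.flatten()),
    )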

In Information Retrieval, a query is compared with a variable number of documents. For each pair ``(Q_i, D_j)``,
a score is computed that measures the relevance of document ``D_j`` w.r.t. query ``Q_i``. Documents are then sorted by score,
and ideally the relevant documents are ranked highest. ``target`` contains the labels for the documents (relevant or not).

Since a query may be compared with a variable number of documents, we use ``indexes`` to keep track of which scores belong to
the set of pairs ``(Q_i, D_j)`` having the same query ``Q_i``.

.. doctest::

    >>> import torch
    >>> from torchmetrics import RetrievalMAP
    >>> # functional version works on a single query at a time
    >>> from torchmetrics.functional import retrieval_average_precision

    >>> # the first query was compared with two documents, the second with three
    >>> indexes = torch.tensor([0, 0, 1, 1, 1])
    >>> preds = torch.tensor([0.8, -0.4, 1.0, 1.4, 0.0])
    >>> target = torch.tensor([0, 1, 0, 1, 1])

    >>> map = RetrievalMAP() # or some other retrieval metric
    >>> map(indexes, preds, target)
    tensor(0.6667)

    >>> # the previous instruction is roughly equivalent to
    >>> res = []
    >>> # iterate over indexes of first and second query
    >>> for idx in ([0, 1], [2, 3, 4]):
    ...     res.append(retrieval_average_precision(preds[idx], target[idx]))
    >>> torch.stack(res).mean()
    tensor(0.6667)
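
Along the same lines, a minimal sketch for the new ``RetrievalMRR`` metric, assuming it follows the same
``(indexes, preds, target)`` interface as ``RetrievalMAP`` above and that ``retrieval_reciprocal_rank``
takes ``(preds, target)`` like ``retrieval_average_precision``:

.. code-block:: python

    import torch
    from torchmetrics import RetrievalMRR
    from torchmetrics.functional import retrieval_reciprocal_rank

    indexes = torch.tensor([0, 0, 1, 1, 1])
    preds = torch.tensor([0.8, -0.4, 1.0, 1.4, 0.0])
    target = torch.tensor([0, 1, 0, 1, 1])

    mrr = RetrievalMRR()
    print(mrr(indexes, preds, target))
    # query 0: first relevant document ranked 2nd (RR = 0.5)
    # query 1: first relevant document ranked 1st (RR = 1.0) -> mean 0.75

    # functional version, one query at a time
    print(retrieval_reciprocal_rank(preds[:2], target[:2]))  # reciprocal rank of query 0
    print(retrieval_reciprocal_rank(preds[2:], target[2:]))  # reciprocal rank of query 1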


RetrievalMAP
~~~~~~~~~~~~

.. autoclass:: torchmetrics.RetrievalMAP
:noindex:


RetrievalMRR
~~~~~~~~~~~~

.. autoclass:: torchmetrics.RetrievalMRR
:noindex:
93 changes: 79 additions & 14 deletions tests/functional/test_retrieval.py
@@ -5,26 +5,91 @@
import torch
from sklearn.metrics import average_precision_score as sk_average_precision

from tests.helpers import seed_all
from tests.retrieval.test_mrr import _reciprocal_rank as reciprocal_rank
from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision
from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank

seed_all(1337)


@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [
    [sk_average_precision, retrieval_average_precision],
    [reciprocal_rank, retrieval_reciprocal_rank],
])
@pytest.mark.parametrize("size", [1, 4, 10])
def test_metrics_output_values(sklearn_metric, torch_metric, size):
    """ Compare PL metrics to sklearn version. """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # test results are computed correctly wrt std implementation
    for i in range(6):
        preds = np.random.randn(size)
        target = np.random.randn(size) > 0

        # sometimes test with integer targets
        if (i % 2) == 0:
            target = target.astype(np.int)

        sk = torch.tensor(sklearn_metric(target, preds), device=device)
        tm = torch_metric(torch.tensor(preds, device=device), torch.tensor(target, device=device))

        # `torch_metric`s return 0 when no label is True
        # while `sklearn` metrics return NaN
        if math.isnan(sk):
            assert tm == 0
        else:
            assert torch.allclose(sk.float(), tm.float())


@pytest.mark.parametrize(['torch_metric'], [
    [retrieval_average_precision],
    [retrieval_reciprocal_rank],
])
def test_input_dtypes(torch_metric) -> None:
    """ Check wrong input dtypes are managed correctly. """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    length = 10  # not important in this case

    # check target is binary
    preds = torch.tensor([0.0, 1.0] * length, device=device, dtype=torch.float32)
    target = torch.tensor([-1, 2] * length, device=device, dtype=torch.int64)

    with pytest.raises(ValueError, match="`target` must be of type `binary`"):
        torch_metric(preds, target)

    # check dtypes and empty target
    preds = torch.tensor([0] * length, device=device, dtype=torch.float32)
    target = torch.tensor([0] * length, device=device, dtype=torch.int64)

    # check error on input dtypes are raised correctly
    with pytest.raises(ValueError, match="`preds` must be a tensor of floats"):
        torch_metric(preds.bool(), target)
    with pytest.raises(ValueError, match="`target` must be a tensor of booleans or integers"):
        torch_metric(preds, target.float())

    # test checks on empty targets
    assert torch.allclose(torch_metric(preds=preds, target=target), torch.tensor(0.0))


@pytest.mark.parametrize(['torch_metric'], [
[retrieval_average_precision],
[retrieval_reciprocal_rank],
])
@pytest.mark.parametrize("size", [1, 4, 10, 100])
def test_against_sklearn(sklearn_metric, torch_metric, size):
"""Compare PL metrics to sklearn version. """
def test_input_shapes(torch_metric) -> None:
""" Check wrong input shapes are managed correctly. """
device = 'cuda' if torch.cuda.is_available() else 'cpu'

a = np.random.randn(size)
b = np.random.randn(size) > 0
# test with empty tensors
preds = torch.tensor([0] * 0, device=device, dtype=torch.float)
target = torch.tensor([0] * 0, device=device, dtype=torch.int64)
with pytest.raises(ValueError, match="`preds` and `target` must be non-empty"):
torch_metric(preds, target)

sk = torch.tensor(sklearn_metric(b, a), device=device)
pl = torch_metric(torch.tensor(a, device=device), torch.tensor(b, device=device))
# test checks when shapes are different
elements_1, elements_2 = np.random.choice(np.arange(1, 20), size=2, replace=False) # ensure sizes are different
preds = torch.tensor([0] * elements_1, device=device, dtype=torch.float)
target = torch.tensor([0] * elements_2, device=device, dtype=torch.int64)

# `torch_metric`s return 0 when no label is True
# while `sklearn.average_precision_score` returns NaN
if math.isnan(sk):
assert pl == 0
else:
assert torch.allclose(sk.float(), pl.float())
with pytest.raises(ValueError, match="`preds` and `target` must be of the same shape"):
torch_metric(preds, target)
Empty file added tests/retrieval/__init__.py
126 changes: 126 additions & 0 deletions tests/retrieval/helpers.py
@@ -0,0 +1,126 @@
from typing import Callable, List

import numpy as np
import pytest
import torch
from torch import Tensor

from tests.helpers import seed_all

seed_all(1337)


def _compute_sklearn_metric(
    metric: Callable, target: List[np.ndarray], preds: List[np.ndarray], behaviour: str
) -> Tensor:
    """ Compute metric with multiple iterations over every query predictions set. """
    sk_results = []

    for b, a in zip(target, preds):
        if b.sum() == 0:
            if behaviour == 'skip':
                pass
            elif behaviour == 'pos':
                sk_results.append(1.0)
            else:
                sk_results.append(0.0)
        else:
            res = metric(b, a)
            sk_results.append(res)

    if len(sk_results) > 0:
        return np.mean(sk_results)
    return np.array(0.0)
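
# Illustrative usage sketch (not part of the helper module itself), assuming sklearn's
# `average_precision_score` is the wrapped metric. It shows how `behaviour` handles a
# query with no relevant documents:
#
#     from sklearn.metrics import average_precision_score as sk_average_precision
#     target = [np.array([1, 0]), np.array([0, 0])]   # second query has no positives
#     preds = [np.array([0.7, 0.3]), np.array([0.2, 0.8])]
#     _compute_sklearn_metric(sk_average_precision, target, preds, 'skip')  # averages over the first query only
#     _compute_sklearn_metric(sk_average_precision, target, preds, 'pos')   # counts the second query as 1.0
#     _compute_sklearn_metric(sk_average_precision, target, preds, 'neg')   # any other value counts it as 0.0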


def _test_retrieval_against_sklearn(
    sklearn_metric,
    torch_metric,
    size,
    n_documents,
    query_without_relevant_docs_options
) -> None:
    """ Compare PL metrics to standard version. """
    metric = torch_metric(query_without_relevant_docs=query_without_relevant_docs_options)
    shape = (size, )

    indexes = []
    preds = []
    target = []

    for i in range(n_documents):
        indexes.append(np.ones(shape, dtype=np.long) * i)
        preds.append(np.random.randn(*shape))
        target.append(np.random.randn(*shape) > 0)

    sk_results = _compute_sklearn_metric(sklearn_metric, target, preds, query_without_relevant_docs_options)
    sk_results = torch.tensor(sk_results)

    indexes_tensor = torch.cat([torch.tensor(i) for i in indexes]).long()
    preds_tensor = torch.cat([torch.tensor(p) for p in preds]).float()
    target_tensor = torch.cat([torch.tensor(t) for t in target]).long()

    # let's assume the data are not ordered
    perm = torch.randperm(indexes_tensor.nelement())
    indexes_tensor = indexes_tensor.view(-1)[perm].view(indexes_tensor.size())
    preds_tensor = preds_tensor.view(-1)[perm].view(preds_tensor.size())
    target_tensor = target_tensor.view(-1)[perm].view(target_tensor.size())

    # shuffling the ids requires the torch metric to also sort the documents correctly
    pl_result = metric(indexes_tensor, preds_tensor, target_tensor)

    assert torch.allclose(sk_results.float(), pl_result.float(), equal_nan=False), (
        f"Test failed comparing metric {sklearn_metric} with {torch_metric}: "
        f"{sk_results.float()} vs {pl_result.float()}. "
        f"indexes: {indexes}, preds: {preds}, target: {target}"
    )


def _test_dtypes(torchmetric) -> None:
    """Check PL metrics inputs are controlled correctly. """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    length = 10  # not important in this test

    # check error when `query_without_relevant_docs='error'` is raised correctly
    indexes = torch.tensor([0] * length, device=device, dtype=torch.int64)
    preds = torch.rand(size=(length, ), device=device, dtype=torch.float32)
    target = torch.tensor([False] * length, device=device, dtype=torch.bool)

    metric = torchmetric(query_without_relevant_docs='error')
    with pytest.raises(ValueError, match="`compute` method was provided with a query with no positive target."):
        metric(indexes, preds, target)

    # check ValueError with invalid `query_without_relevant_docs` argument
    casual_argument = 'casual_argument'
    with pytest.raises(ValueError, match=f"`query_without_relevant_docs` received a wrong value {casual_argument}."):
        metric = torchmetric(query_without_relevant_docs=casual_argument)

    # check input dtypes
    indexes = torch.tensor([0] * length, device=device, dtype=torch.int64)
    preds = torch.tensor([0] * length, device=device, dtype=torch.float32)
    target = torch.tensor([0] * length, device=device, dtype=torch.int64)

    metric = torchmetric(query_without_relevant_docs='error')

    # check error on input dtypes are raised correctly
    with pytest.raises(ValueError, match="`indexes` must be a tensor of long integers"):
        metric(indexes.bool(), preds, target)
    with pytest.raises(ValueError, match="`preds` must be a tensor of floats"):
        metric(indexes, preds.bool(), target)
    with pytest.raises(ValueError, match="`target` must be a tensor of booleans or integers"):
        metric(indexes, preds, target.float())


def _test_input_shapes(torchmetric) -> None:
    """Check PL metrics inputs are controlled correctly. """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    metric = torchmetric(query_without_relevant_docs='error')

    # check input shapes are checked correctly
    elements_1, elements_2 = np.random.choice(np.arange(1, 20), size=2, replace=False)
    indexes = torch.tensor([0] * elements_1, device=device, dtype=torch.int64)
    preds = torch.tensor([0] * elements_2, device=device, dtype=torch.float32)
    target = torch.tensor([0] * elements_2, device=device, dtype=torch.int64)

    with pytest.raises(ValueError, match="`indexes`, `preds` and `target` must be of the same shape"):
        metric(indexes, preds, target)