Information Retrieval (6/5) #161

Merged · 37 commits · Apr 6, 2021
Commits
c8f0d80
init transition to standard metric interface for IR metrics
lucadiliello Mar 31, 2021
cc976b0
fixed typo in dtypes checks
lucadiliello Apr 1, 2021
8555c78
removed IGNORE_IDX, refactored tests using
lucadiliello Apr 3, 2021
34bcefa
added pep8 compatibility
lucadiliello Apr 3, 2021
fbaf05f
fixed np.ndarray to np.array
lucadiliello Apr 3, 2021
644477e
remove lambda functions
lucadiliello Apr 3, 2021
36ce60a
fixed typos with numpy dtype
lucadiliello Apr 3, 2021
44e2db6
fixed typo in doc example
lucadiliello Apr 3, 2021
f516e42
fixed typo in doc examples about new indexes position
lucadiliello Apr 3, 2021
1496d1b
added parameter to class testing to divide kwargs as preds and targets…
lucadiliello Apr 4, 2021
cc42311
fixed typo in doc example
lucadiliello Apr 4, 2021
8bb4830
fixed typo with new parameter fragment_kwargs in MetricTester
lucadiliello Apr 4, 2021
e7c7e96
fixed typo in .cpu() conversion of non-tensor values
lucadiliello Apr 4, 2021
40fd75b
improved test coverage
lucadiliello Apr 4, 2021
7b3d2f8
improved test coverage
lucadiliello Apr 4, 2021
01c43de
added check on Tensor class to avoid calling .cpu() on non-tensor values
lucadiliello Apr 4, 2021
bb62f51
improved doc and changed default values for 'empty_target_action' arg…
lucadiliello Apr 5, 2021
9d65265
implemented fall-out
lucadiliello Apr 5, 2021
148a908
refactored tests lists
lucadiliello Apr 5, 2021
93c87a5
Merge branch 'refactor-IR-test' into feature-ir_fallout
lucadiliello Apr 5, 2021
57ec03b
refactoring and fixed typos in tests
lucadiliello Apr 5, 2021
f918a80
Merge branch 'master' into refactor-IR-test
Borda Apr 6, 2021
62c93f5
formatting
Borda Apr 6, 2021
5b2ce58
simple
Borda Apr 6, 2021
3a3871b
args
Borda Apr 6, 2021
7a3f8c9
format
Borda Apr 6, 2021
d350765
_sk
Borda Apr 6, 2021
c796557
fixed typo in tests
lucadiliello Apr 6, 2021
367f5b0
Merge branch 'refactor-IR-test' into feature-ir_fallout
lucadiliello Apr 6, 2021
6b1ed29
updated tests experiments from tuple to dict
lucadiliello Apr 6, 2021
0256157
fixed typo in doc
lucadiliello Apr 6, 2021
0fb5c35
fixed merge with master
lucadiliello Apr 6, 2021
44f7537
Apply suggestions from code review
Borda Apr 6, 2021
873df5c
chlog
Borda Apr 6, 2021
31e4d2f
Update tests/retrieval/test_fallout.py
lucadiliello Apr 6, 2021
a1c83e2
Update torchmetrics/functional/retrieval/fall_out.py
lucadiliello Apr 6, 2021
4d7f644
added metric link to wikipedia
lucadiliello Apr 6, 2021
7 changes: 7 additions & 0 deletions docs/source/references/functional.rst
@@ -266,3 +266,10 @@ retrieval_recall [func]

.. autofunction:: torchmetrics.functional.retrieval_recall
:noindex:


retrieval_fall_out [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.retrieval_fall_out
:noindex:
7 changes: 7 additions & 0 deletions docs/source/references/modules.rst
@@ -350,6 +350,13 @@ RetrievalRecall
:noindex:


RetrievalFallOut
~~~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.RetrievalFallOut
:noindex:


********
Wrappers
********
12 changes: 6 additions & 6 deletions tests/retrieval/helpers.py
@@ -41,6 +41,7 @@ def _compute_sklearn_metric(
indexes: np.ndarray = None,
metric: Callable = None,
empty_target_action: str = "skip",
reverse: bool = False,
**kwargs
) -> Tensor:
""" Compute metric with multiple iterations over every query predictions set. """
@@ -67,7 +68,7 @@ def _compute_sklearn_metric(
for group in groups:
trg, pds = target[group], preds[group]

if trg.sum() == 0:
# with `reverse=True` (used by fall-out) a query counts as "empty" when it has no negative targets
if ((1 - trg) if reverse else trg).sum() == 0:
if empty_target_action == 'skip':
pass
elif empty_target_action == 'pos':
@@ -268,8 +269,9 @@ def run_class_metric_test(
sk_metric: Callable,
dist_sync_on_step: bool,
metric_args: dict,
reverse: bool = False,
):
_sk_metric_adapted = partial(_compute_sklearn_metric, metric=sk_metric, **metric_args)
_sk_metric_adapted = partial(_compute_sklearn_metric, metric=sk_metric, reverse=reverse, **metric_args)

super().run_class_metric_test(
ddp=ddp,
@@ -290,10 +292,10 @@ def run_functional_metric_test(
metric_functional: Callable,
sk_metric: Callable,
metric_args: dict,
reverse: bool = False,
**kwargs,
):
# action on functional version of IR metrics is to return `tensor(0.0)` if no target is positive.
_sk_metric_adapted = partial(_compute_sklearn_metric, metric=sk_metric, **metric_args)
_sk_metric_adapted = partial(_compute_sklearn_metric, metric=sk_metric, reverse=reverse, **metric_args)

super().run_functional_metric_test(
preds=preds,
@@ -313,7 +315,6 @@ def run_precision_test_cpu(
metric_module: Metric,
metric_functional: Callable,
):
# action on functional version of IR metrics is to return `tensor(0.0)` if not target is positive.
def metric_functional_ignore_indexes(preds, target, indexes):
return metric_functional(preds, target)

@@ -337,7 +338,6 @@ def run_precision_test_gpu(
if not torch.cuda.is_available():
pytest.skip()

# action on functional version of IR metrics is to return `tensor(0.0)` if not target is positive.
def metric_functional_ignore_indexes(preds, target, indexes):
return metric_functional(preds, target)

151 changes: 151 additions & 0 deletions tests/retrieval/test_fallout.py
@@ -0,0 +1,151 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest
from torch import Tensor

from tests.helpers import seed_all
from tests.retrieval.helpers import (
RetrievalMetricTester,
_concat_tests,
_default_metric_class_input_arguments,
_default_metric_functional_input_arguments,
_errors_test_class_metric_parameters_default,
_errors_test_class_metric_parameters_k,
_errors_test_class_metric_parameters_no_neg_target,
_errors_test_functional_metric_parameters_default,
_errors_test_functional_metric_parameters_k,
)
from torchmetrics.functional.retrieval.fall_out import retrieval_fall_out
from torchmetrics.retrieval.retrieval_fallout import RetrievalFallOut

seed_all(42)


def _fallout_at_k(target: np.ndarray, preds: np.ndarray, k: int = None):
"""
No reliable reference implementation of Fall-out for Information Retrieval was found,
so it is reimplemented here. See Wikipedia for the definition.
"""
assert target.shape == preds.shape
assert len(target.shape) == 1 # works only with single dimension inputs

if k is None:
k = len(preds)

target = 1 - target
if target.sum():
order_indexes = np.argsort(preds, axis=0)[::-1]
relevant = np.sum(target[order_indexes][:k])
return relevant * 1.0 / target.sum()
else:
return np.NaN
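
# A hypothetical sanity check (hand-computed, illustrative only; this helper is
# not part of the test suite): with target = [1, 0, 1] the non-relevant mask is
# 1 - target = [0, 1, 0]; the top-2 documents by score are those scored 0.5 and
# 0.3, so 1 of the 1 non-relevant documents is retrieved and fall-out@2 is 1.0.
def _fallout_at_k_sanity_check():
    target = np.array([1, 0, 1])
    preds = np.array([0.2, 0.3, 0.5])
    assert _fallout_at_k(target, preds, k=2) == 1.0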


class TestFallOut(RetrievalMetricTester):

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize("empty_target_action", ['skip', 'neg', 'pos'])
@pytest.mark.parametrize("k", [None, 1, 4, 10])
@pytest.mark.parametrize(**_default_metric_class_input_arguments)
def test_class_metric(
self,
ddp: bool,
indexes: Tensor,
preds: Tensor,
target: Tensor,
dist_sync_on_step: bool,
empty_target_action: str,
k: int,
):
metric_args = {'empty_target_action': empty_target_action, 'k': k}

self.run_class_metric_test(
ddp=ddp,
indexes=indexes,
preds=preds,
target=target,
metric_class=RetrievalFallOut,
sk_metric=_fallout_at_k,
dist_sync_on_step=dist_sync_on_step,
reverse=True,
metric_args=metric_args,
)

@pytest.mark.parametrize(**_default_metric_functional_input_arguments)
@pytest.mark.parametrize("k", [None, 1, 4, 10])
def test_functional_metric(self, preds: Tensor, target: Tensor, k: int):
self.run_functional_metric_test(
preds=preds,
target=target,
metric_functional=retrieval_fall_out,
sk_metric=_fallout_at_k,
reverse=True,
metric_args={},
k=k,
)

@pytest.mark.parametrize(**_default_metric_class_input_arguments)
def test_precision_cpu(self, indexes: Tensor, preds: Tensor, target: Tensor):
self.run_precision_test_cpu(
indexes=indexes,
preds=preds,
target=target,
metric_module=RetrievalFallOut,
metric_functional=retrieval_fall_out,
)

@pytest.mark.parametrize(**_default_metric_class_input_arguments)
def test_precision_gpu(self, indexes: Tensor, preds: Tensor, target: Tensor):
self.run_precision_test_gpu(
indexes=indexes,
preds=preds,
target=target,
metric_module=RetrievalFallOut,
metric_functional=retrieval_fall_out,
)

@pytest.mark.parametrize(**_concat_tests(
_errors_test_class_metric_parameters_default,
_errors_test_class_metric_parameters_no_neg_target,
_errors_test_class_metric_parameters_k,
))
def test_arguments_class_metric(
self, indexes: Tensor, preds: Tensor, target: Tensor, message: str, metric_args: dict
):
self.run_metric_class_arguments_test(
indexes=indexes,
preds=preds,
target=target,
metric_class=RetrievalFallOut,
message=message,
metric_args=metric_args,
exception_type=ValueError,
kwargs_update={},
)

@pytest.mark.parametrize(**_concat_tests(
_errors_test_functional_metric_parameters_default,
_errors_test_functional_metric_parameters_k,
))
def test_arguments_functional_metric(self, preds: Tensor, target: Tensor, message: str, metric_args: dict):
self.run_functional_metric_arguments_test(
preds=preds,
target=target,
metric_functional=retrieval_fall_out,
message=message,
exception_type=ValueError,
kwargs_update=metric_args,
)
8 changes: 7 additions & 1 deletion torchmetrics/__init__.py
@@ -49,5 +49,11 @@
MeanSquaredLogError,
R2Score,
)
from torchmetrics.retrieval import RetrievalMAP, RetrievalMRR, RetrievalPrecision, RetrievalRecall # noqa: F401 E402
from torchmetrics.retrieval import ( # noqa: F401 E402
RetrievalFallOut,
RetrievalMAP,
RetrievalMRR,
RetrievalPrecision,
RetrievalRecall,
)
from torchmetrics.wrappers import BootStrapper # noqa: F401 E402
1 change: 1 addition & 0 deletions torchmetrics/functional/__init__.py
@@ -38,6 +38,7 @@
from torchmetrics.functional.regression.r2score import r2score # noqa: F401
from torchmetrics.functional.regression.ssim import ssim # noqa: F401
from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision # noqa: F401
from torchmetrics.functional.retrieval.fall_out import retrieval_fall_out # noqa: F401
from torchmetrics.functional.retrieval.precision import retrieval_precision # noqa: F401
from torchmetrics.functional.retrieval.recall import retrieval_recall # noqa: F401
from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank # noqa: F401
1 change: 1 addition & 0 deletions torchmetrics/functional/retrieval/__init__.py
@@ -13,6 +13,7 @@
# limitations under the License.

from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision # noqa: F401
from torchmetrics.functional.retrieval.fall_out import retrieval_fall_out # noqa: F401
from torchmetrics.functional.retrieval.precision import retrieval_precision # noqa: F401
from torchmetrics.functional.retrieval.recall import retrieval_recall # noqa: F401
from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank # noqa: F401
59 changes: 59 additions & 0 deletions torchmetrics/functional/retrieval/fall_out.py
@@ -0,0 +1,59 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor, tensor

from torchmetrics.utilities.checks import _check_retrieval_functional_inputs


def retrieval_fall_out(preds: Tensor, target: Tensor, k: int = None) -> Tensor:
"""
Computes the Fall-out (for information retrieval),
as explained `here <https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Fall-out>`__.
Fall-out is the fraction of non-relevant documents retrieved among all the non-relevant documents.

``preds`` and ``target`` should be of the same shape and live on the same device. If no ``target`` is ``False``,
``0`` is returned. ``target`` must be either `bool` or `integers` and ``preds`` must be `float`,
otherwise an error is raised. If you want to measure Fall-out@K, ``k`` must be a positive integer.

Args:
preds: estimated probabilities of each document to be relevant.
target: ground truth about each document being relevant or not.
k: consider only the top k elements (default: ``None``, which considers them all)

Returns:
a single-value tensor with the fall-out (at ``k``) of the predictions ``preds`` w.r.t. the labels ``target``.

Example:
>>> from torch import tensor
>>> from torchmetrics.functional import retrieval_fall_out
>>> preds = tensor([0.2, 0.3, 0.5])
>>> target = tensor([True, False, True])
>>> retrieval_fall_out(preds, target, k=2)
tensor(1.)
"""
preds, target = _check_retrieval_functional_inputs(preds, target)

if k is None:
k = preds.shape[-1]

if not (isinstance(k, int) and k > 0):
raise ValueError("`k` has to be a positive integer or None")

target = (1 - target)

if not target.sum():
return tensor(0.0, device=preds.device)

relevant = target[torch.argsort(preds, dim=-1, descending=True)][:k].sum().float()
return relevant / target.sum()
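
As a quick local check of the new functional metric, here is a minimal sketch (toy inputs, illustrative only, not part of the diff) of how per-query fall-out@k values compose into the average that the class-based `RetrievalFallOut` is meant to report over `indexes`:

import torch
from torchmetrics.functional import retrieval_fall_out

# two queries, identified by `indexes`; scores and binary relevance labels
indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1])
preds = torch.tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
target = torch.tensor([False, False, True, False, True, False, True])

# fall-out@2 for each query, then averaged over queries
scores = [
    retrieval_fall_out(preds[indexes == i], target[indexes == i], k=2)
    for i in indexes.unique()
]
print(torch.stack(scores).mean())  # tensor(0.5000) for these toy inputs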
2 changes: 1 addition & 1 deletion torchmetrics/functional/retrieval/recall.py
@@ -21,7 +21,7 @@ def retrieval_recall(preds: Tensor, target: Tensor, k: int = None) -> Tensor:
"""
Computes the recall metric (for information retrieval),
as explained `here <https://en.wikipedia.org/wiki/Precision_and_recall#Recall>`__.
Recall is the fraction of relevant documents among all the relevant documents.
Recall is the fraction of relevant documents retrieved among all the relevant documents.

``preds`` and ``target`` should be of the same shape and live on the same device. If no ``target`` is ``True``,
``0`` is returned. ``target`` must be either `bool` or `integers` and ``preds`` must be `float`,
1 change: 1 addition & 0 deletions torchmetrics/retrieval/__init__.py
@@ -13,6 +13,7 @@
# limitations under the License.
from torchmetrics.retrieval.mean_average_precision import RetrievalMAP # noqa: F401
from torchmetrics.retrieval.mean_reciprocal_rank import RetrievalMRR # noqa: F401
from torchmetrics.retrieval.retrieval_fallout import RetrievalFallOut # noqa: F401
from torchmetrics.retrieval.retrieval_metric import RetrievalMetric # noqa: F401
from torchmetrics.retrieval.retrieval_precision import RetrievalPrecision # noqa: F401
from torchmetrics.retrieval.retrieval_recall import RetrievalRecall # noqa: F401