From c8f0d8090e2d29d743c9d53c5432afd8652ce02f Mon Sep 17 00:00:00 2001 From: lucadiliello Date: Wed, 31 Mar 2021 21:14:49 +0200 Subject: [PATCH 01/33] init transition to standard metric interface for IR metrics --- .gitignore | 3 + tests/retrieval/helpers.py | 113 ++++++++++++++++----- tests/retrieval/inputs.py | 26 +++++ tests/retrieval/test_map.py | 13 +++ tests/retrieval/test_map_mt.py | 62 +++++++++++ tests/retrieval/test_mrr.py | 13 +++ tests/retrieval/test_precision.py | 13 +++ tests/retrieval/test_recall.py | 13 +++ torchmetrics/retrieval/retrieval_metric.py | 11 +- torchmetrics/utilities/data.py | 20 +++- 10 files changed, 256 insertions(+), 31 deletions(-) create mode 100644 tests/retrieval/inputs.py create mode 100644 tests/retrieval/test_map_mt.py diff --git a/.gitignore b/.gitignore index d078f439dca..e3ac9e11a1b 100644 --- a/.gitignore +++ b/.gitignore @@ -91,6 +91,9 @@ ENV/ env.bak/ venv.bak/ +# Editor configs +.vscode/ + # Spyder project settings .spyderproject .spyproject diff --git a/tests/retrieval/helpers.py b/tests/retrieval/helpers.py index 3adc74e0556..cf1c187d8a9 100644 --- a/tests/retrieval/helpers.py +++ b/tests/retrieval/helpers.py @@ -1,4 +1,18 @@ -from typing import Callable, List +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from torchmetrics.utilities.data import get_group_indexes +from typing import Callable, Union import numpy as np import pytest @@ -7,26 +21,43 @@ from tests.helpers import seed_all from torchmetrics import Metric +from tests.helpers.testers import MetricTester seed_all(1337) def _compute_sklearn_metric( - metric: Callable, target: List[np.ndarray], preds: List[np.ndarray], behaviour: str, **kwargs + preds: Union[Tensor, np.ndarray], + target: Union[Tensor, np.ndarray], + idx: np.ndarray = None, + metric: Callable = None, + empty_target_action: str = "skip", + **kwargs ) -> Tensor: """ Compute metric with multiple iterations over every query predictions set. """ + + if isinstance(preds, Tensor): + preds = preds.cpu().numpy() + if isinstance(target, Tensor): + target = target.cpu().numpy() + + if idx is None: + idx = np.full_like(preds, fill_value=0, dtype=np.int64) + + groups = get_group_indexes(idx) sk_results = [] + for group in groups: + trg, pds = target[group], preds[group] - for b, a in zip(target, preds): - if b.sum() == 0: - if behaviour == 'skip': + if trg.sum() == 0: + if empty_target_action == 'skip': pass - elif behaviour == 'pos': + elif empty_target_action == 'pos': sk_results.append(1.0) else: sk_results.append(0.0) else: - res = metric(b, a, **kwargs) + res = metric(trg, pds, **kwargs) sk_results.append(res) if len(sk_results) > 0: @@ -44,23 +75,20 @@ def _test_retrieval_against_sklearn( ) -> None: """ Compare PL metrics to standard version. 
""" metric = torch_metric(empty_target_action=empty_target_action, **kwargs) - shape = (size, ) + shape = (n_documents, size) - indexes = [] - preds = [] - target = [] + indexes = np.ones(shape, dtype=np.int64) * np.arange(n_documents) + preds = np.random.randn(*shape) + target = np.random.randint(0, 2, size=shape) - for i in range(n_documents): - indexes.append(np.ones(shape, dtype=np.long) * i) - preds.append(np.random.randn(*shape)) - target.append(np.random.randn(*shape) > 0) - - sk_results = _compute_sklearn_metric(sklearn_metric, target, preds, empty_target_action, **kwargs) + sk_results = _compute_sklearn_metric( + preds, target, metric=sklearn_metric, empty_target_action=empty_target_action, **kwargs + ) sk_results = torch.tensor(sk_results) - indexes_tensor = torch.cat([torch.tensor(i) for i in indexes]).long() - preds_tensor = torch.cat([torch.tensor(p) for p in preds]).float() - target_tensor = torch.cat([torch.tensor(t) for t in target]).long() + indexes_tensor = torch.tensor(indexes).long() + preds_tensor = torch.tensor(preds).float() + target_tensor = torch.tensor(target).long() # lets assume data are not ordered perm = torch.randperm(indexes_tensor.nelement()) @@ -69,7 +97,7 @@ def _test_retrieval_against_sklearn( target_tensor = target_tensor.view(-1)[perm].view(target_tensor.size()) # shuffle ids to require also sorting of documents ability from the torch metric - pl_result = metric(indexes_tensor, preds_tensor, target_tensor) + pl_result = metric(preds_tensor, target_tensor, idx=indexes_tensor) assert torch.allclose(sk_results.float(), pl_result.float(), equal_nan=False), ( f"Test failed comparing metric {sklearn_metric} with {torch_metric}: " @@ -90,7 +118,7 @@ def _test_dtypes(torchmetric) -> None: metric = torchmetric(empty_target_action='error') with pytest.raises(ValueError, match="`compute` method was provided with a query with no positive target."): - metric(indexes, preds, target) + metric(preds, target, idx=indexes) # check ValueError with invalid `empty_target_action` argument casual_argument = 'casual_argument' @@ -106,11 +134,11 @@ def _test_dtypes(torchmetric) -> None: # check error on input dtypes are raised correctly with pytest.raises(ValueError, match="`indexes` must be a tensor of long integers"): - metric(indexes.bool(), preds, target) + metric(preds, target, idx=indexes.bool()) with pytest.raises(ValueError, match="`preds` must be a tensor of floats"): - metric(indexes, preds.bool(), target) + metric(preds.bool(), target, idx=indexes) with pytest.raises(ValueError, match="`target` must be a tensor of booleans or integers"): - metric(indexes, preds, target.float()) + metric(preds, target.float(), idx=indexes) def _test_input_shapes(torchmetric) -> None: @@ -125,10 +153,43 @@ def _test_input_shapes(torchmetric) -> None: target = torch.tensor([0] * elements_2, device=device, dtype=torch.int64) with pytest.raises(ValueError, match="`indexes`, `preds` and `target` must be of the same shape"): - metric(indexes, preds, target) + metric(preds, target, idx=indexes) def _test_input_args(torchmetric: Metric, message: str, **kwargs) -> None: """Check invalid args are managed correctly. 
""" with pytest.raises(ValueError, match=message): torchmetric(**kwargs) + + + + + + + +class RetrievalMetricTester(MetricTester): + + """ + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_average_precision(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=RetrievalMAP, + sk_metric=sk_metric, + dist_sync_on_step=dist_sync_on_step, + ) + """ + + def test_average_precision_functional(self, preds, target, sk_metric): + self.run_functional_metric_test( + preds, + target, + metric_functional=retrieval_average_precision, + sk_metric=sk_metric, + ) + + def test_a_caso(self, preds, target, sk_metric): + assert False \ No newline at end of file diff --git a/tests/retrieval/inputs.py b/tests/retrieval/inputs.py new file mode 100644 index 00000000000..6b1dab83ece --- /dev/null +++ b/tests/retrieval/inputs.py @@ -0,0 +1,26 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import namedtuple + +import torch + +from tests.helpers.testers import NUM_BATCHES, BATCH_SIZE + +Input = namedtuple('InputMultiple', ["preds", "target", "idx"]) + +_input_retrieval_scores = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), + idx=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE)) +) diff --git a/tests/retrieval/test_map.py b/tests/retrieval/test_map.py index e866ce597ae..d021820958f 100644 --- a/tests/retrieval/test_map.py +++ b/tests/retrieval/test_map.py @@ -1,3 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import pytest from sklearn.metrics import average_precision_score as sk_average_precision diff --git a/tests/retrieval/test_map_mt.py b/tests/retrieval/test_map_mt.py new file mode 100644 index 00000000000..65c15a33a65 --- /dev/null +++ b/tests/retrieval/test_map_mt.py @@ -0,0 +1,62 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial + +from tests.retrieval.helpers import _test_dtypes, _test_input_shapes, _test_retrieval_against_sklearn +from torchmetrics.retrieval.mean_average_precision import RetrievalMAP +from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision + +import pytest +from sklearn.metrics import average_precision_score as sk_average_precision_score + +from tests.retrieval.inputs import _input_retrieval_scores +from tests.retrieval.helpers import _compute_sklearn_metric + +from tests.helpers import seed_all +from tests.helpers.testers import MetricTester + +seed_all(42) + + +@pytest.mark.parametrize( + "preds, target, idx, sk_metric", [ + (_input_retrieval_scores.preds, _input_retrieval_scores.target, _input_retrieval_scores.idx, sk_average_precision_score), + ] +) +class TestRetrievalMetric(MetricTester): + + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + @pytest.mark.parametrize("empty_target_action", ['skip', 'neg', 'pos']) + def test_average_precision(self, preds, target, idx, sk_metric, ddp, dist_sync_on_step, empty_target_action): + _sk_metric = partial(_compute_sklearn_metric, metric=sk_metric) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=RetrievalMAP, + sk_metric=_sk_metric, + dist_sync_on_step=dist_sync_on_step, + metric_args={'empty_target_action': empty_target_action}, + idx=idx + ) + + def test_average_precision_functional(self, preds, target, idx, sk_metric): + _sk_metric = partial(_compute_sklearn_metric, metric=sk_metric, empty_target_action="neg", idx=None) + self.run_functional_metric_test( + preds, + target, + metric_functional=retrieval_average_precision, + sk_metric=_sk_metric, + ) diff --git a/tests/retrieval/test_mrr.py b/tests/retrieval/test_mrr.py index 62d79be0578..07f05c42145 100644 --- a/tests/retrieval/test_mrr.py +++ b/tests/retrieval/test_mrr.py @@ -1,3 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import numpy as np import pytest from sklearn.metrics import label_ranking_average_precision_score diff --git a/tests/retrieval/test_precision.py b/tests/retrieval/test_precision.py index 1282b521f11..af309f7af6a 100644 --- a/tests/retrieval/test_precision.py +++ b/tests/retrieval/test_precision.py @@ -1,3 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import numpy as np import pytest diff --git a/tests/retrieval/test_recall.py b/tests/retrieval/test_recall.py index 31b9c15ee75..4d8393f218d 100644 --- a/tests/retrieval/test_recall.py +++ b/tests/retrieval/test_recall.py @@ -1,3 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import numpy as np import pytest diff --git a/torchmetrics/retrieval/retrieval_metric.py b/torchmetrics/retrieval/retrieval_metric.py index e9296554855..7df6773d97d 100644 --- a/torchmetrics/retrieval/retrieval_metric.py +++ b/torchmetrics/retrieval/retrieval_metric.py @@ -88,13 +88,22 @@ def __init__( self.empty_target_action = empty_target_action self.exclude = exclude + self.next_index = 0 self.add_state("idx", default=[], dist_reduce_fx=None) self.add_state("preds", default=[], dist_reduce_fx=None) self.add_state("target", default=[], dist_reduce_fx=None) - def update(self, idx: Tensor, preds: Tensor, target: Tensor) -> None: + def update(self, preds: Tensor, target: Tensor, idx: Tensor = None) -> None: """ Check shape, check and convert dtypes, flatten and add to accumulators. """ + if idx is None: + idx = torch.full(preds.shape, fill_value=self.next_index, dtype=torch.long, device=preds.device) + + # update index + actual_max_id = torch.max(idx).item() + if actual_max_id > self.next_index: + self.next_index = actual_max_id + idx, preds, target = _check_retrieval_inputs(idx, preds, target, ignore=IGNORE_IDX) self.idx.append(idx.flatten()) self.preds.append(preds.flatten()) diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py index 1ee0bc72367..c58d442d73b 100644 --- a/torchmetrics/utilities/data.py +++ b/torchmetrics/utilities/data.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import Any, Callable, List, Mapping, Optional, Sequence, Union +import numpy as np import torch from torch import Tensor, tensor @@ -232,22 +233,33 @@ def apply_to_collection( def get_group_indexes(idx: Tensor) -> List[Tensor]: """ - Given an integer `torch.Tensor` `idx`, return a `torch.Tensor` of indexes for + Given an integer `torch.Tensor` or `np.array` `idx`, return a `torch.Tensor` or `np.array` of indexes for each different value in `idx`. 
Args: - idx: a `torch.Tensor` of integers + idx: a `torch.Tensor` or `np.array` of integers Return: - A list of integer `torch.Tensor`s + A list of integer `torch.Tensor`s or `np.array`s Example: >>> indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1]) >>> groups = get_group_indexes(indexes) >>> groups [tensor([0, 1, 2]), tensor([3, 4, 5, 6])] + >>> + >>> indexes = np.ndarray([0, 0, 0, 1, 1, 1, 1]) + >>> groups = get_group_indexes(indexes) + >>> groups + [array([0, 1, 2]), array([3, 4, 5, 6])] """ + if not isinstance(idx, (Tensor, np.ndarray)): + raise ValueError("`idx` must be a torch tensor or numpy array") + + structure = tensor if isinstance(idx, Tensor) else np.array + dtype = torch.long if isinstance(idx, Tensor) else np.int64 + indexes = dict() for i, _id in enumerate(idx): _id = _id.item() @@ -255,4 +267,4 @@ def get_group_indexes(idx: Tensor) -> List[Tensor]: indexes[_id] += [i] else: indexes[_id] = [i] - return [tensor(x, dtype=torch.int64) for x in indexes.values()] + return [structure(x, dtype=dtype) for x in indexes.values()] From cc976b04469d958f16f1fb0b9b1e663d4e5e22a2 Mon Sep 17 00:00:00 2001 From: lucadiliello Date: Thu, 1 Apr 2021 10:55:56 +0200 Subject: [PATCH 02/33] fixed typo in dtypes checks --- tests/retrieval/test_map_mt.py | 5 +++-- torchmetrics/utilities/checks.py | 2 ++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/retrieval/test_map_mt.py b/tests/retrieval/test_map_mt.py index 65c15a33a65..d402818171b 100644 --- a/tests/retrieval/test_map_mt.py +++ b/tests/retrieval/test_map_mt.py @@ -29,9 +29,10 @@ seed_all(42) +@pytest.mark.parametrize("sk_metric", [sk_average_precision_score]) @pytest.mark.parametrize( - "preds, target, idx, sk_metric", [ - (_input_retrieval_scores.preds, _input_retrieval_scores.target, _input_retrieval_scores.idx, sk_average_precision_score), + "preds, target, idx", [ + (_input_retrieval_scores.preds, _input_retrieval_scores.target, _input_retrieval_scores.idx), ] ) class TestRetrievalMetric(MetricTester): diff --git a/torchmetrics/utilities/checks.py b/torchmetrics/utilities/checks.py index 84424b4624a..f5d8f027467 100644 --- a/torchmetrics/utilities/checks.py +++ b/torchmetrics/utilities/checks.py @@ -554,6 +554,8 @@ def _check_retrieval_inputs( """ if ignore is not None: target = target[target != ignore] # ignore check on values that are ignored + preds = preds[target != ignore] + preds, target = _check_retrieval_functional_inputs(preds, target) if indexes.shape != target.shape: From 8555c78e024cc342105dfd5d4464788fb7132579 Mon Sep 17 00:00:00 2001 From: lucadiliello Date: Sat, 3 Apr 2021 22:11:59 +0200 Subject: [PATCH 03/33] removed IGNORE_IDX, refactored tests using --- docs/source/references/modules.rst | 4 +- tests/functional/test_retrieval.py | 150 ------ tests/helpers/testers.py | 44 +- tests/retrieval/helpers.py | 476 +++++++++++++----- tests/retrieval/inputs.py | 44 +- tests/retrieval/test_map.py | 132 ++++- tests/retrieval/test_map_mt.py | 63 --- tests/retrieval/test_mrr.py | 134 ++++- tests/retrieval/test_precision.py | 160 +++++- tests/retrieval/test_recall.py | 165 +++++- .../retrieval/mean_average_precision.py | 6 +- .../retrieval/mean_reciprocal_rank.py | 5 +- torchmetrics/retrieval/retrieval_metric.py | 36 +- torchmetrics/retrieval/retrieval_precision.py | 9 +- torchmetrics/retrieval/retrieval_recall.py | 9 +- torchmetrics/utilities/checks.py | 38 +- torchmetrics/utilities/data.py | 38 +- 17 files changed, 1002 insertions(+), 511 deletions(-) delete mode 100644 tests/functional/test_retrieval.py delete mode 
100644 tests/retrieval/test_map_mt.py diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index 23714f62f6e..b2cfdc5b92c 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -316,8 +316,8 @@ the set of pairs ``(Q_i, D_j)`` having the same query ``Q_i``. >>> # the previous instruction is roughly equivalent to >>> res = [] >>> # iterate over indexes of first and second query - >>> for idx in ([0, 1], [2, 3, 4]): - ... res.append(retrieval_average_precision(preds[idx], target[idx])) + >>> for indexes in ([0, 1], [2, 3, 4]): + ... res.append(retrieval_average_precision(preds[indexes], target[indexes])) >>> torch.stack(res).mean() tensor(0.6667) diff --git a/tests/functional/test_retrieval.py b/tests/functional/test_retrieval.py deleted file mode 100644 index 96cde9e87d4..00000000000 --- a/tests/functional/test_retrieval.py +++ /dev/null @@ -1,150 +0,0 @@ -import math - -import numpy as np -import pytest -import torch -from sklearn.metrics import average_precision_score as sk_average_precision - -from tests.helpers import seed_all -from tests.retrieval.test_mrr import _reciprocal_rank as reciprocal_rank -from tests.retrieval.test_precision import _precision_at_k as precision_at_k -from tests.retrieval.test_recall import _recall_at_k as recall_at_k -from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision -from torchmetrics.functional.retrieval.precision import retrieval_precision -from torchmetrics.functional.retrieval.recall import retrieval_recall -from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank - -seed_all(1337) - - -@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [ - [sk_average_precision, retrieval_average_precision], - [reciprocal_rank, retrieval_reciprocal_rank], -]) -@pytest.mark.parametrize("size", [1, 4, 10]) -def test_metrics_output_values(sklearn_metric, torch_metric, size): - """ Compare PL metrics to sklearn version. """ - device = 'cuda' if torch.cuda.is_available() else 'cpu' - - # test results are computed correctly wrt std implementation - for i in range(6): - preds = np.random.randn(size) - target = np.random.randn(size) > 0 - - # sometimes test with integer targets - if (i % 2) == 0: - target = target.astype(np.int) - - sk = torch.tensor(sklearn_metric(target, preds), device=device) - tm = torch_metric(torch.tensor(preds, device=device), torch.tensor(target, device=device)) - - # `torch_metric`s return 0 when no label is True - # while `sklearn` metrics returns NaN - if math.isnan(sk): - assert tm == 0 - else: - assert torch.allclose(sk.float(), tm.float()) - - -@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [ - [precision_at_k, retrieval_precision], - [recall_at_k, retrieval_recall], -]) -@pytest.mark.parametrize("size", [1, 4, 10]) -@pytest.mark.parametrize("k", [None, 1, 4, 10]) -def test_metrics_output_values_with_k(sklearn_metric, torch_metric, size, k): - """ Compare PL metrics to sklearn version. 
""" - device = 'cuda' if torch.cuda.is_available() else 'cpu' - - # test results are computed correctly wrt std implementation - for i in range(6): - preds = np.random.randn(size) - target = np.random.randn(size) > 0 - - # sometimes test with integer targets - if (i % 2) == 0: - target = target.astype(np.int) - - sk = torch.tensor(sklearn_metric(target, preds, k), device=device) - tm = torch_metric(torch.tensor(preds, device=device), torch.tensor(target, device=device), k) - - # `torch_metric`s return 0 when no label is True - # while `sklearn` metrics returns NaN - if math.isnan(sk): - assert tm == 0 - else: - assert torch.allclose(sk.float(), tm.float()) - - -@pytest.mark.parametrize(['torch_metric'], [ - [retrieval_average_precision], - [retrieval_reciprocal_rank], - [retrieval_precision], - [retrieval_recall], -]) -def test_input_dtypes(torch_metric) -> None: - """ Check wrong input dtypes are managed correctly. """ - device = 'cuda' if torch.cuda.is_available() else 'cpu' - length = 10 # not important in this case - - # check target is binary - preds = torch.tensor([0.0, 1.0] * length, device=device, dtype=torch.float32) - target = torch.tensor([-1, 2] * length, device=device, dtype=torch.int64) - - with pytest.raises(ValueError, match="`target` must be of type `binary`"): - torch_metric(preds, target) - - # check dtypes and empty target - preds = torch.tensor([0] * length, device=device, dtype=torch.float32) - target = torch.tensor([0] * length, device=device, dtype=torch.int64) - - # check error on input dtypes are raised correctly - with pytest.raises(ValueError, match="`preds` must be a tensor of floats"): - torch_metric(preds.bool(), target) - with pytest.raises(ValueError, match="`target` must be a tensor of booleans or integers"): - torch_metric(preds, target.float()) - - # test checks on empty targets - assert torch.allclose(torch_metric(preds=preds, target=target), torch.tensor(0.0)) - - -@pytest.mark.parametrize(['torch_metric'], ( - [retrieval_average_precision], - [retrieval_reciprocal_rank], - [retrieval_precision], - [retrieval_recall], -)) -def test_input_shapes(torch_metric) -> None: - """ Check wrong input shapes are managed correctly. """ - device = 'cuda' if torch.cuda.is_available() else 'cpu' - - # test with empty tensors - preds = torch.tensor([0] * 0, device=device, dtype=torch.float) - target = torch.tensor([0] * 0, device=device, dtype=torch.int64) - with pytest.raises(ValueError, match="`preds` and `target` must be non-empty"): - torch_metric(preds, target) - - # test checks when shapes are different - elements_1, elements_2 = np.random.choice(np.arange(1, 20), size=2, replace=False) # ensure sizes are different - preds = torch.tensor([0] * elements_1, device=device, dtype=torch.float) - target = torch.tensor([0] * elements_2, device=device, dtype=torch.int64) - - with pytest.raises(ValueError, match="`preds` and `target` must be of the same shape"): - torch_metric(preds, target) - - -# test metrics using top K parameter -@pytest.mark.parametrize(['torch_metric'], [ - [retrieval_precision], - [retrieval_recall], -]) -@pytest.mark.parametrize('k', [-1, 1.0]) -def test_input_params(torch_metric, k) -> None: - """ Check wrong input shapes are managed correctly. 
""" - device = 'cuda' if torch.cuda.is_available() else 'cpu' - - # test with random tensors - preds = torch.tensor([0] * 4, device=device, dtype=torch.float) - target = torch.tensor([0] * 4, device=device, dtype=torch.int64) - with pytest.raises(ValueError, match="`k` has to be a positive integer or None"): - torch_metric(preds, target, k=k) diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py index ec6b0196846..cf338e9200f 100644 --- a/tests/helpers/testers.py +++ b/tests/helpers/testers.py @@ -129,18 +129,20 @@ def _class_test( batch_result = metric(preds[i], target[i], **batch_kwargs_update) if metric.dist_sync_on_step and check_dist_sync_on_step and rank == 0: - ddp_preds = torch.cat([preds[i + r] for r in range(worldsize)]).cpu() - ddp_target = torch.cat([target[i + r] for r in range(worldsize)]).cpu() + rank_indexes = [i + r for r in range(worldsize)] + + ddp_preds = preds[rank_indexes].cpu() + ddp_target = target[rank_indexes].cpu() ddp_kwargs_upd = { - k: torch.cat([v[i + r] for r in range(worldsize)]).cpu() if isinstance(v, Tensor) else v - for k, v in batch_kwargs_update.items() + k: v[rank_indexes].cpu() if isinstance(v, Tensor) else v + for k, v in kwargs_update.items() } sk_batch_result = sk_metric(ddp_preds, ddp_target, **ddp_kwargs_upd) _assert_allclose(batch_result, sk_batch_result, atol=atol) elif check_batch and not metric.dist_sync_on_step: - batch_kwargs_update = {k: v.cpu() for k, v in kwargs_update.items()} + batch_kwargs_update = {k: v.cpu() for k, v in batch_kwargs_update.items()} sk_batch_result = sk_metric(preds[i].cpu(), target[i].cpu(), **batch_kwargs_update) _assert_allclose(batch_result, sk_batch_result, atol=atol) @@ -196,7 +198,6 @@ def _functional_test( for i in range(NUM_BATCHES): extra_kwargs = {k: v[i] if isinstance(v, Tensor) else v for k, v in kwargs_update.items()} lightning_result = metric(preds[i], target[i], **extra_kwargs) - extra_kwargs = {k: v.cpu() for k, v in kwargs_update.items()} sk_result = sk_metric(preds[i].cpu(), target[i].cpu(), **extra_kwargs) # assert its the same @@ -209,6 +210,7 @@ def _assert_half_support( preds: torch.Tensor, target: torch.Tensor, device: str = "cpu", + **kwargs_update ): """ Test if an metric can be used with half precision tensors @@ -219,12 +221,18 @@ def _assert_half_support( preds: torch tensor with predictions target: torch tensor with targets device: determine device, either "cpu" or "cuda" + kwargs_update: Additional keyword arguments that will be passed with preds and + target when running update on the metric. 
""" y_hat = preds[0].half().to(device) if preds[0].is_floating_point() else preds[0].to(device) y = target[0].half().to(device) if target[0].is_floating_point() else target[0].to(device) + kwargs_update = { + k: (v[0].half() if v.is_floating_point() else v[0]).to(device) if isinstance(v, Tensor) else v + for k, v in kwargs_update.items() + } metric_module = metric_module.to(device) - _assert_tensor(metric_module(y_hat, y)) - _assert_tensor(metric_functional(y_hat, y)) + _assert_tensor(metric_module(y_hat, y, **kwargs_update)) + _assert_tensor(metric_functional(y_hat, y, **kwargs_update)) class MetricTester: @@ -367,6 +375,7 @@ def run_precision_test_cpu( metric_module: Metric, metric_functional: Callable, metric_args: dict = {}, + **kwargs_update, ): """Test if an metric can be used with half precision tensors on cpu Args: @@ -375,9 +384,16 @@ def run_precision_test_cpu( metric_module: the metric module to test metric_functional: the metric functional to test metric_args: dict with additional arguments used for class initialization + kwargs_update: Additional keyword arguments that will be passed with preds and + target when running update on the metric. """ _assert_half_support( - metric_module(**metric_args), partial(metric_functional, **metric_args), preds, target, device="cpu" + metric_module(**metric_args), + metric_functional, + preds, + target, + device="cpu", + **kwargs_update ) def run_precision_test_gpu( @@ -387,6 +403,7 @@ def run_precision_test_gpu( metric_module: Metric, metric_functional: Callable, metric_args: dict = {}, + **kwargs_update, ): """Test if an metric can be used with half precision tensors on gpu Args: @@ -395,9 +412,16 @@ def run_precision_test_gpu( metric_module: the metric module to test metric_functional: the metric functional to test metric_args: dict with additional arguments used for class initialization + kwargs_update: Additional keyword arguments that will be passed with preds and + target when running update on the metric. """ _assert_half_support( - metric_module(**metric_args), partial(metric_functional, **metric_args), preds, target, device="cuda" + metric_module(**metric_args), + metric_functional, + preds, + target, + device="cuda", + **kwargs_update ) diff --git a/tests/retrieval/helpers.py b/tests/retrieval/helpers.py index cf1c187d8a9..1471cee9602 100644 --- a/tests/retrieval/helpers.py +++ b/tests/retrieval/helpers.py @@ -11,40 +11,59 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from torchmetrics.utilities.data import get_group_indexes +from functools import partial from typing import Callable, Union import numpy as np import pytest import torch +from numpy import array from torch import Tensor from tests.helpers import seed_all -from torchmetrics import Metric -from tests.helpers.testers import MetricTester +from tests.helpers.testers import Metric, MetricTester +from tests.retrieval.inputs import ( + _input_retrieval_scores, + _input_retrieval_scores_empty, + _input_retrieval_scores_extra, + _input_retrieval_scores_mismatching_sizes, + _input_retrieval_scores_mismatching_sizes_func, + _input_retrieval_scores_no_target, + _input_retrieval_scores_wrong_targets, +) +from torchmetrics.utilities.data import get_group_indexes -seed_all(1337) +seed_all(42) def _compute_sklearn_metric( - preds: Union[Tensor, np.ndarray], - target: Union[Tensor, np.ndarray], - idx: np.ndarray = None, + preds: Union[Tensor, array], + target: Union[Tensor, array], + indexes: np.ndarray = None, metric: Callable = None, empty_target_action: str = "skip", **kwargs ) -> Tensor: """ Compute metric with multiple iterations over every query predictions set. """ + if indexes is None: + indexes = np.full_like(preds, fill_value=0, dtype=np.int64) + if isinstance(indexes, Tensor): + indexes = indexes.cpu().numpy() if isinstance(preds, Tensor): preds = preds.cpu().numpy() if isinstance(target, Tensor): target = target.cpu().numpy() + + assert isinstance(indexes, np.ndarray) + assert isinstance(preds, np.ndarray) + assert isinstance(target, np.ndarray) - if idx is None: - idx = np.full_like(preds, fill_value=0, dtype=np.int64) + indexes = indexes.flatten() + preds = preds.flatten() + target = target.flatten() + groups = get_group_indexes(indexes) - groups = get_group_indexes(idx) sk_results = [] for group in groups: trg, pds = target[group], preds[group] @@ -65,131 +84,336 @@ def _compute_sklearn_metric( return np.array(0.0) -def _test_retrieval_against_sklearn( - sklearn_metric: Callable, - torch_metric: Metric, - size: int, - n_documents: int, - empty_target_action: str, - **kwargs -) -> None: - """ Compare PL metrics to standard version. """ - metric = torch_metric(empty_target_action=empty_target_action, **kwargs) - shape = (n_documents, size) - - indexes = np.ones(shape, dtype=np.int64) * np.arange(n_documents) - preds = np.random.randn(*shape) - target = np.random.randint(0, 2, size=shape) - - sk_results = _compute_sklearn_metric( - preds, target, metric=sklearn_metric, empty_target_action=empty_target_action, **kwargs - ) - sk_results = torch.tensor(sk_results) - - indexes_tensor = torch.tensor(indexes).long() - preds_tensor = torch.tensor(preds).float() - target_tensor = torch.tensor(target).long() - - # lets assume data are not ordered - perm = torch.randperm(indexes_tensor.nelement()) - indexes_tensor = indexes_tensor.view(-1)[perm].view(indexes_tensor.size()) - preds_tensor = preds_tensor.view(-1)[perm].view(preds_tensor.size()) - target_tensor = target_tensor.view(-1)[perm].view(target_tensor.size()) - - # shuffle ids to require also sorting of documents ability from the torch metric - pl_result = metric(preds_tensor, target_tensor, idx=indexes_tensor) - - assert torch.allclose(sk_results.float(), pl_result.float(), equal_nan=False), ( - f"Test failed comparing metric {sklearn_metric} with {torch_metric}: " - f"{sk_results.float()} vs {pl_result.float()}. 
" - f"indexes: {indexes}, preds: {preds}, target: {target}" - ) - - -def _test_dtypes(torchmetric) -> None: - """Check PL metrics inputs are controlled correctly. """ - device = 'cuda' if torch.cuda.is_available() else 'cpu' - length = 10 # not important in this test - - # check error when `empty_target_action='error'` is raised correctly - indexes = torch.tensor([0] * length, device=device, dtype=torch.int64) - preds = torch.rand(size=(length, ), device=device, dtype=torch.float32) - target = torch.tensor([False] * length, device=device, dtype=torch.bool) - - metric = torchmetric(empty_target_action='error') - with pytest.raises(ValueError, match="`compute` method was provided with a query with no positive target."): - metric(preds, target, idx=indexes) - - # check ValueError with invalid `empty_target_action` argument - casual_argument = 'casual_argument' - with pytest.raises(ValueError, match=f"`empty_target_action` received a wrong value {casual_argument}."): - metric = torchmetric(empty_target_action=casual_argument) - - # check input dtypes - indexes = torch.tensor([0] * length, device=device, dtype=torch.int64) - preds = torch.tensor([0] * length, device=device, dtype=torch.float32) - target = torch.tensor([0] * length, device=device, dtype=torch.int64) - - metric = torchmetric(empty_target_action='error') - - # check error on input dtypes are raised correctly - with pytest.raises(ValueError, match="`indexes` must be a tensor of long integers"): - metric(preds, target, idx=indexes.bool()) - with pytest.raises(ValueError, match="`preds` must be a tensor of floats"): - metric(preds.bool(), target, idx=indexes) - with pytest.raises(ValueError, match="`target` must be a tensor of booleans or integers"): - metric(preds, target.float(), idx=indexes) - - -def _test_input_shapes(torchmetric) -> None: - """Check PL metrics inputs are controlled correctly. """ - device = 'cuda' if torch.cuda.is_available() else 'cpu' - metric = torchmetric(empty_target_action='error') - - # check input shapes are checked correclty - elements_1, elements_2 = np.random.choice(np.arange(1, 20), size=2, replace=False) - indexes = torch.tensor([0] * elements_1, device=device, dtype=torch.int64) - preds = torch.tensor([0] * elements_2, device=device, dtype=torch.float32) - target = torch.tensor([0] * elements_2, device=device, dtype=torch.int64) - - with pytest.raises(ValueError, match="`indexes`, `preds` and `target` must be of the same shape"): - metric(preds, target, idx=indexes) - - -def _test_input_args(torchmetric: Metric, message: str, **kwargs) -> None: - """Check invalid args are managed correctly. 
""" - with pytest.raises(ValueError, match=message): - torchmetric(**kwargs) - - - - - +_errors_test_functional_metric_parameters = [ + "preds, target, message", [ + # check input shapes are consistent (func) + ( + _input_retrieval_scores_mismatching_sizes_func.preds, + _input_retrieval_scores_mismatching_sizes_func.target, + "`preds` and `target` must be of the same shape", + ), + # check input tensors are not empty + ( + _input_retrieval_scores_empty.preds, + _input_retrieval_scores_empty.target, + "`preds` and `target` must be non-empty and non-scalar tensors", + ), + # check on input dtypes + ( + _input_retrieval_scores.preds.bool(), + _input_retrieval_scores.target, + "`preds` must be a tensor of floats", + ), + ( + _input_retrieval_scores.preds, + _input_retrieval_scores.target.float(), + "`target` must be a tensor of booleans or integers", + ), + # check targets are between 0 and 1 + ( + _input_retrieval_scores_wrong_targets.preds, + _input_retrieval_scores_wrong_targets.target, + "`target` must contain `binary` values", + ), + ] +] + + +_errors_test_class_metric_parameters_k = [ + "indexes, preds, target, message, metric_args", [ + ( + _input_retrieval_scores.index, + _input_retrieval_scores.preds, + _input_retrieval_scores.target, + "`k` has to be a positive integer or None", + {'k': -10}, + ), + ] +] + + +_errors_test_class_metric_parameters = [ + "indexes, preds, target, message, metric_args", [ + ( + None, + _input_retrieval_scores.preds, + _input_retrieval_scores.target, + "`indexes` cannot be None", + {'empty_target_action': "error"}, + ), + # check when error when there are not positive targets + ( + _input_retrieval_scores_no_target.indexes, + _input_retrieval_scores_no_target.preds, + _input_retrieval_scores_no_target.target, + "`compute` method was provided with a query with no positive target.", + {'empty_target_action': "error"}, + ), + # check when input arguments are invalid + ( + _input_retrieval_scores.indexes, + _input_retrieval_scores.preds, + _input_retrieval_scores.target, + "`empty_target_action` received a wrong value `casual_argument`.", + {'empty_target_action': "casual_argument"}, + ), + # check input shapes are consistent + ( + _input_retrieval_scores_mismatching_sizes.indexes, + _input_retrieval_scores_mismatching_sizes.preds, + _input_retrieval_scores_mismatching_sizes.target, + "`indexes`, `preds` and `target` must be of the same shape", + {'empty_target_action': "skip"}, + ), + # check input tensors are not empty + ( + _input_retrieval_scores_empty.indexes, + _input_retrieval_scores_empty.preds, + _input_retrieval_scores_empty.target, + "`indexes`, `preds` and `target` must be non-empty and non-scalar tensors", + {'empty_target_action': "skip"}, + ), + # check on input dtypes + ( + _input_retrieval_scores.indexes.bool(), + _input_retrieval_scores.preds, + _input_retrieval_scores.target, + "`indexes` must be a tensor of long integers", + {'empty_target_action': "skip"}, + ), + ( + _input_retrieval_scores.indexes, + _input_retrieval_scores.preds.bool(), + _input_retrieval_scores.target, + "`preds` must be a tensor of floats", + {'empty_target_action': "skip"}, + ), + ( + _input_retrieval_scores.indexes, + _input_retrieval_scores.preds, + _input_retrieval_scores.target.float(), + "`target` must be a tensor of booleans or integers", + {'empty_target_action': "skip"}, + ), + # check targets are between 0 and 1 + ( + _input_retrieval_scores_wrong_targets.indexes, + _input_retrieval_scores_wrong_targets.preds, + _input_retrieval_scores_wrong_targets.target, + 
"`target` must contain `binary` values", + {'empty_target_action': "skip"}, + ), + ] +] + + +_default_metric_class_input_arguments = [ + "indexes, preds, target", [ + ( + _input_retrieval_scores.indexes, + _input_retrieval_scores.preds, + _input_retrieval_scores.target + ), + ( + _input_retrieval_scores_extra.indexes, + _input_retrieval_scores_extra.preds, + _input_retrieval_scores_extra.target, + ), + ] +] + + +_default_metric_functional_input_arguments = [ + "preds, target", [ + ( + _input_retrieval_scores.preds, + _input_retrieval_scores.target + ), + ( + _input_retrieval_scores_extra.preds, + _input_retrieval_scores_extra.target + ), + ] +] + + +def _errors_test_class_metric( + indexes: Tensor, + preds: Tensor, + target: Tensor, + metric_class: Metric, + message: str = "", + metric_args: dict = {}, + exception_type: Exception = ValueError, + kwargs_update: dict = {}, +): + """Utility function doing checks about types, parameters and errors. + + Args: + indexes: torch tensor with indexes + preds: torch tensor with predictions + target: torch tensor with targets + metric_class: lightning metric class that should be tested + message: message that exception should return + metric_args: arguments for class initialization + exception_type: callable function that is used for comparison + kwargs_update: Additional keyword arguments that will be passed with indexes, preds and + target when running update on the metric. + """ + with pytest.raises(exception_type, match=message): + metric = metric_class(**metric_args) + metric(preds, target, indexes=indexes, **kwargs_update) + + +def _errors_test_functional_metric( + preds: Tensor, + target: Tensor, + metric_functional: Metric, + message: str = "", + exception_type: Exception = ValueError, + kwargs_update: dict = {}, +): + """Utility function doing checks about types, parameters and errors. + + Args: + preds: torch tensor with predictions + target: torch tensor with targets + metric_functional: lightning functional metric that should be tested + message: message that exception should return + exception_type: callable function that is used for comparison + kwargs_update: Additional keyword arguments that will be passed with indexes, preds and + target when running update on the metric. 
+ """ + with pytest.raises(exception_type, match=message): + metric_functional(preds, target, **kwargs_update) class RetrievalMetricTester(MetricTester): - """ - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_average_precision(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): - self.run_class_metric_test( + def run_class_metric_test( + self, + ddp: bool, + indexes: Tensor, + preds: Tensor, + target: Tensor, + metric_class: Metric, + sk_metric: Callable, + dist_sync_on_step: bool, + metric_args: dict, + ): + _sk_metric_adapted = partial(_compute_sklearn_metric, metric=sk_metric, **metric_args) + + super().run_class_metric_test( ddp=ddp, preds=preds, target=target, - metric_class=RetrievalMAP, - sk_metric=sk_metric, + metric_class=metric_class, + sk_metric=_sk_metric_adapted, dist_sync_on_step=dist_sync_on_step, + metric_args=metric_args, + indexes=indexes, # every additional argument will be passed to metric_class and _sk_metric_adapted + ) + + def run_functional_metric_test( + self, + preds: Tensor, + target: Tensor, + metric_functional: Callable, + sk_metric: Callable, + metric_args: dict, + **kwargs, + ): + # action on functional version of IR metrics is to return `tensor(0.0)` if not target is positive. + _sk_metric_adapted = partial(_compute_sklearn_metric, metric=sk_metric, **metric_args) + + super().run_functional_metric_test( + preds=preds, + target=target, + metric_functional=metric_functional, + sk_metric=_sk_metric_adapted, + metric_args=metric_args, + **kwargs, + ) + + def run_precision_test_cpu( + self, + indexes: Tensor, + preds: Tensor, + target: Tensor, + metric_module: Metric, + metric_functional: Callable, + ): + # action on functional version of IR metrics is to return `tensor(0.0)` if not target is positive. + metric_functional_ignore_indexes = lambda preds, target, indexes: metric_functional(preds, target) + + super().run_precision_test_cpu( + preds=preds, + target=target, + metric_module=metric_module, + metric_functional=metric_functional_ignore_indexes, + metric_args={'empty_target_action': 'neg'}, + indexes=indexes, # every additional argument will be passed to RetrievalMAP and _sk_metric_adapted + ) + + def run_precision_test_gpu( + self, + indexes: Tensor, + preds: Tensor, + target: Tensor, + metric_module: Metric, + metric_functional: Callable, + ): + if not torch.cuda.is_available(): + pytest.skip() + + # action on functional version of IR metrics is to return `tensor(0.0)` if not target is positive. 
+ metric_functional_ignore_indexes = lambda preds, target, indexes: metric_functional(preds, target) + + super().run_precision_test_gpu( + preds=preds, + target=target, + metric_module=metric_module, + metric_functional=metric_functional_ignore_indexes, + metric_args={'empty_target_action': 'neg'}, + indexes=indexes, # every additional argument will be passed to RetrievalMAP and _sk_metric_adapted ) - """ - def test_average_precision_functional(self, preds, target, sk_metric): - self.run_functional_metric_test( - preds, - target, - metric_functional=retrieval_average_precision, - sk_metric=sk_metric, + def run_metric_class_arguments_test( + self, + indexes: Tensor, + preds: Tensor, + target: Tensor, + metric_class: Metric, + message: str = "", + metric_args: dict = {}, + exception_type: Exception = ValueError, + kwargs_update: dict = {}, + ): + _errors_test_class_metric( + indexes=indexes, + preds=preds, + target=target, + metric_class=metric_class, + message=message, + metric_args=metric_args, + exception_type=exception_type, + **kwargs_update, ) - def test_a_caso(self, preds, target, sk_metric): - assert False \ No newline at end of file + def run_functional_metric_arguments_test( + self, + preds: Tensor, + target: Tensor, + metric_functional: Callable, + message: str = "", + exception_type: Exception = ValueError, + kwargs_update: dict = {}, + ): + _errors_test_functional_metric( + preds=preds, + target=target, + metric_functional=metric_functional, + message=message, + exception_type=exception_type, + **kwargs_update, + ) diff --git a/tests/retrieval/inputs.py b/tests/retrieval/inputs.py index 6b1dab83ece..1cb7663361d 100644 --- a/tests/retrieval/inputs.py +++ b/tests/retrieval/inputs.py @@ -15,12 +15,50 @@ import torch -from tests.helpers.testers import NUM_BATCHES, BATCH_SIZE +from tests.helpers.testers import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES -Input = namedtuple('InputMultiple', ["preds", "target", "idx"]) +Input = namedtuple('InputMultiple', ["indexes", "preds", "target"]) +# correct _input_retrieval_scores = Input( + indexes=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE)), preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), - idx=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE)) +) + +_input_retrieval_scores_extra = Input( + indexes=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), +) + +# with errors +_input_retrieval_scores_no_target = Input( + indexes=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE)), + preds=torch.rand(NUM_BATCHES, BATCH_SIZE), + target=torch.randint(high=1, size=(NUM_BATCHES, BATCH_SIZE)), +) + +_input_retrieval_scores_empty = Input( + indexes=torch.randint(high=10, size=[0]), + preds=torch.rand(0), + target=torch.randint(high=2, size=[0]), +) + +_input_retrieval_scores_mismatching_sizes = Input( + indexes=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE - 2)), + preds=torch.rand(NUM_BATCHES, BATCH_SIZE), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), +) + +_input_retrieval_scores_mismatching_sizes_func = Input( + indexes=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE)), + preds=torch.rand(NUM_BATCHES, BATCH_SIZE - 2), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), +) + +_input_retrieval_scores_wrong_targets = Input( + indexes=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE)), + 
preds=torch.rand(NUM_BATCHES, BATCH_SIZE), + target=torch.randint(low=-2**31, high=2**31, size=(NUM_BATCHES, BATCH_SIZE)), ) diff --git a/tests/retrieval/test_map.py b/tests/retrieval/test_map.py index d021820958f..dfc4910e606 100644 --- a/tests/retrieval/test_map.py +++ b/tests/retrieval/test_map.py @@ -12,27 +12,127 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest -from sklearn.metrics import average_precision_score as sk_average_precision +from sklearn.metrics import average_precision_score as sk_average_precision_score +from torch import Tensor -from tests.retrieval.helpers import _test_dtypes, _test_input_shapes, _test_retrieval_against_sklearn +from tests.helpers import seed_all +from tests.retrieval.helpers import ( + RetrievalMetricTester, + _default_metric_class_input_arguments, + _default_metric_functional_input_arguments, + _errors_test_class_metric_parameters, + _errors_test_functional_metric_parameters, +) +from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision from torchmetrics.retrieval.mean_average_precision import RetrievalMAP +seed_all(42) -@pytest.mark.parametrize('size', [1, 4, 10]) -@pytest.mark.parametrize('n_documents', [1, 5]) -@pytest.mark.parametrize('empty_target_action', ['skip', 'pos', 'neg']) -def test_results(size, n_documents, empty_target_action): - """ Test metrics are computed correctly. """ - _test_retrieval_against_sklearn( - sk_average_precision, RetrievalMAP, size, n_documents, empty_target_action - ) +class TestMAP(RetrievalMetricTester): -def test_dtypes(): - """ Check dypes are managed correctly. """ - _test_dtypes(RetrievalMAP) + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + @pytest.mark.parametrize("empty_target_action", ['skip', 'neg', 'pos']) + @pytest.mark.parametrize(*_default_metric_class_input_arguments) + def test_class_metric( + self, + ddp: bool, + indexes: Tensor, + preds: Tensor, + target: Tensor, + dist_sync_on_step: bool, + empty_target_action: str, + ): + metric_args = {'empty_target_action': empty_target_action} + self.run_class_metric_test( + ddp=ddp, + indexes=indexes, + preds=preds, + target=target, + metric_class=RetrievalMAP, + sk_metric=sk_average_precision_score, + dist_sync_on_step=dist_sync_on_step, + metric_args=metric_args, + ) -def test_input_shapes() -> None: - """Check inputs shapes are managed correctly. 
""" - _test_input_shapes(RetrievalMAP) + @pytest.mark.parametrize(*_default_metric_functional_input_arguments) + def test_functional_metric( + self, + preds: Tensor, + target: Tensor, + ): + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=retrieval_average_precision, + sk_metric=sk_average_precision_score, + metric_args={}, + ) + + @pytest.mark.parametrize(*_default_metric_class_input_arguments) + def test_precision_cpu( + self, + indexes: Tensor, + preds: Tensor, + target: Tensor, + ): + self.run_precision_test_cpu( + indexes=indexes, + preds=preds, + target=target, + metric_module=RetrievalMAP, + metric_functional=retrieval_average_precision, + ) + + @pytest.mark.parametrize(*_default_metric_class_input_arguments) + def test_precision_gpu( + self, + indexes: Tensor, + preds: Tensor, + target: Tensor, + ): + self.run_precision_test_gpu( + indexes=indexes, + preds=preds, + target=target, + metric_module=RetrievalMAP, + metric_functional=retrieval_average_precision, + ) + + @pytest.mark.parametrize(*_errors_test_class_metric_parameters) + def test_arguments_class_metric( + self, + indexes: Tensor, + preds: Tensor, + target: Tensor, + message: str, + metric_args: dict, + ): + self.run_metric_class_arguments_test( + indexes=indexes, + preds=preds, + target=target, + metric_class=RetrievalMAP, + message=message, + metric_args=metric_args, + exception_type=ValueError, + kwargs_update={}, + ) + + @pytest.mark.parametrize(*_errors_test_functional_metric_parameters) + def test_arguments_class_metric( + self, + preds: Tensor, + target: Tensor, + message: str, + ): + self.run_functional_metric_arguments_test( + preds=preds, + target=target, + metric_functional=retrieval_average_precision, + message=message, + exception_type=ValueError, + kwargs_update={}, + ) diff --git a/tests/retrieval/test_map_mt.py b/tests/retrieval/test_map_mt.py deleted file mode 100644 index d402818171b..00000000000 --- a/tests/retrieval/test_map_mt.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from functools import partial - -from tests.retrieval.helpers import _test_dtypes, _test_input_shapes, _test_retrieval_against_sklearn -from torchmetrics.retrieval.mean_average_precision import RetrievalMAP -from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision - -import pytest -from sklearn.metrics import average_precision_score as sk_average_precision_score - -from tests.retrieval.inputs import _input_retrieval_scores -from tests.retrieval.helpers import _compute_sklearn_metric - -from tests.helpers import seed_all -from tests.helpers.testers import MetricTester - -seed_all(42) - - -@pytest.mark.parametrize("sk_metric", [sk_average_precision_score]) -@pytest.mark.parametrize( - "preds, target, idx", [ - (_input_retrieval_scores.preds, _input_retrieval_scores.target, _input_retrieval_scores.idx), - ] -) -class TestRetrievalMetric(MetricTester): - - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - @pytest.mark.parametrize("empty_target_action", ['skip', 'neg', 'pos']) - def test_average_precision(self, preds, target, idx, sk_metric, ddp, dist_sync_on_step, empty_target_action): - _sk_metric = partial(_compute_sklearn_metric, metric=sk_metric) - self.run_class_metric_test( - ddp=ddp, - preds=preds, - target=target, - metric_class=RetrievalMAP, - sk_metric=_sk_metric, - dist_sync_on_step=dist_sync_on_step, - metric_args={'empty_target_action': empty_target_action}, - idx=idx - ) - - def test_average_precision_functional(self, preds, target, idx, sk_metric): - _sk_metric = partial(_compute_sklearn_metric, metric=sk_metric, empty_target_action="neg", idx=None) - self.run_functional_metric_test( - preds, - target, - metric_functional=retrieval_average_precision, - sk_metric=_sk_metric, - ) diff --git a/tests/retrieval/test_mrr.py b/tests/retrieval/test_mrr.py index 07f05c42145..4c275f0215c 100644 --- a/tests/retrieval/test_mrr.py +++ b/tests/retrieval/test_mrr.py @@ -1,3 +1,4 @@ + # Copyright The PyTorch Lightning team. 
# # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,10 +15,22 @@ import numpy as np import pytest from sklearn.metrics import label_ranking_average_precision_score +from torch import Tensor + +from tests.helpers import seed_all +from tests.retrieval.helpers import ( + RetrievalMetricTester, + _default_metric_class_input_arguments, + _default_metric_functional_input_arguments, + _errors_test_class_metric_parameters, + _errors_test_functional_metric_parameters, +) -from tests.retrieval.helpers import _test_dtypes, _test_input_shapes, _test_retrieval_against_sklearn +from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank from torchmetrics.retrieval.mean_reciprocal_rank import RetrievalMRR +seed_all(42) + def _reciprocal_rank(target: np.array, preds: np.array): """ @@ -30,7 +43,7 @@ def _reciprocal_rank(target: np.array, preds: np.array): assert len(target.shape) == 1 # works only with single dimension inputs # going to remove T targets that are not ranked as highest - indexes = preds[target.astype(np.bool)] + indexes = preds[target.astype(bool)] if len(indexes) > 0: target[preds != indexes.max(-1, keepdims=True)[0]] = 0 # ensure that only 1 positive label is present @@ -41,21 +54,110 @@ def _reciprocal_rank(target: np.array, preds: np.array): return 0.0 -@pytest.mark.parametrize('size', [1, 4, 10]) -@pytest.mark.parametrize('n_documents', [1, 5]) -@pytest.mark.parametrize('empty_target_action', ['skip', 'pos', 'neg']) -def test_results(size, n_documents, empty_target_action): - """ Test metrics are computed correctly. """ - _test_retrieval_against_sklearn( - _reciprocal_rank, RetrievalMRR, size, n_documents, empty_target_action - ) +class TestMRR(RetrievalMetricTester): + + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + @pytest.mark.parametrize("empty_target_action", ['skip', 'neg', 'pos']) + @pytest.mark.parametrize(*_default_metric_class_input_arguments) + def test_class_metric( + self, + ddp: bool, + indexes: Tensor, + preds: Tensor, + target: Tensor, + dist_sync_on_step: bool, + empty_target_action: str, + ): + metric_args = {'empty_target_action': empty_target_action} + + self.run_class_metric_test( + ddp=ddp, + indexes=indexes, + preds=preds, + target=target, + metric_class=RetrievalMRR, + sk_metric=_reciprocal_rank, + dist_sync_on_step=dist_sync_on_step, + metric_args=metric_args, + ) + @pytest.mark.parametrize(*_default_metric_functional_input_arguments) + def test_functional_metric( + self, + preds: Tensor, + target: Tensor, + ): + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=retrieval_reciprocal_rank, + sk_metric=_reciprocal_rank, + metric_args={}, + ) -def test_dtypes(): - """ Check dypes are managed correctly. 
""" - _test_dtypes(RetrievalMRR) + @pytest.mark.parametrize(*_default_metric_class_input_arguments) + def test_precision_cpu( + self, + indexes: Tensor, + preds: Tensor, + target: Tensor, + ): + self.run_precision_test_cpu( + indexes=indexes, + preds=preds, + target=target, + metric_module=RetrievalMRR, + metric_functional=retrieval_reciprocal_rank, + ) + + @pytest.mark.parametrize(*_default_metric_class_input_arguments) + def test_precision_gpu( + self, + indexes: Tensor, + preds: Tensor, + target: Tensor, + ): + self.run_precision_test_gpu( + indexes=indexes, + preds=preds, + target=target, + metric_module=RetrievalMRR, + metric_functional=retrieval_reciprocal_rank, + ) + @pytest.mark.parametrize(*_errors_test_class_metric_parameters) + def test_arguments_class_metric( + self, + indexes: Tensor, + preds: Tensor, + target: Tensor, + message: str, + metric_args: dict, + ): + self.run_metric_class_arguments_test( + indexes=indexes, + preds=preds, + target=target, + metric_class=RetrievalMRR, + message=message, + metric_args=metric_args, + exception_type=ValueError, + kwargs_update={}, + ) -def test_input_shapes() -> None: - """Check inputs shapes are managed correctly. """ - _test_input_shapes(RetrievalMRR) + @pytest.mark.parametrize(*_errors_test_functional_metric_parameters) + def test_arguments_class_metric( + self, + preds: Tensor, + target: Tensor, + message: str, + ): + self.run_functional_metric_arguments_test( + preds=preds, + target=target, + metric_functional=retrieval_reciprocal_rank, + message=message, + exception_type=ValueError, + kwargs_update={}, + ) diff --git a/tests/retrieval/test_precision.py b/tests/retrieval/test_precision.py index af309f7af6a..ab3215c8b18 100644 --- a/tests/retrieval/test_precision.py +++ b/tests/retrieval/test_precision.py @@ -13,10 +13,23 @@ # limitations under the License. import numpy as np import pytest +from torch import Tensor -from tests.retrieval.helpers import _test_dtypes, _test_input_args, _test_input_shapes, _test_retrieval_against_sklearn +from tests.helpers import seed_all +from tests.retrieval.helpers import ( + RetrievalMetricTester, + _default_metric_class_input_arguments, + _default_metric_functional_input_arguments, + _errors_test_class_metric_parameters, + _errors_test_functional_metric_parameters, + _errors_test_class_metric_parameters_k, +) + +from torchmetrics.functional.retrieval.precision import retrieval_precision from torchmetrics.retrieval.retrieval_precision import RetrievalPrecision +seed_all(42) + def _precision_at_k(target: np.array, preds: np.array, k: int = None): """ @@ -38,28 +51,135 @@ def _precision_at_k(target: np.array, preds: np.array, k: int = None): return np.NaN -@pytest.mark.parametrize('size', [1, 4, 10]) -@pytest.mark.parametrize('n_documents', [1, 5]) -@pytest.mark.parametrize('empty_target_action', ['skip', 'pos', 'neg']) -@pytest.mark.parametrize('k', [None, 1, 4, 10]) -def test_results(size, n_documents, empty_target_action, k): - """ Test metrics are computed correctly. 
""" - _test_retrieval_against_sklearn( - _precision_at_k, RetrievalPrecision, size, n_documents, empty_target_action, k=k - ) +class TestPrecision(RetrievalMetricTester): + + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + @pytest.mark.parametrize("empty_target_action", ['skip', 'neg', 'pos']) + @pytest.mark.parametrize("k", [None, 1, 4, 10]) + @pytest.mark.parametrize(*_default_metric_class_input_arguments) + def test_class_metric( + self, + ddp: bool, + indexes: Tensor, + preds: Tensor, + target: Tensor, + dist_sync_on_step: bool, + empty_target_action: str, + k: int, + ): + metric_args = {'empty_target_action': empty_target_action, 'k': k} + self.run_class_metric_test( + ddp=ddp, + indexes=indexes, + preds=preds, + target=target, + metric_class=RetrievalPrecision, + sk_metric=_precision_at_k, + dist_sync_on_step=dist_sync_on_step, + metric_args=metric_args, + ) -def test_dtypes(): - """ Check dypes are managed correctly. """ - _test_dtypes(RetrievalPrecision) + @pytest.mark.parametrize(*_default_metric_functional_input_arguments) + @pytest.mark.parametrize("k", [None, 1, 4, 10]) + def test_functional_metric( + self, + preds: Tensor, + target: Tensor, + k: int, + ): + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=retrieval_precision, + sk_metric=_precision_at_k, + metric_args={}, + k=k, + ) + @pytest.mark.parametrize(*_default_metric_class_input_arguments) + def test_precision_cpu( + self, + indexes: Tensor, + preds: Tensor, + target: Tensor, + ): + self.run_precision_test_cpu( + indexes=indexes, + preds=preds, + target=target, + metric_module=RetrievalPrecision, + metric_functional=retrieval_precision, + ) + + @pytest.mark.parametrize(*_default_metric_class_input_arguments) + def test_precision_gpu( + self, + indexes: Tensor, + preds: Tensor, + target: Tensor, + ): + self.run_precision_test_gpu( + indexes=indexes, + preds=preds, + target=target, + metric_module=RetrievalPrecision, + metric_functional=retrieval_precision, + ) -def test_input_shapes() -> None: - """Check inputs shapes are managed correctly. """ - _test_input_shapes(RetrievalPrecision) + @pytest.mark.parametrize(*_errors_test_class_metric_parameters) + def test_arguments_class_metric( + self, + indexes: Tensor, + preds: Tensor, + target: Tensor, + message: str, + metric_args: dict, + ): + self.run_metric_class_arguments_test( + indexes=indexes, + preds=preds, + target=target, + metric_class=RetrievalPrecision, + message=message, + metric_args=metric_args, + exception_type=ValueError, + kwargs_update={}, + ) + @pytest.mark.parametrize(*_errors_test_class_metric_parameters_k) + def test_additional_arguments_class_metric( + self, + indexes: Tensor, + preds: Tensor, + target: Tensor, + message: str, + metric_args: dict, + ): + self.run_metric_class_arguments_test( + indexes=indexes, + preds=preds, + target=target, + metric_class=RetrievalPrecision, + message=message, + metric_args=metric_args, + exception_type=ValueError, + kwargs_update={}, + ) -@pytest.mark.parametrize('k', [-1, 1.0]) -def test_input_params(k) -> None: - """Check invalid args are managed correctly. 
""" - _test_input_args(RetrievalPrecision, "`k` has to be a positive integer or None", k=k) + @pytest.mark.parametrize(*_errors_test_functional_metric_parameters) + def test_arguments_class_metric( + self, + preds: Tensor, + target: Tensor, + message: str, + ): + self.run_functional_metric_arguments_test( + preds=preds, + target=target, + metric_functional=retrieval_precision, + message=message, + exception_type=ValueError, + kwargs_update={}, + ) diff --git a/tests/retrieval/test_recall.py b/tests/retrieval/test_recall.py index 4d8393f218d..a05658aba33 100644 --- a/tests/retrieval/test_recall.py +++ b/tests/retrieval/test_recall.py @@ -13,10 +13,23 @@ # limitations under the License. import numpy as np import pytest +from torch import Tensor -from tests.retrieval.helpers import _test_dtypes, _test_input_args, _test_input_shapes, _test_retrieval_against_sklearn +from tests.helpers import seed_all +from tests.retrieval.helpers import ( + RetrievalMetricTester, + _default_metric_class_input_arguments, + _default_metric_functional_input_arguments, + _errors_test_class_metric_parameters, + _errors_test_functional_metric_parameters, + _errors_test_class_metric_parameters_k, +) + +from torchmetrics.functional.retrieval.recall import retrieval_recall from torchmetrics.retrieval.retrieval_recall import RetrievalRecall +seed_all(42) + def _recall_at_k(target: np.array, preds: np.array, k: int = None): """ @@ -37,33 +50,135 @@ def _recall_at_k(target: np.array, preds: np.array, k: int = None): return np.NaN -@pytest.mark.parametrize('size', [1, 4, 10]) -@pytest.mark.parametrize('n_documents', [1, 5]) -@pytest.mark.parametrize('empty_target_action', ['skip', 'pos', 'neg']) -@pytest.mark.parametrize('k', [None, 1, 4, 10]) -def test_results(size, n_documents, empty_target_action, k): - """ Test metrics are computed correctly. """ - _test_retrieval_against_sklearn( - _recall_at_k, - RetrievalRecall, - size, - n_documents, - empty_target_action, - k=k - ) +class TestRecall(RetrievalMetricTester): + + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + @pytest.mark.parametrize("empty_target_action", ['skip', 'neg', 'pos']) + @pytest.mark.parametrize("k", [None, 1, 4, 10]) + @pytest.mark.parametrize(*_default_metric_class_input_arguments) + def test_class_metric( + self, + ddp: bool, + indexes: Tensor, + preds: Tensor, + target: Tensor, + dist_sync_on_step: bool, + empty_target_action: str, + k: int, + ): + metric_args = {'empty_target_action': empty_target_action, 'k': k} + self.run_class_metric_test( + ddp=ddp, + indexes=indexes, + preds=preds, + target=target, + metric_class=RetrievalRecall, + sk_metric=_recall_at_k, + dist_sync_on_step=dist_sync_on_step, + metric_args=metric_args, + ) -def test_dtypes(): - """ Check dypes are managed correctly. 
""" - _test_dtypes(RetrievalRecall) + @pytest.mark.parametrize(*_default_metric_functional_input_arguments) + @pytest.mark.parametrize("k", [None, 1, 4, 10]) + def test_functional_metric( + self, + preds: Tensor, + target: Tensor, + k: int, + ): + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=retrieval_recall, + sk_metric=_recall_at_k, + metric_args={}, + k=k, + ) + @pytest.mark.parametrize(*_default_metric_class_input_arguments) + def test_precision_cpu( + self, + indexes: Tensor, + preds: Tensor, + target: Tensor, + ): + self.run_precision_test_cpu( + indexes=indexes, + preds=preds, + target=target, + metric_module=RetrievalRecall, + metric_functional=retrieval_recall, + ) + + @pytest.mark.parametrize(*_default_metric_class_input_arguments) + def test_precision_gpu( + self, + indexes: Tensor, + preds: Tensor, + target: Tensor, + ): + self.run_precision_test_gpu( + indexes=indexes, + preds=preds, + target=target, + metric_module=RetrievalRecall, + metric_functional=retrieval_recall, + ) -def test_input_shapes() -> None: - """Check inputs shapes are managed correctly. """ - _test_input_shapes(RetrievalRecall) + @pytest.mark.parametrize(*_errors_test_class_metric_parameters) + def test_arguments_class_metric( + self, + indexes: Tensor, + preds: Tensor, + target: Tensor, + message: str, + metric_args: dict, + ): + self.run_metric_class_arguments_test( + indexes=indexes, + preds=preds, + target=target, + metric_class=RetrievalRecall, + message=message, + metric_args=metric_args, + exception_type=ValueError, + kwargs_update={}, + ) + @pytest.mark.parametrize(*_errors_test_class_metric_parameters_k) + def test_additional_arguments_class_metric( + self, + indexes: Tensor, + preds: Tensor, + target: Tensor, + message: str, + metric_args: dict, + ): + self.run_metric_class_arguments_test( + indexes=indexes, + preds=preds, + target=target, + metric_class=RetrievalRecall, + message=message, + metric_args=metric_args, + exception_type=ValueError, + kwargs_update={}, + ) -@pytest.mark.parametrize('k', [-1, 1.0]) -def test_input_params(k) -> None: - """Check invalid args are managed correctly. """ - _test_input_args(RetrievalRecall, "`k` has to be a positive integer or None", k=k) + @pytest.mark.parametrize(*_errors_test_functional_metric_parameters) + def test_arguments_class_metric( + self, + preds: Tensor, + target: Tensor, + message: str, + ): + self.run_functional_metric_arguments_test( + preds=preds, + target=target, + metric_functional=retrieval_recall, + message=message, + exception_type=ValueError, + kwargs_update={}, + ) diff --git a/torchmetrics/retrieval/mean_average_precision.py b/torchmetrics/retrieval/mean_average_precision.py index fbc57ca6fcc..15897fa8ee3 100644 --- a/torchmetrics/retrieval/mean_average_precision.py +++ b/torchmetrics/retrieval/mean_average_precision.py @@ -43,8 +43,7 @@ class RetrievalMAP(RetrievalMetric): - ``'error'``: raise a ``ValueError`` - ``'pos'``: score on those queries is counted as ``1.0`` - ``'neg'``: score on those queries is counted as ``0.0`` - exclude: - Do not take into account predictions where the ``target`` is equal to this value. default `-100` + compute_on_step: Forward only calls ``update()`` and return None if this is set to False. 
default: True dist_sync_on_step: @@ -68,5 +67,4 @@ class RetrievalMAP(RetrievalMetric): """ def _metric(self, preds: Tensor, target: Tensor) -> Tensor: - valid_indexes = (target != self.exclude) - return retrieval_average_precision(preds[valid_indexes], target[valid_indexes]) + return retrieval_average_precision(preds, target) diff --git a/torchmetrics/retrieval/mean_reciprocal_rank.py b/torchmetrics/retrieval/mean_reciprocal_rank.py index 688c1f8c393..8fb4fe17459 100644 --- a/torchmetrics/retrieval/mean_reciprocal_rank.py +++ b/torchmetrics/retrieval/mean_reciprocal_rank.py @@ -44,8 +44,6 @@ class RetrievalMRR(RetrievalMetric): - ``'pos'``: score on those queries is counted as ``1.0`` - ``'neg'``: score on those queries is counted as ``0.0`` - exclude: - Do not take into account predictions where the ``target`` is equal to this value. default `-100` compute_on_step: Forward only calls ``update()`` and return None if this is set to False. default: True dist_sync_on_step: @@ -69,5 +67,4 @@ class RetrievalMRR(RetrievalMetric): """ def _metric(self, preds: Tensor, target: Tensor) -> Tensor: - valid_indexes = (target != self.exclude) - return retrieval_reciprocal_rank(preds[valid_indexes], target[valid_indexes]) + return retrieval_reciprocal_rank(preds, target) diff --git a/torchmetrics/retrieval/retrieval_metric.py b/torchmetrics/retrieval/retrieval_metric.py index 7df6773d97d..a9663503544 100644 --- a/torchmetrics/retrieval/retrieval_metric.py +++ b/torchmetrics/retrieval/retrieval_metric.py @@ -22,7 +22,6 @@ from torchmetrics.utilities.data import get_group_indexes #: get_group_indexes is used to group predictions belonging to the same document -IGNORE_IDX = -100 class RetrievalMetric(Metric, ABC): @@ -51,8 +50,6 @@ class RetrievalMetric(Metric, ABC): - ``'error'``: raise a ``ValueError`` - ``'pos'``: score on those queries is counted as ``1.0`` - ``'neg'``: score on those queries is counted as ``0.0`` - exclude: - Do not take into account predictions where the target is equal to this value. default `-100` compute_on_step: Forward only calls ``update()`` and return None if this is set to False. default: True dist_sync_on_step: @@ -69,7 +66,6 @@ class RetrievalMetric(Metric, ABC): def __init__( self, empty_target_action: str = 'skip', - exclude: int = IGNORE_IDX, compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, @@ -84,46 +80,40 @@ def __init__( empty_target_action_options = ('error', 'skip', 'pos', 'neg') if empty_target_action not in empty_target_action_options: - raise ValueError(f"`empty_target_action` received a wrong value {empty_target_action}.") + raise ValueError(f"`empty_target_action` received a wrong value `{empty_target_action}`.") self.empty_target_action = empty_target_action - self.exclude = exclude - self.next_index = 0 - self.add_state("idx", default=[], dist_reduce_fx=None) + self.add_state("indexes", default=[], dist_reduce_fx=None) self.add_state("preds", default=[], dist_reduce_fx=None) self.add_state("target", default=[], dist_reduce_fx=None) - def update(self, preds: Tensor, target: Tensor, idx: Tensor = None) -> None: + def update(self, preds: Tensor, target: Tensor, indexes: Tensor = None) -> None: """ Check shape, check and convert dtypes, flatten and add to accumulators. 
""" - if idx is None: - idx = torch.full(preds.shape, fill_value=self.next_index, dtype=torch.long, device=preds.device) + if indexes is None: + raise ValueError("`indexes` cannot be None") - # update index - actual_max_id = torch.max(idx).item() - if actual_max_id > self.next_index: - self.next_index = actual_max_id + indexes, preds, target = _check_retrieval_inputs(indexes, preds, target) - idx, preds, target = _check_retrieval_inputs(idx, preds, target, ignore=IGNORE_IDX) - self.idx.append(idx.flatten()) - self.preds.append(preds.flatten()) - self.target.append(target.flatten()) + self.indexes.append(indexes) + self.preds.append(preds) + self.target.append(target) def compute(self) -> Tensor: """ - First concat state `idx`, `preds` and `target` since they were stored as lists. After that, + First concat state `indexes`, `preds` and `target` since they were stored as lists. After that, compute list of groups that will help in keeping together predictions about the same query. Finally, for each group compute the `_metric` if the number of positive targets is at least 1, otherwise behave as specified by `self.empty_target_action`. """ - idx = torch.cat(self.idx, dim=0) + indexes = torch.cat(self.indexes, dim=0) preds = torch.cat(self.preds, dim=0) target = torch.cat(self.target, dim=0) res = [] - kwargs = {'device': idx.device, 'dtype': torch.float32} + kwargs = {'device': indexes.device, 'dtype': torch.float32} - groups = get_group_indexes(idx) + groups = get_group_indexes(indexes) for group in groups: mini_preds = preds[group] diff --git a/torchmetrics/retrieval/retrieval_precision.py b/torchmetrics/retrieval/retrieval_precision.py index b44698f980e..766575c245e 100644 --- a/torchmetrics/retrieval/retrieval_precision.py +++ b/torchmetrics/retrieval/retrieval_precision.py @@ -16,7 +16,7 @@ from torch import Tensor, tensor from torchmetrics.functional.retrieval.precision import retrieval_precision -from torchmetrics.retrieval.retrieval_metric import IGNORE_IDX, RetrievalMetric +from torchmetrics.retrieval.retrieval_metric import RetrievalMetric class RetrievalPrecision(RetrievalMetric): @@ -46,8 +46,6 @@ class RetrievalPrecision(RetrievalMetric): - ``'pos'``: score on those queries is counted as ``1.0`` - ``'neg'``: score on those queries is counted as ``0.0`` - exclude: - Do not take into account predictions where the ``target`` is equal to this value. default `-100` compute_on_step: Forward only calls ``update()`` and return None if this is set to False. 
default: True dist_sync_on_step: @@ -74,7 +72,6 @@ class RetrievalPrecision(RetrievalMetric): def __init__( self, empty_target_action: str = 'skip', - exclude: int = IGNORE_IDX, compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, @@ -83,7 +80,6 @@ def __init__( ): super().__init__( empty_target_action=empty_target_action, - exclude=exclude, compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group, @@ -95,5 +91,4 @@ def __init__( self.k = k def _metric(self, preds: Tensor, target: Tensor) -> Tensor: - valid_indexes = (target != self.exclude) - return retrieval_precision(preds[valid_indexes], target[valid_indexes], k=self.k) + return retrieval_precision(preds, target, k=self.k) diff --git a/torchmetrics/retrieval/retrieval_recall.py b/torchmetrics/retrieval/retrieval_recall.py index ace27761d6b..f8d06db5cd3 100644 --- a/torchmetrics/retrieval/retrieval_recall.py +++ b/torchmetrics/retrieval/retrieval_recall.py @@ -16,7 +16,7 @@ from torch import Tensor, tensor from torchmetrics.functional.retrieval.recall import retrieval_recall -from torchmetrics.retrieval.retrieval_metric import IGNORE_IDX, RetrievalMetric +from torchmetrics.retrieval.retrieval_metric import RetrievalMetric class RetrievalRecall(RetrievalMetric): @@ -46,8 +46,6 @@ class RetrievalRecall(RetrievalMetric): - ``'pos'``: score on those queries is counted as ``1.0`` - ``'neg'``: score on those queries is counted as ``0.0`` - exclude: - Do not take into account predictions where the ``target`` is equal to this value. default `-100` compute_on_step: Forward only calls ``update()`` and return None if this is set to False. default: True dist_sync_on_step: @@ -74,7 +72,6 @@ class RetrievalRecall(RetrievalMetric): def __init__( self, empty_target_action: str = 'skip', - exclude: int = IGNORE_IDX, compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, @@ -83,7 +80,6 @@ def __init__( ): super().__init__( empty_target_action=empty_target_action, - exclude=exclude, compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group, @@ -95,5 +91,4 @@ def __init__( self.k = k def _metric(self, preds: Tensor, target: Tensor) -> Tensor: - valid_indexes = (target != self.exclude) - return retrieval_recall(preds[valid_indexes], target[valid_indexes], k=self.k) + return retrieval_recall(preds, target, k=self.k) diff --git a/torchmetrics/utilities/checks.py b/torchmetrics/utilities/checks.py index f5d8f027467..a83ba160669 100644 --- a/torchmetrics/utilities/checks.py +++ b/torchmetrics/utilities/checks.py @@ -14,7 +14,7 @@ from typing import Optional, Tuple import torch -from torch import Tensor +from torch import Tensor, tensor from torchmetrics.utilities.data import select_topk, to_onehot from torchmetrics.utilities.enums import DataType @@ -513,26 +513,25 @@ def _check_retrieval_functional_inputs(preds: Tensor, target: Tensor) -> None: if preds.shape != target.shape: raise ValueError("`preds` and `target` must be of the same shape") - if not preds.numel() or not target.numel(): - raise ValueError("`preds` and `target` must be non-empty") + if not preds.numel() or not preds.size(): + raise ValueError("`preds` and `target` must be non-empty and non-scalar tensors") if target.dtype not in (torch.bool, torch.long, torch.int): raise ValueError("`target` must be a tensor of booleans or integers") - if target.max() > 1 or target.min() < 0: - raise ValueError("`target` must be of type 
`binary`") - if not preds.is_floating_point(): raise ValueError("`preds` must be a tensor of floats") - return preds.float(), target.long() + if target.max() > 1 or target.min() < 0: + raise ValueError("`target` must contain `binary` values") + + return preds.float().flatten(), target.long().flatten() def _check_retrieval_inputs( indexes: Tensor, preds: Tensor, target: Tensor, - ignore: int = None, ) -> Tuple[Tensor, Tensor, Tensor]: """Check ``indexes``, ``preds`` and ``target`` tensors are of the same shape and of the correct dtype. @@ -540,7 +539,6 @@ def _check_retrieval_inputs( indexes: tensor with queries indexes preds: tensor with scores/logits target: tensor with ground true labels - ignore: ignore target with this value Raises: ValueError: @@ -552,16 +550,22 @@ def _check_retrieval_inputs( preds: as torch.float32 target: as torch.long """ - if ignore is not None: - target = target[target != ignore] # ignore check on values that are ignored - preds = preds[target != ignore] - - preds, target = _check_retrieval_functional_inputs(preds, target) - - if indexes.shape != target.shape: + if indexes.shape != preds.shape or preds.shape != target.shape: raise ValueError("`indexes`, `preds` and `target` must be of the same shape") + if not indexes.numel() or not indexes.size(): + raise ValueError("`indexes`, `preds` and `target` must be non-empty and non-scalar tensors",) + if indexes.dtype is not torch.long: raise ValueError("`indexes` must be a tensor of long integers") - return indexes, preds, target + if not preds.is_floating_point(): + raise ValueError("`preds` must be a tensor of floats") + + if target.dtype not in (torch.bool, torch.long, torch.int): + raise ValueError("`target` must be a tensor of booleans or integers") + + if target.max() > 1 or target.min() < 0: + raise ValueError("`target` must contain `binary` values") + + return indexes.long().flatten(), preds.float().flatten(), target.long().flatten() diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py index c58d442d73b..c6edce71ee5 100644 --- a/torchmetrics/utilities/data.py +++ b/torchmetrics/utilities/data.py @@ -231,40 +231,42 @@ def apply_to_collection( return data -def get_group_indexes(idx: Tensor) -> List[Tensor]: +def get_group_indexes(indexes: Union[Tensor, np.ndarray]) -> List[Union[Tensor, np.ndarray]]: """ - Given an integer `torch.Tensor` or `np.array` `idx`, return a `torch.Tensor` or `np.array` of indexes for - each different value in `idx`. + Given an integer `torch.Tensor` or `np.array` `indexes`, return a `torch.Tensor` or `np.array` of indexes for + each different value in `indexes`. 
Args: - idx: a `torch.Tensor` or `np.array` of integers + indexes: a `torch.Tensor` or `np.array` of integers Return: A list of integer `torch.Tensor`s or `np.array`s Example: >>> indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1]) - >>> groups = get_group_indexes(indexes) - >>> groups + >>> get_group_indexes(indexes) [tensor([0, 1, 2]), tensor([3, 4, 5, 6])] >>> >>> indexes = np.ndarray([0, 0, 0, 1, 1, 1, 1]) - >>> groups = get_group_indexes(indexes) - >>> groups + >>> get_group_indexes(indexes) [array([0, 1, 2]), array([3, 4, 5, 6])] """ - if not isinstance(idx, (Tensor, np.ndarray)): - raise ValueError("`idx` must be a torch tensor or numpy array") + if not isinstance(indexes, (Tensor, np.ndarray)): + raise ValueError("`indexes` must be a torch tensor or numpy array") - structure = tensor if isinstance(idx, Tensor) else np.array - dtype = torch.long if isinstance(idx, Tensor) else np.int64 + if not len(indexes.shape) == 1: + raise ValueError("`indexes` must have a single dimension") - indexes = dict() - for i, _id in enumerate(idx): + structure = tensor if isinstance(indexes, Tensor) else np.array + dtype = torch.long if isinstance(indexes, Tensor) else np.int64 + + res = dict() + for i, _id in enumerate(indexes): _id = _id.item() - if _id in indexes: - indexes[_id] += [i] + if _id in res: + res[_id] += [i] else: - indexes[_id] = [i] - return [structure(x, dtype=dtype) for x in indexes.values()] + res[_id] = [i] + + return [structure(x, dtype=dtype) for x in res.values()] From 34bcefa1444cd701e110e716f93d9ba201bc7f48 Mon Sep 17 00:00:00 2001 From: lucadiliello Date: Sat, 3 Apr 2021 22:21:38 +0200 Subject: [PATCH 04/33] added pep8 compatibility --- tests/retrieval/helpers.py | 2 +- tests/retrieval/test_map.py | 4 ++-- tests/retrieval/test_mrr.py | 5 ++--- tests/retrieval/test_precision.py | 7 +++---- tests/retrieval/test_recall.py | 7 +++---- torchmetrics/utilities/checks.py | 2 +- 6 files changed, 12 insertions(+), 15 deletions(-) diff --git a/tests/retrieval/helpers.py b/tests/retrieval/helpers.py index 1471cee9602..5743af7098f 100644 --- a/tests/retrieval/helpers.py +++ b/tests/retrieval/helpers.py @@ -54,7 +54,7 @@ def _compute_sklearn_metric( preds = preds.cpu().numpy() if isinstance(target, Tensor): target = target.cpu().numpy() - + assert isinstance(indexes, np.ndarray) assert isinstance(preds, np.ndarray) assert isinstance(target, np.ndarray) diff --git a/tests/retrieval/test_map.py b/tests/retrieval/test_map.py index dfc4910e606..276bb0d0248 100644 --- a/tests/retrieval/test_map.py +++ b/tests/retrieval/test_map.py @@ -85,7 +85,7 @@ def test_precision_cpu( metric_module=RetrievalMAP, metric_functional=retrieval_average_precision, ) - + @pytest.mark.parametrize(*_default_metric_class_input_arguments) def test_precision_gpu( self, @@ -122,7 +122,7 @@ def test_arguments_class_metric( ) @pytest.mark.parametrize(*_errors_test_functional_metric_parameters) - def test_arguments_class_metric( + def test_arguments_functional_metric( self, preds: Tensor, target: Tensor, diff --git a/tests/retrieval/test_mrr.py b/tests/retrieval/test_mrr.py index 4c275f0215c..6bdf30ab38d 100644 --- a/tests/retrieval/test_mrr.py +++ b/tests/retrieval/test_mrr.py @@ -25,7 +25,6 @@ _errors_test_class_metric_parameters, _errors_test_functional_metric_parameters, ) - from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank from torchmetrics.retrieval.mean_reciprocal_rank import RetrievalMRR @@ -110,7 +109,7 @@ def test_precision_cpu( metric_module=RetrievalMRR, 
metric_functional=retrieval_reciprocal_rank, ) - + @pytest.mark.parametrize(*_default_metric_class_input_arguments) def test_precision_gpu( self, @@ -147,7 +146,7 @@ def test_arguments_class_metric( ) @pytest.mark.parametrize(*_errors_test_functional_metric_parameters) - def test_arguments_class_metric( + def test_arguments_functional_metric( self, preds: Tensor, target: Tensor, diff --git a/tests/retrieval/test_precision.py b/tests/retrieval/test_precision.py index ab3215c8b18..106ac5b8b73 100644 --- a/tests/retrieval/test_precision.py +++ b/tests/retrieval/test_precision.py @@ -21,10 +21,9 @@ _default_metric_class_input_arguments, _default_metric_functional_input_arguments, _errors_test_class_metric_parameters, - _errors_test_functional_metric_parameters, _errors_test_class_metric_parameters_k, + _errors_test_functional_metric_parameters, ) - from torchmetrics.functional.retrieval.precision import retrieval_precision from torchmetrics.retrieval.retrieval_precision import RetrievalPrecision @@ -112,7 +111,7 @@ def test_precision_cpu( metric_module=RetrievalPrecision, metric_functional=retrieval_precision, ) - + @pytest.mark.parametrize(*_default_metric_class_input_arguments) def test_precision_gpu( self, @@ -169,7 +168,7 @@ def test_additional_arguments_class_metric( ) @pytest.mark.parametrize(*_errors_test_functional_metric_parameters) - def test_arguments_class_metric( + def test_arguments_functional_metric( self, preds: Tensor, target: Tensor, diff --git a/tests/retrieval/test_recall.py b/tests/retrieval/test_recall.py index a05658aba33..1be55b590c6 100644 --- a/tests/retrieval/test_recall.py +++ b/tests/retrieval/test_recall.py @@ -21,10 +21,9 @@ _default_metric_class_input_arguments, _default_metric_functional_input_arguments, _errors_test_class_metric_parameters, - _errors_test_functional_metric_parameters, _errors_test_class_metric_parameters_k, + _errors_test_functional_metric_parameters, ) - from torchmetrics.functional.retrieval.recall import retrieval_recall from torchmetrics.retrieval.retrieval_recall import RetrievalRecall @@ -111,7 +110,7 @@ def test_precision_cpu( metric_module=RetrievalRecall, metric_functional=retrieval_recall, ) - + @pytest.mark.parametrize(*_default_metric_class_input_arguments) def test_precision_gpu( self, @@ -168,7 +167,7 @@ def test_additional_arguments_class_metric( ) @pytest.mark.parametrize(*_errors_test_functional_metric_parameters) - def test_arguments_class_metric( + def test_arguments_functional_metric( self, preds: Tensor, target: Tensor, diff --git a/torchmetrics/utilities/checks.py b/torchmetrics/utilities/checks.py index a83ba160669..23278b9c86a 100644 --- a/torchmetrics/utilities/checks.py +++ b/torchmetrics/utilities/checks.py @@ -14,7 +14,7 @@ from typing import Optional, Tuple import torch -from torch import Tensor, tensor +from torch import Tensor from torchmetrics.utilities.data import select_topk, to_onehot from torchmetrics.utilities.enums import DataType From fbaf05f4214591f9de8af4b34bbf07d7616ee850 Mon Sep 17 00:00:00 2001 From: lucadiliello Date: Sat, 3 Apr 2021 22:42:09 +0200 Subject: [PATCH 05/33] fixed np.ndarray to np.array --- tests/retrieval/helpers.py | 8 ++++---- torchmetrics/utilities/data.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/retrieval/helpers.py b/tests/retrieval/helpers.py index 5743af7098f..4ec35737af6 100644 --- a/tests/retrieval/helpers.py +++ b/tests/retrieval/helpers.py @@ -39,7 +39,7 @@ def _compute_sklearn_metric( preds: Union[Tensor, array], target: Union[Tensor, 
array], - indexes: np.ndarray = None, + indexes: np.array = None, metric: Callable = None, empty_target_action: str = "skip", **kwargs @@ -55,9 +55,9 @@ def _compute_sklearn_metric( if isinstance(target, Tensor): target = target.cpu().numpy() - assert isinstance(indexes, np.ndarray) - assert isinstance(preds, np.ndarray) - assert isinstance(target, np.ndarray) + assert isinstance(indexes, np.array) + assert isinstance(preds, np.array) + assert isinstance(target, np.array) indexes = indexes.flatten() preds = preds.flatten() diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py index c6edce71ee5..dc60d7559b7 100644 --- a/torchmetrics/utilities/data.py +++ b/torchmetrics/utilities/data.py @@ -231,7 +231,7 @@ def apply_to_collection( return data -def get_group_indexes(indexes: Union[Tensor, np.ndarray]) -> List[Union[Tensor, np.ndarray]]: +def get_group_indexes(indexes: Union[Tensor, np.array]) -> List[Union[Tensor, np.array]]: """ Given an integer `torch.Tensor` or `np.array` `indexes`, return a `torch.Tensor` or `np.array` of indexes for each different value in `indexes`. @@ -247,12 +247,12 @@ def get_group_indexes(indexes: Union[Tensor, np.ndarray]) -> List[Union[Tensor, >>> get_group_indexes(indexes) [tensor([0, 1, 2]), tensor([3, 4, 5, 6])] >>> - >>> indexes = np.ndarray([0, 0, 0, 1, 1, 1, 1]) + >>> indexes = np.array([0, 0, 0, 1, 1, 1, 1]) >>> get_group_indexes(indexes) [array([0, 1, 2]), array([3, 4, 5, 6])] """ - if not isinstance(indexes, (Tensor, np.ndarray)): + if not isinstance(indexes, (Tensor, np.array)): raise ValueError("`indexes` must be a torch tensor or numpy array") if not len(indexes.shape) == 1: From 644477e7a1273bae4dd5a1e259a83d0d2994f77e Mon Sep 17 00:00:00 2001 From: lucadiliello Date: Sat, 3 Apr 2021 22:44:10 +0200 Subject: [PATCH 06/33] remove lambda functions --- tests/retrieval/helpers.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/retrieval/helpers.py b/tests/retrieval/helpers.py index 4ec35737af6..861860145f5 100644 --- a/tests/retrieval/helpers.py +++ b/tests/retrieval/helpers.py @@ -344,7 +344,8 @@ def run_precision_test_cpu( metric_functional: Callable, ): # action on functional version of IR metrics is to return `tensor(0.0)` if not target is positive. - metric_functional_ignore_indexes = lambda preds, target, indexes: metric_functional(preds, target) + def metric_functional_ignore_indexes(preds, target, indexes): + metric_functional(preds, target) super().run_precision_test_cpu( preds=preds, @@ -367,7 +368,8 @@ def run_precision_test_gpu( pytest.skip() # action on functional version of IR metrics is to return `tensor(0.0)` if not target is positive. 
- metric_functional_ignore_indexes = lambda preds, target, indexes: metric_functional(preds, target) + def metric_functional_ignore_indexes(preds, target, indexes): + metric_functional(preds, target) super().run_precision_test_gpu( preds=preds, From 36ce60a3e8895a96beeb18893b5967d6a3ba3a3c Mon Sep 17 00:00:00 2001 From: lucadiliello Date: Sat, 3 Apr 2021 22:54:34 +0200 Subject: [PATCH 07/33] fixed typos with numpy dtype --- tests/retrieval/helpers.py | 12 ++++++------ tests/retrieval/test_mrr.py | 2 +- tests/retrieval/test_precision.py | 2 +- tests/retrieval/test_recall.py | 2 +- torchmetrics/utilities/data.py | 10 +++++----- 5 files changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/retrieval/helpers.py b/tests/retrieval/helpers.py index 861860145f5..2f3c5f5cf77 100644 --- a/tests/retrieval/helpers.py +++ b/tests/retrieval/helpers.py @@ -39,7 +39,7 @@ def _compute_sklearn_metric( preds: Union[Tensor, array], target: Union[Tensor, array], - indexes: np.array = None, + indexes: np.ndarray = None, metric: Callable = None, empty_target_action: str = "skip", **kwargs @@ -55,9 +55,9 @@ def _compute_sklearn_metric( if isinstance(target, Tensor): target = target.cpu().numpy() - assert isinstance(indexes, np.array) - assert isinstance(preds, np.array) - assert isinstance(target, np.array) + assert isinstance(indexes, np.ndarray) + assert isinstance(preds, np.ndarray) + assert isinstance(target, np.ndarray) indexes = indexes.flatten() preds = preds.flatten() @@ -345,7 +345,7 @@ def run_precision_test_cpu( ): # action on functional version of IR metrics is to return `tensor(0.0)` if not target is positive. def metric_functional_ignore_indexes(preds, target, indexes): - metric_functional(preds, target) + return metric_functional(preds, target) super().run_precision_test_cpu( preds=preds, @@ -369,7 +369,7 @@ def run_precision_test_gpu( # action on functional version of IR metrics is to return `tensor(0.0)` if not target is positive. def metric_functional_ignore_indexes(preds, target, indexes): - metric_functional(preds, target) + return metric_functional(preds, target) super().run_precision_test_gpu( preds=preds, diff --git a/tests/retrieval/test_mrr.py b/tests/retrieval/test_mrr.py index 6bdf30ab38d..7fb409bd662 100644 --- a/tests/retrieval/test_mrr.py +++ b/tests/retrieval/test_mrr.py @@ -31,7 +31,7 @@ seed_all(42) -def _reciprocal_rank(target: np.array, preds: np.array): +def _reciprocal_rank(target: np.ndarray, preds: np.ndarray): """ Adaptation of `sklearn.metrics.label_ranking_average_precision_score`. Since the original sklearn metric works as RR only when the number of positive diff --git a/tests/retrieval/test_precision.py b/tests/retrieval/test_precision.py index 106ac5b8b73..008fbe18f49 100644 --- a/tests/retrieval/test_precision.py +++ b/tests/retrieval/test_precision.py @@ -30,7 +30,7 @@ seed_all(42) -def _precision_at_k(target: np.array, preds: np.array, k: int = None): +def _precision_at_k(target: np.ndarray, preds: np.ndarray, k: int = None): """ Didn't find a reliable implementation of Precision in Information Retrieval, so, reimplementing here. 
A good explanation can be found diff --git a/tests/retrieval/test_recall.py b/tests/retrieval/test_recall.py index 1be55b590c6..a281751813a 100644 --- a/tests/retrieval/test_recall.py +++ b/tests/retrieval/test_recall.py @@ -30,7 +30,7 @@ seed_all(42) -def _recall_at_k(target: np.array, preds: np.array, k: int = None): +def _recall_at_k(target: np.ndarray, preds: np.ndarray, k: int = None): """ Didn't find a reliable implementation of Recall in Information Retrieval, so, reimplementing here. See wikipedia for more information about definition. diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py index dc60d7559b7..7bac95770d4 100644 --- a/torchmetrics/utilities/data.py +++ b/torchmetrics/utilities/data.py @@ -231,16 +231,16 @@ def apply_to_collection( return data -def get_group_indexes(indexes: Union[Tensor, np.array]) -> List[Union[Tensor, np.array]]: +def get_group_indexes(indexes: Union[Tensor, np.ndarray]) -> List[Union[Tensor, np.ndarray]]: """ - Given an integer `torch.Tensor` or `np.array` `indexes`, return a `torch.Tensor` or `np.array` of indexes for + Given an integer `torch.Tensor` or `np.ndarray` `indexes`, return a `torch.Tensor` or `np.ndarray` of indexes for each different value in `indexes`. Args: - indexes: a `torch.Tensor` or `np.array` of integers + indexes: a `torch.Tensor` or `np.ndarray` of integers Return: - A list of integer `torch.Tensor`s or `np.array`s + A list of integer `torch.Tensor`s or `np.ndarray`s Example: >>> indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1]) @@ -252,7 +252,7 @@ def get_group_indexes(indexes: Union[Tensor, np.array]) -> List[Union[Tensor, np [array([0, 1, 2]), array([3, 4, 5, 6])] """ - if not isinstance(indexes, (Tensor, np.array)): + if not isinstance(indexes, (Tensor, np.ndarray)): raise ValueError("`indexes` must be a torch tensor or numpy array") if not len(indexes.shape) == 1: From 44e2db61cf6fe0ed7ed76e9230319d9372415901 Mon Sep 17 00:00:00 2001 From: lucadiliello Date: Sat, 3 Apr 2021 23:04:13 +0200 Subject: [PATCH 08/33] fixed typo in doc example --- torchmetrics/utilities/data.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py index 7bac95770d4..46cbf6ee55b 100644 --- a/torchmetrics/utilities/data.py +++ b/torchmetrics/utilities/data.py @@ -246,10 +246,6 @@ def get_group_indexes(indexes: Union[Tensor, np.ndarray]) -> List[Union[Tensor, >>> indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1]) >>> get_group_indexes(indexes) [tensor([0, 1, 2]), tensor([3, 4, 5, 6])] - >>> - >>> indexes = np.array([0, 0, 0, 1, 1, 1, 1]) - >>> get_group_indexes(indexes) - [array([0, 1, 2]), array([3, 4, 5, 6])] """ if not isinstance(indexes, (Tensor, np.ndarray)): From f516e42100e3680a6506e3effb2c3c0f2001f7ff Mon Sep 17 00:00:00 2001 From: lucadiliello Date: Sat, 3 Apr 2021 23:08:07 +0200 Subject: [PATCH 09/33] fixed typo in doc examples about new indexes position --- torchmetrics/retrieval/mean_average_precision.py | 2 +- torchmetrics/retrieval/mean_reciprocal_rank.py | 2 +- torchmetrics/retrieval/retrieval_precision.py | 2 +- torchmetrics/retrieval/retrieval_recall.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/torchmetrics/retrieval/mean_average_precision.py b/torchmetrics/retrieval/mean_average_precision.py index 15897fa8ee3..c00ba459cd7 100644 --- a/torchmetrics/retrieval/mean_average_precision.py +++ b/torchmetrics/retrieval/mean_average_precision.py @@ -62,7 +62,7 @@ class RetrievalMAP(RetrievalMetric): >>> preds = tensor([0.2, 0.3, 0.5, 
0.1, 0.3, 0.5, 0.2]) >>> target = tensor([False, False, True, False, True, False, True]) >>> map = RetrievalMAP() - >>> map(indexes, preds, target) + >>> map(preds, target, indexes=indexes) tensor(0.7917) """ diff --git a/torchmetrics/retrieval/mean_reciprocal_rank.py b/torchmetrics/retrieval/mean_reciprocal_rank.py index 8fb4fe17459..6eaecf235ce 100644 --- a/torchmetrics/retrieval/mean_reciprocal_rank.py +++ b/torchmetrics/retrieval/mean_reciprocal_rank.py @@ -62,7 +62,7 @@ class RetrievalMRR(RetrievalMetric): >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]) >>> target = tensor([False, False, True, False, True, False, True]) >>> mrr = RetrievalMRR() - >>> mrr(indexes, preds, target) + >>> mrr(preds, target, indexes=indexes) tensor(0.7500) """ diff --git a/torchmetrics/retrieval/retrieval_precision.py b/torchmetrics/retrieval/retrieval_precision.py index 766575c245e..6ad7c27622e 100644 --- a/torchmetrics/retrieval/retrieval_precision.py +++ b/torchmetrics/retrieval/retrieval_precision.py @@ -65,7 +65,7 @@ class RetrievalPrecision(RetrievalMetric): >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]) >>> target = tensor([False, False, True, False, True, False, True]) >>> p2 = RetrievalPrecision(k=2) - >>> p2(indexes, preds, target) + >>> p2(preds, target, indexes=indexes) tensor(0.5000) """ diff --git a/torchmetrics/retrieval/retrieval_recall.py b/torchmetrics/retrieval/retrieval_recall.py index f8d06db5cd3..bee66e8f4da 100644 --- a/torchmetrics/retrieval/retrieval_recall.py +++ b/torchmetrics/retrieval/retrieval_recall.py @@ -65,7 +65,7 @@ class RetrievalRecall(RetrievalMetric): >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]) >>> target = tensor([False, False, True, False, True, False, True]) >>> r2 = RetrievalRecall(k=2) - >>> r2(indexes, preds, target) + >>> r2(preds, target, indexes=indexes) tensor(0.7500) """ From 1496d1b397eac2dddeb8781205e75024765c737b Mon Sep 17 00:00:00 2001 From: lucadiliello Date: Sun, 4 Apr 2021 12:13:43 +0200 Subject: [PATCH 10/33] added paramter to class testing to divide kwargs as preds and targets. Fixed typo in doc format --- tests/helpers/testers.py | 28 ++++++++++++++----- tests/retrieval/helpers.py | 2 ++ torchmetrics/retrieval/retrieval_precision.py | 8 +++--- torchmetrics/retrieval/retrieval_recall.py | 8 +++--- 4 files changed, 31 insertions(+), 15 deletions(-) diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py index cf338e9200f..e2593923dd9 100644 --- a/tests/helpers/testers.py +++ b/tests/helpers/testers.py @@ -82,6 +82,7 @@ def _class_test( check_batch: bool = True, atol: float = 1e-8, device: str = 'cpu', + fragment_kwargs: bool = False, **kwargs_update: Any, ): """Utility function doing the actual comparison between lightning class metric @@ -102,6 +103,7 @@ def _class_test( check_batch: bool, if true will check if the metric is also correctly calculated across devices for each batch (and not just at the end) device: determine which device to run on, either 'cuda' or 'cpu' + fragment_kwargs: whether tensors in kwargs should be divided as `preds` and `target` among processes kwargs_update: Additional keyword arguments that will be passed with preds and target when running update on the metric. 
""" @@ -129,20 +131,21 @@ def _class_test( batch_result = metric(preds[i], target[i], **batch_kwargs_update) if metric.dist_sync_on_step and check_dist_sync_on_step and rank == 0: - rank_indexes = [i + r for r in range(worldsize)] - - ddp_preds = preds[rank_indexes].cpu() - ddp_target = target[rank_indexes].cpu() + ddp_preds = torch.cat([preds[i + r] for r in range(worldsize)]).cpu() + ddp_target = torch.cat([target[i + r] for r in range(worldsize)]).cpu() ddp_kwargs_upd = { - k: v[rank_indexes].cpu() if isinstance(v, Tensor) else v - for k, v in kwargs_update.items() + k: torch.cat([v[i + r] for r in range(worldsize)]).cpu() if isinstance(v, Tensor) else v + for k, v in (kwargs_update if fragment_kwargs else batch_kwargs_update).items() } sk_batch_result = sk_metric(ddp_preds, ddp_target, **ddp_kwargs_upd) _assert_allclose(batch_result, sk_batch_result, atol=atol) elif check_batch and not metric.dist_sync_on_step: - batch_kwargs_update = {k: v.cpu() for k, v in batch_kwargs_update.items()} + batch_kwargs_update = { + k: v.cpu() + for k, v in (batch_kwargs_update if fragment_kwargs else kwargs_update).items() + } sk_batch_result = sk_metric(preds[i].cpu(), target[i].cpu(), **batch_kwargs_update) _assert_allclose(batch_result, sk_batch_result, atol=atol) @@ -170,6 +173,7 @@ def _functional_test( metric_args: dict = None, atol: float = 1e-8, device: str = 'cpu', + fragment_kwargs: bool = False, **kwargs_update, ): """Utility function doing the actual comparison between lightning functional metric @@ -182,6 +186,7 @@ def _functional_test( sk_metric: callable function that is used for comparison metric_args: dict with additional arguments used for class initialization device: determine which device to run on, either 'cuda' or 'cpu' + fragment_kwargs: whether tensors in kwargs should be divided as `preds` and `target` among processes kwargs_update: Additional keyword arguments that will be passed with preds and target when running update on the metric. """ @@ -198,6 +203,8 @@ def _functional_test( for i in range(NUM_BATCHES): extra_kwargs = {k: v[i] if isinstance(v, Tensor) else v for k, v in kwargs_update.items()} lightning_result = metric(preds[i], target[i], **extra_kwargs) + if not fragment_kwargs: + extra_kwargs = {k: v.cpu() for k, v in kwargs_update.items()} sk_result = sk_metric(preds[i].cpu(), target[i].cpu(), **extra_kwargs) # assert its the same @@ -268,6 +275,7 @@ def run_functional_metric_test( metric_functional: Callable, sk_metric: Callable, metric_args: dict = None, + fragment_kwargs: bool = False, **kwargs_update, ): """Main method that should be used for testing functions. Call this inside @@ -279,6 +287,7 @@ def run_functional_metric_test( metric_functional: lightning metric class that should be tested sk_metric: callable function that is used for comparison metric_args: dict with additional arguments used for class initialization + fragment_kwargs: whether tensors in kwargs should be divided as `preds` and `target` among processes kwargs_update: Additional keyword arguments that will be passed with preds and target when running update on the metric. """ @@ -292,6 +301,7 @@ def run_functional_metric_test( metric_args=metric_args, atol=self.atol, device=device, + fragment_kwargs=fragment_kwargs, **kwargs_update, ) @@ -306,6 +316,7 @@ def run_class_metric_test( metric_args: dict = None, check_dist_sync_on_step: bool = True, check_batch: bool = True, + fragment_kwargs: bool = False, **kwargs_update, ): """Main method that should be used for testing class. 
Call this inside testing @@ -324,6 +335,7 @@ def run_class_metric_test( calculated per batch per device (and not just at the end) check_batch: bool, if true will check if the metric is also correctly calculated across devices for each batch (and not just at the end) + fragment_kwargs: whether tensors in kwargs should be divided as `preds` and `target` among processes kwargs_update: Additional keyword arguments that will be passed with preds and target when running update on the metric. """ @@ -345,6 +357,7 @@ def run_class_metric_test( check_dist_sync_on_step=check_dist_sync_on_step, check_batch=check_batch, atol=self.atol, + fragment_kwargs=fragment_kwargs, **kwargs_update, ), [(rank, self.poolSize) for rank in range(self.poolSize)], @@ -365,6 +378,7 @@ def run_class_metric_test( check_batch=check_batch, atol=self.atol, device=device, + fragment_kwargs=fragment_kwargs, **kwargs_update, ) diff --git a/tests/retrieval/helpers.py b/tests/retrieval/helpers.py index 2f3c5f5cf77..39c1991bf0b 100644 --- a/tests/retrieval/helpers.py +++ b/tests/retrieval/helpers.py @@ -311,6 +311,7 @@ def run_class_metric_test( sk_metric=_sk_metric_adapted, dist_sync_on_step=dist_sync_on_step, metric_args=metric_args, + fragment_kwargs=True, indexes=indexes, # every additional argument will be passed to metric_class and _sk_metric_adapted ) @@ -332,6 +333,7 @@ def run_functional_metric_test( metric_functional=metric_functional, sk_metric=_sk_metric_adapted, metric_args=metric_args, + fragment_kwargs=True, **kwargs, ) diff --git a/torchmetrics/retrieval/retrieval_precision.py b/torchmetrics/retrieval/retrieval_precision.py index 6ad7c27622e..acfbeb0eb2e 100644 --- a/torchmetrics/retrieval/retrieval_precision.py +++ b/torchmetrics/retrieval/retrieval_precision.py @@ -41,10 +41,10 @@ class RetrievalPrecision(RetrievalMetric): empty_target_action: Specify what to do with queries that do not have at least a positive ``target``. Choose from: - - ``'skip'``: skip those queries (default); if all queries are skipped, ``0.0`` is returned - - ``'error'``: raise a ``ValueError`` - - ``'pos'``: score on those queries is counted as ``1.0`` - - ``'neg'``: score on those queries is counted as ``0.0`` + - ``'skip'``: skip those queries (default); if all queries are skipped, ``0.0`` is returned + - ``'error'``: raise a ``ValueError`` + - ``'pos'``: score on those queries is counted as ``1.0`` + - ``'neg'``: score on those queries is counted as ``0.0`` compute_on_step: Forward only calls ``update()`` and return None if this is set to False. default: True diff --git a/torchmetrics/retrieval/retrieval_recall.py b/torchmetrics/retrieval/retrieval_recall.py index bee66e8f4da..a7752544f17 100644 --- a/torchmetrics/retrieval/retrieval_recall.py +++ b/torchmetrics/retrieval/retrieval_recall.py @@ -41,10 +41,10 @@ class RetrievalRecall(RetrievalMetric): empty_target_action: Specify what to do with queries that do not have at least a positive ``target``. 
Choose from:

-        - ``'skip'``: skip those queries (default); if all queries are skipped, ``0.0`` is returned
-        - ``'error'``: raise a ``ValueError``
-        - ``'pos'``: score on those queries is counted as ``1.0``
-        - ``'neg'``: score on those queries is counted as ``0.0``
+            - ``'skip'``: skip those queries (default); if all queries are skipped, ``0.0`` is returned
+            - ``'error'``: raise a ``ValueError``
+            - ``'pos'``: score on those queries is counted as ``1.0``
+            - ``'neg'``: score on those queries is counted as ``0.0``

         compute_on_step:
             Forward only calls ``update()`` and return None if this is set to False. default: True

From cc4231157272ae929bc84521bc5a94dd748bbf69 Mon Sep 17 00:00:00 2001
From: lucadiliello
Date: Sun, 4 Apr 2021 13:01:39 +0200
Subject: [PATCH 11/33] fixed typo in doc example

---
 docs/source/references/modules.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst
index b2cfdc5b92c..9d78092aa65 100644
--- a/docs/source/references/modules.rst
+++ b/docs/source/references/modules.rst
@@ -310,7 +310,7 @@ the set of pairs ``(Q_i, D_j)`` having the same query ``Q_i``.
     >>> target = torch.tensor([0, 1, 0, 1, 1])

     >>> map = RetrievalMAP() # or some other retrieval metric
-    >>> map(indexes, preds, target)
+    >>> map(preds, target, indexes=indexes)
     tensor(0.6667)

     >>> # the previous instruction is roughly equivalent to

From 8bb4830fc2771072c7ac80d32334d38c4c4e4f4f Mon Sep 17 00:00:00 2001
From: lucadiliello
Date: Sun, 4 Apr 2021 13:04:41 +0200
Subject: [PATCH 12/33] fixed handling of new parameter fragment_kwargs in MetricTester

---
 tests/helpers/testers.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py
index e2593923dd9..68120869beb 100644
--- a/tests/helpers/testers.py
+++ b/tests/helpers/testers.py
@@ -203,8 +203,7 @@ def _functional_test(
     for i in range(NUM_BATCHES):
         extra_kwargs = {k: v[i] if isinstance(v, Tensor) else v for k, v in kwargs_update.items()}
         lightning_result = metric(preds[i], target[i], **extra_kwargs)
-        if not fragment_kwargs:
-            extra_kwargs = {k: v.cpu() for k, v in kwargs_update.items()}
+        extra_kwargs = {k: v.cpu() for k, v in (extra_kwargs if fragment_kwargs else kwargs_update).items()}
         sk_result = sk_metric(preds[i].cpu(), target[i].cpu(), **extra_kwargs)

         # assert its the same

From e7c7e961d5163c84a10a7fef3304c0dc961fa14e Mon Sep 17 00:00:00 2001
From: lucadiliello
Date: Sun, 4 Apr 2021 13:59:54 +0200
Subject: [PATCH 13/33] fixed .cpu() conversion of non-tensor values

---
 tests/helpers/testers.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py
index 68120869beb..e8bc68f600a 100644
--- a/tests/helpers/testers.py
+++ b/tests/helpers/testers.py
@@ -203,7 +203,10 @@ def _functional_test(
     for i in range(NUM_BATCHES):
         extra_kwargs = {k: v[i] if isinstance(v, Tensor) else v for k, v in kwargs_update.items()}
         lightning_result = metric(preds[i], target[i], **extra_kwargs)
-        extra_kwargs = {k: v.cpu() for k, v in (extra_kwargs if fragment_kwargs else kwargs_update).items()}
+        extra_kwargs = {
+            k: v.cpu() if isinstance(v, Tensor) else v
+            for k, v in (extra_kwargs if fragment_kwargs else kwargs_update).items()
+        }
         sk_result = sk_metric(preds[i].cpu(), target[i].cpu(), **extra_kwargs)

         # assert its the same

From 40fd75bfb511fefc93425d515aef1497bb9236df Mon Sep 17 00:00:00 2001
From: lucadiliello
Date: Sun, 4 Apr 2021 14:45:11 +0200
Subject:
[PATCH 14/33] improved test coverage --- tests/retrieval/helpers.py | 30 ++++++++++++++++++++++++++---- tests/retrieval/test_precision.py | 18 ++++++++++++++++++ tests/retrieval/test_recall.py | 18 ++++++++++++++++++ torchmetrics/utilities/data.py | 10 +--------- 4 files changed, 63 insertions(+), 13 deletions(-) diff --git a/tests/retrieval/helpers.py b/tests/retrieval/helpers.py index 39c1991bf0b..a6c93c84b78 100644 --- a/tests/retrieval/helpers.py +++ b/tests/retrieval/helpers.py @@ -119,15 +119,20 @@ def _compute_sklearn_metric( ] -_errors_test_class_metric_parameters_k = [ - "indexes, preds, target, message, metric_args", [ +_errors_test_functional_metric_parameters_k = [ + "preds, target, message, metric_args", [ ( - _input_retrieval_scores.index, _input_retrieval_scores.preds, _input_retrieval_scores.target, "`k` has to be a positive integer or None", {'k': -10}, ), + ( + _input_retrieval_scores.preds, + _input_retrieval_scores.target, + "`k` has to be a positive integer or None", + {'k': 4.0}, + ), ] ] @@ -207,6 +212,19 @@ def _compute_sklearn_metric( ] +_errors_test_class_metric_parameters_k = [ + "indexes, preds, target, message, metric_args", [ + ( + _input_retrieval_scores.index, + _input_retrieval_scores.preds, + _input_retrieval_scores.target, + "`k` has to be a positive integer or None", + {'k': -10}, + ), + ] +] + + _default_metric_class_input_arguments = [ "indexes, preds, target", [ ( @@ -233,6 +251,10 @@ def _compute_sklearn_metric( _input_retrieval_scores_extra.preds, _input_retrieval_scores_extra.target ), + ( + _input_retrieval_scores_no_target.preds, + _input_retrieval_scores_no_target.target + ), ] ] @@ -419,5 +441,5 @@ def run_functional_metric_arguments_test( metric_functional=metric_functional, message=message, exception_type=exception_type, - **kwargs_update, + kwargs_update=kwargs_update, ) diff --git a/tests/retrieval/test_precision.py b/tests/retrieval/test_precision.py index 008fbe18f49..6ea2dd547d9 100644 --- a/tests/retrieval/test_precision.py +++ b/tests/retrieval/test_precision.py @@ -23,6 +23,7 @@ _errors_test_class_metric_parameters, _errors_test_class_metric_parameters_k, _errors_test_functional_metric_parameters, + _errors_test_functional_metric_parameters_k, ) from torchmetrics.functional.retrieval.precision import retrieval_precision from torchmetrics.retrieval.retrieval_precision import RetrievalPrecision @@ -182,3 +183,20 @@ def test_arguments_functional_metric( exception_type=ValueError, kwargs_update={}, ) + + @pytest.mark.parametrize(*_errors_test_functional_metric_parameters_k) + def test_additional_arguments_functional_metric( + self, + preds: Tensor, + target: Tensor, + message: str, + metric_args: dict, + ): + self.run_functional_metric_arguments_test( + preds=preds, + target=target, + metric_functional=retrieval_precision, + message=message, + exception_type=ValueError, + kwargs_update=metric_args, + ) diff --git a/tests/retrieval/test_recall.py b/tests/retrieval/test_recall.py index a281751813a..4c79478f664 100644 --- a/tests/retrieval/test_recall.py +++ b/tests/retrieval/test_recall.py @@ -23,6 +23,7 @@ _errors_test_class_metric_parameters, _errors_test_class_metric_parameters_k, _errors_test_functional_metric_parameters, + _errors_test_functional_metric_parameters_k, ) from torchmetrics.functional.retrieval.recall import retrieval_recall from torchmetrics.retrieval.retrieval_recall import RetrievalRecall @@ -181,3 +182,20 @@ def test_arguments_functional_metric( exception_type=ValueError, kwargs_update={}, ) + + 
@pytest.mark.parametrize(*_errors_test_functional_metric_parameters_k) + def test_additional_arguments_functional_metric( + self, + preds: Tensor, + target: Tensor, + message: str, + metric_args: dict, + ): + self.run_functional_metric_arguments_test( + preds=preds, + target=target, + metric_functional=retrieval_recall, + message=message, + exception_type=ValueError, + kwargs_update=metric_args, + ) diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py index 46cbf6ee55b..383385be8a9 100644 --- a/torchmetrics/utilities/data.py +++ b/torchmetrics/utilities/data.py @@ -247,15 +247,7 @@ def get_group_indexes(indexes: Union[Tensor, np.ndarray]) -> List[Union[Tensor, >>> get_group_indexes(indexes) [tensor([0, 1, 2]), tensor([3, 4, 5, 6])] """ - - if not isinstance(indexes, (Tensor, np.ndarray)): - raise ValueError("`indexes` must be a torch tensor or numpy array") - - if not len(indexes.shape) == 1: - raise ValueError("`indexes` must have a single dimension") - - structure = tensor if isinstance(indexes, Tensor) else np.array - dtype = torch.long if isinstance(indexes, Tensor) else np.int64 + structure, dtype = (tensor, torch.long) if isinstance(indexes, Tensor) else (np.array, np.int64) res = dict() for i, _id in enumerate(indexes): From 7b3d2f82558ac743cd20644991f413bcfa786470 Mon Sep 17 00:00:00 2001 From: lucadiliello Date: Sun, 4 Apr 2021 19:07:34 +0200 Subject: [PATCH 15/33] improved test coverage --- tests/retrieval/helpers.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/retrieval/helpers.py b/tests/retrieval/helpers.py index a6c93c84b78..324c67a296f 100644 --- a/tests/retrieval/helpers.py +++ b/tests/retrieval/helpers.py @@ -237,6 +237,11 @@ def _compute_sklearn_metric( _input_retrieval_scores_extra.preds, _input_retrieval_scores_extra.target, ), + ( + _input_retrieval_scores_no_target.indexes, + _input_retrieval_scores_no_target.preds, + _input_retrieval_scores_no_target.target, + ), ] ] From 01c43de17b5bdbc85c7bf189c8a25046ea4fe5da Mon Sep 17 00:00:00 2001 From: lucadiliello Date: Sun, 4 Apr 2021 19:11:18 +0200 Subject: [PATCH 16/33] added check on Tensor class to avoid calling .cpu() on non-tensor values --- tests/helpers/testers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py index e8bc68f600a..5788930a120 100644 --- a/tests/helpers/testers.py +++ b/tests/helpers/testers.py @@ -143,7 +143,7 @@ def _class_test( elif check_batch and not metric.dist_sync_on_step: batch_kwargs_update = { - k: v.cpu() + k: v.cpu() if isinstance(v, Tensor) else v for k, v in (batch_kwargs_update if fragment_kwargs else kwargs_update).items() } sk_batch_result = sk_metric(preds[i].cpu(), target[i].cpu(), **batch_kwargs_update) From bb62f519908f4dd8ec9776713a69689b7955ce8a Mon Sep 17 00:00:00 2001 From: lucadiliello Date: Mon, 5 Apr 2021 17:16:55 +0200 Subject: [PATCH 17/33] improved doc and changed default values for 'empty_target_action' argument --- .../retrieval/mean_average_precision.py | 8 +++--- .../retrieval/mean_reciprocal_rank.py | 8 +++--- torchmetrics/retrieval/retrieval_metric.py | 27 +++++++++---------- torchmetrics/retrieval/retrieval_precision.py | 10 +++---- torchmetrics/retrieval/retrieval_recall.py | 10 +++---- 5 files changed, 30 insertions(+), 33 deletions(-) diff --git a/torchmetrics/retrieval/mean_average_precision.py b/torchmetrics/retrieval/mean_average_precision.py index c00ba459cd7..06b16ae0072 100644 --- a/torchmetrics/retrieval/mean_average_precision.py +++ 
b/torchmetrics/retrieval/mean_average_precision.py @@ -26,9 +26,9 @@ class RetrievalMAP(RetrievalMetric): Forward accepts - - ``indexes`` (long tensor): ``(N, ...)`` - ``preds`` (float tensor): ``(N, ...)`` - ``target`` (long or bool tensor): ``(N, ...)`` + - ``indexes`` (long tensor): ``(N, ...)`` ``indexes``, ``preds`` and ``target`` must have the same dimension. ``indexes`` indicate to which query a prediction belongs. @@ -39,10 +39,10 @@ class RetrievalMAP(RetrievalMetric): empty_target_action: Specify what to do with queries that do not have at least a positive ``target``. Choose from: - - ``'skip'``: skip those queries (default); if all queries are skipped, ``0.0`` is returned + - ``'neg'``: those queries count as ``0.0`` (default) + - ``'pos'``: those queries count as ``1.0`` + - ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned - ``'error'``: raise a ``ValueError`` - - ``'pos'``: score on those queries is counted as ``1.0`` - - ``'neg'``: score on those queries is counted as ``0.0`` compute_on_step: Forward only calls ``update()`` and return None if this is set to False. default: True diff --git a/torchmetrics/retrieval/mean_reciprocal_rank.py b/torchmetrics/retrieval/mean_reciprocal_rank.py index 6eaecf235ce..30ae4c66747 100644 --- a/torchmetrics/retrieval/mean_reciprocal_rank.py +++ b/torchmetrics/retrieval/mean_reciprocal_rank.py @@ -26,9 +26,9 @@ class RetrievalMRR(RetrievalMetric): Forward accepts - - ``indexes`` (long tensor): ``(N, ...)`` - ``preds`` (float tensor): ``(N, ...)`` - ``target`` (long or bool tensor): ``(N, ...)`` + - ``indexes`` (long tensor): ``(N, ...)`` ``indexes``, ``preds`` and ``target`` must have the same dimension. ``indexes`` indicate to which query a prediction belongs. @@ -39,10 +39,10 @@ class RetrievalMRR(RetrievalMetric): empty_target_action: Specify what to do with queries that do not have at least a positive ``target``. Choose from: - - ``'skip'``: skip those queries (default); if all queries are skipped, ``0.0`` is returned + - ``'neg'``: those queries count as ``0.0`` (default) + - ``'pos'``: those queries count as ``1.0`` + - ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned - ``'error'``: raise a ``ValueError`` - - ``'pos'``: score on those queries is counted as ``1.0`` - - ``'neg'``: score on those queries is counted as ``0.0`` compute_on_step: Forward only calls ``update()`` and return None if this is set to False. default: True diff --git a/torchmetrics/retrieval/retrieval_metric.py b/torchmetrics/retrieval/retrieval_metric.py index a9663503544..23117f01b3b 100644 --- a/torchmetrics/retrieval/retrieval_metric.py +++ b/torchmetrics/retrieval/retrieval_metric.py @@ -30,9 +30,9 @@ class RetrievalMetric(Metric, ABC): Forward accepts - - ``indexes`` (long tensor): ``(N, ...)`` - ``preds`` (float tensor): ``(N, ...)`` - ``target`` (long or bool tensor): ``(N, ...)`` + - ``indexes`` (long tensor): ``(N, ...)`` `indexes`, `preds` and `target` must have the same dimension and will be flatten to single dimension once provided. @@ -46,10 +46,11 @@ class RetrievalMetric(Metric, ABC): empty_target_action: Specify what to do with queries that do not have at least a positive target. 
Choose from: - - ``'skip'``: skip those queries (default); if all queries are skipped, ``0.0`` is returned + - ``'neg'``: those queries count as ``0.0`` (default) + - ``'pos'``: those queries count as ``1.0`` + - ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned - ``'error'``: raise a ``ValueError`` - - ``'pos'``: score on those queries is counted as ``1.0`` - - ``'neg'``: score on those queries is counted as ``0.0`` + compute_on_step: Forward only calls ``update()`` and return None if this is set to False. default: True dist_sync_on_step: @@ -65,7 +66,7 @@ class RetrievalMetric(Metric, ABC): def __init__( self, - empty_target_action: str = 'skip', + empty_target_action: str = 'neg', compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, @@ -78,7 +79,7 @@ def __init__( dist_sync_fn=dist_sync_fn ) - empty_target_action_options = ('error', 'skip', 'pos', 'neg') + empty_target_action_options = ('error', 'skip', 'neg', 'pos') if empty_target_action not in empty_target_action_options: raise ValueError(f"`empty_target_action` received a wrong value `{empty_target_action}`.") @@ -111,11 +112,9 @@ def compute(self) -> Tensor: target = torch.cat(self.target, dim=0) res = [] - kwargs = {'device': indexes.device, 'dtype': torch.float32} - groups = get_group_indexes(indexes) - for group in groups: + for group in groups: mini_preds = preds[group] mini_target = target[group] @@ -123,16 +122,14 @@ def compute(self) -> Tensor: if self.empty_target_action == 'error': raise ValueError("`compute` method was provided with a query with no positive target.") if self.empty_target_action == 'pos': - res.append(tensor(1.0, **kwargs)) + res.append(tensor(1.0)) elif self.empty_target_action == 'neg': - res.append(tensor(0.0, **kwargs)) + res.append(tensor(0.0)) else: # ensure list containt only float tensors - res.append(self._metric(mini_preds, mini_target).to(**kwargs)) + res.append(self._metric(mini_preds, mini_target)) - if len(res) > 0: - return torch.stack(res).mean() - return tensor(0.0, **kwargs) + return torch.stack([x.to(preds) for x in res]).mean() if len(res) else tensor(0.0).to(preds) @abstractmethod def _metric(self, preds: Tensor, target: Tensor) -> Tensor: diff --git a/torchmetrics/retrieval/retrieval_precision.py b/torchmetrics/retrieval/retrieval_precision.py index acfbeb0eb2e..3695fcf3bf2 100644 --- a/torchmetrics/retrieval/retrieval_precision.py +++ b/torchmetrics/retrieval/retrieval_precision.py @@ -28,9 +28,9 @@ class RetrievalPrecision(RetrievalMetric): Forward accepts: - - ``indexes`` (long tensor): ``(N, ...)`` - ``preds`` (float tensor): ``(N, ...)`` - ``target`` (long or bool tensor): ``(N, ...)`` + - ``indexes`` (long tensor): ``(N, ...)`` ``indexes``, ``preds`` and ``target`` must have the same dimension. ``indexes`` indicate to which query a prediction belongs. @@ -41,10 +41,10 @@ class RetrievalPrecision(RetrievalMetric): empty_target_action: Specify what to do with queries that do not have at least a positive ``target``. 
Choose from: - - ``'skip'``: skip those queries (default); if all queries are skipped, ``0.0`` is returned + - ``'neg'``: those queries count as ``0.0`` (default) + - ``'pos'``: those queries count as ``1.0`` + - ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned - ``'error'``: raise a ``ValueError`` - - ``'pos'``: score on those queries is counted as ``1.0`` - - ``'neg'``: score on those queries is counted as ``0.0`` compute_on_step: Forward only calls ``update()`` and return None if this is set to False. default: True @@ -71,7 +71,7 @@ class RetrievalPrecision(RetrievalMetric): def __init__( self, - empty_target_action: str = 'skip', + empty_target_action: str = 'neg', compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, diff --git a/torchmetrics/retrieval/retrieval_recall.py b/torchmetrics/retrieval/retrieval_recall.py index a7752544f17..09fc3d2d695 100644 --- a/torchmetrics/retrieval/retrieval_recall.py +++ b/torchmetrics/retrieval/retrieval_recall.py @@ -28,9 +28,9 @@ class RetrievalRecall(RetrievalMetric): Forward accepts: - - ``indexes`` (long tensor): ``(N, ...)`` - ``preds`` (float tensor): ``(N, ...)`` - ``target`` (long or bool tensor): ``(N, ...)`` + - ``indexes`` (long tensor): ``(N, ...)`` ``indexes``, ``preds`` and ``target`` must have the same dimension. ``indexes`` indicate to which query a prediction belongs. @@ -41,10 +41,10 @@ class RetrievalRecall(RetrievalMetric): empty_target_action: Specify what to do with queries that do not have at least a positive ``target``. Choose from: - - ``'skip'``: skip those queries (default); if all queries are skipped, ``0.0`` is returned + - ``'neg'``: those queries count as ``0.0`` (default) + - ``'pos'``: those queries count as ``1.0`` + - ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned - ``'error'``: raise a ``ValueError`` - - ``'pos'``: score on those queries is counted as ``1.0`` - - ``'neg'``: score on those queries is counted as ``0.0`` compute_on_step: Forward only calls ``update()`` and return None if this is set to False. 
default: True @@ -71,7 +71,7 @@ class RetrievalRecall(RetrievalMetric): def __init__( self, - empty_target_action: str = 'skip', + empty_target_action: str = 'neg', compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, From 9d65265f7890acb3bd2c11bcae805e5b7333f3c1 Mon Sep 17 00:00:00 2001 From: lucadiliello Date: Mon, 5 Apr 2021 20:10:28 +0200 Subject: [PATCH 18/33] implemented fall-out --- tests/retrieval/helpers.py | 9 +- tests/retrieval/test_fallout.py | 204 ++++++++++++++++++ torchmetrics/__init__.py | 8 +- torchmetrics/functional/retrieval/__init__.py | 1 + torchmetrics/functional/retrieval/fall_out.py | 59 +++++ torchmetrics/retrieval/__init__.py | 1 + torchmetrics/retrieval/retrieval_fallout.py | 127 +++++++++++ torchmetrics/retrieval/retrieval_metric.py | 3 +- 8 files changed, 407 insertions(+), 5 deletions(-) create mode 100644 tests/retrieval/test_fallout.py create mode 100644 torchmetrics/functional/retrieval/fall_out.py create mode 100644 torchmetrics/retrieval/retrieval_fallout.py diff --git a/tests/retrieval/helpers.py b/tests/retrieval/helpers.py index 324c67a296f..7d7764f8cdc 100644 --- a/tests/retrieval/helpers.py +++ b/tests/retrieval/helpers.py @@ -42,6 +42,7 @@ def _compute_sklearn_metric( indexes: np.ndarray = None, metric: Callable = None, empty_target_action: str = "skip", + reverse: bool = False, **kwargs ) -> Tensor: """ Compute metric with multiple iterations over every query predictions set. """ @@ -68,7 +69,7 @@ def _compute_sklearn_metric( for group in groups: trg, pds = target[group], preds[group] - if trg.sum() == 0: + if ((1 - trg) if reverse else trg).sum() == 0: if empty_target_action == 'skip': pass elif empty_target_action == 'pos': @@ -327,8 +328,9 @@ def run_class_metric_test( sk_metric: Callable, dist_sync_on_step: bool, metric_args: dict, + reverse: bool = True, ): - _sk_metric_adapted = partial(_compute_sklearn_metric, metric=sk_metric, **metric_args) + _sk_metric_adapted = partial(_compute_sklearn_metric, metric=sk_metric, reverse=reverse, **metric_args) super().run_class_metric_test( ddp=ddp, @@ -349,10 +351,11 @@ def run_functional_metric_test( metric_functional: Callable, sk_metric: Callable, metric_args: dict, + reverse: bool = True, **kwargs, ): # action on functional version of IR metrics is to return `tensor(0.0)` if not target is positive. - _sk_metric_adapted = partial(_compute_sklearn_metric, metric=sk_metric, **metric_args) + _sk_metric_adapted = partial(_compute_sklearn_metric, metric=sk_metric, reverse=reverse, **metric_args) super().run_functional_metric_test( preds=preds, diff --git a/tests/retrieval/test_fallout.py b/tests/retrieval/test_fallout.py new file mode 100644 index 00000000000..aef9ba9e35c --- /dev/null +++ b/tests/retrieval/test_fallout.py @@ -0,0 +1,204 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import numpy as np +import pytest +from torch import Tensor + +from tests.helpers import seed_all +from tests.retrieval.helpers import ( + RetrievalMetricTester, + _default_metric_class_input_arguments, + _default_metric_functional_input_arguments, + _errors_test_class_metric_parameters, + _errors_test_class_metric_parameters_k, + _errors_test_functional_metric_parameters, + _errors_test_functional_metric_parameters_k, +) +from torchmetrics.functional.retrieval.fall_out import retrieval_fall_out +from torchmetrics.retrieval.retrieval_fallout import RetrievalFallOut + +seed_all(42) + + +def _fallout_at_k(target: np.ndarray, preds: np.ndarray, k: int = None): + """ + Didn't find a reliable implementation of Fall-out in Information Retrieval, so, + reimplementing here. See wikipedia for more information about definition. + """ + assert target.shape == preds.shape + assert len(target.shape) == 1 # works only with single dimension inputs + + if k is None: + k = len(preds) + + target = 1 - target + if target.sum(): + order_indexes = np.argsort(preds, axis=0)[::-1] + relevant = np.sum(target[order_indexes][:k]) + return relevant * 1.0 / target.sum() + else: + return np.NaN + + +class TestFallOut(RetrievalMetricTester): + + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + @pytest.mark.parametrize("empty_target_action", ['skip', 'neg', 'pos']) + @pytest.mark.parametrize("k", [None, 1, 4, 10]) + @pytest.mark.parametrize(*_default_metric_class_input_arguments) + def test_class_metric( + self, + ddp: bool, + indexes: Tensor, + preds: Tensor, + target: Tensor, + dist_sync_on_step: bool, + empty_target_action: str, + k: int, + ): + metric_args = {'empty_target_action': empty_target_action, 'k': k} + + self.run_class_metric_test( + ddp=ddp, + indexes=indexes, + preds=preds, + target=target, + metric_class=RetrievalFallOut, + sk_metric=_fallout_at_k, + dist_sync_on_step=dist_sync_on_step, + reverse=True, + metric_args=metric_args, + ) + + @pytest.mark.parametrize(*_default_metric_functional_input_arguments) + @pytest.mark.parametrize("k", [None, 1, 4, 10]) + def test_functional_metric( + self, + preds: Tensor, + target: Tensor, + k: int, + ): + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=retrieval_fall_out, + sk_metric=_fallout_at_k, + reverse=True, + metric_args={}, + k=k, + ) + + @pytest.mark.parametrize(*_default_metric_class_input_arguments) + def test_precision_cpu( + self, + indexes: Tensor, + preds: Tensor, + target: Tensor, + ): + self.run_precision_test_cpu( + indexes=indexes, + preds=preds, + target=target, + metric_module=RetrievalFallOut, + metric_functional=retrieval_fall_out, + ) + + @pytest.mark.parametrize(*_default_metric_class_input_arguments) + def test_precision_gpu( + self, + indexes: Tensor, + preds: Tensor, + target: Tensor, + ): + self.run_precision_test_gpu( + indexes=indexes, + preds=preds, + target=target, + metric_module=RetrievalFallOut, + metric_functional=retrieval_fall_out, + ) + + @pytest.mark.parametrize(*_errors_test_class_metric_parameters) + def test_arguments_class_metric( + self, + indexes: Tensor, + preds: Tensor, + target: Tensor, + message: str, + metric_args: dict, + ): + self.run_metric_class_arguments_test( + indexes=indexes, + preds=preds, + target=target, + metric_class=RetrievalFallOut, + message=message, + metric_args=metric_args, + exception_type=ValueError, + kwargs_update={}, + ) + + @pytest.mark.parametrize(*_errors_test_class_metric_parameters_k) + 
def test_additional_arguments_class_metric( + self, + indexes: Tensor, + preds: Tensor, + target: Tensor, + message: str, + metric_args: dict, + ): + self.run_metric_class_arguments_test( + indexes=indexes, + preds=preds, + target=target, + metric_class=RetrievalFallOut, + message=message, + metric_args=metric_args, + exception_type=ValueError, + kwargs_update={}, + ) + + @pytest.mark.parametrize(*_errors_test_functional_metric_parameters) + def test_arguments_functional_metric( + self, + preds: Tensor, + target: Tensor, + message: str, + ): + self.run_functional_metric_arguments_test( + preds=preds, + target=target, + metric_functional=retrieval_fall_out, + message=message, + exception_type=ValueError, + kwargs_update={}, + ) + + @pytest.mark.parametrize(*_errors_test_functional_metric_parameters_k) + def test_additional_arguments_functional_metric( + self, + preds: Tensor, + target: Tensor, + message: str, + metric_args: dict, + ): + self.run_functional_metric_arguments_test( + preds=preds, + target=target, + metric_functional=retrieval_fall_out, + message=message, + exception_type=ValueError, + kwargs_update=metric_args, + ) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 1edcd153a1d..284daa4b064 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -49,5 +49,11 @@ MeanSquaredLogError, R2Score, ) -from torchmetrics.retrieval import RetrievalMAP, RetrievalMRR, RetrievalPrecision, RetrievalRecall # noqa: F401 E402 +from torchmetrics.retrieval import ( # noqa: F401 E402 + RetrievalFallOut, + RetrievalMAP, + RetrievalMRR, + RetrievalPrecision, + RetrievalRecall, +) from torchmetrics.wrappers import BootStrapper # noqa: F401 E402 diff --git a/torchmetrics/functional/retrieval/__init__.py b/torchmetrics/functional/retrieval/__init__.py index 1bac1fe5c18..b078465c886 100644 --- a/torchmetrics/functional/retrieval/__init__.py +++ b/torchmetrics/functional/retrieval/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision # noqa: F401 +from torchmetrics.functional.retrieval.fall_out import retrieval_fall_out # noqa: F401 from torchmetrics.functional.retrieval.precision import retrieval_precision # noqa: F401 from torchmetrics.functional.retrieval.recall import retrieval_recall # noqa: F401 from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank # noqa: F401 diff --git a/torchmetrics/functional/retrieval/fall_out.py b/torchmetrics/functional/retrieval/fall_out.py new file mode 100644 index 00000000000..4225759cc7c --- /dev/null +++ b/torchmetrics/functional/retrieval/fall_out.py @@ -0,0 +1,59 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import torch +from torch import Tensor, tensor + +from torchmetrics.utilities.checks import _check_retrieval_functional_inputs + + +def retrieval_fall_out(preds: Tensor, target: Tensor, k: int = None) -> Tensor: + """ + Computes the Fall-out (for information retrieval), + as explained `here `__. + Fall-out is the fraction of non-relevant documents retrieved among all the non-relevant documents. + + ``preds`` and ``target`` should be of the same shape and live on the same device. If no ``target`` is ``True``, + ``0`` is returned. ``target`` must be either `bool` or `integers` and ``preds`` must be `float`, + otherwise an error is raised. If you want to measure Fall-out@K, ``k`` must be a positive integer. + + Args: + preds: estimated probabilities of each document to be relevant. + target: ground truth about each document being relevant or not. + k: consider only the top k elements (default: None) + + Returns: + a single-value tensor with the fall-out (at ``k``) of the predictions ``preds`` w.r.t. the labels ``target``. + + Example: + >>> from torchmetrics.functional import retrieval_fall_out + >>> preds = tensor([0.2, 0.3, 0.5]) + >>> target = tensor([True, False, True]) + >>> retrieval_recall(preds, target, k=2) + tensor(1.) + """ + preds, target = _check_retrieval_functional_inputs(preds, target) + + if k is None: + k = preds.shape[-1] + + if not (isinstance(k, int) and k > 0): + raise ValueError("`k` has to be a positive integer or None") + + target = (1 - target) + + if not target.sum(): + return tensor(0.0, device=preds.device) + + relevant = target[torch.argsort(preds, dim=-1, descending=True)][:k].sum().float() + return relevant / target.sum() diff --git a/torchmetrics/retrieval/__init__.py b/torchmetrics/retrieval/__init__.py index b9b250bd805..235f7068257 100644 --- a/torchmetrics/retrieval/__init__.py +++ b/torchmetrics/retrieval/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from torchmetrics.retrieval.mean_average_precision import RetrievalMAP # noqa: F401 from torchmetrics.retrieval.mean_reciprocal_rank import RetrievalMRR # noqa: F401 +from torchmetrics.retrieval.retrieval_fallout import RetrievalFallOut # noqa: F401 from torchmetrics.retrieval.retrieval_metric import RetrievalMetric # noqa: F401 from torchmetrics.retrieval.retrieval_precision import RetrievalPrecision # noqa: F401 from torchmetrics.retrieval.retrieval_recall import RetrievalRecall # noqa: F401 diff --git a/torchmetrics/retrieval/retrieval_fallout.py b/torchmetrics/retrieval/retrieval_fallout.py new file mode 100644 index 00000000000..0f3cdc12952 --- /dev/null +++ b/torchmetrics/retrieval/retrieval_fallout.py @@ -0,0 +1,127 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Any, Callable, Optional + +import torch +from torch import Tensor, tensor + +from torchmetrics.functional.retrieval.fall_out import retrieval_fall_out +from torchmetrics.retrieval.retrieval_metric import RetrievalMetric +from torchmetrics.utilities.data import get_group_indexes + + +class RetrievalFallOut(RetrievalMetric): + """ + Computes `Fall-out + `__. + + Works with binary target data. Accepts float predictions from a model output. + + Forward accepts: + + - ``preds`` (float tensor): ``(N, ...)`` + - ``target`` (long or bool tensor): ``(N, ...)`` + - ``indexes`` (long tensor): ``(N, ...)`` + + ``indexes``, ``preds`` and ``target`` must have the same dimension. + ``indexes`` indicate to which query a prediction belongs. + Predictions will be first grouped by ``indexes`` and then `Fall-out` will be computed as the mean + of the `Fall-out` over each query. + + Args: + empty_target_action: + Specify what to do with queries that do not have at least a negative ``target``. Choose from: + + - ``'neg'``: those queries count as ``0.0`` (default) + - ``'pos'``: those queries count as ``1.0`` + - ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned + - ``'error'``: raise a ``ValueError`` + + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. default: False + process_group: + Specify the process group on which synchronization is called. default: None (which selects + the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. When `None`, DDP + will be used to perform the allgather. default: None + k: consider only the top k elements for each query. default: None + + Example: + >>> from torchmetrics import RetrievalFallOut + >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1]) + >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]) + >>> target = tensor([False, False, True, False, True, False, True]) + >>> fo = RetrievalFallOut(k=2) + >>> fo(preds, target, indexes=indexes) + tensor(0.5000) + """ + + def __init__( + self, + empty_target_action: str = 'pos', + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, + k: int = None + ): + super().__init__( + empty_target_action=empty_target_action, + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn + ) + + if (k is not None) and not (isinstance(k, int) and k > 0): + raise ValueError("`k` has to be a positive integer or None") + self.k = k + + def compute(self) -> Tensor: + """ + First concat state `indexes`, `preds` and `target` since they were stored as lists. After that, + compute list of groups that will help in keeping together predictions about the same query. + Finally, for each group compute the `_metric` if the number of negative targets is at least + 1, otherwise behave as specified by `self.empty_target_action`. 
+ """ + indexes = torch.cat(self.indexes, dim=0) + preds = torch.cat(self.preds, dim=0) + target = torch.cat(self.target, dim=0) + + res = [] + groups = get_group_indexes(indexes) + + for group in groups: + mini_preds = preds[group] + mini_target = target[group] + + if not (1 - mini_target).sum(): + if self.empty_target_action == 'error': + raise ValueError("`compute` method was provided with a query with no negative target.") + if self.empty_target_action == 'pos': + res.append(tensor(1.0)) + elif self.empty_target_action == 'neg': + res.append(tensor(0.0)) + else: + # ensure list containt only float tensors + res.append(self._metric(mini_preds, mini_target)) + + return torch.stack([x.to(preds) for x in res]).mean() if len(res) else tensor(0.0).to(preds) + + def _metric(self, preds: Tensor, target: Tensor) -> Tensor: + return retrieval_fall_out(preds, target, k=self.k) diff --git a/torchmetrics/retrieval/retrieval_metric.py b/torchmetrics/retrieval/retrieval_metric.py index 23117f01b3b..5b48f5b63d0 100644 --- a/torchmetrics/retrieval/retrieval_metric.py +++ b/torchmetrics/retrieval/retrieval_metric.py @@ -44,7 +44,8 @@ class RetrievalMetric(Metric, ABC): Args: empty_target_action: - Specify what to do with queries that do not have at least a positive target. Choose from: + Specify what to do with queries that do not have at least a positive + or negative (depend on metric) target. Choose from: - ``'neg'``: those queries count as ``0.0`` (default) - ``'pos'``: those queries count as ``1.0`` From 148a908d7054ee79c27e9defb4bdc46f8c7e4796 Mon Sep 17 00:00:00 2001 From: lucadiliello Date: Mon, 5 Apr 2021 20:36:35 +0200 Subject: [PATCH 19/33] refactored tests lists --- tests/retrieval/helpers.py | 57 +++++++++++++++---- tests/retrieval/inputs.py | 6 ++ tests/retrieval/test_map.py | 16 ++++-- tests/retrieval/test_mrr.py | 16 ++++-- tests/retrieval/test_precision.py | 53 +++++------------ tests/retrieval/test_recall.py | 53 +++++------------ .../functional/retrieval/average_precision.py | 2 +- .../functional/retrieval/precision.py | 2 +- torchmetrics/functional/retrieval/recall.py | 2 +- .../functional/retrieval/reciprocal_rank.py | 2 +- 10 files changed, 103 insertions(+), 106 deletions(-) diff --git a/tests/retrieval/helpers.py b/tests/retrieval/helpers.py index 324c67a296f..ed1e9884267 100644 --- a/tests/retrieval/helpers.py +++ b/tests/retrieval/helpers.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from functools import partial -from typing import Callable, Union +from typing import Callable, Tuple, Union import numpy as np import pytest @@ -24,6 +24,7 @@ from tests.helpers.testers import Metric, MetricTester from tests.retrieval.inputs import ( _input_retrieval_scores, + _input_retrieval_scores_all_target, _input_retrieval_scores_empty, _input_retrieval_scores_extra, _input_retrieval_scores_mismatching_sizes, @@ -84,36 +85,48 @@ def _compute_sklearn_metric( return np.array(0.0) -_errors_test_functional_metric_parameters = [ - "preds, target, message", [ +def _concat_tests(*tests: Tuple[str, Tuple]) -> Tuple[str, Tuple]: + """Concat tests composed by a string and a list of arguments.""" + assert len(tests), "cannot concatenate " + assert all([tests[0][0] == x[0] for x in tests[1:]]), "the header must be the same for all tests" + return (tests[0][0], sum([x[1] for x in tests], [])) + + +_errors_test_functional_metric_parameters_default = [ + "preds, target, message, metric_args", [ # check input shapes are consistent (func) ( _input_retrieval_scores_mismatching_sizes_func.preds, _input_retrieval_scores_mismatching_sizes_func.target, "`preds` and `target` must be of the same shape", + {}, ), # check input tensors are not empty ( _input_retrieval_scores_empty.preds, _input_retrieval_scores_empty.target, "`preds` and `target` must be non-empty and non-scalar tensors", + {}, ), # check on input dtypes ( _input_retrieval_scores.preds.bool(), _input_retrieval_scores.target, "`preds` must be a tensor of floats", + {}, ), ( _input_retrieval_scores.preds, _input_retrieval_scores.target.float(), "`target` must be a tensor of booleans or integers", + {}, ), # check targets are between 0 and 1 ( _input_retrieval_scores_wrong_targets.preds, _input_retrieval_scores_wrong_targets.target, "`target` must contain `binary` values", + {}, ), ] ] @@ -137,16 +150,9 @@ def _compute_sklearn_metric( ] -_errors_test_class_metric_parameters = [ +_errors_test_class_metric_parameters_no_pos_target = [ "indexes, preds, target, message, metric_args", [ - ( - None, - _input_retrieval_scores.preds, - _input_retrieval_scores.target, - "`indexes` cannot be None", - {'empty_target_action': "error"}, - ), - # check when error when there are not positive targets + # check when error when there are no positive targets ( _input_retrieval_scores_no_target.indexes, _input_retrieval_scores_no_target.preds, @@ -154,6 +160,33 @@ def _compute_sklearn_metric( "`compute` method was provided with a query with no positive target.", {'empty_target_action': "error"}, ), + ] +] + + +_errors_test_class_metric_parameters_no_neg_target = [ + "indexes, preds, target, message, metric_args", [ + # check when error when there are no negative targets + ( + _input_retrieval_scores_all_target.indexes, + _input_retrieval_scores_all_target.preds, + _input_retrieval_scores_all_target.target, + "`compute` method was provided with a query with no negative target.", + {'empty_target_action': "error"}, + ), + ] +] + + +_errors_test_class_metric_parameters_default = [ + "indexes, preds, target, message, metric_args", [ + ( + None, + _input_retrieval_scores.preds, + _input_retrieval_scores.target, + "`indexes` cannot be None", + {'empty_target_action': "error"}, + ), # check when input arguments are invalid ( _input_retrieval_scores.indexes, diff --git a/tests/retrieval/inputs.py b/tests/retrieval/inputs.py index 1cb7663361d..addcc56c787 100644 --- a/tests/retrieval/inputs.py +++ b/tests/retrieval/inputs.py @@ -39,6 +39,12 @@ target=torch.randint(high=1, 
size=(NUM_BATCHES, BATCH_SIZE)), ) +_input_retrieval_scores_all_target = Input( + indexes=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE)), + preds=torch.rand(NUM_BATCHES, BATCH_SIZE), + target=torch.randint(low=1, high=2, size=(NUM_BATCHES, BATCH_SIZE)), +) + _input_retrieval_scores_empty = Input( indexes=torch.randint(high=10, size=[0]), preds=torch.rand(0), diff --git a/tests/retrieval/test_map.py b/tests/retrieval/test_map.py index 276bb0d0248..63275f71797 100644 --- a/tests/retrieval/test_map.py +++ b/tests/retrieval/test_map.py @@ -18,10 +18,12 @@ from tests.helpers import seed_all from tests.retrieval.helpers import ( RetrievalMetricTester, + _concat_tests, _default_metric_class_input_arguments, _default_metric_functional_input_arguments, - _errors_test_class_metric_parameters, - _errors_test_functional_metric_parameters, + _errors_test_class_metric_parameters_default, + _errors_test_class_metric_parameters_no_pos_target, + _errors_test_functional_metric_parameters_default, ) from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision from torchmetrics.retrieval.mean_average_precision import RetrievalMAP @@ -101,7 +103,10 @@ def test_precision_gpu( metric_functional=retrieval_average_precision, ) - @pytest.mark.parametrize(*_errors_test_class_metric_parameters) + @pytest.mark.parametrize(*_concat_tests( + _errors_test_class_metric_parameters_default, + _errors_test_class_metric_parameters_no_pos_target, + )) def test_arguments_class_metric( self, indexes: Tensor, @@ -121,12 +126,13 @@ def test_arguments_class_metric( kwargs_update={}, ) - @pytest.mark.parametrize(*_errors_test_functional_metric_parameters) + @pytest.mark.parametrize(*_errors_test_functional_metric_parameters_default) def test_arguments_functional_metric( self, preds: Tensor, target: Tensor, message: str, + metric_args: dict, ): self.run_functional_metric_arguments_test( preds=preds, @@ -134,5 +140,5 @@ def test_arguments_functional_metric( metric_functional=retrieval_average_precision, message=message, exception_type=ValueError, - kwargs_update={}, + kwargs_update=metric_args, ) diff --git a/tests/retrieval/test_mrr.py b/tests/retrieval/test_mrr.py index 7fb409bd662..110b8242eb2 100644 --- a/tests/retrieval/test_mrr.py +++ b/tests/retrieval/test_mrr.py @@ -20,10 +20,12 @@ from tests.helpers import seed_all from tests.retrieval.helpers import ( RetrievalMetricTester, + _concat_tests, _default_metric_class_input_arguments, _default_metric_functional_input_arguments, - _errors_test_class_metric_parameters, - _errors_test_functional_metric_parameters, + _errors_test_class_metric_parameters_default, + _errors_test_class_metric_parameters_no_pos_target, + _errors_test_functional_metric_parameters_default, ) from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank from torchmetrics.retrieval.mean_reciprocal_rank import RetrievalMRR @@ -125,7 +127,10 @@ def test_precision_gpu( metric_functional=retrieval_reciprocal_rank, ) - @pytest.mark.parametrize(*_errors_test_class_metric_parameters) + @pytest.mark.parametrize(*_concat_tests( + _errors_test_class_metric_parameters_default, + _errors_test_class_metric_parameters_no_pos_target, + )) def test_arguments_class_metric( self, indexes: Tensor, @@ -145,12 +150,13 @@ def test_arguments_class_metric( kwargs_update={}, ) - @pytest.mark.parametrize(*_errors_test_functional_metric_parameters) + @pytest.mark.parametrize(*_errors_test_functional_metric_parameters_default) def test_arguments_functional_metric( self, 
preds: Tensor, target: Tensor, message: str, + metric_args: dict, ): self.run_functional_metric_arguments_test( preds=preds, @@ -158,5 +164,5 @@ def test_arguments_functional_metric( metric_functional=retrieval_reciprocal_rank, message=message, exception_type=ValueError, - kwargs_update={}, + kwargs_update=metric_args, ) diff --git a/tests/retrieval/test_precision.py b/tests/retrieval/test_precision.py index 6ea2dd547d9..70af5da9458 100644 --- a/tests/retrieval/test_precision.py +++ b/tests/retrieval/test_precision.py @@ -18,11 +18,13 @@ from tests.helpers import seed_all from tests.retrieval.helpers import ( RetrievalMetricTester, + _concat_tests, _default_metric_class_input_arguments, _default_metric_functional_input_arguments, - _errors_test_class_metric_parameters, + _errors_test_class_metric_parameters_default, _errors_test_class_metric_parameters_k, - _errors_test_functional_metric_parameters, + _errors_test_class_metric_parameters_no_pos_target, + _errors_test_functional_metric_parameters_default, _errors_test_functional_metric_parameters_k, ) from torchmetrics.functional.retrieval.precision import retrieval_precision @@ -128,7 +130,11 @@ def test_precision_gpu( metric_functional=retrieval_precision, ) - @pytest.mark.parametrize(*_errors_test_class_metric_parameters) + @pytest.mark.parametrize(*_concat_tests( + _errors_test_class_metric_parameters_default, + _errors_test_class_metric_parameters_no_pos_target, + _errors_test_class_metric_parameters_k, + )) def test_arguments_class_metric( self, indexes: Tensor, @@ -148,48 +154,15 @@ def test_arguments_class_metric( kwargs_update={}, ) - @pytest.mark.parametrize(*_errors_test_class_metric_parameters_k) - def test_additional_arguments_class_metric( - self, - indexes: Tensor, - preds: Tensor, - target: Tensor, - message: str, - metric_args: dict, - ): - self.run_metric_class_arguments_test( - indexes=indexes, - preds=preds, - target=target, - metric_class=RetrievalPrecision, - message=message, - metric_args=metric_args, - exception_type=ValueError, - kwargs_update={}, - ) - - @pytest.mark.parametrize(*_errors_test_functional_metric_parameters) + @pytest.mark.parametrize(*_concat_tests( + _errors_test_functional_metric_parameters_default, + _errors_test_functional_metric_parameters_k, + )) def test_arguments_functional_metric( self, preds: Tensor, target: Tensor, message: str, - ): - self.run_functional_metric_arguments_test( - preds=preds, - target=target, - metric_functional=retrieval_precision, - message=message, - exception_type=ValueError, - kwargs_update={}, - ) - - @pytest.mark.parametrize(*_errors_test_functional_metric_parameters_k) - def test_additional_arguments_functional_metric( - self, - preds: Tensor, - target: Tensor, - message: str, metric_args: dict, ): self.run_functional_metric_arguments_test( diff --git a/tests/retrieval/test_recall.py b/tests/retrieval/test_recall.py index 4c79478f664..f180c1bab45 100644 --- a/tests/retrieval/test_recall.py +++ b/tests/retrieval/test_recall.py @@ -18,11 +18,13 @@ from tests.helpers import seed_all from tests.retrieval.helpers import ( RetrievalMetricTester, + _concat_tests, _default_metric_class_input_arguments, _default_metric_functional_input_arguments, - _errors_test_class_metric_parameters, + _errors_test_class_metric_parameters_default, _errors_test_class_metric_parameters_k, - _errors_test_functional_metric_parameters, + _errors_test_class_metric_parameters_no_pos_target, + _errors_test_functional_metric_parameters_default, _errors_test_functional_metric_parameters_k, ) from 
torchmetrics.functional.retrieval.recall import retrieval_recall @@ -127,7 +129,11 @@ def test_precision_gpu( metric_functional=retrieval_recall, ) - @pytest.mark.parametrize(*_errors_test_class_metric_parameters) + @pytest.mark.parametrize(*_concat_tests( + _errors_test_class_metric_parameters_default, + _errors_test_class_metric_parameters_no_pos_target, + _errors_test_class_metric_parameters_k, + )) def test_arguments_class_metric( self, indexes: Tensor, @@ -147,48 +153,15 @@ def test_arguments_class_metric( kwargs_update={}, ) - @pytest.mark.parametrize(*_errors_test_class_metric_parameters_k) - def test_additional_arguments_class_metric( - self, - indexes: Tensor, - preds: Tensor, - target: Tensor, - message: str, - metric_args: dict, - ): - self.run_metric_class_arguments_test( - indexes=indexes, - preds=preds, - target=target, - metric_class=RetrievalRecall, - message=message, - metric_args=metric_args, - exception_type=ValueError, - kwargs_update={}, - ) - - @pytest.mark.parametrize(*_errors_test_functional_metric_parameters) + @pytest.mark.parametrize(*_concat_tests( + _errors_test_functional_metric_parameters_default, + _errors_test_functional_metric_parameters_k, + )) def test_arguments_functional_metric( self, preds: Tensor, target: Tensor, message: str, - ): - self.run_functional_metric_arguments_test( - preds=preds, - target=target, - metric_functional=retrieval_recall, - message=message, - exception_type=ValueError, - kwargs_update={}, - ) - - @pytest.mark.parametrize(*_errors_test_functional_metric_parameters_k) - def test_additional_arguments_functional_metric( - self, - preds: Tensor, - target: Tensor, - message: str, metric_args: dict, ): self.run_functional_metric_arguments_test( diff --git a/torchmetrics/functional/retrieval/average_precision.py b/torchmetrics/functional/retrieval/average_precision.py index 4e4672b91c7..c8bf87757f4 100644 --- a/torchmetrics/functional/retrieval/average_precision.py +++ b/torchmetrics/functional/retrieval/average_precision.py @@ -42,7 +42,7 @@ def retrieval_average_precision(preds: Tensor, target: Tensor) -> Tensor: """ preds, target = _check_retrieval_functional_inputs(preds, target) - if target.sum() == 0: + if not target.sum(): return tensor(0.0, device=preds.device) target = target[torch.argsort(preds, dim=-1, descending=True)] diff --git a/torchmetrics/functional/retrieval/precision.py b/torchmetrics/functional/retrieval/precision.py index d896697d358..0005d9f4dd3 100644 --- a/torchmetrics/functional/retrieval/precision.py +++ b/torchmetrics/functional/retrieval/precision.py @@ -49,7 +49,7 @@ def retrieval_precision(preds: Tensor, target: Tensor, k: int = None) -> Tensor: if not (isinstance(k, int) and k > 0): raise ValueError("`k` has to be a positive integer or None") - if target.sum() == 0: + if not target.sum(): return tensor(0.0, device=preds.device) relevant = target[torch.argsort(preds, dim=-1, descending=True)][:k].sum().float() diff --git a/torchmetrics/functional/retrieval/recall.py b/torchmetrics/functional/retrieval/recall.py index 5f7dcb9b3b2..7207027c8a3 100644 --- a/torchmetrics/functional/retrieval/recall.py +++ b/torchmetrics/functional/retrieval/recall.py @@ -50,7 +50,7 @@ def retrieval_recall(preds: Tensor, target: Tensor, k: int = None) -> Tensor: if not (isinstance(k, int) and k > 0): raise ValueError("`k` has to be a positive integer or None") - if target.sum() == 0: + if not target.sum(): return tensor(0.0, device=preds.device) relevant = target[torch.argsort(preds, dim=-1, descending=True)][:k].sum().float() 
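
Note: the `target.sum() == 0` checks above are simply rewritten as the equivalent truthiness test `not target.sum()`; both short-circuit to `tensor(0.0)` when a query has no relevant documents. A minimal sketch of the shared top-k pattern these functional retrieval metrics follow (sort by descending score, count relevant documents among the top k, normalise); the helper name `_topk_fraction` is illustrative only and not part of the library:

    import torch
    from torch import Tensor, tensor

    def _topk_fraction(preds: Tensor, target: Tensor, k: int = None) -> Tensor:
        # illustrative sketch of the pattern used by retrieval_precision/recall/fall_out
        k = preds.shape[-1] if k is None else k
        if not target.sum():
            # no relevant documents for this query -> early exit, as in the diffs above
            return tensor(0.0, device=preds.device)
        # sort documents by descending score and count relevant ones among the top k
        relevant = target[torch.argsort(preds, dim=-1, descending=True)][:k].sum().float()
        # recall-style normalisation over all relevant documents
        return relevant / target.sum()

    preds = torch.tensor([0.2, 0.3, 0.5, 0.1])
    target = torch.tensor([0, 1, 0, 1])
    print(_topk_fraction(preds, target, k=2))  # tensor(0.5000): one of two relevant docs in the top 2
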
diff --git a/torchmetrics/functional/retrieval/reciprocal_rank.py b/torchmetrics/functional/retrieval/reciprocal_rank.py index 3daed08ac33..35755774969 100644 --- a/torchmetrics/functional/retrieval/reciprocal_rank.py +++ b/torchmetrics/functional/retrieval/reciprocal_rank.py @@ -42,7 +42,7 @@ def retrieval_reciprocal_rank(preds: Tensor, target: Tensor) -> Tensor: """ preds, target = _check_retrieval_functional_inputs(preds, target) - if target.sum() == 0: + if not target.sum(): return tensor(0.0, device=preds.device) target = target[torch.argsort(preds, dim=-1, descending=True)] From 57ec03bf9b7d51b5c10bc38ca2d58c00c5c0812e Mon Sep 17 00:00:00 2001 From: lucadiliello Date: Mon, 5 Apr 2021 20:54:32 +0200 Subject: [PATCH 20/33] refactoring and fixed typos in tests --- docs/source/references/functional.rst | 7 +++ docs/source/references/modules.rst | 7 +++ tests/retrieval/helpers.py | 9 ++-- tests/retrieval/test_fallout.py | 53 +++++-------------- torchmetrics/functional/retrieval/fall_out.py | 2 +- torchmetrics/functional/retrieval/recall.py | 2 +- 6 files changed, 32 insertions(+), 48 deletions(-) diff --git a/docs/source/references/functional.rst b/docs/source/references/functional.rst index 9fda00135a1..3d8619ffafb 100644 --- a/docs/source/references/functional.rst +++ b/docs/source/references/functional.rst @@ -266,3 +266,10 @@ retrieval_recall [func] .. autofunction:: torchmetrics.functional.retrieval_recall :noindex: + + +retrieval_fall_out [func] +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: torchmetrics.functional.retrieval_fall_out + :noindex: diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index 9d78092aa65..7459f87ff78 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -350,6 +350,13 @@ RetrievalRecall :noindex: +RetrievalFallOut +~~~~~~~~~~~~~~~~~~ + +.. autoclass:: torchmetrics.RetrievalFallOut + :noindex: + + ******** Wrappers ******** diff --git a/tests/retrieval/helpers.py b/tests/retrieval/helpers.py index df698395f60..5e236644635 100644 --- a/tests/retrieval/helpers.py +++ b/tests/retrieval/helpers.py @@ -361,9 +361,9 @@ def run_class_metric_test( sk_metric: Callable, dist_sync_on_step: bool, metric_args: dict, - reverse: bool = True, + reverse: bool = False, ): - _sk_metric_adapted = partial(_compute_sklearn_metric, metric=sk_metric, reverse=reverse, **metric_args) + _sk_metric_adapted = partial(_compute_sklearn_metric, metric=sk_metric, reverse=reverse, **metric_args) super().run_class_metric_test( ddp=ddp, @@ -384,10 +384,9 @@ def run_functional_metric_test( metric_functional: Callable, sk_metric: Callable, metric_args: dict, - reverse: bool = True, + reverse: bool = False, **kwargs, ): - # action on functional version of IR metrics is to return `tensor(0.0)` if not target is positive. _sk_metric_adapted = partial(_compute_sklearn_metric, metric=sk_metric, reverse=reverse, **metric_args) super().run_functional_metric_test( @@ -408,7 +407,6 @@ def run_precision_test_cpu( metric_module: Metric, metric_functional: Callable, ): - # action on functional version of IR metrics is to return `tensor(0.0)` if not target is positive. def metric_functional_ignore_indexes(preds, target, indexes): return metric_functional(preds, target) @@ -432,7 +430,6 @@ def run_precision_test_gpu( if not torch.cuda.is_available(): pytest.skip() - # action on functional version of IR metrics is to return `tensor(0.0)` if not target is positive. 
def metric_functional_ignore_indexes(preds, target, indexes): return metric_functional(preds, target) diff --git a/tests/retrieval/test_fallout.py b/tests/retrieval/test_fallout.py index aef9ba9e35c..70622a24e67 100644 --- a/tests/retrieval/test_fallout.py +++ b/tests/retrieval/test_fallout.py @@ -18,11 +18,13 @@ from tests.helpers import seed_all from tests.retrieval.helpers import ( RetrievalMetricTester, + _concat_tests, _default_metric_class_input_arguments, _default_metric_functional_input_arguments, - _errors_test_class_metric_parameters, + _errors_test_class_metric_parameters_default, _errors_test_class_metric_parameters_k, - _errors_test_functional_metric_parameters, + _errors_test_class_metric_parameters_no_neg_target, + _errors_test_functional_metric_parameters_default, _errors_test_functional_metric_parameters_k, ) from torchmetrics.functional.retrieval.fall_out import retrieval_fall_out @@ -130,7 +132,11 @@ def test_precision_gpu( metric_functional=retrieval_fall_out, ) - @pytest.mark.parametrize(*_errors_test_class_metric_parameters) + @pytest.mark.parametrize(*_concat_tests( + _errors_test_class_metric_parameters_default, + _errors_test_class_metric_parameters_no_neg_target, + _errors_test_class_metric_parameters_k, + )) def test_arguments_class_metric( self, indexes: Tensor, @@ -150,48 +156,15 @@ def test_arguments_class_metric( kwargs_update={}, ) - @pytest.mark.parametrize(*_errors_test_class_metric_parameters_k) - def test_additional_arguments_class_metric( - self, - indexes: Tensor, - preds: Tensor, - target: Tensor, - message: str, - metric_args: dict, - ): - self.run_metric_class_arguments_test( - indexes=indexes, - preds=preds, - target=target, - metric_class=RetrievalFallOut, - message=message, - metric_args=metric_args, - exception_type=ValueError, - kwargs_update={}, - ) - - @pytest.mark.parametrize(*_errors_test_functional_metric_parameters) + @pytest.mark.parametrize(*_concat_tests( + _errors_test_functional_metric_parameters_default, + _errors_test_functional_metric_parameters_k, + )) def test_arguments_functional_metric( self, preds: Tensor, target: Tensor, message: str, - ): - self.run_functional_metric_arguments_test( - preds=preds, - target=target, - metric_functional=retrieval_fall_out, - message=message, - exception_type=ValueError, - kwargs_update={}, - ) - - @pytest.mark.parametrize(*_errors_test_functional_metric_parameters_k) - def test_additional_arguments_functional_metric( - self, - preds: Tensor, - target: Tensor, - message: str, metric_args: dict, ): self.run_functional_metric_arguments_test( diff --git a/torchmetrics/functional/retrieval/fall_out.py b/torchmetrics/functional/retrieval/fall_out.py index 4225759cc7c..43a732021d9 100644 --- a/torchmetrics/functional/retrieval/fall_out.py +++ b/torchmetrics/functional/retrieval/fall_out.py @@ -39,7 +39,7 @@ def retrieval_fall_out(preds: Tensor, target: Tensor, k: int = None) -> Tensor: >>> from torchmetrics.functional import retrieval_fall_out >>> preds = tensor([0.2, 0.3, 0.5]) >>> target = tensor([True, False, True]) - >>> retrieval_recall(preds, target, k=2) + >>> retrieval_fall_out(preds, target, k=2) tensor(1.) 
""" preds, target = _check_retrieval_functional_inputs(preds, target) diff --git a/torchmetrics/functional/retrieval/recall.py b/torchmetrics/functional/retrieval/recall.py index 7207027c8a3..16aaff85d4d 100644 --- a/torchmetrics/functional/retrieval/recall.py +++ b/torchmetrics/functional/retrieval/recall.py @@ -21,7 +21,7 @@ def retrieval_recall(preds: Tensor, target: Tensor, k: int = None) -> Tensor: """ Computes the recall metric (for information retrieval), as explained `here `__. - Recall is the fraction of relevant documents among all the relevant documents. + Recall is the fraction of relevant documents retrieved among all the relevant documents. ``preds`` and ``target`` should be of the same shape and live on the same device. If no ``target`` is ``True``, ``0`` is returned. ``target`` must be either `bool` or `integers` and ``preds`` must be `float`, From 62c93f52cf180339ce70ca4629aa85e972c46c44 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 6 Apr 2021 10:04:05 +0200 Subject: [PATCH 21/33] formatting --- tests/helpers/testers.py | 14 +--- tests/retrieval/helpers.py | 104 +++++++++++++++++------------- tests/retrieval/test_map.py | 10 +-- tests/retrieval/test_mrr.py | 11 ++-- tests/retrieval/test_precision.py | 22 ++++--- tests/retrieval/test_recall.py | 22 ++++--- torchmetrics/utilities/checks.py | 2 +- 7 files changed, 100 insertions(+), 85 deletions(-) diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py index 5788930a120..2fa72fd2df9 100644 --- a/tests/helpers/testers.py +++ b/tests/helpers/testers.py @@ -404,12 +404,7 @@ def run_precision_test_cpu( target when running update on the metric. """ _assert_half_support( - metric_module(**metric_args), - metric_functional, - preds, - target, - device="cpu", - **kwargs_update + metric_module(**metric_args), metric_functional, preds, target, device="cpu", **kwargs_update ) def run_precision_test_gpu( @@ -432,12 +427,7 @@ def run_precision_test_gpu( target when running update on the metric. 
""" _assert_half_support( - metric_module(**metric_args), - metric_functional, - preds, - target, - device="cuda", - **kwargs_update + metric_module(**metric_args), metric_functional, preds, target, device="cuda", **kwargs_update ) diff --git a/tests/retrieval/helpers.py b/tests/retrieval/helpers.py index ed1e9884267..40c92d6f60e 100644 --- a/tests/retrieval/helpers.py +++ b/tests/retrieval/helpers.py @@ -93,7 +93,8 @@ def _concat_tests(*tests: Tuple[str, Tuple]) -> Tuple[str, Tuple]: _errors_test_functional_metric_parameters_default = [ - "preds, target, message, metric_args", [ + "preds, target, message, metric_args", + [ # check input shapes are consistent (func) ( _input_retrieval_scores_mismatching_sizes_func.preds, @@ -131,61 +132,71 @@ def _concat_tests(*tests: Tuple[str, Tuple]) -> Tuple[str, Tuple]: ] ] - _errors_test_functional_metric_parameters_k = [ - "preds, target, message, metric_args", [ + "preds, target, message, metric_args", + [ ( _input_retrieval_scores.preds, _input_retrieval_scores.target, "`k` has to be a positive integer or None", - {'k': -10}, + { + 'k': -10 + }, ), ( _input_retrieval_scores.preds, _input_retrieval_scores.target, "`k` has to be a positive integer or None", - {'k': 4.0}, + { + 'k': 4.0 + }, ), ] ] - _errors_test_class_metric_parameters_no_pos_target = [ - "indexes, preds, target, message, metric_args", [ + "indexes, preds, target, message, metric_args", + [ # check when error when there are no positive targets ( _input_retrieval_scores_no_target.indexes, _input_retrieval_scores_no_target.preds, _input_retrieval_scores_no_target.target, "`compute` method was provided with a query with no positive target.", - {'empty_target_action': "error"}, + { + 'empty_target_action': "error" + }, ), ] ] - _errors_test_class_metric_parameters_no_neg_target = [ - "indexes, preds, target, message, metric_args", [ + "indexes, preds, target, message, metric_args", + [ # check when error when there are no negative targets ( _input_retrieval_scores_all_target.indexes, _input_retrieval_scores_all_target.preds, _input_retrieval_scores_all_target.target, "`compute` method was provided with a query with no negative target.", - {'empty_target_action': "error"}, + { + 'empty_target_action': "error" + }, ), ] ] - _errors_test_class_metric_parameters_default = [ - "indexes, preds, target, message, metric_args", [ + "indexes, preds, target, message, metric_args", + [ ( None, _input_retrieval_scores.preds, _input_retrieval_scores.target, "`indexes` cannot be None", - {'empty_target_action': "error"}, + { + 'empty_target_action': "error" + }, ), # check when input arguments are invalid ( @@ -193,7 +204,9 @@ def _concat_tests(*tests: Tuple[str, Tuple]) -> Tuple[str, Tuple]: _input_retrieval_scores.preds, _input_retrieval_scores.target, "`empty_target_action` received a wrong value `casual_argument`.", - {'empty_target_action': "casual_argument"}, + { + 'empty_target_action': "casual_argument" + }, ), # check input shapes are consistent ( @@ -201,7 +214,9 @@ def _concat_tests(*tests: Tuple[str, Tuple]) -> Tuple[str, Tuple]: _input_retrieval_scores_mismatching_sizes.preds, _input_retrieval_scores_mismatching_sizes.target, "`indexes`, `preds` and `target` must be of the same shape", - {'empty_target_action': "skip"}, + { + 'empty_target_action': "skip" + }, ), # check input tensors are not empty ( @@ -209,7 +224,9 @@ def _concat_tests(*tests: Tuple[str, Tuple]) -> Tuple[str, Tuple]: _input_retrieval_scores_empty.preds, _input_retrieval_scores_empty.target, "`indexes`, `preds` and 
`target` must be non-empty and non-scalar tensors", - {'empty_target_action': "skip"}, + { + 'empty_target_action': "skip" + }, ), # check on input dtypes ( @@ -217,21 +234,27 @@ def _concat_tests(*tests: Tuple[str, Tuple]) -> Tuple[str, Tuple]: _input_retrieval_scores.preds, _input_retrieval_scores.target, "`indexes` must be a tensor of long integers", - {'empty_target_action': "skip"}, + { + 'empty_target_action': "skip" + }, ), ( _input_retrieval_scores.indexes, _input_retrieval_scores.preds.bool(), _input_retrieval_scores.target, "`preds` must be a tensor of floats", - {'empty_target_action': "skip"}, + { + 'empty_target_action': "skip" + }, ), ( _input_retrieval_scores.indexes, _input_retrieval_scores.preds, _input_retrieval_scores.target.float(), "`target` must be a tensor of booleans or integers", - {'empty_target_action': "skip"}, + { + 'empty_target_action': "skip" + }, ), # check targets are between 0 and 1 ( @@ -239,32 +262,32 @@ def _concat_tests(*tests: Tuple[str, Tuple]) -> Tuple[str, Tuple]: _input_retrieval_scores_wrong_targets.preds, _input_retrieval_scores_wrong_targets.target, "`target` must contain `binary` values", - {'empty_target_action': "skip"}, + { + 'empty_target_action': "skip" + }, ), ] ] - _errors_test_class_metric_parameters_k = [ - "indexes, preds, target, message, metric_args", [ + "indexes, preds, target, message, metric_args", + [ ( _input_retrieval_scores.index, _input_retrieval_scores.preds, _input_retrieval_scores.target, "`k` has to be a positive integer or None", - {'k': -10}, + { + 'k': -10 + }, ), ] ] - _default_metric_class_input_arguments = [ - "indexes, preds, target", [ - ( - _input_retrieval_scores.indexes, - _input_retrieval_scores.preds, - _input_retrieval_scores.target - ), + "indexes, preds, target", + [ + (_input_retrieval_scores.indexes, _input_retrieval_scores.preds, _input_retrieval_scores.target), ( _input_retrieval_scores_extra.indexes, _input_retrieval_scores_extra.preds, @@ -278,21 +301,12 @@ def _concat_tests(*tests: Tuple[str, Tuple]) -> Tuple[str, Tuple]: ] ] - _default_metric_functional_input_arguments = [ - "preds, target", [ - ( - _input_retrieval_scores.preds, - _input_retrieval_scores.target - ), - ( - _input_retrieval_scores_extra.preds, - _input_retrieval_scores_extra.target - ), - ( - _input_retrieval_scores_no_target.preds, - _input_retrieval_scores_no_target.target - ), + "preds, target", + [ + (_input_retrieval_scores.preds, _input_retrieval_scores.target), + (_input_retrieval_scores_extra.preds, _input_retrieval_scores_extra.target), + (_input_retrieval_scores_no_target.preds, _input_retrieval_scores_no_target.target), ] ] diff --git a/tests/retrieval/test_map.py b/tests/retrieval/test_map.py index 63275f71797..0e755ce8013 100644 --- a/tests/retrieval/test_map.py +++ b/tests/retrieval/test_map.py @@ -103,10 +103,12 @@ def test_precision_gpu( metric_functional=retrieval_average_precision, ) - @pytest.mark.parametrize(*_concat_tests( - _errors_test_class_metric_parameters_default, - _errors_test_class_metric_parameters_no_pos_target, - )) + @pytest.mark.parametrize( + *_concat_tests( + _errors_test_class_metric_parameters_default, + _errors_test_class_metric_parameters_no_pos_target, + ) + ) def test_arguments_class_metric( self, indexes: Tensor, diff --git a/tests/retrieval/test_mrr.py b/tests/retrieval/test_mrr.py index 110b8242eb2..3c692c00159 100644 --- a/tests/retrieval/test_mrr.py +++ b/tests/retrieval/test_mrr.py @@ -1,4 +1,3 @@ - # Copyright The PyTorch Lightning team. 
# # Licensed under the Apache License, Version 2.0 (the "License"); @@ -127,10 +126,12 @@ def test_precision_gpu( metric_functional=retrieval_reciprocal_rank, ) - @pytest.mark.parametrize(*_concat_tests( - _errors_test_class_metric_parameters_default, - _errors_test_class_metric_parameters_no_pos_target, - )) + @pytest.mark.parametrize( + *_concat_tests( + _errors_test_class_metric_parameters_default, + _errors_test_class_metric_parameters_no_pos_target, + ) + ) def test_arguments_class_metric( self, indexes: Tensor, diff --git a/tests/retrieval/test_precision.py b/tests/retrieval/test_precision.py index 70af5da9458..8e404afd1d0 100644 --- a/tests/retrieval/test_precision.py +++ b/tests/retrieval/test_precision.py @@ -130,11 +130,13 @@ def test_precision_gpu( metric_functional=retrieval_precision, ) - @pytest.mark.parametrize(*_concat_tests( - _errors_test_class_metric_parameters_default, - _errors_test_class_metric_parameters_no_pos_target, - _errors_test_class_metric_parameters_k, - )) + @pytest.mark.parametrize( + *_concat_tests( + _errors_test_class_metric_parameters_default, + _errors_test_class_metric_parameters_no_pos_target, + _errors_test_class_metric_parameters_k, + ) + ) def test_arguments_class_metric( self, indexes: Tensor, @@ -154,10 +156,12 @@ def test_arguments_class_metric( kwargs_update={}, ) - @pytest.mark.parametrize(*_concat_tests( - _errors_test_functional_metric_parameters_default, - _errors_test_functional_metric_parameters_k, - )) + @pytest.mark.parametrize( + *_concat_tests( + _errors_test_functional_metric_parameters_default, + _errors_test_functional_metric_parameters_k, + ) + ) def test_arguments_functional_metric( self, preds: Tensor, diff --git a/tests/retrieval/test_recall.py b/tests/retrieval/test_recall.py index f180c1bab45..78f90c6d0ec 100644 --- a/tests/retrieval/test_recall.py +++ b/tests/retrieval/test_recall.py @@ -129,11 +129,13 @@ def test_precision_gpu( metric_functional=retrieval_recall, ) - @pytest.mark.parametrize(*_concat_tests( - _errors_test_class_metric_parameters_default, - _errors_test_class_metric_parameters_no_pos_target, - _errors_test_class_metric_parameters_k, - )) + @pytest.mark.parametrize( + *_concat_tests( + _errors_test_class_metric_parameters_default, + _errors_test_class_metric_parameters_no_pos_target, + _errors_test_class_metric_parameters_k, + ) + ) def test_arguments_class_metric( self, indexes: Tensor, @@ -153,10 +155,12 @@ def test_arguments_class_metric( kwargs_update={}, ) - @pytest.mark.parametrize(*_concat_tests( - _errors_test_functional_metric_parameters_default, - _errors_test_functional_metric_parameters_k, - )) + @pytest.mark.parametrize( + *_concat_tests( + _errors_test_functional_metric_parameters_default, + _errors_test_functional_metric_parameters_k, + ) + ) def test_arguments_functional_metric( self, preds: Tensor, diff --git a/torchmetrics/utilities/checks.py b/torchmetrics/utilities/checks.py index 23278b9c86a..010fbd15d00 100644 --- a/torchmetrics/utilities/checks.py +++ b/torchmetrics/utilities/checks.py @@ -554,7 +554,7 @@ def _check_retrieval_inputs( raise ValueError("`indexes`, `preds` and `target` must be of the same shape") if not indexes.numel() or not indexes.size(): - raise ValueError("`indexes`, `preds` and `target` must be non-empty and non-scalar tensors",) + raise ValueError("`indexes`, `preds` and `target` must be non-empty and non-scalar tensors", ) if indexes.dtype is not torch.long: raise ValueError("`indexes` must be a tensor of long integers") From 
5b2ce58569e2393004115d0623abe0f7477a6455 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 6 Apr 2021 10:25:35 +0200 Subject: [PATCH 22/33] simple --- tests/retrieval/helpers.py | 208 +++++++++---------------------------- 1 file changed, 51 insertions(+), 157 deletions(-) diff --git a/tests/retrieval/helpers.py b/tests/retrieval/helpers.py index 40c92d6f60e..97abc9984ea 100644 --- a/tests/retrieval/helpers.py +++ b/tests/retrieval/helpers.py @@ -22,16 +22,14 @@ from tests.helpers import seed_all from tests.helpers.testers import Metric, MetricTester -from tests.retrieval.inputs import ( - _input_retrieval_scores, - _input_retrieval_scores_all_target, - _input_retrieval_scores_empty, - _input_retrieval_scores_extra, - _input_retrieval_scores_mismatching_sizes, - _input_retrieval_scores_mismatching_sizes_func, - _input_retrieval_scores_no_target, - _input_retrieval_scores_wrong_targets, -) +from tests.retrieval.inputs import _input_retrieval_scores as _irs +from tests.retrieval.inputs import _input_retrieval_scores_all_target as _irs_all +from tests.retrieval.inputs import _input_retrieval_scores_empty as _irs_empty +from tests.retrieval.inputs import _input_retrieval_scores_extra as _irs_extra +from tests.retrieval.inputs import _input_retrieval_scores_mismatching_sizes as _irs_mis_sz +from tests.retrieval.inputs import _input_retrieval_scores_mismatching_sizes_func as _irs_mis_sz_fn +from tests.retrieval.inputs import _input_retrieval_scores_no_target as _irs_no_tgt +from tests.retrieval.inputs import _input_retrieval_scores_wrong_targets as _irs_bad_tgt from torchmetrics.utilities.data import get_group_indexes seed_all(42) @@ -96,61 +94,22 @@ def _concat_tests(*tests: Tuple[str, Tuple]) -> Tuple[str, Tuple]: "preds, target, message, metric_args", [ # check input shapes are consistent (func) - ( - _input_retrieval_scores_mismatching_sizes_func.preds, - _input_retrieval_scores_mismatching_sizes_func.target, - "`preds` and `target` must be of the same shape", - {}, - ), + (_irs_mis_sz_fn.preds, _irs_mis_sz_fn.target, "`preds` and `target` must be of the same shape", {}), # check input tensors are not empty - ( - _input_retrieval_scores_empty.preds, - _input_retrieval_scores_empty.target, - "`preds` and `target` must be non-empty and non-scalar tensors", - {}, - ), + (_irs_empty.preds, _irs_empty.target, "`preds` and `target` must be non-empty and non-scalar tensors", {}), # check on input dtypes - ( - _input_retrieval_scores.preds.bool(), - _input_retrieval_scores.target, - "`preds` must be a tensor of floats", - {}, - ), - ( - _input_retrieval_scores.preds, - _input_retrieval_scores.target.float(), - "`target` must be a tensor of booleans or integers", - {}, - ), + (_irs.preds.bool(), _irs.target, "`preds` must be a tensor of floats", {}), + (_irs.preds, _irs.target.float(), "`target` must be a tensor of booleans or integers", {}), # check targets are between 0 and 1 - ( - _input_retrieval_scores_wrong_targets.preds, - _input_retrieval_scores_wrong_targets.target, - "`target` must contain `binary` values", - {}, - ), + (_irs_bad_tgt.preds, _irs_bad_tgt.target, "`target` must contain `binary` values", {}), ] ] _errors_test_functional_metric_parameters_k = [ "preds, target, message, metric_args", [ - ( - _input_retrieval_scores.preds, - _input_retrieval_scores.target, - "`k` has to be a positive integer or None", - { - 'k': -10 - }, - ), - ( - _input_retrieval_scores.preds, - _input_retrieval_scores.target, - "`k` has to be a positive integer or None", - { - 'k': 4.0 - }, - ), + 
(_irs.preds, _irs.target, "`k` has to be a positive integer or None", dict(k=-10)), + (_irs.preds, _irs.target, "`k` has to be a positive integer or None", dict(k=4.0)), ] ] @@ -159,13 +118,8 @@ def _concat_tests(*tests: Tuple[str, Tuple]) -> Tuple[str, Tuple]: [ # check when error when there are no positive targets ( - _input_retrieval_scores_no_target.indexes, - _input_retrieval_scores_no_target.preds, - _input_retrieval_scores_no_target.target, - "`compute` method was provided with a query with no positive target.", - { - 'empty_target_action': "error" - }, + _irs_no_tgt.indexes, _irs_no_tgt.preds, _irs_no_tgt.target, + "`compute` method was provided with a query with no positive target.", dict(empty_target_action="error") ), ] ] @@ -175,13 +129,8 @@ def _concat_tests(*tests: Tuple[str, Tuple]) -> Tuple[str, Tuple]: [ # check when error when there are no negative targets ( - _input_retrieval_scores_all_target.indexes, - _input_retrieval_scores_all_target.preds, - _input_retrieval_scores_all_target.target, - "`compute` method was provided with a query with no negative target.", - { - 'empty_target_action': "error" - }, + _irs_all.indexes, _irs_all.preds, _irs_all.target, + "`compute` method was provided with a query with no negative target.", dict(empty_target_action="error") ), ] ] @@ -189,82 +138,40 @@ def _concat_tests(*tests: Tuple[str, Tuple]) -> Tuple[str, Tuple]: _errors_test_class_metric_parameters_default = [ "indexes, preds, target, message, metric_args", [ - ( - None, - _input_retrieval_scores.preds, - _input_retrieval_scores.target, - "`indexes` cannot be None", - { - 'empty_target_action': "error" - }, - ), + (None, _irs.preds, _irs.target, "`indexes` cannot be None", dict(empty_target_action="error")), # check when input arguments are invalid ( - _input_retrieval_scores.indexes, - _input_retrieval_scores.preds, - _input_retrieval_scores.target, - "`empty_target_action` received a wrong value `casual_argument`.", - { - 'empty_target_action': "casual_argument" - }, + _irs.indexes, _irs.preds, _irs.target, "`empty_target_action` received a wrong value `casual_argument`.", + dict(empty_target_action="casual_argument") ), # check input shapes are consistent ( - _input_retrieval_scores_mismatching_sizes.indexes, - _input_retrieval_scores_mismatching_sizes.preds, - _input_retrieval_scores_mismatching_sizes.target, - "`indexes`, `preds` and `target` must be of the same shape", - { - 'empty_target_action': "skip" - }, + _irs_mis_sz.indexes, _irs_mis_sz.preds, _irs_mis_sz.target, + "`indexes`, `preds` and `target` must be of the same shape", dict(empty_target_action="skip") ), # check input tensors are not empty ( - _input_retrieval_scores_empty.indexes, - _input_retrieval_scores_empty.preds, - _input_retrieval_scores_empty.target, - "`indexes`, `preds` and `target` must be non-empty and non-scalar tensors", - { - 'empty_target_action': "skip" - }, + _irs_empty.indexes, _irs_empty.preds, + _irs_empty.target, "`indexes`, `preds` and `target` must be non-empty and non-scalar tensors", + dict(empty_target_action="skip") ), # check on input dtypes ( - _input_retrieval_scores.indexes.bool(), - _input_retrieval_scores.preds, - _input_retrieval_scores.target, - "`indexes` must be a tensor of long integers", - { - 'empty_target_action': "skip" - }, + _irs.indexes.bool(), _irs.preds, _irs.target, "`indexes` must be a tensor of long integers", + dict(empty_target_action="skip") ), ( - _input_retrieval_scores.indexes, - _input_retrieval_scores.preds.bool(), - _input_retrieval_scores.target, - 
"`preds` must be a tensor of floats", - { - 'empty_target_action': "skip" - }, + _irs.indexes, _irs.preds.bool(), _irs.target, "`preds` must be a tensor of floats", + dict(empty_target_action="skip") ), ( - _input_retrieval_scores.indexes, - _input_retrieval_scores.preds, - _input_retrieval_scores.target.float(), - "`target` must be a tensor of booleans or integers", - { - 'empty_target_action': "skip" - }, + _irs.indexes, _irs.preds, _irs.target.float(), "`target` must be a tensor of booleans or integers", + dict(empty_target_action="skip") ), # check targets are between 0 and 1 ( - _input_retrieval_scores_wrong_targets.indexes, - _input_retrieval_scores_wrong_targets.preds, - _input_retrieval_scores_wrong_targets.target, - "`target` must contain `binary` values", - { - 'empty_target_action': "skip" - }, + _irs_bad_tgt.indexes, _irs_bad_tgt.preds, _irs_bad_tgt.target, "`target` must contain `binary` values", + dict(empty_target_action="skip") ), ] ] @@ -272,41 +179,25 @@ def _concat_tests(*tests: Tuple[str, Tuple]) -> Tuple[str, Tuple]: _errors_test_class_metric_parameters_k = [ "indexes, preds, target, message, metric_args", [ - ( - _input_retrieval_scores.index, - _input_retrieval_scores.preds, - _input_retrieval_scores.target, - "`k` has to be a positive integer or None", - { - 'k': -10 - }, - ), + (_irs.index, _irs.preds, _irs.target, "`k` has to be a positive integer or None", dict(k=-10)), ] ] _default_metric_class_input_arguments = [ "indexes, preds, target", [ - (_input_retrieval_scores.indexes, _input_retrieval_scores.preds, _input_retrieval_scores.target), - ( - _input_retrieval_scores_extra.indexes, - _input_retrieval_scores_extra.preds, - _input_retrieval_scores_extra.target, - ), - ( - _input_retrieval_scores_no_target.indexes, - _input_retrieval_scores_no_target.preds, - _input_retrieval_scores_no_target.target, - ), + (_irs.indexes, _irs.preds, _irs.target), + (_irs_extra.indexes, _irs_extra.preds, _irs_extra.target), + (_irs_no_tgt.indexes, _irs_no_tgt.preds, _irs_no_tgt.target), ] ] _default_metric_functional_input_arguments = [ "preds, target", [ - (_input_retrieval_scores.preds, _input_retrieval_scores.target), - (_input_retrieval_scores_extra.preds, _input_retrieval_scores_extra.target), - (_input_retrieval_scores_no_target.preds, _input_retrieval_scores_no_target.target), + (_irs.preds, _irs.target), + (_irs_extra.preds, _irs_extra.target), + (_irs_no_tgt.preds, _irs_no_tgt.target), ] ] @@ -317,9 +208,9 @@ def _errors_test_class_metric( target: Tensor, metric_class: Metric, message: str = "", - metric_args: dict = {}, + metric_args: dict = None, exception_type: Exception = ValueError, - kwargs_update: dict = {}, + kwargs_update: dict = None, ): """Utility function doing checks about types, parameters and errors. @@ -334,6 +225,8 @@ def _errors_test_class_metric( kwargs_update: Additional keyword arguments that will be passed with indexes, preds and target when running update on the metric. """ + metric_args = metric_args or {} + kwargs_update = kwargs_update or {} with pytest.raises(exception_type, match=message): metric = metric_class(**metric_args) metric(preds, target, indexes=indexes, **kwargs_update) @@ -345,7 +238,7 @@ def _errors_test_functional_metric( metric_functional: Metric, message: str = "", exception_type: Exception = ValueError, - kwargs_update: dict = {}, + kwargs_update: dict = None, ): """Utility function doing checks about types, parameters and errors. 
@@ -358,6 +251,7 @@ def _errors_test_functional_metric( kwargs_update: Additional keyword arguments that will be passed with indexes, preds and target when running update on the metric. """ + kwargs_update = kwargs_update or {} with pytest.raises(exception_type, match=message): metric_functional(preds, target, **kwargs_update) @@ -463,9 +357,9 @@ def run_metric_class_arguments_test( target: Tensor, metric_class: Metric, message: str = "", - metric_args: dict = {}, + metric_args: dict = None, exception_type: Exception = ValueError, - kwargs_update: dict = {}, + kwargs_update: dict = None, ): _errors_test_class_metric( indexes=indexes, @@ -485,7 +379,7 @@ def run_functional_metric_arguments_test( metric_functional: Callable, message: str = "", exception_type: Exception = ValueError, - kwargs_update: dict = {}, + kwargs_update: dict = None, ): _errors_test_functional_metric( preds=preds, From 3a3871bfc7d8ab2deacedb8db9325067d3a93276 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 6 Apr 2021 10:36:06 +0200 Subject: [PATCH 23/33] agrs --- tests/retrieval/helpers.py | 64 +++++++++++++++---------------- tests/retrieval/test_map.py | 10 ++--- tests/retrieval/test_mrr.py | 10 ++--- tests/retrieval/test_precision.py | 8 ++-- tests/retrieval/test_recall.py | 8 ++-- 5 files changed, 50 insertions(+), 50 deletions(-) diff --git a/tests/retrieval/helpers.py b/tests/retrieval/helpers.py index 97abc9984ea..82efb4e1cca 100644 --- a/tests/retrieval/helpers.py +++ b/tests/retrieval/helpers.py @@ -90,9 +90,9 @@ def _concat_tests(*tests: Tuple[str, Tuple]) -> Tuple[str, Tuple]: return (tests[0][0], sum([x[1] for x in tests], [])) -_errors_test_functional_metric_parameters_default = [ - "preds, target, message, metric_args", - [ +_errors_test_functional_metric_parameters_default = dict( + argnames="preds,target,message,metric_args", + argvalues=[ # check input shapes are consistent (func) (_irs_mis_sz_fn.preds, _irs_mis_sz_fn.target, "`preds` and `target` must be of the same shape", {}), # check input tensors are not empty @@ -103,41 +103,41 @@ def _concat_tests(*tests: Tuple[str, Tuple]) -> Tuple[str, Tuple]: # check targets are between 0 and 1 (_irs_bad_tgt.preds, _irs_bad_tgt.target, "`target` must contain `binary` values", {}), ] -] +) -_errors_test_functional_metric_parameters_k = [ - "preds, target, message, metric_args", - [ +_errors_test_functional_metric_parameters_k = dict( + argnames="preds,target,message,metric_args", + argvalues=[ (_irs.preds, _irs.target, "`k` has to be a positive integer or None", dict(k=-10)), (_irs.preds, _irs.target, "`k` has to be a positive integer or None", dict(k=4.0)), ] -] +) -_errors_test_class_metric_parameters_no_pos_target = [ - "indexes, preds, target, message, metric_args", - [ +_errors_test_class_metric_parameters_no_pos_target = dict( + argnames="indexes,preds,target,message,metric_args", + argvalues=[ # check when error when there are no positive targets ( _irs_no_tgt.indexes, _irs_no_tgt.preds, _irs_no_tgt.target, "`compute` method was provided with a query with no positive target.", dict(empty_target_action="error") ), ] -] +) -_errors_test_class_metric_parameters_no_neg_target = [ - "indexes, preds, target, message, metric_args", - [ +_errors_test_class_metric_parameters_no_neg_target = dict( + argnames="indexes,preds,target,message,metric_args", + argvalues=[ # check when error when there are no negative targets ( _irs_all.indexes, _irs_all.preds, _irs_all.target, "`compute` method was provided with a query with no negative target.", 
dict(empty_target_action="error") ), ] -] +) -_errors_test_class_metric_parameters_default = [ - "indexes, preds, target, message, metric_args", - [ +_errors_test_class_metric_parameters_default = dict( + argnames="indexes,preds,target,message,metric_args", + argvalues=[ (None, _irs.preds, _irs.target, "`indexes` cannot be None", dict(empty_target_action="error")), # check when input arguments are invalid ( @@ -174,32 +174,32 @@ def _concat_tests(*tests: Tuple[str, Tuple]) -> Tuple[str, Tuple]: dict(empty_target_action="skip") ), ] -] +) -_errors_test_class_metric_parameters_k = [ - "indexes, preds, target, message, metric_args", - [ +_errors_test_class_metric_parameters_k = dict( + argnames="indexes,preds,target,message,metric_args", + argvalues=[ (_irs.index, _irs.preds, _irs.target, "`k` has to be a positive integer or None", dict(k=-10)), ] -] +) -_default_metric_class_input_arguments = [ - "indexes, preds, target", - [ +_default_metric_class_input_arguments = dict( + argnames="indexes,preds,target", + argvalues=[ (_irs.indexes, _irs.preds, _irs.target), (_irs_extra.indexes, _irs_extra.preds, _irs_extra.target), (_irs_no_tgt.indexes, _irs_no_tgt.preds, _irs_no_tgt.target), ] -] +) -_default_metric_functional_input_arguments = [ - "preds, target", - [ +_default_metric_functional_input_arguments = dict( + argnames="preds,target", + argvalues=[ (_irs.preds, _irs.target), (_irs_extra.preds, _irs_extra.target), (_irs_no_tgt.preds, _irs_no_tgt.target), ] -] +) def _errors_test_class_metric( diff --git a/tests/retrieval/test_map.py b/tests/retrieval/test_map.py index 0e755ce8013..52016a50aa8 100644 --- a/tests/retrieval/test_map.py +++ b/tests/retrieval/test_map.py @@ -36,7 +36,7 @@ class TestMAP(RetrievalMetricTester): @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ['skip', 'neg', 'pos']) - @pytest.mark.parametrize(*_default_metric_class_input_arguments) + @pytest.mark.parametrize(**_default_metric_class_input_arguments) def test_class_metric( self, ddp: bool, @@ -59,7 +59,7 @@ def test_class_metric( metric_args=metric_args, ) - @pytest.mark.parametrize(*_default_metric_functional_input_arguments) + @pytest.mark.parametrize(**_default_metric_functional_input_arguments) def test_functional_metric( self, preds: Tensor, @@ -73,7 +73,7 @@ def test_functional_metric( metric_args={}, ) - @pytest.mark.parametrize(*_default_metric_class_input_arguments) + @pytest.mark.parametrize(**_default_metric_class_input_arguments) def test_precision_cpu( self, indexes: Tensor, @@ -88,7 +88,7 @@ def test_precision_cpu( metric_functional=retrieval_average_precision, ) - @pytest.mark.parametrize(*_default_metric_class_input_arguments) + @pytest.mark.parametrize(**_default_metric_class_input_arguments) def test_precision_gpu( self, indexes: Tensor, @@ -128,7 +128,7 @@ def test_arguments_class_metric( kwargs_update={}, ) - @pytest.mark.parametrize(*_errors_test_functional_metric_parameters_default) + @pytest.mark.parametrize(**_errors_test_functional_metric_parameters_default) def test_arguments_functional_metric( self, preds: Tensor, diff --git a/tests/retrieval/test_mrr.py b/tests/retrieval/test_mrr.py index 3c692c00159..0306f116458 100644 --- a/tests/retrieval/test_mrr.py +++ b/tests/retrieval/test_mrr.py @@ -59,7 +59,7 @@ class TestMRR(RetrievalMetricTester): @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) 
@pytest.mark.parametrize("empty_target_action", ['skip', 'neg', 'pos']) - @pytest.mark.parametrize(*_default_metric_class_input_arguments) + @pytest.mark.parametrize(**_default_metric_class_input_arguments) def test_class_metric( self, ddp: bool, @@ -82,7 +82,7 @@ def test_class_metric( metric_args=metric_args, ) - @pytest.mark.parametrize(*_default_metric_functional_input_arguments) + @pytest.mark.parametrize(**_default_metric_functional_input_arguments) def test_functional_metric( self, preds: Tensor, @@ -96,7 +96,7 @@ def test_functional_metric( metric_args={}, ) - @pytest.mark.parametrize(*_default_metric_class_input_arguments) + @pytest.mark.parametrize(**_default_metric_class_input_arguments) def test_precision_cpu( self, indexes: Tensor, @@ -111,7 +111,7 @@ def test_precision_cpu( metric_functional=retrieval_reciprocal_rank, ) - @pytest.mark.parametrize(*_default_metric_class_input_arguments) + @pytest.mark.parametrize(**_default_metric_class_input_arguments) def test_precision_gpu( self, indexes: Tensor, @@ -151,7 +151,7 @@ def test_arguments_class_metric( kwargs_update={}, ) - @pytest.mark.parametrize(*_errors_test_functional_metric_parameters_default) + @pytest.mark.parametrize(**_errors_test_functional_metric_parameters_default) def test_arguments_functional_metric( self, preds: Tensor, diff --git a/tests/retrieval/test_precision.py b/tests/retrieval/test_precision.py index 8e404afd1d0..3ff57aa89e1 100644 --- a/tests/retrieval/test_precision.py +++ b/tests/retrieval/test_precision.py @@ -59,7 +59,7 @@ class TestPrecision(RetrievalMetricTester): @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ['skip', 'neg', 'pos']) @pytest.mark.parametrize("k", [None, 1, 4, 10]) - @pytest.mark.parametrize(*_default_metric_class_input_arguments) + @pytest.mark.parametrize(**_default_metric_class_input_arguments) def test_class_metric( self, ddp: bool, @@ -83,7 +83,7 @@ def test_class_metric( metric_args=metric_args, ) - @pytest.mark.parametrize(*_default_metric_functional_input_arguments) + @pytest.mark.parametrize(**_default_metric_functional_input_arguments) @pytest.mark.parametrize("k", [None, 1, 4, 10]) def test_functional_metric( self, @@ -100,7 +100,7 @@ def test_functional_metric( k=k, ) - @pytest.mark.parametrize(*_default_metric_class_input_arguments) + @pytest.mark.parametrize(**_default_metric_class_input_arguments) def test_precision_cpu( self, indexes: Tensor, @@ -115,7 +115,7 @@ def test_precision_cpu( metric_functional=retrieval_precision, ) - @pytest.mark.parametrize(*_default_metric_class_input_arguments) + @pytest.mark.parametrize(**_default_metric_class_input_arguments) def test_precision_gpu( self, indexes: Tensor, diff --git a/tests/retrieval/test_recall.py b/tests/retrieval/test_recall.py index 78f90c6d0ec..d5c41a547db 100644 --- a/tests/retrieval/test_recall.py +++ b/tests/retrieval/test_recall.py @@ -58,7 +58,7 @@ class TestRecall(RetrievalMetricTester): @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ['skip', 'neg', 'pos']) @pytest.mark.parametrize("k", [None, 1, 4, 10]) - @pytest.mark.parametrize(*_default_metric_class_input_arguments) + @pytest.mark.parametrize(**_default_metric_class_input_arguments) def test_class_metric( self, ddp: bool, @@ -82,7 +82,7 @@ def test_class_metric( metric_args=metric_args, ) - @pytest.mark.parametrize(*_default_metric_functional_input_arguments) + @pytest.mark.parametrize(**_default_metric_functional_input_arguments) 
@pytest.mark.parametrize("k", [None, 1, 4, 10]) def test_functional_metric( self, @@ -99,7 +99,7 @@ def test_functional_metric( k=k, ) - @pytest.mark.parametrize(*_default_metric_class_input_arguments) + @pytest.mark.parametrize(**_default_metric_class_input_arguments) def test_precision_cpu( self, indexes: Tensor, @@ -114,7 +114,7 @@ def test_precision_cpu( metric_functional=retrieval_recall, ) - @pytest.mark.parametrize(*_default_metric_class_input_arguments) + @pytest.mark.parametrize(**_default_metric_class_input_arguments) def test_precision_gpu( self, indexes: Tensor, From 7a3f8c9c95537b5fd2f8d19d14f65c6b2790b2cb Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 6 Apr 2021 10:42:37 +0200 Subject: [PATCH 24/33] format --- tests/retrieval/test_map.py | 35 +++++------------------------- tests/retrieval/test_mrr.py | 35 +++++------------------------- tests/retrieval/test_precision.py | 36 +++++-------------------------- tests/retrieval/test_recall.py | 36 +++++-------------------------- 4 files changed, 20 insertions(+), 122 deletions(-) diff --git a/tests/retrieval/test_map.py b/tests/retrieval/test_map.py index 52016a50aa8..86fcf63b73a 100644 --- a/tests/retrieval/test_map.py +++ b/tests/retrieval/test_map.py @@ -60,11 +60,7 @@ def test_class_metric( ) @pytest.mark.parametrize(**_default_metric_functional_input_arguments) - def test_functional_metric( - self, - preds: Tensor, - target: Tensor, - ): + def test_functional_metric(self, preds: Tensor, target: Tensor): self.run_functional_metric_test( preds=preds, target=target, @@ -74,12 +70,7 @@ def test_functional_metric( ) @pytest.mark.parametrize(**_default_metric_class_input_arguments) - def test_precision_cpu( - self, - indexes: Tensor, - preds: Tensor, - target: Tensor, - ): + def test_precision_cpu(self, indexes: Tensor, preds: Tensor, target: Tensor): self.run_precision_test_cpu( indexes=indexes, preds=preds, @@ -89,12 +80,7 @@ def test_precision_cpu( ) @pytest.mark.parametrize(**_default_metric_class_input_arguments) - def test_precision_gpu( - self, - indexes: Tensor, - preds: Tensor, - target: Tensor, - ): + def test_precision_gpu(self, indexes: Tensor, preds: Tensor, target: Tensor): self.run_precision_test_gpu( indexes=indexes, preds=preds, @@ -110,12 +96,7 @@ def test_precision_gpu( ) ) def test_arguments_class_metric( - self, - indexes: Tensor, - preds: Tensor, - target: Tensor, - message: str, - metric_args: dict, + self, indexes: Tensor, preds: Tensor, target: Tensor, message: str, metric_args: dict ): self.run_metric_class_arguments_test( indexes=indexes, @@ -129,13 +110,7 @@ def test_arguments_class_metric( ) @pytest.mark.parametrize(**_errors_test_functional_metric_parameters_default) - def test_arguments_functional_metric( - self, - preds: Tensor, - target: Tensor, - message: str, - metric_args: dict, - ): + def test_arguments_functional_metric(self, preds: Tensor, target: Tensor, message: str, metric_args: dict): self.run_functional_metric_arguments_test( preds=preds, target=target, diff --git a/tests/retrieval/test_mrr.py b/tests/retrieval/test_mrr.py index 0306f116458..ecf0f314ece 100644 --- a/tests/retrieval/test_mrr.py +++ b/tests/retrieval/test_mrr.py @@ -83,11 +83,7 @@ def test_class_metric( ) @pytest.mark.parametrize(**_default_metric_functional_input_arguments) - def test_functional_metric( - self, - preds: Tensor, - target: Tensor, - ): + def test_functional_metric(self, preds: Tensor, target: Tensor): self.run_functional_metric_test( preds=preds, target=target, @@ -97,12 +93,7 @@ def 
test_functional_metric( ) @pytest.mark.parametrize(**_default_metric_class_input_arguments) - def test_precision_cpu( - self, - indexes: Tensor, - preds: Tensor, - target: Tensor, - ): + def test_precision_cpu(self, indexes: Tensor, preds: Tensor, target: Tensor): self.run_precision_test_cpu( indexes=indexes, preds=preds, @@ -112,12 +103,7 @@ def test_precision_cpu( ) @pytest.mark.parametrize(**_default_metric_class_input_arguments) - def test_precision_gpu( - self, - indexes: Tensor, - preds: Tensor, - target: Tensor, - ): + def test_precision_gpu(self, indexes: Tensor, preds: Tensor, target: Tensor): self.run_precision_test_gpu( indexes=indexes, preds=preds, @@ -133,12 +119,7 @@ def test_precision_gpu( ) ) def test_arguments_class_metric( - self, - indexes: Tensor, - preds: Tensor, - target: Tensor, - message: str, - metric_args: dict, + self, indexes: Tensor, preds: Tensor, target: Tensor, message: str, metric_args: dict ): self.run_metric_class_arguments_test( indexes=indexes, @@ -152,13 +133,7 @@ def test_arguments_class_metric( ) @pytest.mark.parametrize(**_errors_test_functional_metric_parameters_default) - def test_arguments_functional_metric( - self, - preds: Tensor, - target: Tensor, - message: str, - metric_args: dict, - ): + def test_arguments_functional_metric(self, preds: Tensor, target: Tensor, message: str, metric_args: dict): self.run_functional_metric_arguments_test( preds=preds, target=target, diff --git a/tests/retrieval/test_precision.py b/tests/retrieval/test_precision.py index 3ff57aa89e1..56e491e24d8 100644 --- a/tests/retrieval/test_precision.py +++ b/tests/retrieval/test_precision.py @@ -85,12 +85,7 @@ def test_class_metric( @pytest.mark.parametrize(**_default_metric_functional_input_arguments) @pytest.mark.parametrize("k", [None, 1, 4, 10]) - def test_functional_metric( - self, - preds: Tensor, - target: Tensor, - k: int, - ): + def test_functional_metric(self, preds: Tensor, target: Tensor, k: int): self.run_functional_metric_test( preds=preds, target=target, @@ -101,12 +96,7 @@ def test_functional_metric( ) @pytest.mark.parametrize(**_default_metric_class_input_arguments) - def test_precision_cpu( - self, - indexes: Tensor, - preds: Tensor, - target: Tensor, - ): + def test_precision_cpu(self, indexes: Tensor, preds: Tensor, target: Tensor): self.run_precision_test_cpu( indexes=indexes, preds=preds, @@ -116,12 +106,7 @@ def test_precision_cpu( ) @pytest.mark.parametrize(**_default_metric_class_input_arguments) - def test_precision_gpu( - self, - indexes: Tensor, - preds: Tensor, - target: Tensor, - ): + def test_precision_gpu(self, indexes: Tensor, preds: Tensor, target: Tensor): self.run_precision_test_gpu( indexes=indexes, preds=preds, @@ -138,12 +123,7 @@ def test_precision_gpu( ) ) def test_arguments_class_metric( - self, - indexes: Tensor, - preds: Tensor, - target: Tensor, - message: str, - metric_args: dict, + self, indexes: Tensor, preds: Tensor, target: Tensor, message: str, metric_args: dict ): self.run_metric_class_arguments_test( indexes=indexes, @@ -162,13 +142,7 @@ def test_arguments_class_metric( _errors_test_functional_metric_parameters_k, ) ) - def test_arguments_functional_metric( - self, - preds: Tensor, - target: Tensor, - message: str, - metric_args: dict, - ): + def test_arguments_functional_metric(self, preds: Tensor, target: Tensor, message: str, metric_args: dict): self.run_functional_metric_arguments_test( preds=preds, target=target, diff --git a/tests/retrieval/test_recall.py b/tests/retrieval/test_recall.py index 
d5c41a547db..3f6beefb797 100644 --- a/tests/retrieval/test_recall.py +++ b/tests/retrieval/test_recall.py @@ -84,12 +84,7 @@ def test_class_metric( @pytest.mark.parametrize(**_default_metric_functional_input_arguments) @pytest.mark.parametrize("k", [None, 1, 4, 10]) - def test_functional_metric( - self, - preds: Tensor, - target: Tensor, - k: int, - ): + def test_functional_metric(self, preds: Tensor, target: Tensor, k: int): self.run_functional_metric_test( preds=preds, target=target, @@ -100,12 +95,7 @@ def test_functional_metric( ) @pytest.mark.parametrize(**_default_metric_class_input_arguments) - def test_precision_cpu( - self, - indexes: Tensor, - preds: Tensor, - target: Tensor, - ): + def test_precision_cpu(self, indexes: Tensor, preds: Tensor, target: Tensor): self.run_precision_test_cpu( indexes=indexes, preds=preds, @@ -115,12 +105,7 @@ def test_precision_cpu( ) @pytest.mark.parametrize(**_default_metric_class_input_arguments) - def test_precision_gpu( - self, - indexes: Tensor, - preds: Tensor, - target: Tensor, - ): + def test_precision_gpu(self, indexes: Tensor, preds: Tensor, target: Tensor): self.run_precision_test_gpu( indexes=indexes, preds=preds, @@ -137,12 +122,7 @@ def test_precision_gpu( ) ) def test_arguments_class_metric( - self, - indexes: Tensor, - preds: Tensor, - target: Tensor, - message: str, - metric_args: dict, + self, indexes: Tensor, preds: Tensor, target: Tensor, message: str, metric_args: dict ): self.run_metric_class_arguments_test( indexes=indexes, @@ -161,13 +141,7 @@ def test_arguments_class_metric( _errors_test_functional_metric_parameters_k, ) ) - def test_arguments_functional_metric( - self, - preds: Tensor, - target: Tensor, - message: str, - metric_args: dict, - ): + def test_arguments_functional_metric(self, preds: Tensor, target: Tensor, message: str, metric_args: dict): self.run_functional_metric_arguments_test( preds=preds, target=target, From d3507656b0572b102dee5fd274d5ba8522fd26bc Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 6 Apr 2021 11:06:15 +0200 Subject: [PATCH 25/33] _sk --- tests/regression/test_ssim.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/regression/test_ssim.py b/tests/regression/test_ssim.py index cea7ecb6d0f..ca33d06ccce 100644 --- a/tests/regression/test_ssim.py +++ b/tests/regression/test_ssim.py @@ -42,7 +42,7 @@ )) -def _sk_metric(preds, target, data_range, multichannel): +def _sk_ssim(preds, target, data_range, multichannel): c, h, w = preds.shape[-3:] sk_preds = preds.view(-1, c, h, w).permute(0, 2, 3, 1).numpy() sk_target = target.view(-1, c, h, w).permute(0, 2, 3, 1).numpy() @@ -77,7 +77,7 @@ def test_ssim(self, preds, target, multichannel, ddp, dist_sync_on_step): preds, target, SSIM, - partial(_sk_metric, data_range=1.0, multichannel=multichannel), + partial(_sk_ssim, data_range=1.0, multichannel=multichannel), metric_args={"data_range": 1.0}, dist_sync_on_step=dist_sync_on_step, ) @@ -87,7 +87,7 @@ def test_ssim_functional(self, preds, target, multichannel): preds, target, ssim, - partial(_sk_metric, data_range=1.0, multichannel=multichannel), + partial(_sk_ssim, data_range=1.0, multichannel=multichannel), metric_args={"data_range": 1.0}, ) From c796557ae6e0f594fadd5a796236b3cc75122308 Mon Sep 17 00:00:00 2001 From: lucadiliello Date: Tue, 6 Apr 2021 11:13:52 +0200 Subject: [PATCH 26/33] fixed typo in tests --- tests/retrieval/helpers.py | 10 +++++----- tests/retrieval/test_map.py | 2 +- tests/retrieval/test_mrr.py | 2 +- tests/retrieval/test_precision.py | 4 ++-- 
tests/retrieval/test_recall.py | 4 ++-- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/retrieval/helpers.py b/tests/retrieval/helpers.py index 82efb4e1cca..eee3c2e2bd8 100644 --- a/tests/retrieval/helpers.py +++ b/tests/retrieval/helpers.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import partial -from typing import Callable, Tuple, Union +from typing import Callable, Dict, Tuple, Union import numpy as np import pytest @@ -83,11 +83,11 @@ def _compute_sklearn_metric( return np.array(0.0) -def _concat_tests(*tests: Tuple[str, Tuple]) -> Tuple[str, Tuple]: +def _concat_tests(*tests: Tuple[Dict]) -> Dict: """Concat tests composed by a string and a list of arguments.""" - assert len(tests), "cannot concatenate " - assert all([tests[0][0] == x[0] for x in tests[1:]]), "the header must be the same for all tests" - return (tests[0][0], sum([x[1] for x in tests], [])) + assert len(tests), "`_concat_tests` expects at least an argument" + assert all([tests[0]['argnames'] == x['argnames'] for x in tests[1:]]), "the header must be the same for all tests" + return dict(argnames=tests[0]['argnames'], argvalues=sum([x['argvalues'] for x in tests], [])) _errors_test_functional_metric_parameters_default = dict( diff --git a/tests/retrieval/test_map.py b/tests/retrieval/test_map.py index 86fcf63b73a..829bede28cc 100644 --- a/tests/retrieval/test_map.py +++ b/tests/retrieval/test_map.py @@ -90,7 +90,7 @@ def test_precision_gpu(self, indexes: Tensor, preds: Tensor, target: Tensor): ) @pytest.mark.parametrize( - *_concat_tests( + **_concat_tests( _errors_test_class_metric_parameters_default, _errors_test_class_metric_parameters_no_pos_target, ) diff --git a/tests/retrieval/test_mrr.py b/tests/retrieval/test_mrr.py index ecf0f314ece..ec5f2f14ed8 100644 --- a/tests/retrieval/test_mrr.py +++ b/tests/retrieval/test_mrr.py @@ -113,7 +113,7 @@ def test_precision_gpu(self, indexes: Tensor, preds: Tensor, target: Tensor): ) @pytest.mark.parametrize( - *_concat_tests( + **_concat_tests( _errors_test_class_metric_parameters_default, _errors_test_class_metric_parameters_no_pos_target, ) diff --git a/tests/retrieval/test_precision.py b/tests/retrieval/test_precision.py index 56e491e24d8..edf8d6bc975 100644 --- a/tests/retrieval/test_precision.py +++ b/tests/retrieval/test_precision.py @@ -116,7 +116,7 @@ def test_precision_gpu(self, indexes: Tensor, preds: Tensor, target: Tensor): ) @pytest.mark.parametrize( - *_concat_tests( + **_concat_tests( _errors_test_class_metric_parameters_default, _errors_test_class_metric_parameters_no_pos_target, _errors_test_class_metric_parameters_k, @@ -137,7 +137,7 @@ def test_arguments_class_metric( ) @pytest.mark.parametrize( - *_concat_tests( + **_concat_tests( _errors_test_functional_metric_parameters_default, _errors_test_functional_metric_parameters_k, ) diff --git a/tests/retrieval/test_recall.py b/tests/retrieval/test_recall.py index 3f6beefb797..d8e522ffaba 100644 --- a/tests/retrieval/test_recall.py +++ b/tests/retrieval/test_recall.py @@ -115,7 +115,7 @@ def test_precision_gpu(self, indexes: Tensor, preds: Tensor, target: Tensor): ) @pytest.mark.parametrize( - *_concat_tests( + **_concat_tests( _errors_test_class_metric_parameters_default, _errors_test_class_metric_parameters_no_pos_target, _errors_test_class_metric_parameters_k, @@ -136,7 +136,7 @@ def test_arguments_class_metric( ) @pytest.mark.parametrize( - *_concat_tests( + **_concat_tests( 
_errors_test_functional_metric_parameters_default, _errors_test_functional_metric_parameters_k, ) From 6b1ed29b6386f5d6f61200fa87dc8e870a219259 Mon Sep 17 00:00:00 2001 From: lucadiliello Date: Tue, 6 Apr 2021 12:45:26 +0200 Subject: [PATCH 27/33] updated tests experiments from tuple to dict --- tests/retrieval/test_fallout.py | 48 ++++++++------------------------- 1 file changed, 11 insertions(+), 37 deletions(-) diff --git a/tests/retrieval/test_fallout.py b/tests/retrieval/test_fallout.py index 70622a24e67..a7f5c3c9364 100644 --- a/tests/retrieval/test_fallout.py +++ b/tests/retrieval/test_fallout.py @@ -59,7 +59,7 @@ class TestFallOut(RetrievalMetricTester): @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ['skip', 'neg', 'pos']) @pytest.mark.parametrize("k", [None, 1, 4, 10]) - @pytest.mark.parametrize(*_default_metric_class_input_arguments) + @pytest.mark.parametrize(**_default_metric_class_input_arguments) def test_class_metric( self, ddp: bool, @@ -84,14 +84,9 @@ def test_class_metric( metric_args=metric_args, ) - @pytest.mark.parametrize(*_default_metric_functional_input_arguments) + @pytest.mark.parametrize(**_default_metric_functional_input_arguments) @pytest.mark.parametrize("k", [None, 1, 4, 10]) - def test_functional_metric( - self, - preds: Tensor, - target: Tensor, - k: int, - ): + def test_functional_metric(self, preds: Tensor, target: Tensor, k: int): self.run_functional_metric_test( preds=preds, target=target, @@ -102,13 +97,8 @@ def test_functional_metric( k=k, ) - @pytest.mark.parametrize(*_default_metric_class_input_arguments) - def test_precision_cpu( - self, - indexes: Tensor, - preds: Tensor, - target: Tensor, - ): + @pytest.mark.parametrize(**_default_metric_class_input_arguments) + def test_precision_cpu(self, indexes: Tensor, preds: Tensor, target: Tensor): self.run_precision_test_cpu( indexes=indexes, preds=preds, @@ -117,13 +107,8 @@ def test_precision_cpu( metric_functional=retrieval_fall_out, ) - @pytest.mark.parametrize(*_default_metric_class_input_arguments) - def test_precision_gpu( - self, - indexes: Tensor, - preds: Tensor, - target: Tensor, - ): + @pytest.mark.parametrize(**_default_metric_class_input_arguments) + def test_precision_gpu(self, indexes: Tensor, preds: Tensor, target: Tensor): self.run_precision_test_gpu( indexes=indexes, preds=preds, @@ -132,18 +117,13 @@ def test_precision_gpu( metric_functional=retrieval_fall_out, ) - @pytest.mark.parametrize(*_concat_tests( + @pytest.mark.parametrize(**_concat_tests( _errors_test_class_metric_parameters_default, _errors_test_class_metric_parameters_no_neg_target, _errors_test_class_metric_parameters_k, )) def test_arguments_class_metric( - self, - indexes: Tensor, - preds: Tensor, - target: Tensor, - message: str, - metric_args: dict, + self, indexes: Tensor, preds: Tensor, target: Tensor, message: str, metric_args: dict ): self.run_metric_class_arguments_test( indexes=indexes, @@ -156,17 +136,11 @@ def test_arguments_class_metric( kwargs_update={}, ) - @pytest.mark.parametrize(*_concat_tests( + @pytest.mark.parametrize(**_concat_tests( _errors_test_functional_metric_parameters_default, _errors_test_functional_metric_parameters_k, )) - def test_arguments_functional_metric( - self, - preds: Tensor, - target: Tensor, - message: str, - metric_args: dict, - ): + def test_arguments_functional_metric(self, preds: Tensor, target: Tensor, message: str, metric_args: dict): self.run_functional_metric_arguments_test( preds=preds, target=target, From 
0256157d73b7612d0814c8caedee398e30bfc45d Mon Sep 17 00:00:00 2001 From: lucadiliello Date: Tue, 6 Apr 2021 14:31:26 +0200 Subject: [PATCH 28/33] fixed typo in doc --- torchmetrics/functional/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchmetrics/functional/__init__.py b/torchmetrics/functional/__init__.py index 7ad036f1ce4..488794fdaf6 100644 --- a/torchmetrics/functional/__init__.py +++ b/torchmetrics/functional/__init__.py @@ -38,6 +38,7 @@ from torchmetrics.functional.regression.r2score import r2score # noqa: F401 from torchmetrics.functional.regression.ssim import ssim # noqa: F401 from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision # noqa: F401 +from torchmetrics.functional.retrieval.fall_out import retrieval_fall_out # noqa: F401 from torchmetrics.functional.retrieval.precision import retrieval_precision # noqa: F401 from torchmetrics.functional.retrieval.recall import retrieval_recall # noqa: F401 from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank # noqa: F401 From 44f753714f950fadd126cd1bbff04c78ab5de1a2 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 6 Apr 2021 16:30:26 +0200 Subject: [PATCH 29/33] Apply suggestions from code review --- docs/source/references/functional.rst | 2 +- docs/source/references/modules.rst | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/references/functional.rst b/docs/source/references/functional.rst index 3d8619ffafb..e0888c880e6 100644 --- a/docs/source/references/functional.rst +++ b/docs/source/references/functional.rst @@ -269,7 +269,7 @@ retrieval_recall [func] retrieval_fall_out [func] -~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: torchmetrics.functional.retrieval_fall_out :noindex: diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index 7459f87ff78..54f9f10e480 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -351,7 +351,7 @@ RetrievalRecall RetrievalFallOut -~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~ .. 
autoclass:: torchmetrics.RetrievalFallOut :noindex: From 873df5c2da12359b906117ad1c481e9f43b4095f Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 6 Apr 2021 16:40:19 +0200 Subject: [PATCH 30/33] chlog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 963fa761111..5b25a56b671 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `RetrievalRecall` metric for Information Retrieval ([#146](https://github.com/PyTorchLightning/metrics/pull/146)) +- Added `RetrievalFallOut` metric for Information Retrieval ([#161](https://github.com/PyTorchLightning/metrics/pull/161)) + + - Added `average='micro'` as an option in AUROC for multilabel problems ([#110](https://github.com/PyTorchLightning/metrics/pull/110)) From 31e4d2f37675c895655581b9d3315518732a6af5 Mon Sep 17 00:00:00 2001 From: Luca Di Liello Date: Tue, 6 Apr 2021 19:15:27 +0200 Subject: [PATCH 31/33] Update tests/retrieval/test_fallout.py Co-authored-by: Nicki Skafte --- tests/retrieval/test_fallout.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/retrieval/test_fallout.py b/tests/retrieval/test_fallout.py index a7f5c3c9364..e40274e5992 100644 --- a/tests/retrieval/test_fallout.py +++ b/tests/retrieval/test_fallout.py @@ -41,8 +41,7 @@ def _fallout_at_k(target: np.ndarray, preds: np.ndarray, k: int = None): assert target.shape == preds.shape assert len(target.shape) == 1 # works only with single dimension inputs - if k is None: - k = len(preds) + k = len(preds) if k is None else k target = 1 - target if target.sum(): From a1c83e29fc4752fdfe7a79b0938552c70fa8a1ab Mon Sep 17 00:00:00 2001 From: Luca Di Liello Date: Tue, 6 Apr 2021 19:16:09 +0200 Subject: [PATCH 32/33] Update torchmetrics/functional/retrieval/fall_out.py Co-authored-by: Nicki Skafte --- torchmetrics/functional/retrieval/fall_out.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchmetrics/functional/retrieval/fall_out.py b/torchmetrics/functional/retrieval/fall_out.py index 43a732021d9..1b86b90fedd 100644 --- a/torchmetrics/functional/retrieval/fall_out.py +++ b/torchmetrics/functional/retrieval/fall_out.py @@ -44,8 +44,7 @@ def retrieval_fall_out(preds: Tensor, target: Tensor, k: int = None) -> Tensor: """ preds, target = _check_retrieval_functional_inputs(preds, target) - if k is None: - k = preds.shape[-1] + k = preds.shape[-1] if k is None else k if not (isinstance(k, int) and k > 0): raise ValueError("`k` has to be a positive integer or None") From 4d7f644f3f8f686f54c500eedc94b3ea817848c9 Mon Sep 17 00:00:00 2001 From: lucadiliello Date: Tue, 6 Apr 2021 19:22:48 +0200 Subject: [PATCH 33/33] added metric link to wikipedia --- tests/retrieval/test_fallout.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/retrieval/test_fallout.py b/tests/retrieval/test_fallout.py index e40274e5992..b5f5c31735b 100644 --- a/tests/retrieval/test_fallout.py +++ b/tests/retrieval/test_fallout.py @@ -36,7 +36,9 @@ def _fallout_at_k(target: np.ndarray, preds: np.ndarray, k: int = None): """ Didn't find a reliable implementation of Fall-out in Information Retrieval, so, - reimplementing here. See wikipedia for more information about definition. + reimplementing here. + See `Wikipedia `__ + for more information about the metric definition. """ assert target.shape == preds.shape assert len(target.shape) == 1 # works only with single dimension inputs
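For context on the metric these last patches wire up: fall-out at k is the fraction of non-relevant documents that appear among the top-k retrieved results, out of all non-relevant documents for a query. A minimal NumPy sketch of that definition (illustrative only, not the torchmetrics implementation; the 0.0 result for queries with no non-relevant documents is an assumption of this sketch):

import numpy as np


def fallout_at_k(target: np.ndarray, preds: np.ndarray, k: int = None) -> float:
    # fall-out@k = (# non-relevant documents among the top-k results) / (# non-relevant documents)
    k = len(preds) if k is None else k
    non_relevant = 1 - target  # 1 marks a non-relevant document
    if non_relevant.sum() == 0:
        return 0.0  # assumption: a query with no non-relevant documents gets fall-out 0
    top_k = np.argsort(preds)[::-1][:k]  # indices of the k highest-scored documents
    return float(non_relevant[top_k].sum() / non_relevant.sum())


# Example: fallout_at_k(np.array([0, 0, 1, 0]), np.array([0.3, 0.9, 0.4, 0.1]), k=2)
# retrieves documents 1 and 2; only document 1 is non-relevant, so the result is 1/3.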