From 2e7175978cf1923201a352e8f8638303d4f149f4 Mon Sep 17 00:00:00 2001 From: twsl <45483159+twsl@users.noreply.github.com> Date: Sun, 19 Dec 2021 00:51:37 +0000 Subject: [PATCH 01/18] Add test cases --- requirements/test.txt | 1 + tests/helpers/testers.py | 50 ++++++++++++++++++++++++++++++++++++---- 2 files changed, 47 insertions(+), 4 deletions(-) diff --git a/requirements/test.txt b/requirements/test.txt index c78e1e9a453..6c57c0b2cda 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -2,6 +2,7 @@ coverage>5.2 codecov>=2.1 pytest>=6.0 pytest-cov>2.10 +pytest-cases>=3.6.5 # pytest-xdist # pytest-flake8 flake8 diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py index 66877b9cd7a..0695bf1e14d 100644 --- a/tests/helpers/testers.py +++ b/tests/helpers/testers.py @@ -20,6 +20,7 @@ import numpy as np import pytest import torch +from pytest_cases import case from torch import Tensor, tensor from torch.multiprocessing import Pool, set_start_method @@ -326,6 +327,48 @@ def _assert_half_support( _assert_tensor(metric_functional(y_hat, y, **kwargs_update)) +# https://github.com/pytest-dev/pytest/issues/349 +class MetricTesterDDPCases: + @case(tags="ddp") + def case_ddp_false(self): + return False + + @case(tags="ddp") + def case_ddp_true(self): + return True + + @case(tags="device") + def case_device_cpu(self): + return "cpu" + + @case(tags="device") + def case_device_gpu(self): + return "gpu" + + @case(tags="strategy") + def case_ddp_false_device_cpu(self): + return False, "cpu" + + # @pytest.mark.skipif(os.cpu_count() < 2, reason="More than one CPU Core required") + @case(tags="strategy") # order is important, @case after @pytest.mark or markers as argument for @case + def case_ddp_true_device_cpu(self): + return True, "cpu" + + @case(tags="strategy", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU required")) + def case_ddp_false_device_gpu(self): + return False, "gpu" + + @case( + tags="strategy", + marks=pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + reason="More than one GPU required for DDP", + ), + ) + def case_ddp_true_device_gpu(self): + return True, "gpu" + + class MetricTester: """Class used for efficiently run alot of parametrized tests in ddp mode. Makes sure that ddp is only setup once and that pool of processes are used for all tests. @@ -359,6 +402,7 @@ def run_functional_metric_test( sk_metric: Callable, metric_args: dict = None, fragment_kwargs: bool = False, + device: str = "cpu", **kwargs_update, ): """Main method that should be used for testing functions. Call this inside testing method. @@ -373,8 +417,6 @@ def run_functional_metric_test( kwargs_update: Additional keyword arguments that will be passed with preds and target when running update on the metric. 
""" - device = "cuda" if (torch.cuda.is_available() and torch.cuda.device_count() > 0) else "cpu" - _functional_test( preds=preds, target=target, @@ -398,6 +440,7 @@ def run_class_metric_test( metric_args: dict = None, check_dist_sync_on_step: bool = True, check_batch: bool = True, + device: str = "cpu", fragment_kwargs: bool = False, check_scriptable: bool = True, **kwargs_update, @@ -439,6 +482,7 @@ def run_class_metric_test( check_dist_sync_on_step=check_dist_sync_on_step, check_batch=check_batch, atol=self.atol, + device=device, fragment_kwargs=fragment_kwargs, check_scriptable=check_scriptable, **kwargs_update, @@ -446,8 +490,6 @@ def run_class_metric_test( [(rank, self.poolSize) for rank in range(self.poolSize)], ) else: - device = "cuda" if (torch.cuda.is_available() and torch.cuda.device_count() > 0) else "cpu" - _class_test( rank=0, worldsize=1, From 17578f69424b24fe569d51017c712eac018bb87c Mon Sep 17 00:00:00 2001 From: twsl <45483159+twsl@users.noreply.github.com> Date: Sun, 19 Dec 2021 00:57:28 +0000 Subject: [PATCH 02/18] Remove comment --- tests/helpers/testers.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py index 0695bf1e14d..f2b653ceaf5 100644 --- a/tests/helpers/testers.py +++ b/tests/helpers/testers.py @@ -349,8 +349,7 @@ def case_device_gpu(self): def case_ddp_false_device_cpu(self): return False, "cpu" - # @pytest.mark.skipif(os.cpu_count() < 2, reason="More than one CPU Core required") - @case(tags="strategy") # order is important, @case after @pytest.mark or markers as argument for @case + @case(tags="strategy") def case_ddp_true_device_cpu(self): return True, "cpu" From 3131b5a4e0baa8152f59f5632e70215f7bfac582 Mon Sep 17 00:00:00 2001 From: twsl <45483159+twsl@users.noreply.github.com> Date: Sun, 19 Dec 2021 01:04:20 +0000 Subject: [PATCH 03/18] Add test strategy to map --- tests/detection/test_map.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/detection/test_map.py b/tests/detection/test_map.py index 8b02ca160f6..eff35689ead 100644 --- a/tests/detection/test_map.py +++ b/tests/detection/test_map.py @@ -16,8 +16,9 @@ import pytest import torch +from pytest_cases import parametrize_with_cases -from tests.helpers.testers import MetricTester +from tests.helpers.testers import MetricTester, MetricTesterDDPCases from torchmetrics.detection.map import MAP from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_8 @@ -174,8 +175,8 @@ class TestMAP(MetricTester): atol = 1e-1 - @pytest.mark.parametrize("ddp", [False, True]) - def test_map(self, ddp): + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + def test_map(self, ddp, device): """Test modular implementation for correctness.""" self.run_class_metric_test( ddp=ddp, @@ -185,6 +186,7 @@ def test_map(self, ddp): sk_metric=_compare_fn, dist_sync_on_step=False, check_batch=False, + device=device, metric_args={"class_metrics": True}, ) From 7f51e998efcd8938ebb64a3ee483bf2b63066c0c Mon Sep 17 00:00:00 2001 From: twsl <45483159+twsl@users.noreply.github.com> Date: Sun, 19 Dec 2021 01:11:09 +0000 Subject: [PATCH 04/18] Fix device type --- tests/helpers/testers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py index f2b653ceaf5..702be5acd15 100644 --- a/tests/helpers/testers.py +++ b/tests/helpers/testers.py @@ -343,7 +343,7 @@ def case_device_cpu(self): 
@case(tags="device") def case_device_gpu(self): - return "gpu" + return "cuda" @case(tags="strategy") def case_ddp_false_device_cpu(self): @@ -355,7 +355,7 @@ def case_ddp_true_device_cpu(self): @case(tags="strategy", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU required")) def case_ddp_false_device_gpu(self): - return False, "gpu" + return False, "cuda" @case( tags="strategy", @@ -365,7 +365,7 @@ def case_ddp_false_device_gpu(self): ), ) def case_ddp_true_device_gpu(self): - return True, "gpu" + return True, "cuda" class MetricTester: From 27aa6c2a1fed0a387425c1e21ec2b606353fb7f4 Mon Sep 17 00:00:00 2001 From: twsl <45483159+twsl@users.noreply.github.com> Date: Sun, 19 Dec 2021 15:13:18 +0000 Subject: [PATCH 05/18] Add test strategy to audio metrics --- tests/audio/test_pesq.py | 12 ++++++++---- tests/audio/test_pit.py | 12 ++++++++---- tests/audio/test_sdr.py | 12 ++++++++---- tests/audio/test_si_sdr.py | 12 ++++++++---- tests/audio/test_si_snr.py | 12 ++++++++---- tests/audio/test_snr.py | 12 ++++++++---- tests/audio/test_stoi.py | 12 ++++++++---- 7 files changed, 56 insertions(+), 28 deletions(-) diff --git a/tests/audio/test_pesq.py b/tests/audio/test_pesq.py index 93fce3fe365..80184351b8a 100644 --- a/tests/audio/test_pesq.py +++ b/tests/audio/test_pesq.py @@ -17,10 +17,11 @@ import pytest import torch from pesq import pesq as pesq_backend +from pytest_cases import parametrize_with_cases from torch import Tensor from tests.helpers import seed_all -from tests.helpers.testers import MetricTester +from tests.helpers.testers import MetricTester, MetricTesterDDPCases from torchmetrics.audio.pesq import PESQ from torchmetrics.functional.audio.pesq import pesq from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 @@ -75,9 +76,9 @@ def average_metric(preds, target, metric_func): class TestPESQ(MetricTester): atol = 1e-2 - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_pesq(self, preds, target, sk_metric, fs, mode, ddp, dist_sync_on_step): + def test_pesq(self, preds, target, sk_metric, fs, mode, ddp, dist_sync_on_step, device): self.run_class_metric_test( ddp, preds, @@ -85,15 +86,18 @@ def test_pesq(self, preds, target, sk_metric, fs, mode, ddp, dist_sync_on_step): PESQ, sk_metric=partial(average_metric, metric_func=sk_metric), dist_sync_on_step=dist_sync_on_step, + device=device, metric_args=dict(fs=fs, mode=mode), ) - def test_pesq_functional(self, preds, target, sk_metric, fs, mode): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_pesq_functional(self, preds, target, sk_metric, fs, mode, device): self.run_functional_metric_test( preds, target, pesq, sk_metric, + device=device, metric_args=dict(fs=fs, mode=mode), ) diff --git a/tests/audio/test_pit.py b/tests/audio/test_pit.py index 22b229cdbb5..c01e95b67d8 100644 --- a/tests/audio/test_pit.py +++ b/tests/audio/test_pit.py @@ -18,11 +18,12 @@ import numpy as np import pytest import torch +from pytest_cases import parametrize_with_cases from scipy.optimize import linear_sum_assignment from torch import Tensor from tests.helpers import seed_all -from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester +from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester, MetricTesterDDPCases from torchmetrics.audio import PIT from torchmetrics.functional import pit, si_sdr, snr from 
torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 @@ -112,9 +113,9 @@ def _average_metric(preds: Tensor, target: Tensor, metric_func: Callable) -> Ten class TestPIT(MetricTester): atol = 1e-2 - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_pit(self, preds, target, sk_metric, metric_func, eval_func, ddp, dist_sync_on_step): + def test_pit(self, preds, target, sk_metric, metric_func, eval_func, ddp, device, dist_sync_on_step): self.run_class_metric_test( ddp, preds, @@ -122,15 +123,18 @@ def test_pit(self, preds, target, sk_metric, metric_func, eval_func, ddp, dist_s PIT, sk_metric=partial(_average_metric, metric_func=sk_metric), dist_sync_on_step=dist_sync_on_step, + device=device, metric_args=dict(metric_func=metric_func, eval_func=eval_func), ) - def test_pit_functional(self, preds, target, sk_metric, metric_func, eval_func): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_pit_functional(self, preds, target, sk_metric, device, metric_func, eval_func): self.run_functional_metric_test( preds=preds, target=target, metric_functional=pit, sk_metric=sk_metric, + device=device, metric_args=dict(metric_func=metric_func, eval_func=eval_func), ) diff --git a/tests/audio/test_sdr.py b/tests/audio/test_sdr.py index 2f28e8976f3..eb3f58762f5 100644 --- a/tests/audio/test_sdr.py +++ b/tests/audio/test_sdr.py @@ -19,11 +19,12 @@ import pytest import torch from mir_eval.separation import bss_eval_sources +from pytest_cases import parametrize_with_cases from scipy.io import wavfile from torch import Tensor from tests.helpers import seed_all -from tests.helpers.testers import MetricTester +from tests.helpers.testers import MetricTester, MetricTesterDDPCases from torchmetrics.audio import SDR from torchmetrics.functional import sdr from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6, _TORCH_GREATER_EQUAL_1_8 @@ -76,9 +77,9 @@ def average_metric(preds: Tensor, target: Tensor, metric_func: Callable) -> Tens class TestSDR(MetricTester): atol = 1e-2 - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_sdr(self, preds, target, sk_metric, ddp, dist_sync_on_step): + def test_sdr(self, preds, target, sk_metric, ddp, dist_sync_on_step, device): self.run_class_metric_test( ddp, preds, @@ -86,15 +87,18 @@ def test_sdr(self, preds, target, sk_metric, ddp, dist_sync_on_step): SDR, sk_metric=partial(average_metric, metric_func=sk_metric), dist_sync_on_step=dist_sync_on_step, + device=device, metric_args=dict(), ) - def test_sdr_functional(self, preds, target, sk_metric): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_sdr_functional(self, preds, target, sk_metric, device): self.run_functional_metric_test( preds, target, sdr, sk_metric, + device=device, metric_args=dict(), ) diff --git a/tests/audio/test_si_sdr.py b/tests/audio/test_si_sdr.py index 479b2c09c6a..b172b3dd8f2 100644 --- a/tests/audio/test_si_sdr.py +++ b/tests/audio/test_si_sdr.py @@ -17,10 +17,11 @@ import pytest import speechmetrics import torch +from pytest_cases import parametrize_with_cases from torch import Tensor from tests.helpers import seed_all -from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester +from 
tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester, MetricTesterDDPCases from torchmetrics.audio import SI_SDR from torchmetrics.functional import si_sdr from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 @@ -77,9 +78,9 @@ def average_metric(preds, target, metric_func): class TestSISDR(MetricTester): atol = 1e-2 - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_si_sdr(self, preds, target, sk_metric, zero_mean, ddp, dist_sync_on_step): + def test_si_sdr(self, preds, target, sk_metric, zero_mean, ddp, dist_sync_on_step, device): self.run_class_metric_test( ddp, preds, @@ -87,15 +88,18 @@ def test_si_sdr(self, preds, target, sk_metric, zero_mean, ddp, dist_sync_on_ste SI_SDR, sk_metric=partial(average_metric, metric_func=sk_metric), dist_sync_on_step=dist_sync_on_step, + device=device, metric_args=dict(zero_mean=zero_mean), ) - def test_si_sdr_functional(self, preds, target, sk_metric, zero_mean): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_si_sdr_functional(self, preds, target, sk_metric, zero_mean, device): self.run_functional_metric_test( preds, target, si_sdr, sk_metric, + device=device, metric_args=dict(zero_mean=zero_mean), ) diff --git a/tests/audio/test_si_snr.py b/tests/audio/test_si_snr.py index 063790b3579..744ed097da5 100644 --- a/tests/audio/test_si_snr.py +++ b/tests/audio/test_si_snr.py @@ -17,10 +17,11 @@ import pytest import speechmetrics import torch +from pytest_cases import parametrize_with_cases from torch import Tensor from tests.helpers import seed_all -from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester +from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester, MetricTesterDDPCases from torchmetrics.audio import SI_SNR from torchmetrics.functional import si_snr from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 @@ -72,9 +73,9 @@ def average_metric(preds, target, metric_func): class TestSISNR(MetricTester): atol = 1e-2 - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_si_snr(self, preds, target, sk_metric, ddp, dist_sync_on_step): + def test_si_snr(self, preds, target, sk_metric, ddp, dist_sync_on_step, device): self.run_class_metric_test( ddp, preds, @@ -82,14 +83,17 @@ def test_si_snr(self, preds, target, sk_metric, ddp, dist_sync_on_step): SI_SNR, sk_metric=partial(average_metric, metric_func=sk_metric), dist_sync_on_step=dist_sync_on_step, + device=device, ) - def test_si_snr_functional(self, preds, target, sk_metric): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_si_snr_functional(self, preds, target, sk_metric, device): self.run_functional_metric_test( preds, target, si_snr, sk_metric, + device=device, ) def test_si_snr_differentiability(self, preds, target, sk_metric): diff --git a/tests/audio/test_snr.py b/tests/audio/test_snr.py index ad63f99c9b5..758751a2cd8 100644 --- a/tests/audio/test_snr.py +++ b/tests/audio/test_snr.py @@ -18,10 +18,11 @@ import pytest import torch from mir_eval.separation import bss_eval_images as mir_eval_bss_eval_images +from pytest_cases import parametrize_with_cases from torch import Tensor from tests.helpers import seed_all -from 
tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester +from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester, MetricTesterDDPCases from torchmetrics.audio import SNR from torchmetrics.functional import snr from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 @@ -79,9 +80,9 @@ def average_metric(preds: Tensor, target: Tensor, metric_func: Callable): class TestSNR(MetricTester): atol = 1e-2 - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_snr(self, preds, target, sk_metric, zero_mean, ddp, dist_sync_on_step): + def test_snr(self, preds, target, sk_metric, zero_mean, ddp, dist_sync_on_step, device): self.run_class_metric_test( ddp, preds, @@ -89,15 +90,18 @@ def test_snr(self, preds, target, sk_metric, zero_mean, ddp, dist_sync_on_step): SNR, sk_metric=partial(average_metric, metric_func=sk_metric), dist_sync_on_step=dist_sync_on_step, + device=device, metric_args=dict(zero_mean=zero_mean), ) - def test_snr_functional(self, preds, target, sk_metric, zero_mean): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_snr_functional(self, preds, target, sk_metric, zero_mean, device): self.run_functional_metric_test( preds, target, snr, sk_metric, + device=device, metric_args=dict(zero_mean=zero_mean), ) diff --git a/tests/audio/test_stoi.py b/tests/audio/test_stoi.py index cd4192e83d7..e53f13c6fea 100644 --- a/tests/audio/test_stoi.py +++ b/tests/audio/test_stoi.py @@ -17,10 +17,11 @@ import pytest import torch from pystoi import stoi as stoi_backend +from pytest_cases import parametrize_with_cases from torch import Tensor from tests.helpers import seed_all -from tests.helpers.testers import MetricTester +from tests.helpers.testers import MetricTester, MetricTesterDDPCases from torchmetrics.audio.stoi import STOI from torchmetrics.functional.audio.stoi import stoi from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 @@ -75,9 +76,9 @@ def average_metric(preds, target, metric_func): class TestSTOI(MetricTester): atol = 1e-2 - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_stoi(self, preds, target, sk_metric, fs, extended, ddp, dist_sync_on_step): + def test_stoi(self, preds, target, sk_metric, fs, extended, ddp, dist_sync_on_step, device): self.run_class_metric_test( ddp, preds, @@ -85,15 +86,18 @@ def test_stoi(self, preds, target, sk_metric, fs, extended, ddp, dist_sync_on_st STOI, sk_metric=partial(average_metric, metric_func=sk_metric), dist_sync_on_step=dist_sync_on_step, + device=device, metric_args=dict(fs=fs, extended=extended), ) - def test_stoi_functional(self, preds, target, sk_metric, fs, extended): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_stoi_functional(self, preds, target, sk_metric, fs, extended, device): self.run_functional_metric_test( preds, target, stoi, sk_metric, + device=device, metric_args=dict(fs=fs, extended=extended), ) From 258da03fc2074fd4e641d4b3af06684b78923fe9 Mon Sep 17 00:00:00 2001 From: twsl <45483159+twsl@users.noreply.github.com> Date: Sun, 19 Dec 2021 15:20:26 +0000 Subject: [PATCH 06/18] Add test strategy to bases --- tests/bases/test_aggregation.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 
deletions(-) diff --git a/tests/bases/test_aggregation.py b/tests/bases/test_aggregation.py index 106621e9cb4..75668e1ad1c 100644 --- a/tests/bases/test_aggregation.py +++ b/tests/bases/test_aggregation.py @@ -1,8 +1,9 @@ import numpy as np import pytest import torch +from pytest_cases import parametrize_with_cases -from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester +from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester, MetricTesterDDPCases from torchmetrics.aggregation import CatMetric, MaxMetric, MeanMetric, MinMetric, SumMetric @@ -80,15 +81,16 @@ def update(self, values, weights): class TestAggregation(MetricTester): """Test aggregation metrics.""" - @pytest.mark.parametrize("ddp", [False, True]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [False]) - def test_aggreagation(self, ddp, dist_sync_on_step, metric_class, compare_fn, values, weights): + def test_aggreagation(self, ddp, dist_sync_on_step, metric_class, compare_fn, values, weights, device): """test modular implementation.""" self.run_class_metric_test( ddp=ddp, dist_sync_on_step=dist_sync_on_step, metric_class=metric_class, sk_metric=compare_fn, + device=device, check_scriptable=True, # Abuse of names here preds=values, From b4dae4d8a02d016326d1671cbdf02a20d645a9ab Mon Sep 17 00:00:00 2001 From: twsl <45483159+twsl@users.noreply.github.com> Date: Sun, 19 Dec 2021 15:45:26 +0000 Subject: [PATCH 07/18] Add test strategy to classification --- tests/classification/test_accuracy.py | 12 ++++++++---- tests/classification/test_auc.py | 18 +++++++++++++----- tests/classification/test_auroc.py | 12 ++++++++---- tests/classification/test_average_precision.py | 12 ++++++++---- .../test_binned_precision_recall.py | 15 ++++++++++----- tests/classification/test_calibration_error.py | 12 ++++++++---- tests/classification/test_cohen_kappa.py | 12 ++++++++---- tests/classification/test_confusion_matrix.py | 12 ++++++++---- tests/classification/test_f_beta.py | 10 ++++++++-- tests/classification/test_hamming_distance.py | 12 ++++++++---- tests/classification/test_hinge.py | 12 ++++++++---- tests/classification/test_jaccard.py | 12 ++++++++---- tests/classification/test_kl_divergence.py | 12 ++++++++---- tests/classification/test_matthews_corrcoef.py | 12 ++++++++---- tests/classification/test_precision_recall.py | 10 ++++++++-- .../test_precision_recall_curve.py | 12 ++++++++---- tests/classification/test_roc.py | 12 ++++++++---- tests/classification/test_specificity.py | 10 ++++++++-- tests/classification/test_stat_scores.py | 10 +++++++++- 19 files changed, 160 insertions(+), 69 deletions(-) diff --git a/tests/classification/test_accuracy.py b/tests/classification/test_accuracy.py index d8e1cef230a..eff44f5dddd 100644 --- a/tests/classification/test_accuracy.py +++ b/tests/classification/test_accuracy.py @@ -16,6 +16,7 @@ import numpy as np import pytest import torch +from pytest_cases import parametrize_with_cases from sklearn.metrics import accuracy_score as sk_accuracy from torch import tensor @@ -32,7 +33,7 @@ from tests.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob from tests.helpers import seed_all -from tests.helpers.testers import NUM_BATCHES, NUM_CLASSES, THRESHOLD, MetricTester +from tests.helpers.testers import NUM_BATCHES, NUM_CLASSES, THRESHOLD, MetricTester, MetricTesterDDPCases from 
torchmetrics import Accuracy from torchmetrics.functional import accuracy from torchmetrics.utilities.checks import _input_format_classification @@ -81,9 +82,9 @@ def _sk_accuracy(preds, target, subset_accuracy): ], ) class TestAccuracies(MetricTester): - @pytest.mark.parametrize("ddp", [False, True]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [False, True]) - def test_accuracy_class(self, ddp, dist_sync_on_step, preds, target, subset_accuracy): + def test_accuracy_class(self, ddp, dist_sync_on_step, preds, target, subset_accuracy, device): self.run_class_metric_test( ddp=ddp, preds=preds, @@ -91,15 +92,18 @@ def test_accuracy_class(self, ddp, dist_sync_on_step, preds, target, subset_accu metric_class=Accuracy, sk_metric=partial(_sk_accuracy, subset_accuracy=subset_accuracy), dist_sync_on_step=dist_sync_on_step, + device=device, metric_args={"threshold": THRESHOLD, "subset_accuracy": subset_accuracy}, ) - def test_accuracy_fn(self, preds, target, subset_accuracy): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_accuracy_fn(self, preds, target, subset_accuracy, device): self.run_functional_metric_test( preds, target, metric_functional=accuracy, sk_metric=partial(_sk_accuracy, subset_accuracy=subset_accuracy), + device=device, metric_args={"threshold": THRESHOLD, "subset_accuracy": subset_accuracy}, ) diff --git a/tests/classification/test_auc.py b/tests/classification/test_auc.py index 177186e7e02..6bd6462e133 100644 --- a/tests/classification/test_auc.py +++ b/tests/classification/test_auc.py @@ -16,11 +16,12 @@ import numpy as np import pytest +from pytest_cases import parametrize_with_cases from sklearn.metrics import auc as _sk_auc from torch import tensor from tests.helpers import seed_all -from tests.helpers.testers import NUM_BATCHES, MetricTester +from tests.helpers.testers import NUM_BATCHES, MetricTester, MetricTesterDDPCases from torchmetrics.classification.auc import AUC from torchmetrics.functional import auc @@ -55,9 +56,9 @@ def sk_auc(x, y, reorder=False): @pytest.mark.parametrize("x, y", _examples) class TestAUC(MetricTester): - @pytest.mark.parametrize("ddp", [False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_auc(self, x, y, ddp, dist_sync_on_step): + def test_auc(self, x, y, ddp, dist_sync_on_step, device): self.run_class_metric_test( ddp=ddp, preds=x, @@ -65,12 +66,19 @@ def test_auc(self, x, y, ddp, dist_sync_on_step): metric_class=AUC, sk_metric=sk_auc, dist_sync_on_step=dist_sync_on_step, + device=device, ) + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") @pytest.mark.parametrize("reorder", [True, False]) - def test_auc_functional(self, x, y, reorder): + def test_auc_functional(self, x, y, reorder, device): self.run_functional_metric_test( - x, y, metric_functional=auc, sk_metric=partial(sk_auc, reorder=reorder), metric_args={"reorder": reorder} + x, + y, + metric_functional=auc, + sk_metric=partial(sk_auc, reorder=reorder), + device=device, + metric_args={"reorder": reorder}, ) @pytest.mark.parametrize("reorder", [True, False]) diff --git a/tests/classification/test_auroc.py b/tests/classification/test_auroc.py index f616c8c341d..d4d4c131fa9 100644 --- a/tests/classification/test_auroc.py +++ b/tests/classification/test_auroc.py @@ -15,6 +15,7 @@ import pytest import torch +from 
pytest_cases import parametrize_with_cases from sklearn.metrics import roc_auc_score as sk_roc_auc_score from tests.classification.inputs import _input_binary_prob @@ -23,7 +24,7 @@ from tests.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob from tests.helpers import seed_all -from tests.helpers.testers import NUM_CLASSES, MetricTester +from tests.helpers.testers import NUM_CLASSES, MetricTester, MetricTesterDDPCases from torchmetrics.classification.auroc import AUROC from torchmetrics.functional import auroc from torchmetrics.utilities.imports import _TORCH_LOWER_1_6 @@ -99,9 +100,9 @@ def _sk_auroc_multilabel_multidim_prob(preds, target, num_classes, average="macr ], ) class TestAUROC(MetricTester): - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_auroc(self, preds, target, sk_metric, num_classes, average, max_fpr, ddp, dist_sync_on_step): + def test_auroc(self, preds, target, sk_metric, num_classes, average, max_fpr, ddp, dist_sync_on_step, device): # max_fpr different from None is not support in multi class if max_fpr is not None and num_classes != 1: pytest.skip("max_fpr parameter not support for multi class or multi label") @@ -121,10 +122,12 @@ def test_auroc(self, preds, target, sk_metric, num_classes, average, max_fpr, dd metric_class=AUROC, sk_metric=partial(sk_metric, num_classes=num_classes, average=average, max_fpr=max_fpr), dist_sync_on_step=dist_sync_on_step, + device=device, metric_args={"num_classes": num_classes, "average": average, "max_fpr": max_fpr}, ) - def test_auroc_functional(self, preds, target, sk_metric, num_classes, average, max_fpr): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_auroc_functional(self, preds, target, sk_metric, num_classes, average, max_fpr, device): # max_fpr different from None is not support in multi class if max_fpr is not None and num_classes != 1: pytest.skip("max_fpr parameter not support for multi class or multi label") @@ -142,6 +145,7 @@ def test_auroc_functional(self, preds, target, sk_metric, num_classes, average, target, metric_functional=auroc, sk_metric=partial(sk_metric, num_classes=num_classes, average=average, max_fpr=max_fpr), + device=device, metric_args={"num_classes": num_classes, "average": average, "max_fpr": max_fpr}, ) diff --git a/tests/classification/test_average_precision.py b/tests/classification/test_average_precision.py index cb44624c99f..d747fa96bd0 100644 --- a/tests/classification/test_average_precision.py +++ b/tests/classification/test_average_precision.py @@ -15,6 +15,7 @@ import numpy as np import pytest +from pytest_cases import parametrize_with_cases from sklearn.metrics import average_precision_score as sk_average_precision_score from torch import tensor @@ -23,7 +24,7 @@ from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob from tests.classification.inputs import _input_multilabel from tests.helpers import seed_all -from tests.helpers.testers import NUM_CLASSES, MetricTester +from tests.helpers.testers import NUM_CLASSES, MetricTester, MetricTesterDDPCases from torchmetrics.classification.avg_precision import AveragePrecision from torchmetrics.functional import average_precision @@ -87,9 +88,9 @@ def _sk_avg_prec_multidim_multiclass_prob(preds, 
target, num_classes=1, average= ) @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) class TestAveragePrecision(MetricTester): - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_average_precision(self, preds, target, sk_metric, num_classes, average, ddp, dist_sync_on_step): + def test_average_precision(self, preds, target, sk_metric, num_classes, average, ddp, dist_sync_on_step, device): if target.max() > 1 and average == "micro": pytest.skip("average=micro and multiclass input cannot be used together") @@ -100,10 +101,12 @@ def test_average_precision(self, preds, target, sk_metric, num_classes, average, metric_class=AveragePrecision, sk_metric=partial(sk_metric, num_classes=num_classes, average=average), dist_sync_on_step=dist_sync_on_step, + device=device, metric_args={"num_classes": num_classes, "average": average}, ) - def test_average_precision_functional(self, preds, target, sk_metric, num_classes, average): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_average_precision_functional(self, preds, target, sk_metric, num_classes, average, device): if target.max() > 1 and average == "micro": pytest.skip("average=micro and multiclass input cannot be used together") @@ -112,6 +115,7 @@ def test_average_precision_functional(self, preds, target, sk_metric, num_classe target=target, metric_functional=average_precision, sk_metric=partial(sk_metric, num_classes=num_classes, average=average), + device=device, metric_args={"num_classes": num_classes, "average": average}, ) diff --git a/tests/classification/test_binned_precision_recall.py b/tests/classification/test_binned_precision_recall.py index 95f288670bc..b5256d11407 100644 --- a/tests/classification/test_binned_precision_recall.py +++ b/tests/classification/test_binned_precision_recall.py @@ -18,6 +18,7 @@ import numpy as np import pytest import torch +from pytest_cases import parametrize_with_cases from sklearn.metrics import average_precision_score as _sk_average_precision_score from sklearn.metrics import precision_recall_curve as _sk_precision_recall_curve from torch import Tensor @@ -27,7 +28,7 @@ from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob from tests.classification.inputs import _input_multilabel_prob_plausible as _input_mlb_prob_ok from tests.helpers import seed_all -from tests.helpers.testers import NUM_CLASSES, MetricTester +from tests.helpers.testers import NUM_CLASSES, MetricTester, MetricTesterDDPCases from torchmetrics.classification.binned_precision_recall import BinnedAveragePrecision, BinnedRecallAtFixedPrecision seed_all(42) @@ -77,11 +78,11 @@ def _sk_avg_prec_multiclass(predictions, targets, num_classes): class TestBinnedRecallAtPrecision(MetricTester): atol = 0.02 - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("min_precision", [0.05, 0.1, 0.3, 0.5, 0.8, 0.95]) def test_binned_recall_at_precision( - self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step, min_precision + self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step, min_precision, device ): # rounding will simulate binning for both implementations preds = Tensor(np.round(preds.numpy(), 2)) + 1e-6 @@ -93,6 
+94,7 @@ def test_binned_recall_at_precision( metric_class=BinnedRecallAtFixedPrecision, sk_metric=partial(sk_metric, num_classes=num_classes, min_precision=min_precision), dist_sync_on_step=dist_sync_on_step, + device=device, metric_args={ "num_classes": num_classes, "min_precision": min_precision, @@ -111,10 +113,12 @@ def test_binned_recall_at_precision( ], ) class TestBinnedAveragePrecision(MetricTester): - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("thresholds", (301, torch.linspace(0.0, 1.0, 101))) - def test_binned_average_precision(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step, thresholds): + def test_binned_average_precision( + self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step, thresholds, device + ): # rounding will simulate binning for both implementations preds = Tensor(np.round(preds.numpy(), 2)) + 1e-6 @@ -125,5 +129,6 @@ def test_binned_average_precision(self, preds, target, sk_metric, num_classes, d metric_class=BinnedAveragePrecision, sk_metric=partial(sk_metric, num_classes=num_classes), dist_sync_on_step=dist_sync_on_step, + device=device, metric_args={"num_classes": num_classes, "thresholds": thresholds}, ) diff --git a/tests/classification/test_calibration_error.py b/tests/classification/test_calibration_error.py index 68822b863f8..87ff211432d 100644 --- a/tests/classification/test_calibration_error.py +++ b/tests/classification/test_calibration_error.py @@ -3,6 +3,7 @@ import numpy as np import pytest +from pytest_cases import parametrize_with_cases from tests.classification.inputs import _input_binary_prob from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob @@ -12,7 +13,7 @@ # TODO: replace this with official sklearn implementation after next sklearn release from tests.helpers.non_sklearn_metrics import calibration_error as sk_calib -from tests.helpers.testers import THRESHOLD, MetricTester +from tests.helpers.testers import THRESHOLD, MetricTester, MetricTesterDDPCases from torchmetrics import CalibrationError from torchmetrics.functional import calibration_error from torchmetrics.utilities.checks import _input_format_classification @@ -51,9 +52,9 @@ def _sk_calibration(preds, target, n_bins, norm, debias=False): ], ) class TestCE(MetricTester): - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_ce(self, preds, target, n_bins, ddp, dist_sync_on_step, norm): + def test_ce(self, preds, target, n_bins, ddp, dist_sync_on_step, norm, device): self.run_class_metric_test( ddp=ddp, preds=preds, @@ -61,15 +62,18 @@ def test_ce(self, preds, target, n_bins, ddp, dist_sync_on_step, norm): metric_class=CalibrationError, sk_metric=functools.partial(_sk_calibration, n_bins=n_bins, norm=norm), dist_sync_on_step=dist_sync_on_step, + device=device, metric_args={"n_bins": n_bins, "norm": norm}, ) - def test_ce_functional(self, preds, target, n_bins, norm): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_ce_functional(self, preds, target, n_bins, norm, device): self.run_functional_metric_test( preds, target, metric_functional=calibration_error, sk_metric=functools.partial(_sk_calibration, n_bins=n_bins, norm=norm), + device=device, 
metric_args={"n_bins": n_bins, "norm": norm}, ) diff --git a/tests/classification/test_cohen_kappa.py b/tests/classification/test_cohen_kappa.py index ce83e85ea46..83c91397f69 100644 --- a/tests/classification/test_cohen_kappa.py +++ b/tests/classification/test_cohen_kappa.py @@ -3,6 +3,7 @@ import numpy as np import pytest import torch +from pytest_cases import parametrize_with_cases from sklearn.metrics import cohen_kappa_score as sk_cohen_kappa from tests.classification.inputs import _input_binary, _input_binary_prob @@ -13,7 +14,7 @@ from tests.classification.inputs import _input_multilabel as _input_mlb from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob from tests.helpers import seed_all -from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester +from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, MetricTesterDDPCases from torchmetrics.classification.cohen_kappa import CohenKappa from torchmetrics.functional.classification.cohen_kappa import cohen_kappa @@ -93,9 +94,9 @@ def _sk_cohen_kappa_multidim_multiclass(preds, target, weights=None): class TestCohenKappa(MetricTester): atol = 1e-5 - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_cohen_kappa(self, weights, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): + def test_cohen_kappa(self, weights, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step, device): self.run_class_metric_test( ddp=ddp, preds=preds, @@ -103,15 +104,18 @@ def test_cohen_kappa(self, weights, preds, target, sk_metric, num_classes, ddp, metric_class=CohenKappa, sk_metric=partial(sk_metric, weights=weights), dist_sync_on_step=dist_sync_on_step, + device=device, metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "weights": weights}, ) - def test_cohen_kappa_functional(self, weights, preds, target, sk_metric, num_classes): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_cohen_kappa_functional(self, weights, preds, target, sk_metric, num_classes, device): self.run_functional_metric_test( preds, target, metric_functional=cohen_kappa, sk_metric=partial(sk_metric, weights=weights), + device=device, metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "weights": weights}, ) diff --git a/tests/classification/test_confusion_matrix.py b/tests/classification/test_confusion_matrix.py index 79b05f4a2f4..e3631f11437 100644 --- a/tests/classification/test_confusion_matrix.py +++ b/tests/classification/test_confusion_matrix.py @@ -16,6 +16,7 @@ import numpy as np import pytest import torch +from pytest_cases import parametrize_with_cases from sklearn.metrics import confusion_matrix as sk_confusion_matrix from sklearn.metrics import multilabel_confusion_matrix as sk_multilabel_confusion_matrix @@ -29,7 +30,7 @@ from tests.classification.inputs import _input_multilabel_logits as _input_mlb_logits from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob from tests.helpers import seed_all -from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester +from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, MetricTesterDDPCases from torchmetrics.classification.confusion_matrix import ConfusionMatrix from torchmetrics.functional import confusion_matrix @@ -128,10 +129,10 @@ def _sk_cm_multidim_multiclass(preds, target, 
normalize=None): ], ) class TestConfusionMatrix(MetricTester): - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_confusion_matrix( - self, normalize, preds, target, sk_metric, num_classes, multilabel, ddp, dist_sync_on_step + self, normalize, preds, target, sk_metric, num_classes, multilabel, ddp, dist_sync_on_step, device ): self.run_class_metric_test( ddp=ddp, @@ -140,6 +141,7 @@ def test_confusion_matrix( metric_class=ConfusionMatrix, sk_metric=partial(sk_metric, normalize=normalize), dist_sync_on_step=dist_sync_on_step, + device=device, metric_args={ "num_classes": num_classes, "threshold": THRESHOLD, @@ -148,12 +150,14 @@ def test_confusion_matrix( }, ) - def test_confusion_matrix_functional(self, normalize, preds, target, sk_metric, num_classes, multilabel): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_confusion_matrix_functional(self, normalize, preds, target, sk_metric, num_classes, multilabel, device): self.run_functional_metric_test( preds=preds, target=target, metric_functional=confusion_matrix, sk_metric=partial(sk_metric, normalize=normalize), + device=device, metric_args={ "num_classes": num_classes, "threshold": THRESHOLD, diff --git a/tests/classification/test_f_beta.py b/tests/classification/test_f_beta.py index 2cabdd1017d..4dffad4c476 100644 --- a/tests/classification/test_f_beta.py +++ b/tests/classification/test_f_beta.py @@ -17,6 +17,7 @@ import numpy as np import pytest import torch +from pytest_cases import parametrize_with_cases from sklearn.metrics import f1_score, fbeta_score from torch import Tensor @@ -31,7 +32,7 @@ from tests.classification.inputs import _input_multilabel_logits as _input_mlb_logits from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob from tests.helpers import seed_all -from tests.helpers.testers import NUM_BATCHES, NUM_CLASSES, THRESHOLD, MetricTester +from tests.helpers.testers import NUM_BATCHES, NUM_CLASSES, THRESHOLD, MetricTester, MetricTesterDDPCases from torchmetrics import F1, FBeta, Metric from torchmetrics.functional import f1, fbeta from torchmetrics.utilities.checks import _input_format_classification @@ -239,7 +240,7 @@ def test_class_not_present(metric_class, metric_fn, ignore_index, expected): ], ) class TestFBeta(MetricTester): - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_fbeta_f1( self, @@ -256,6 +257,7 @@ def test_fbeta_f1( average: str, mdmc_average: Optional[str], ignore_index: Optional[int], + device: str, ): if num_classes == 1 and average != "micro": pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") @@ -281,6 +283,7 @@ def test_fbeta_f1( mdmc_average=mdmc_average, ), dist_sync_on_step=dist_sync_on_step, + device=device, metric_args={ "num_classes": num_classes, "average": average, @@ -293,6 +296,7 @@ def test_fbeta_f1( check_batch=True, ) + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") def test_fbeta_f1_functional( self, preds: Tensor, @@ -306,6 +310,7 @@ def test_fbeta_f1_functional( average: str, mdmc_average: Optional[str], ignore_index: Optional[int], + device: str, ): if num_classes == 1 and average != "micro": pytest.skip("Only test binary data for 
'micro' avg (equivalent of 'binary' in sklearn)") @@ -329,6 +334,7 @@ def test_fbeta_f1_functional( ignore_index=ignore_index, mdmc_average=mdmc_average, ), + device=device, metric_args={ "num_classes": num_classes, "average": average, diff --git a/tests/classification/test_hamming_distance.py b/tests/classification/test_hamming_distance.py index eeac1bce8c9..d75b4267fed 100644 --- a/tests/classification/test_hamming_distance.py +++ b/tests/classification/test_hamming_distance.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest +from pytest_cases import parametrize_with_cases from sklearn.metrics import hamming_loss as sk_hamming_loss from tests.classification.inputs import _input_binary, _input_binary_logits, _input_binary_prob @@ -26,7 +27,7 @@ from tests.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob from tests.helpers import seed_all -from tests.helpers.testers import THRESHOLD, MetricTester +from tests.helpers.testers import THRESHOLD, MetricTester, MetricTesterDDPCases from torchmetrics import HammingDistance from torchmetrics.functional import hamming_distance from torchmetrics.utilities.checks import _input_format_classification @@ -61,9 +62,9 @@ def _sk_hamming_loss(preds, target): ], ) class TestHammingDistance(MetricTester): - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [False, True]) - def test_hamming_distance_class(self, ddp, dist_sync_on_step, preds, target): + def test_hamming_distance_class(self, ddp, dist_sync_on_step, preds, target, device): self.run_class_metric_test( ddp=ddp, preds=preds, @@ -71,15 +72,18 @@ def test_hamming_distance_class(self, ddp, dist_sync_on_step, preds, target): metric_class=HammingDistance, sk_metric=_sk_hamming_loss, dist_sync_on_step=dist_sync_on_step, + device=device, metric_args={"threshold": THRESHOLD}, ) - def test_hamming_distance_fn(self, preds, target): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_hamming_distance_fn(self, preds, target, device): self.run_functional_metric_test( preds=preds, target=target, metric_functional=hamming_distance, sk_metric=_sk_hamming_loss, + device=device, metric_args={"threshold": THRESHOLD}, ) diff --git a/tests/classification/test_hinge.py b/tests/classification/test_hinge.py index 07e9f81de9c..06bcdb55d9b 100644 --- a/tests/classification/test_hinge.py +++ b/tests/classification/test_hinge.py @@ -16,11 +16,12 @@ import numpy as np import pytest import torch +from pytest_cases import parametrize_with_cases from sklearn.metrics import hinge_loss as sk_hinge from sklearn.preprocessing import OneHotEncoder from tests.classification.inputs import Input -from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, MetricTester +from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, MetricTester, MetricTesterDDPCases from torchmetrics import Hinge from torchmetrics.functional import hinge from torchmetrics.functional.classification.hinge import MulticlassMode @@ -88,9 +89,9 @@ def _sk_hinge(preds, target, squared, multiclass_mode): ], ) class TestHinge(MetricTester): - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") 
@pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_hinge_class(self, ddp, dist_sync_on_step, preds, target, squared, multiclass_mode): + def test_hinge_class(self, ddp, dist_sync_on_step, preds, target, squared, multiclass_mode, device): self.run_class_metric_test( ddp=ddp, preds=preds, @@ -98,18 +99,21 @@ def test_hinge_class(self, ddp, dist_sync_on_step, preds, target, squared, multi metric_class=Hinge, sk_metric=partial(_sk_hinge, squared=squared, multiclass_mode=multiclass_mode), dist_sync_on_step=dist_sync_on_step, + device=device, metric_args={ "squared": squared, "multiclass_mode": multiclass_mode, }, ) - def test_hinge_fn(self, preds, target, squared, multiclass_mode): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_hinge_fn(self, preds, target, squared, multiclass_mode, device): self.run_functional_metric_test( preds=preds, target=target, metric_functional=partial(hinge, squared=squared, multiclass_mode=multiclass_mode), sk_metric=partial(_sk_hinge, squared=squared, multiclass_mode=multiclass_mode), + device=device, ) def test_hinge_differentiability(self, preds, target, squared, multiclass_mode): diff --git a/tests/classification/test_jaccard.py b/tests/classification/test_jaccard.py index 3b6260d0b92..4c62afc419b 100644 --- a/tests/classification/test_jaccard.py +++ b/tests/classification/test_jaccard.py @@ -16,6 +16,7 @@ import numpy as np import pytest import torch +from pytest_cases import parametrize_with_cases from sklearn.metrics import jaccard_score as sk_jaccard_score from torch import Tensor, tensor @@ -26,7 +27,7 @@ from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob from tests.classification.inputs import _input_multilabel as _input_mlb from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob -from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester +from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, MetricTesterDDPCases from torchmetrics.classification.jaccard import JaccardIndex from torchmetrics.functional import jaccard_index @@ -102,9 +103,9 @@ def _sk_jaccard_multidim_multiclass(preds, target, average=None): ], ) class TestJaccardIndex(MetricTester): - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_jaccard(self, reduction, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): + def test_jaccard(self, reduction, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step, device): average = "macro" if reduction == "elementwise_mean" else None # convert tags self.run_class_metric_test( ddp=ddp, @@ -113,16 +114,19 @@ def test_jaccard(self, reduction, preds, target, sk_metric, num_classes, ddp, di metric_class=JaccardIndex, sk_metric=partial(sk_metric, average=average), dist_sync_on_step=dist_sync_on_step, + device=device, metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "reduction": reduction}, ) - def test_jaccard_functional(self, reduction, preds, target, sk_metric, num_classes): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_jaccard_functional(self, reduction, preds, target, sk_metric, num_classes, device): average = "macro" if reduction == "elementwise_mean" else None # convert tags self.run_functional_metric_test( preds, target, metric_functional=jaccard_index, 
sk_metric=partial(sk_metric, average=average), + device=device, metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "reduction": reduction}, ) diff --git a/tests/classification/test_kl_divergence.py b/tests/classification/test_kl_divergence.py index 09ac20428dd..a7a8ff64d34 100644 --- a/tests/classification/test_kl_divergence.py +++ b/tests/classification/test_kl_divergence.py @@ -18,11 +18,12 @@ import numpy as np import pytest import torch +from pytest_cases import parametrize_with_cases from scipy.stats import entropy from torch import Tensor from tests.helpers import seed_all -from tests.helpers.testers import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, MetricTester +from tests.helpers.testers import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, MetricTester, MetricTesterDDPCases from torchmetrics.classification import KLDivergence from torchmetrics.functional import kl_divergence @@ -60,9 +61,9 @@ def _sk_metric(p: Tensor, q: Tensor, log_prob: bool, reduction: Optional[str] = class TestKLDivergence(MetricTester): atol = 1e-6 - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_kldivergence(self, reduction, p, q, log_prob, ddp, dist_sync_on_step): + def test_kldivergence(self, reduction, p, q, log_prob, ddp, dist_sync_on_step, device): self.run_class_metric_test( ddp, p, @@ -70,16 +71,19 @@ def test_kldivergence(self, reduction, p, q, log_prob, ddp, dist_sync_on_step): KLDivergence, partial(_sk_metric, log_prob=log_prob, reduction=reduction), dist_sync_on_step, + device=device, metric_args=dict(log_prob=log_prob, reduction=reduction), ) - def test_kldivergence_functional(self, reduction, p, q, log_prob): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_kldivergence_functional(self, reduction, p, q, log_prob, device): # todo: `num_outputs` is unused self.run_functional_metric_test( p, q, kl_divergence, partial(_sk_metric, log_prob=log_prob, reduction=reduction), + device=device, metric_args=dict(log_prob=log_prob, reduction=reduction), ) diff --git a/tests/classification/test_matthews_corrcoef.py b/tests/classification/test_matthews_corrcoef.py index 4e412dece51..cc1d0d674b0 100644 --- a/tests/classification/test_matthews_corrcoef.py +++ b/tests/classification/test_matthews_corrcoef.py @@ -14,6 +14,7 @@ import numpy as np import pytest import torch +from pytest_cases import parametrize_with_cases from sklearn.metrics import matthews_corrcoef as sk_matthews_corrcoef from tests.classification.inputs import _input_binary, _input_binary_prob @@ -24,7 +25,7 @@ from tests.classification.inputs import _input_multilabel as _input_mlb from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob from tests.helpers import seed_all -from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester +from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, MetricTesterDDPCases from torchmetrics.classification.matthews_corrcoef import MatthewsCorrcoef from torchmetrics.functional.classification.matthews_corrcoef import matthews_corrcoef @@ -101,9 +102,9 @@ def _sk_matthews_corrcoef_multidim_multiclass(preds, target): ], ) class TestMatthewsCorrCoef(MetricTester): - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def 
test_matthews_corrcoef(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): + def test_matthews_corrcoef(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step, device): self.run_class_metric_test( ddp=ddp, preds=preds, @@ -111,18 +112,21 @@ def test_matthews_corrcoef(self, preds, target, sk_metric, num_classes, ddp, dis metric_class=MatthewsCorrcoef, sk_metric=sk_metric, dist_sync_on_step=dist_sync_on_step, + device=device, metric_args={ "num_classes": num_classes, "threshold": THRESHOLD, }, ) - def test_matthews_corrcoef_functional(self, preds, target, sk_metric, num_classes): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_matthews_corrcoef_functional(self, preds, target, sk_metric, num_classes, device): self.run_functional_metric_test( preds, target, metric_functional=matthews_corrcoef, sk_metric=sk_metric, + device=device, metric_args={ "num_classes": num_classes, "threshold": THRESHOLD, diff --git a/tests/classification/test_precision_recall.py b/tests/classification/test_precision_recall.py index 6c35cf5e5f4..ecdb6cefb12 100644 --- a/tests/classification/test_precision_recall.py +++ b/tests/classification/test_precision_recall.py @@ -17,6 +17,7 @@ import numpy as np import pytest import torch +from pytest_cases import parametrize_with_cases from sklearn.metrics import precision_score, recall_score from torch import Tensor, tensor @@ -31,7 +32,7 @@ from tests.classification.inputs import _input_multilabel_logits as _input_mlb_logits from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob from tests.helpers import seed_all -from tests.helpers.testers import NUM_BATCHES, NUM_CLASSES, THRESHOLD, MetricTester +from tests.helpers.testers import NUM_BATCHES, NUM_CLASSES, THRESHOLD, MetricTester, MetricTesterDDPCases from torchmetrics import Metric, Precision, Recall from torchmetrics.functional import precision, precision_recall, recall from torchmetrics.utilities.checks import _input_format_classification @@ -210,7 +211,7 @@ def test_no_support(metric_class, metric_fn): ], ) class TestPrecisionRecall(MetricTester): - @pytest.mark.parametrize("ddp", [False, True]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [False]) def test_precision_recall_class( self, @@ -227,6 +228,7 @@ def test_precision_recall_class( average: str, mdmc_average: Optional[str], ignore_index: Optional[int], + device: str, ): # todo: `metric_fn` is unused if num_classes == 1 and average != "micro": @@ -253,6 +255,7 @@ def test_precision_recall_class( mdmc_average=mdmc_average, ), dist_sync_on_step=dist_sync_on_step, + device=device, metric_args={ "num_classes": num_classes, "average": average, @@ -265,6 +268,7 @@ def test_precision_recall_class( check_batch=True, ) + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") def test_precision_recall_fn( self, preds: Tensor, @@ -278,6 +282,7 @@ def test_precision_recall_fn( average: str, mdmc_average: Optional[str], ignore_index: Optional[int], + device: str, ): # todo: `metric_class` is unused if num_classes == 1 and average != "micro": @@ -302,6 +307,7 @@ def test_precision_recall_fn( ignore_index=ignore_index, mdmc_average=mdmc_average, ), + device=device, metric_args={ "num_classes": num_classes, "average": average, diff --git a/tests/classification/test_precision_recall_curve.py b/tests/classification/test_precision_recall_curve.py index 
d7d731d0342..7921c57ce07 100644 --- a/tests/classification/test_precision_recall_curve.py +++ b/tests/classification/test_precision_recall_curve.py @@ -16,6 +16,7 @@ import numpy as np import pytest import torch +from pytest_cases import parametrize_with_cases from sklearn.metrics import precision_recall_curve as sk_precision_recall_curve from torch import Tensor, tensor @@ -23,7 +24,7 @@ from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob from tests.helpers import seed_all -from tests.helpers.testers import NUM_CLASSES, MetricTester +from tests.helpers.testers import NUM_CLASSES, MetricTester, MetricTesterDDPCases from torchmetrics.classification.precision_recall_curve import PrecisionRecallCurve from torchmetrics.functional import precision_recall_curve from torchmetrics.functional.classification.precision_recall_curve import _binary_clf_curve @@ -76,9 +77,9 @@ def _sk_prec_rc_multidim_multiclass_prob(preds, target, num_classes=1): ], ) class TestPrecisionRecallCurve(MetricTester): - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_precision_recall_curve(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): + def test_precision_recall_curve(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step, device): self.run_class_metric_test( ddp=ddp, preds=preds, @@ -86,15 +87,18 @@ def test_precision_recall_curve(self, preds, target, sk_metric, num_classes, ddp metric_class=PrecisionRecallCurve, sk_metric=partial(sk_metric, num_classes=num_classes), dist_sync_on_step=dist_sync_on_step, + device=device, metric_args={"num_classes": num_classes}, ) - def test_precision_recall_curve_functional(self, preds, target, sk_metric, num_classes): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_precision_recall_curve_functional(self, preds, target, sk_metric, num_classes, device): self.run_functional_metric_test( preds, target, metric_functional=precision_recall_curve, sk_metric=partial(sk_metric, num_classes=num_classes), + device=device, metric_args={"num_classes": num_classes}, ) diff --git a/tests/classification/test_roc.py b/tests/classification/test_roc.py index 7dd1b83c25b..177ddb9c834 100644 --- a/tests/classification/test_roc.py +++ b/tests/classification/test_roc.py @@ -16,6 +16,7 @@ import numpy as np import pytest import torch +from pytest_cases import parametrize_with_cases from sklearn.metrics import roc_curve as sk_roc_curve from torch import tensor @@ -25,7 +26,7 @@ from tests.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob from tests.helpers import seed_all -from tests.helpers.testers import NUM_CLASSES, MetricTester +from tests.helpers.testers import NUM_CLASSES, MetricTester, MetricTesterDDPCases from torchmetrics.classification.roc import ROC from torchmetrics.functional import roc @@ -95,9 +96,9 @@ def _sk_roc_multilabel_multidim_prob(preds, target, num_classes=1): ], ) class TestROC(MetricTester): - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_roc(self, preds, target, 
sk_metric, num_classes, ddp, dist_sync_on_step): + def test_roc(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step, device): self.run_class_metric_test( ddp=ddp, preds=preds, @@ -105,15 +106,18 @@ def test_roc(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step metric_class=ROC, sk_metric=partial(sk_metric, num_classes=num_classes), dist_sync_on_step=dist_sync_on_step, + device=device, metric_args={"num_classes": num_classes}, ) - def test_roc_functional(self, preds, target, sk_metric, num_classes): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_roc_functional(self, preds, target, sk_metric, num_classes, device): self.run_functional_metric_test( preds, target, metric_functional=roc, sk_metric=partial(sk_metric, num_classes=num_classes), + device=device, metric_args={"num_classes": num_classes}, ) diff --git a/tests/classification/test_specificity.py b/tests/classification/test_specificity.py index ada49edb332..ccc24c85654 100644 --- a/tests/classification/test_specificity.py +++ b/tests/classification/test_specificity.py @@ -18,6 +18,7 @@ import numpy as np import pytest import torch +from pytest_cases import parametrize_with_cases from sklearn.metrics import multilabel_confusion_matrix from torch import Tensor, tensor @@ -29,7 +30,7 @@ from tests.classification.inputs import _input_multilabel as _input_mlb from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob from tests.helpers import seed_all -from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester +from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, MetricTesterDDPCases from torchmetrics import Metric, Specificity from torchmetrics.functional import specificity from torchmetrics.functional.classification.stat_scores import _reduce_stat_scores @@ -218,7 +219,7 @@ def test_no_support(metric_class, metric_fn): ], ) class TestSpecificity(MetricTester): - @pytest.mark.parametrize("ddp", [False, True]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_specificity_class( self, @@ -234,6 +235,7 @@ def test_specificity_class( average: str, mdmc_average: Optional[str], ignore_index: Optional[int], + device: str, ): # todo: `metric_fn` is unused if num_classes == 1 and average != "micro": @@ -259,6 +261,7 @@ def test_specificity_class( mdmc_reduce=mdmc_average, ), dist_sync_on_step=dist_sync_on_step, + device=device, metric_args={ "num_classes": num_classes, "average": average, @@ -271,6 +274,7 @@ def test_specificity_class( check_batch=True, ) + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") def test_specificity_fn( self, preds: Tensor, @@ -283,6 +287,7 @@ def test_specificity_fn( average: str, mdmc_average: Optional[str], ignore_index: Optional[int], + device: str, ): # todo: `metric_class` is unused if num_classes == 1 and average != "micro": @@ -306,6 +311,7 @@ def test_specificity_fn( ignore_index=ignore_index, mdmc_reduce=mdmc_average, ), + device=device, metric_args={ "num_classes": num_classes, "average": average, diff --git a/tests/classification/test_stat_scores.py b/tests/classification/test_stat_scores.py index 5a550db19cd..fc5ba4c2b3d 100644 --- a/tests/classification/test_stat_scores.py +++ b/tests/classification/test_stat_scores.py @@ -17,6 +17,7 @@ import numpy as np import pytest import torch +from pytest_cases import parametrize_with_cases from 
sklearn.metrics import multilabel_confusion_matrix from torch import Tensor, tensor @@ -29,7 +30,7 @@ from tests.classification.inputs import _input_multilabel_logits as _input_mlb_logits from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob from tests.helpers import seed_all -from tests.helpers.testers import NUM_CLASSES, MetricTester +from tests.helpers.testers import NUM_CLASSES, MetricTester, MetricTesterDDPCases from torchmetrics import StatScores from torchmetrics.functional import stat_scores from torchmetrics.utilities.checks import _input_format_classification @@ -172,6 +173,8 @@ def test_wrong_params(reduce, mdmc_reduce, num_classes, inputs, ignore_index): class TestStatScores(MetricTester): # DDP tests temporarily disabled due to hanging issues @pytest.mark.parametrize("ddp", [False]) + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + # @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_stat_scores_class( self, @@ -187,6 +190,7 @@ def test_stat_scores_class( ignore_index: Optional[int], top_k: Optional[int], threshold: Optional[float], + device: str, ): if ignore_index is not None and preds.ndim == 2: pytest.skip("Skipping ignore_index test with binary inputs.") @@ -207,6 +211,7 @@ def test_stat_scores_class( threshold=threshold, ), dist_sync_on_step=dist_sync_on_step, + device=device, metric_args={ "num_classes": num_classes, "reduce": reduce, @@ -220,6 +225,7 @@ def test_stat_scores_class( check_batch=True, ) + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") def test_stat_scores_fn( self, sk_fn: Callable, @@ -232,6 +238,7 @@ def test_stat_scores_fn( ignore_index: Optional[int], top_k: Optional[int], threshold: Optional[float], + device: str, ): if ignore_index is not None and preds.ndim == 2: pytest.skip("Skipping ignore_index test with binary inputs.") @@ -250,6 +257,7 @@ def test_stat_scores_fn( top_k=top_k, threshold=threshold, ), + device=device, metric_args={ "num_classes": num_classes, "reduce": reduce, From 8439b4a3d55e42db654617350605a490d05a9739 Mon Sep 17 00:00:00 2001 From: twsl <45483159+twsl@users.noreply.github.com> Date: Sun, 19 Dec 2021 15:58:08 +0000 Subject: [PATCH 08/18] Add test strategy to image --- tests/image/test_lpips.py | 8 +++++--- tests/image/test_psnr.py | 12 ++++++++---- tests/image/test_ssim.py | 12 ++++++++---- 3 files changed, 21 insertions(+), 11 deletions(-) diff --git a/tests/image/test_lpips.py b/tests/image/test_lpips.py index 5a51bef987f..5b42bdbbe54 100644 --- a/tests/image/test_lpips.py +++ b/tests/image/test_lpips.py @@ -17,10 +17,11 @@ import pytest import torch from lpips import LPIPS as reference_LPIPS +from pytest_cases import parametrize_with_cases from torch import Tensor from tests.helpers import seed_all -from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester +from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester, MetricTesterDDPCases from torchmetrics.image.lpip_similarity import LPIPS from torchmetrics.utilities.imports import _LPIPS_AVAILABLE @@ -46,8 +47,8 @@ def _compare_fn(img1: Tensor, img2: Tensor, net_type: str, reduction: str = "mea @pytest.mark.skipif(not _LPIPS_AVAILABLE, reason="test requires that lpips is installed") @pytest.mark.parametrize("net_type", ["vgg", "alex", "squeeze"]) class TestLPIPS(MetricTester): - @pytest.mark.parametrize("ddp", [True, False]) - def test_lpips(self, 
net_type, ddp): + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + def test_lpips(self, net_type, ddp, device): """test modular implementation for correctness.""" self.run_class_metric_test( ddp=ddp, @@ -56,6 +57,7 @@ def test_lpips(self, net_type, ddp): metric_class=LPIPS, sk_metric=partial(_compare_fn, net_type=net_type), dist_sync_on_step=False, + device=device, check_scriptable=False, metric_args={"net_type": net_type}, ) diff --git a/tests/image/test_psnr.py b/tests/image/test_psnr.py index 1cc27bed303..8c2a920b544 100644 --- a/tests/image/test_psnr.py +++ b/tests/image/test_psnr.py @@ -18,10 +18,11 @@ import numpy as np import pytest import torch +from pytest_cases import parametrize_with_cases from skimage.metrics import peak_signal_noise_ratio from tests.helpers import seed_all -from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester +from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester, MetricTesterDDPCases from torchmetrics.functional import psnr from torchmetrics.image import PSNR @@ -93,9 +94,9 @@ def _base_e_sk_psnr(preds, target, data_range, reduction, dim): ], ) class TestPSNR(MetricTester): - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_psnr(self, preds, target, data_range, base, reduction, dim, sk_metric, ddp, dist_sync_on_step): + def test_psnr(self, preds, target, data_range, base, reduction, dim, sk_metric, ddp, dist_sync_on_step, device): _args = {"data_range": data_range, "base": base, "reduction": reduction, "dim": dim} self.run_class_metric_test( ddp, @@ -103,17 +104,20 @@ def test_psnr(self, preds, target, data_range, base, reduction, dim, sk_metric, target, PSNR, partial(sk_metric, data_range=data_range, reduction=reduction, dim=dim), + device=device, metric_args=_args, dist_sync_on_step=dist_sync_on_step, ) - def test_psnr_functional(self, preds, target, sk_metric, data_range, base, reduction, dim): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_psnr_functional(self, preds, target, sk_metric, data_range, base, reduction, dim, device): _args = {"data_range": data_range, "base": base, "reduction": reduction, "dim": dim} self.run_functional_metric_test( preds, target, psnr, partial(sk_metric, data_range=data_range, reduction=reduction, dim=dim), + device=device, metric_args=_args, ) diff --git a/tests/image/test_ssim.py b/tests/image/test_ssim.py index 9ef37504fce..ca24b68bced 100644 --- a/tests/image/test_ssim.py +++ b/tests/image/test_ssim.py @@ -16,10 +16,11 @@ import pytest import torch +from pytest_cases import parametrize_with_cases from skimage.metrics import structural_similarity from tests.helpers import seed_all -from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester +from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester, MetricTesterDDPCases from torchmetrics.functional import ssim from torchmetrics.image import SSIM @@ -72,25 +73,28 @@ def _sk_ssim(preds, target, data_range, multichannel, kernel_size): class TestSSIM(MetricTester): atol = 6e-3 - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_ssim(self, preds, target, multichannel, kernel_size, ddp, dist_sync_on_step): + def 
test_ssim(self, preds, target, multichannel, kernel_size, ddp, dist_sync_on_step, device): self.run_class_metric_test( ddp, preds, target, SSIM, partial(_sk_ssim, data_range=1.0, multichannel=multichannel, kernel_size=kernel_size), + device=device, metric_args={"data_range": 1.0, "kernel_size": (kernel_size, kernel_size)}, dist_sync_on_step=dist_sync_on_step, ) - def test_ssim_functional(self, preds, target, multichannel, kernel_size): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_ssim_functional(self, preds, target, multichannel, kernel_size, device): self.run_functional_metric_test( preds, target, ssim, partial(_sk_ssim, data_range=1.0, multichannel=multichannel, kernel_size=kernel_size), + device=device, metric_args={"data_range": 1.0, "kernel_size": (kernel_size, kernel_size)}, ) From 30b4ea768b1d981d732a2acc12a2f106ac533657 Mon Sep 17 00:00:00 2001 From: twsl <45483159+twsl@users.noreply.github.com> Date: Sun, 19 Dec 2021 16:00:59 +0000 Subject: [PATCH 09/18] Add test strategy to pairwise --- tests/pairwise/test_pairwise_distance.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/pairwise/test_pairwise_distance.py b/tests/pairwise/test_pairwise_distance.py index 5a809b9e282..a424df61358 100644 --- a/tests/pairwise/test_pairwise_distance.py +++ b/tests/pairwise/test_pairwise_distance.py @@ -16,10 +16,11 @@ import pytest import torch +from pytest_cases import parametrize_with_cases from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances, linear_kernel, manhattan_distances from tests.helpers import seed_all -from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester +from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester, MetricTesterDDPCases from torchmetrics.functional import ( pairwise_cosine_similarity, pairwise_euclidean_distance, @@ -81,13 +82,15 @@ class TestPairwise(MetricTester): atol = 1e-4 - def test_pairwise_functional(self, x, y, metric_functional, sk_fn, reduction): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_pairwise_functional(self, x, y, metric_functional, sk_fn, reduction, device): """test functional pairwise implementations.""" self.run_functional_metric_test( preds=x, target=y, metric_functional=metric_functional, sk_metric=partial(_sk_metric, sk_fn=sk_fn, reduction=reduction), + device=device, metric_args={"reduction": reduction}, ) From 2f5dc34141745f8a8a04cf0ca2d39510b4ed476e Mon Sep 17 00:00:00 2001 From: twsl <45483159+twsl@users.noreply.github.com> Date: Sun, 19 Dec 2021 16:07:56 +0000 Subject: [PATCH 10/18] Add test strategy to regression --- tests/regression/test_cosine_similarity.py | 12 +++++++---- tests/regression/test_explained_variance.py | 12 +++++++---- tests/regression/test_mean_error.py | 24 +++++++++++++++++---- tests/regression/test_pearson.py | 13 ++++++----- tests/regression/test_r2.py | 12 +++++++---- tests/regression/test_spearman.py | 13 ++++++----- tests/regression/test_tweedie_deviance.py | 12 +++++++---- 7 files changed, 68 insertions(+), 30 deletions(-) diff --git a/tests/regression/test_cosine_similarity.py b/tests/regression/test_cosine_similarity.py index 2009f402ecf..04219236b86 100644 --- a/tests/regression/test_cosine_similarity.py +++ b/tests/regression/test_cosine_similarity.py @@ -17,10 +17,11 @@ import numpy as np import pytest import torch +from pytest_cases import parametrize_with_cases from sklearn.metrics.pairwise import cosine_similarity as sk_cosine 
from tests.helpers import seed_all -from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester +from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester, MetricTesterDDPCases from torchmetrics.functional.regression.cosine_similarity import cosine_similarity from torchmetrics.regression.cosine_similarity import CosineSimilarity @@ -82,9 +83,9 @@ def _single_target_sk_metric(preds, target, reduction, sk_fn=sk_cosine): ], ) class TestCosineSimilarity(MetricTester): - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_cosine_similarity(self, reduction, preds, target, sk_metric, ddp, dist_sync_on_step): + def test_cosine_similarity(self, reduction, preds, target, sk_metric, ddp, dist_sync_on_step, device): self.run_class_metric_test( ddp, preds, @@ -92,15 +93,18 @@ def test_cosine_similarity(self, reduction, preds, target, sk_metric, ddp, dist_ CosineSimilarity, partial(sk_metric, reduction=reduction), dist_sync_on_step, + device=device, metric_args=dict(reduction=reduction), ) - def test_cosine_similarity_functional(self, reduction, preds, target, sk_metric): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_cosine_similarity_functional(self, reduction, preds, target, sk_metric, device): self.run_functional_metric_test( preds, target, cosine_similarity, partial(sk_metric, reduction=reduction), + device=device, metric_args=dict(reduction=reduction), ) diff --git a/tests/regression/test_explained_variance.py b/tests/regression/test_explained_variance.py index 0bf0e48d069..5f8e3eecf7e 100644 --- a/tests/regression/test_explained_variance.py +++ b/tests/regression/test_explained_variance.py @@ -16,10 +16,11 @@ import pytest import torch +from pytest_cases import parametrize_with_cases from sklearn.metrics import explained_variance_score from tests.helpers import seed_all -from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester +from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester, MetricTesterDDPCases from torchmetrics.functional import explained_variance from torchmetrics.regression import ExplainedVariance from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 @@ -62,9 +63,9 @@ def _multi_target_sk_metric(preds, target, sk_fn=explained_variance_score): ], ) class TestExplainedVariance(MetricTester): - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_explained_variance(self, multioutput, preds, target, sk_metric, ddp, dist_sync_on_step): + def test_explained_variance(self, multioutput, preds, target, sk_metric, ddp, dist_sync_on_step, device): self.run_class_metric_test( ddp, preds, @@ -72,15 +73,18 @@ def test_explained_variance(self, multioutput, preds, target, sk_metric, ddp, di ExplainedVariance, partial(sk_metric, sk_fn=partial(explained_variance_score, multioutput=multioutput)), dist_sync_on_step, + device=device, metric_args=dict(multioutput=multioutput), ) - def test_explained_variance_functional(self, multioutput, preds, target, sk_metric): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_explained_variance_functional(self, multioutput, preds, target, sk_metric, device): self.run_functional_metric_test( preds, target, 
explained_variance, partial(sk_metric, sk_fn=partial(explained_variance_score, multioutput=multioutput)), + device=device, metric_args=dict(multioutput=multioutput), ) diff --git a/tests/regression/test_mean_error.py b/tests/regression/test_mean_error.py index aaa35a41f6f..7d4d9b70a7b 100644 --- a/tests/regression/test_mean_error.py +++ b/tests/regression/test_mean_error.py @@ -17,6 +17,7 @@ import pytest import torch +from pytest_cases import parametrize_with_cases from sklearn.metrics import mean_absolute_error as sk_mean_absolute_error from sklearn.metrics import mean_absolute_percentage_error as sk_mean_abs_percentage_error from sklearn.metrics import mean_squared_error as sk_mean_squared_error @@ -26,7 +27,7 @@ from tests.helpers.non_sklearn_metrics import ( symmetric_mean_absolute_percentage_error as sk_sym_mean_abs_percentage_error, ) -from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester +from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester, MetricTesterDDPCases from torchmetrics.functional import ( mean_absolute_error, mean_absolute_percentage_error, @@ -108,10 +109,20 @@ def _multi_target_sk_metric(preds, target, sk_fn, metric_args): ], ) class TestMeanError(MetricTester): - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_mean_error_class( - self, preds, target, sk_metric, metric_class, metric_functional, sk_fn, metric_args, ddp, dist_sync_on_step + self, + preds, + target, + sk_metric, + metric_class, + metric_functional, + sk_fn, + metric_args, + ddp, + dist_sync_on_step, + device, ): # todo: `metric_functional` is unused self.run_class_metric_test( @@ -121,16 +132,21 @@ def test_mean_error_class( metric_class=metric_class, sk_metric=partial(sk_metric, sk_fn=sk_fn, metric_args=metric_args), dist_sync_on_step=dist_sync_on_step, + device=device, metric_args=metric_args, ) - def test_mean_error_functional(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn, metric_args): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_mean_error_functional( + self, preds, target, sk_metric, metric_class, metric_functional, sk_fn, metric_args, device + ): # todo: `metric_class` is unused self.run_functional_metric_test( preds=preds, target=target, metric_functional=metric_functional, sk_metric=partial(sk_metric, sk_fn=sk_fn, metric_args=metric_args), + device=device, metric_args=metric_args, ) diff --git a/tests/regression/test_pearson.py b/tests/regression/test_pearson.py index c4540f5c377..cb0fb1917e0 100644 --- a/tests/regression/test_pearson.py +++ b/tests/regression/test_pearson.py @@ -15,10 +15,11 @@ import pytest import torch +from pytest_cases import parametrize_with_cases from scipy.stats import pearsonr from tests.helpers import seed_all -from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester +from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester, MetricTesterDDPCases from torchmetrics.functional.regression.pearson import pearson_corrcoef from torchmetrics.regression.pearson import PearsonCorrcoef @@ -53,8 +54,8 @@ def _sk_pearsonr(preds, target): class TestPearsonCorrcoef(MetricTester): atol = 1e-2 - @pytest.mark.parametrize("ddp", [True, False]) - def test_pearson_corrcoef(self, preds, target, ddp): + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + def 
test_pearson_corrcoef(self, preds, target, ddp, device): self.run_class_metric_test( ddp=ddp, preds=preds, @@ -62,11 +63,13 @@ def test_pearson_corrcoef(self, preds, target, ddp): metric_class=PearsonCorrcoef, sk_metric=_sk_pearsonr, dist_sync_on_step=False, + device=device, ) - def test_pearson_corrcoef_functional(self, preds, target): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_pearson_corrcoef_functional(self, preds, target, device): self.run_functional_metric_test( - preds=preds, target=target, metric_functional=pearson_corrcoef, sk_metric=_sk_pearsonr + preds=preds, target=target, metric_functional=pearson_corrcoef, sk_metric=_sk_pearsonr, device=device ) def test_pearson_corrcoef_differentiability(self, preds, target): diff --git a/tests/regression/test_r2.py b/tests/regression/test_r2.py index 6882191b870..4ad2b8f882d 100644 --- a/tests/regression/test_r2.py +++ b/tests/regression/test_r2.py @@ -16,10 +16,11 @@ import pytest import torch +from pytest_cases import parametrize_with_cases from sklearn.metrics import r2_score as sk_r2score from tests.helpers import seed_all -from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester +from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester, MetricTesterDDPCases from torchmetrics.functional import r2_score from torchmetrics.regression import R2Score from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 @@ -69,9 +70,9 @@ def _multi_target_sk_metric(preds, target, adjusted, multioutput): ], ) class TestR2Score(MetricTester): - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_r2(self, adjusted, multioutput, preds, target, sk_metric, num_outputs, ddp, dist_sync_on_step): + def test_r2(self, adjusted, multioutput, preds, target, sk_metric, num_outputs, ddp, dist_sync_on_step, device): self.run_class_metric_test( ddp, preds, @@ -79,16 +80,19 @@ def test_r2(self, adjusted, multioutput, preds, target, sk_metric, num_outputs, R2Score, partial(sk_metric, adjusted=adjusted, multioutput=multioutput), dist_sync_on_step, + device=device, metric_args=dict(adjusted=adjusted, multioutput=multioutput, num_outputs=num_outputs), ) - def test_r2_functional(self, adjusted, multioutput, preds, target, sk_metric, num_outputs): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_r2_functional(self, adjusted, multioutput, preds, target, sk_metric, num_outputs, device): # todo: `num_outputs` is unused self.run_functional_metric_test( preds, target, r2_score, partial(sk_metric, adjusted=adjusted, multioutput=multioutput), + device=device, metric_args=dict(adjusted=adjusted, multioutput=multioutput), ) diff --git a/tests/regression/test_spearman.py b/tests/regression/test_spearman.py index ef59dd7a697..d2ed7b4c861 100644 --- a/tests/regression/test_spearman.py +++ b/tests/regression/test_spearman.py @@ -15,10 +15,11 @@ import pytest import torch +from pytest_cases import parametrize_with_cases from scipy.stats import rankdata, spearmanr from tests.helpers import seed_all -from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester +from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester, MetricTesterDDPCases from torchmetrics.functional.regression.spearman import _rank_data, spearman_corrcoef from torchmetrics.regression.spearman import SpearmanCorrcoef 
@@ -76,9 +77,9 @@ def _sk_metric(preds, target): class TestSpearmanCorrcoef(MetricTester): atol = 1e-2 - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_spearman_corrcoef(self, preds, target, ddp, dist_sync_on_step): + def test_spearman_corrcoef(self, preds, target, ddp, dist_sync_on_step, device): self.run_class_metric_test( ddp, preds, @@ -86,10 +87,12 @@ def test_spearman_corrcoef(self, preds, target, ddp, dist_sync_on_step): SpearmanCorrcoef, _sk_metric, dist_sync_on_step, + device=device, ) - def test_spearman_corrcoef_functional(self, preds, target): - self.run_functional_metric_test(preds, target, spearman_corrcoef, _sk_metric) + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_spearman_corrcoef_functional(self, preds, target, device): + self.run_functional_metric_test(preds, target, spearman_corrcoef, _sk_metric, device=device) def test_spearman_corrcoef_differentiability(self, preds, target): self.run_differentiability_test( diff --git a/tests/regression/test_tweedie_deviance.py b/tests/regression/test_tweedie_deviance.py index 3b2d61ec63d..2a2298fa3ea 100644 --- a/tests/regression/test_tweedie_deviance.py +++ b/tests/regression/test_tweedie_deviance.py @@ -16,11 +16,12 @@ import pytest import torch +from pytest_cases import parametrize_with_cases from sklearn.metrics import mean_tweedie_deviance from torch import Tensor from tests.helpers import seed_all -from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester +from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester, MetricTesterDDPCases from torchmetrics.functional.regression.tweedie_deviance import tweedie_deviance_score from torchmetrics.regression.tweedie_deviance import TweedieDevianceScore @@ -60,9 +61,9 @@ def _sk_deviance(preds: Tensor, targets: Tensor, power: float): ], ) class TestDevianceScore(MetricTester): - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_deviance_scores_class(self, ddp, dist_sync_on_step, preds, targets, power): + def test_deviance_scores_class(self, ddp, dist_sync_on_step, preds, targets, power, device): self.run_class_metric_test( ddp, preds, @@ -70,15 +71,18 @@ def test_deviance_scores_class(self, ddp, dist_sync_on_step, preds, targets, pow TweedieDevianceScore, partial(_sk_deviance, power=power), dist_sync_on_step, + device=device, metric_args=dict(power=power), ) - def test_deviance_scores_functional(self, preds, targets, power): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_deviance_scores_functional(self, preds, targets, power, device): self.run_functional_metric_test( preds, targets, tweedie_deviance_score, partial(_sk_deviance, power=power), + device=device, metric_args=dict(power=power), ) From 2dfe6c8f1550f2a2d7b82e7f09aa03850c1771d1 Mon Sep 17 00:00:00 2001 From: twsl <45483159+twsl@users.noreply.github.com> Date: Sun, 19 Dec 2021 16:20:55 +0000 Subject: [PATCH 11/18] Add test strategy to retrieval --- tests/retrieval/helpers.py | 4 ++++ tests/retrieval/test_fallout.py | 14 +++++++++++--- tests/retrieval/test_hit_rate.py | 14 +++++++++++--- tests/retrieval/test_map.py | 14 +++++++++++--- tests/retrieval/test_mrr.py | 14 +++++++++++--- tests/retrieval/test_ndcg.py 
| 14 +++++++++++--- tests/retrieval/test_precision.py | 14 +++++++++++--- tests/retrieval/test_r_precision.py | 14 +++++++++++--- tests/retrieval/test_recall.py | 14 +++++++++++--- 9 files changed, 92 insertions(+), 24 deletions(-) diff --git a/tests/retrieval/helpers.py b/tests/retrieval/helpers.py index 1cbf3078096..2ffec010c3f 100644 --- a/tests/retrieval/helpers.py +++ b/tests/retrieval/helpers.py @@ -417,6 +417,7 @@ def run_class_metric_test( metric_class: Metric, sk_metric: Callable, dist_sync_on_step: bool, + device: str, metric_args: dict, reverse: bool = False, ): @@ -429,6 +430,7 @@ def run_class_metric_test( metric_class=metric_class, sk_metric=_sk_metric_adapted, dist_sync_on_step=dist_sync_on_step, + device=device, metric_args=metric_args, fragment_kwargs=True, indexes=indexes, # every additional argument will be passed to metric_class and _sk_metric_adapted @@ -440,6 +442,7 @@ def run_functional_metric_test( target: Tensor, metric_functional: Callable, sk_metric: Callable, + device: str, metric_args: dict, reverse: bool = False, **kwargs, @@ -451,6 +454,7 @@ def run_functional_metric_test( target=target, metric_functional=metric_functional, sk_metric=_sk_metric_adapted, + device=device, metric_args=metric_args, fragment_kwargs=True, **kwargs, diff --git a/tests/retrieval/test_fallout.py b/tests/retrieval/test_fallout.py index 61cd94960d5..330976e4d8e 100644 --- a/tests/retrieval/test_fallout.py +++ b/tests/retrieval/test_fallout.py @@ -13,9 +13,11 @@ # limitations under the License. import numpy as np import pytest +from pytest_cases import parametrize_with_cases from torch import Tensor from tests.helpers import seed_all +from tests.helpers.testers import MetricTesterDDPCases from tests.retrieval.helpers import ( RetrievalMetricTester, _concat_tests, @@ -53,7 +55,7 @@ def _fallout_at_k(target: np.ndarray, preds: np.ndarray, k: int = None): class TestFallOut(RetrievalMetricTester): - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) @pytest.mark.parametrize("ignore_index", [None, 1]) # avoid setting 0, otherwise test with all 0 targets will fail @@ -69,6 +71,7 @@ def test_class_metric( empty_target_action: str, ignore_index: int, k: int, + device: str, ): metric_args = dict(empty_target_action=empty_target_action, k=k, ignore_index=ignore_index) @@ -80,11 +83,12 @@ def test_class_metric( metric_class=RetrievalFallOut, sk_metric=_fallout_at_k, dist_sync_on_step=dist_sync_on_step, + device=device, reverse=True, metric_args=metric_args, ) - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) @pytest.mark.parametrize("k", [None, 1, 4, 10]) @@ -98,6 +102,7 @@ def test_class_metric_ignore_index( dist_sync_on_step: bool, empty_target_action: str, k: int, + device: str, ): metric_args = dict(empty_target_action=empty_target_action, k=k, ignore_index=-100) @@ -109,18 +114,21 @@ def test_class_metric_ignore_index( metric_class=RetrievalFallOut, sk_metric=_fallout_at_k, dist_sync_on_step=dist_sync_on_step, + device=device, reverse=True, metric_args=metric_args, ) + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") 
@pytest.mark.parametrize(**_default_metric_functional_input_arguments) @pytest.mark.parametrize("k", [None, 1, 4, 10]) - def test_functional_metric(self, preds: Tensor, target: Tensor, k: int): + def test_functional_metric(self, preds: Tensor, target: Tensor, k: int, device: str): self.run_functional_metric_test( preds=preds, target=target, metric_functional=retrieval_fall_out, sk_metric=_fallout_at_k, + device=device, reverse=True, metric_args={}, k=k, diff --git a/tests/retrieval/test_hit_rate.py b/tests/retrieval/test_hit_rate.py index 06da9be6587..250305aecd7 100644 --- a/tests/retrieval/test_hit_rate.py +++ b/tests/retrieval/test_hit_rate.py @@ -13,9 +13,11 @@ # limitations under the License. import numpy as np import pytest +from pytest_cases import parametrize_with_cases from torch import Tensor from tests.helpers import seed_all +from tests.helpers.testers import MetricTesterDDPCases from tests.retrieval.helpers import ( RetrievalMetricTester, _concat_tests, @@ -50,7 +52,7 @@ def _hit_rate_at_k(target: np.ndarray, preds: np.ndarray, k: int = None): class TestHitRate(RetrievalMetricTester): - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) @pytest.mark.parametrize("ignore_index", [None, 1]) # avoid setting 0, otherwise test with all 0 targets will fail @@ -66,6 +68,7 @@ def test_class_metric( empty_target_action: str, ignore_index: int, k: int, + device: str, ): metric_args = dict(empty_target_action=empty_target_action, k=k, ignore_index=ignore_index) @@ -77,10 +80,11 @@ def test_class_metric( metric_class=RetrievalHitRate, sk_metric=_hit_rate_at_k, dist_sync_on_step=dist_sync_on_step, + device=device, metric_args=metric_args, ) - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) @pytest.mark.parametrize("k", [None, 1, 4, 10]) @@ -94,6 +98,7 @@ def test_class_metric_ignore_index( dist_sync_on_step: bool, empty_target_action: str, k: int, + device: str, ): metric_args = dict(empty_target_action=empty_target_action, k=k, ignore_index=-100) @@ -105,17 +110,20 @@ def test_class_metric_ignore_index( metric_class=RetrievalHitRate, sk_metric=_hit_rate_at_k, dist_sync_on_step=dist_sync_on_step, + device=device, metric_args=metric_args, ) + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") @pytest.mark.parametrize(**_default_metric_functional_input_arguments) @pytest.mark.parametrize("k", [None, 1, 4, 10]) - def test_functional_metric(self, preds: Tensor, target: Tensor, k: int): + def test_functional_metric(self, preds: Tensor, target: Tensor, k: int, device: str): self.run_functional_metric_test( preds=preds, target=target, metric_functional=retrieval_hit_rate, sk_metric=_hit_rate_at_k, + device=device, metric_args={}, k=k, ) diff --git a/tests/retrieval/test_map.py b/tests/retrieval/test_map.py index 8a7e3a67a75..a874ff42fe7 100644 --- a/tests/retrieval/test_map.py +++ b/tests/retrieval/test_map.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import pytest +from pytest_cases import parametrize_with_cases from sklearn.metrics import average_precision_score as sk_average_precision_score from torch import Tensor from tests.helpers import seed_all +from tests.helpers.testers import MetricTesterDDPCases from tests.retrieval.helpers import ( RetrievalMetricTester, _concat_tests, @@ -33,7 +35,7 @@ class TestMAP(RetrievalMetricTester): - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) @pytest.mark.parametrize("ignore_index", [None, 1]) # avoid setting 0, otherwise test with all 0 targets will fail @@ -47,6 +49,7 @@ def test_class_metric( dist_sync_on_step: bool, empty_target_action: str, ignore_index: int, + device: str, ): metric_args = dict(empty_target_action=empty_target_action, ignore_index=ignore_index) @@ -58,10 +61,11 @@ def test_class_metric( metric_class=RetrievalMAP, sk_metric=sk_average_precision_score, dist_sync_on_step=dist_sync_on_step, + device=device, metric_args=metric_args, ) - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) @pytest.mark.parametrize(**_default_metric_class_input_arguments_ignore_index) @@ -73,6 +77,7 @@ def test_class_metric_ignore_index( target: Tensor, dist_sync_on_step: bool, empty_target_action: str, + device: str, ): metric_args = dict(empty_target_action=empty_target_action, ignore_index=-100) @@ -84,16 +89,19 @@ def test_class_metric_ignore_index( metric_class=RetrievalMAP, sk_metric=sk_average_precision_score, dist_sync_on_step=dist_sync_on_step, + device=device, metric_args=metric_args, ) + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") @pytest.mark.parametrize(**_default_metric_functional_input_arguments) - def test_functional_metric(self, preds: Tensor, target: Tensor): + def test_functional_metric(self, preds: Tensor, target: Tensor, device: str): self.run_functional_metric_test( preds=preds, target=target, metric_functional=retrieval_average_precision, sk_metric=sk_average_precision_score, + device=device, metric_args={}, ) diff --git a/tests/retrieval/test_mrr.py b/tests/retrieval/test_mrr.py index 9e3cc318876..19c01a8a854 100644 --- a/tests/retrieval/test_mrr.py +++ b/tests/retrieval/test_mrr.py @@ -13,10 +13,12 @@ # limitations under the License. 
import numpy as np import pytest +from pytest_cases import parametrize_with_cases from sklearn.metrics import label_ranking_average_precision_score from torch import Tensor from tests.helpers import seed_all +from tests.helpers.testers import MetricTesterDDPCases from tests.retrieval.helpers import ( RetrievalMetricTester, _concat_tests, @@ -55,7 +57,7 @@ def _reciprocal_rank(target: np.ndarray, preds: np.ndarray): class TestMRR(RetrievalMetricTester): - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) @pytest.mark.parametrize("ignore_index", [None, 1]) # avoid setting 0, otherwise test with all 0 targets will fail @@ -69,6 +71,7 @@ def test_class_metric( dist_sync_on_step: bool, empty_target_action: str, ignore_index: int, + device: str, ): metric_args = dict(empty_target_action=empty_target_action, ignore_index=ignore_index) @@ -80,10 +83,11 @@ def test_class_metric( metric_class=RetrievalMRR, sk_metric=_reciprocal_rank, dist_sync_on_step=dist_sync_on_step, + device=device, metric_args=metric_args, ) - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) @pytest.mark.parametrize(**_default_metric_class_input_arguments_ignore_index) @@ -95,6 +99,7 @@ def test_class_metric_ignore_index( target: Tensor, dist_sync_on_step: bool, empty_target_action: str, + device: str, ): metric_args = dict(empty_target_action=empty_target_action, ignore_index=-100) @@ -106,16 +111,19 @@ def test_class_metric_ignore_index( metric_class=RetrievalMRR, sk_metric=_reciprocal_rank, dist_sync_on_step=dist_sync_on_step, + device=device, metric_args=metric_args, ) + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") @pytest.mark.parametrize(**_default_metric_functional_input_arguments) - def test_functional_metric(self, preds: Tensor, target: Tensor): + def test_functional_metric(self, preds: Tensor, target: Tensor, device: str): self.run_functional_metric_test( preds=preds, target=target, metric_functional=retrieval_reciprocal_rank, sk_metric=_reciprocal_rank, + device=device, metric_args={}, ) diff --git a/tests/retrieval/test_ndcg.py b/tests/retrieval/test_ndcg.py index ff6b5a0737a..cf8f31ea78f 100644 --- a/tests/retrieval/test_ndcg.py +++ b/tests/retrieval/test_ndcg.py @@ -13,10 +13,12 @@ # limitations under the License. 
import numpy as np import pytest +from pytest_cases import parametrize_with_cases from sklearn.metrics import ndcg_score from torch import Tensor from tests.helpers import seed_all +from tests.helpers.testers import MetricTesterDDPCases from tests.retrieval.helpers import ( RetrievalMetricTester, _concat_tests, @@ -49,7 +51,7 @@ def _ndcg_at_k(target: np.ndarray, preds: np.ndarray, k: int = None): class TestNDCG(RetrievalMetricTester): - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) @pytest.mark.parametrize("ignore_index", [None, 3]) # avoid setting 0, otherwise test with all 0 targets will fail @@ -65,6 +67,7 @@ def test_class_metric( empty_target_action: str, ignore_index: int, k: int, + device: str, ): metric_args = dict(empty_target_action=empty_target_action, k=k, ignore_index=ignore_index) @@ -76,10 +79,11 @@ def test_class_metric( metric_class=RetrievalNormalizedDCG, sk_metric=_ndcg_at_k, dist_sync_on_step=dist_sync_on_step, + device=device, metric_args=metric_args, ) - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) @pytest.mark.parametrize("k", [None, 1, 4, 10]) @@ -93,6 +97,7 @@ def test_class_metric_ignore_index( dist_sync_on_step: bool, empty_target_action: str, k: int, + device: str, ): metric_args = dict(empty_target_action=empty_target_action, k=k, ignore_index=-100) @@ -104,17 +109,20 @@ def test_class_metric_ignore_index( metric_class=RetrievalNormalizedDCG, sk_metric=_ndcg_at_k, dist_sync_on_step=dist_sync_on_step, + device=device, metric_args=metric_args, ) + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") @pytest.mark.parametrize(**_default_metric_functional_input_arguments_with_non_binary_target) @pytest.mark.parametrize("k", [None, 1, 4, 10]) - def test_functional_metric(self, preds: Tensor, target: Tensor, k: int): + def test_functional_metric(self, preds: Tensor, target: Tensor, k: int, device: str): self.run_functional_metric_test( preds=preds, target=target, metric_functional=retrieval_normalized_dcg, sk_metric=_ndcg_at_k, + device=device, metric_args={}, k=k, ) diff --git a/tests/retrieval/test_precision.py b/tests/retrieval/test_precision.py index e8541a60cc4..db7e2c83739 100644 --- a/tests/retrieval/test_precision.py +++ b/tests/retrieval/test_precision.py @@ -13,9 +13,11 @@ # limitations under the License. 
import numpy as np import pytest +from pytest_cases import parametrize_with_cases from torch import Tensor from tests.helpers import seed_all +from tests.helpers.testers import MetricTesterDDPCases from tests.retrieval.helpers import ( RetrievalMetricTester, _concat_tests, @@ -54,7 +56,7 @@ def _precision_at_k(target: np.ndarray, preds: np.ndarray, k: int = None): class TestPrecision(RetrievalMetricTester): - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) @pytest.mark.parametrize("ignore_index", [None, 1]) # avoid setting 0, otherwise test with all 0 targets will fail @@ -70,6 +72,7 @@ def test_class_metric( empty_target_action: str, ignore_index: int, k: int, + device: str, ): metric_args = dict(empty_target_action=empty_target_action, k=k, ignore_index=ignore_index) @@ -81,10 +84,11 @@ def test_class_metric( metric_class=RetrievalPrecision, sk_metric=_precision_at_k, dist_sync_on_step=dist_sync_on_step, + device=device, metric_args=metric_args, ) - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) @pytest.mark.parametrize("k", [None, 1, 4, 10]) @@ -98,6 +102,7 @@ def test_class_metric_ignore_index( dist_sync_on_step: bool, empty_target_action: str, k: int, + device: str, ): metric_args = dict(empty_target_action=empty_target_action, k=k, ignore_index=-100) @@ -109,17 +114,20 @@ def test_class_metric_ignore_index( metric_class=RetrievalPrecision, sk_metric=_precision_at_k, dist_sync_on_step=dist_sync_on_step, + device=device, metric_args=metric_args, ) + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") @pytest.mark.parametrize(**_default_metric_functional_input_arguments) @pytest.mark.parametrize("k", [None, 1, 4, 10]) - def test_functional_metric(self, preds: Tensor, target: Tensor, k: int): + def test_functional_metric(self, preds: Tensor, target: Tensor, k: int, device: str): self.run_functional_metric_test( preds=preds, target=target, metric_functional=retrieval_precision, sk_metric=_precision_at_k, + device=device, metric_args={}, k=k, ) diff --git a/tests/retrieval/test_r_precision.py b/tests/retrieval/test_r_precision.py index 5d2c103c916..b3e590cc09c 100644 --- a/tests/retrieval/test_r_precision.py +++ b/tests/retrieval/test_r_precision.py @@ -13,9 +13,11 @@ # limitations under the License. 
import numpy as np import pytest +from pytest_cases import parametrize_with_cases from torch import Tensor from tests.helpers import seed_all +from tests.helpers.testers import MetricTesterDDPCases from tests.retrieval.helpers import ( RetrievalMetricTester, _concat_tests, @@ -49,7 +51,7 @@ def _r_precision(target: np.ndarray, preds: np.ndarray): class TestRPrecision(RetrievalMetricTester): - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) @pytest.mark.parametrize("ignore_index", [None, 1]) # avoid setting 0, otherwise test with all 0 targets will fail @@ -63,6 +65,7 @@ def test_class_metric( dist_sync_on_step: bool, empty_target_action: str, ignore_index: int, + device: str, ): metric_args = dict(empty_target_action=empty_target_action, ignore_index=ignore_index) @@ -74,10 +77,11 @@ def test_class_metric( metric_class=RetrievalRPrecision, sk_metric=_r_precision, dist_sync_on_step=dist_sync_on_step, + device=device, metric_args=metric_args, ) - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) @pytest.mark.parametrize(**_default_metric_class_input_arguments_ignore_index) @@ -89,6 +93,7 @@ def test_class_metric_ignore_index( target: Tensor, dist_sync_on_step: bool, empty_target_action: str, + device: str, ): metric_args = dict(empty_target_action=empty_target_action, ignore_index=-100) @@ -100,16 +105,19 @@ def test_class_metric_ignore_index( metric_class=RetrievalRPrecision, sk_metric=_r_precision, dist_sync_on_step=dist_sync_on_step, + device=device, metric_args=metric_args, ) + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") @pytest.mark.parametrize(**_default_metric_functional_input_arguments) - def test_functional_metric(self, preds: Tensor, target: Tensor): + def test_functional_metric(self, preds: Tensor, target: Tensor, device: str): self.run_functional_metric_test( preds=preds, target=target, metric_functional=retrieval_r_precision, sk_metric=_r_precision, + device=device, metric_args={}, ) diff --git a/tests/retrieval/test_recall.py b/tests/retrieval/test_recall.py index cbb5b8fd40d..63cbe890c70 100644 --- a/tests/retrieval/test_recall.py +++ b/tests/retrieval/test_recall.py @@ -13,9 +13,11 @@ # limitations under the License. 
import numpy as np import pytest +from pytest_cases import parametrize_with_cases from torch import Tensor from tests.helpers import seed_all +from tests.helpers.testers import MetricTesterDDPCases from tests.retrieval.helpers import ( RetrievalMetricTester, _concat_tests, @@ -53,7 +55,7 @@ def _recall_at_k(target: np.ndarray, preds: np.ndarray, k: int = None): class TestRecall(RetrievalMetricTester): - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) @pytest.mark.parametrize("ignore_index", [None, 1]) # avoid setting 0, otherwise test with all 0 targets will fail @@ -69,6 +71,7 @@ def test_class_metric( empty_target_action: str, ignore_index: int, k: int, + device: str, ): metric_args = dict(empty_target_action=empty_target_action, k=k, ignore_index=ignore_index) @@ -80,10 +83,11 @@ def test_class_metric( metric_class=RetrievalRecall, sk_metric=_recall_at_k, dist_sync_on_step=dist_sync_on_step, + device=device, metric_args=metric_args, ) - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) @pytest.mark.parametrize("k", [None, 1, 4, 10]) @@ -97,6 +101,7 @@ def test_class_metric_ignore_index( dist_sync_on_step: bool, empty_target_action: str, k: int, + device: str, ): metric_args = dict(empty_target_action=empty_target_action, k=k, ignore_index=-100) @@ -108,17 +113,20 @@ def test_class_metric_ignore_index( metric_class=RetrievalRecall, sk_metric=_recall_at_k, dist_sync_on_step=dist_sync_on_step, + device=device, metric_args=metric_args, ) + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") @pytest.mark.parametrize(**_default_metric_functional_input_arguments) @pytest.mark.parametrize("k", [None, 1, 4, 10]) - def test_functional_metric(self, preds: Tensor, target: Tensor, k: int): + def test_functional_metric(self, preds: Tensor, target: Tensor, k: int, device: str): self.run_functional_metric_test( preds=preds, target=target, metric_functional=retrieval_recall, sk_metric=_recall_at_k, + device=device, metric_args={}, k=k, ) From 2e9f2355b4a2fec5a9fb6defbf40ca40a5c2dbaa Mon Sep 17 00:00:00 2001 From: twsl <45483159+twsl@users.noreply.github.com> Date: Sun, 19 Dec 2021 16:34:50 +0000 Subject: [PATCH 12/18] Add test strategy to text --- tests/text/helpers.py | 7 +++---- tests/text/test_bleu.py | 13 ++++++++++--- tests/text/test_cer.py | 11 ++++++++--- tests/text/test_chrf.py | 11 ++++++++--- tests/text/test_mer.py | 11 ++++++++--- tests/text/test_rouge.py | 11 ++++++++--- tests/text/test_sacre_bleu.py | 11 ++++++++--- tests/text/test_ter.py | 11 ++++++++--- tests/text/test_wer.py | 11 ++++++++--- tests/text/test_wil.py | 11 ++++++++--- tests/text/test_wip.py | 11 ++++++++--- 11 files changed, 85 insertions(+), 34 deletions(-) diff --git a/tests/text/helpers.py b/tests/text/helpers.py index 9d328e81362..00fd83362f3 100644 --- a/tests/text/helpers.py +++ b/tests/text/helpers.py @@ -269,6 +269,7 @@ def run_functional_metric_test( targets: TEXT_METRIC_INPUT, metric_functional: Callable, sk_metric: Callable, + device: str = "cpu", metric_args: dict = None, fragment_kwargs: bool = False, input_order: INPUT_ORDER = INPUT_ORDER.PREDS_FIRST, @@ 
-290,8 +291,6 @@ def run_functional_metric_test( kwargs_update: Additional keyword arguments that will be passed with preds and targets when running update on the metric. """ - device = "cuda" if (torch.cuda.is_available() and torch.cuda.device_count() > 0) else "cpu" - _functional_test( preds=preds, targets=targets, @@ -314,6 +313,7 @@ def run_class_metric_test( metric_class: Metric, sk_metric: Callable, dist_sync_on_step: bool, + device: str = "cpu", metric_args: dict = None, check_dist_sync_on_step: bool = True, check_batch: bool = True, @@ -363,6 +363,7 @@ def run_class_metric_test( check_dist_sync_on_step=check_dist_sync_on_step, check_batch=check_batch, atol=self.atol, + device=device, fragment_kwargs=fragment_kwargs, check_scriptable=check_scriptable, input_order=input_order, @@ -372,8 +373,6 @@ def run_class_metric_test( [(rank, self.poolSize) for rank in range(self.poolSize)], ) else: - device = "cuda" if (torch.cuda.is_available() and torch.cuda.device_count() > 0) else "cpu" - _class_test( rank=0, worldsize=1, diff --git a/tests/text/test_bleu.py b/tests/text/test_bleu.py index 1866094fb6f..264273cc16f 100644 --- a/tests/text/test_bleu.py +++ b/tests/text/test_bleu.py @@ -16,8 +16,10 @@ import pytest from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu +from pytest_cases import parametrize_with_cases from torch import tensor +from tests.helpers.testers import MetricTesterDDPCases from tests.text.helpers import INPUT_ORDER, TextTester from tests.text.inputs import _inputs_multiple_references from torchmetrics.functional.text.bleu import bleu_score @@ -53,9 +55,11 @@ def _compute_bleu_metric_nltk(list_of_references, hypotheses, weights, smoothing [(_inputs_multiple_references.preds, _inputs_multiple_references.targets)], ) class TestBLEUScore(TextTester): - @pytest.mark.parametrize("ddp", [False, True]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [False, True]) - def test_bleu_score_class(self, ddp, dist_sync_on_step, preds, targets, weights, n_gram, smooth_func, smooth): + def test_bleu_score_class( + self, ddp, dist_sync_on_step, preds, targets, weights, n_gram, smooth_func, smooth, device + ): metric_args = {"n_gram": n_gram, "smooth": smooth} compute_bleu_metric_nltk = partial(_compute_bleu_metric_nltk, weights=weights, smoothing_function=smooth_func) @@ -66,11 +70,13 @@ def test_bleu_score_class(self, ddp, dist_sync_on_step, preds, targets, weights, metric_class=BLEUScore, sk_metric=compute_bleu_metric_nltk, dist_sync_on_step=dist_sync_on_step, + device=device, metric_args=metric_args, input_order=INPUT_ORDER.TARGETS_FIRST, ) - def test_bleu_score_functional(self, preds, targets, weights, n_gram, smooth_func, smooth): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_bleu_score_functional(self, preds, targets, weights, n_gram, smooth_func, smooth, device): metric_args = {"n_gram": n_gram, "smooth": smooth} compute_bleu_metric_nltk = partial(_compute_bleu_metric_nltk, weights=weights, smoothing_function=smooth_func) @@ -79,6 +85,7 @@ def test_bleu_score_functional(self, preds, targets, weights, n_gram, smooth_fun targets, metric_functional=bleu_score, sk_metric=compute_bleu_metric_nltk, + device=device, metric_args=metric_args, input_order=INPUT_ORDER.TARGETS_FIRST, ) diff --git a/tests/text/test_cer.py b/tests/text/test_cer.py index 729ec27d9ce..b29c271dba9 100644 --- a/tests/text/test_cer.py +++ b/tests/text/test_cer.py @@ -1,7 
+1,9 @@ from typing import Callable, List, Union import pytest +from pytest_cases import parametrize_with_cases +from tests.helpers.testers import MetricTesterDDPCases from tests.text.helpers import INPUT_ORDER, TextTester from tests.text.inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 from torchmetrics.functional.text.cer import char_error_rate @@ -30,9 +32,9 @@ def compare_fn(prediction: Union[str, List[str]], reference: Union[str, List[str class TestCharErrorRate(TextTester): """test class for character error rate.""" - @pytest.mark.parametrize("ddp", [False, True]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [False, True]) - def test_cer_class(self, ddp, dist_sync_on_step, preds, targets): + def test_cer_class(self, ddp, dist_sync_on_step, preds, targets, device): """test modular version of cer.""" self.run_class_metric_test( ddp=ddp, @@ -41,16 +43,19 @@ def test_cer_class(self, ddp, dist_sync_on_step, preds, targets): metric_class=CharErrorRate, sk_metric=compare_fn, dist_sync_on_step=dist_sync_on_step, + device=device, input_order=INPUT_ORDER.PREDS_FIRST, ) - def test_cer_functional(self, preds, targets): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_cer_functional(self, preds, targets, device): """test functional version of cer.""" self.run_functional_metric_test( preds, targets, metric_functional=char_error_rate, sk_metric=compare_fn, + device=device, input_order=INPUT_ORDER.PREDS_FIRST, ) diff --git a/tests/text/test_chrf.py b/tests/text/test_chrf.py index 76743b6ecd4..c938bc57d3e 100644 --- a/tests/text/test_chrf.py +++ b/tests/text/test_chrf.py @@ -2,8 +2,10 @@ from typing import Sequence import pytest +from pytest_cases import parametrize_with_cases from torch import Tensor, tensor +from tests.helpers.testers import MetricTesterDDPCases from tests.text.helpers import INPUT_ORDER, TextTester from tests.text.inputs import _inputs_multiple_references, _inputs_single_sentence_multiple_references from torchmetrics.functional.text.chrf import chrf_score @@ -48,10 +50,10 @@ def sacrebleu_chrf_fn( ) @pytest.mark.skipif(not _SACREBLEU_AVAILABLE, reason="test requires sacrebleu") class TestCHRFScore(TextTester): - @pytest.mark.parametrize("ddp", [False, True]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [False, True]) def test_chrf_score_class( - self, ddp, dist_sync_on_step, preds, targets, char_order, word_order, lowercase, whitespace + self, ddp, dist_sync_on_step, preds, targets, char_order, word_order, lowercase, whitespace, device ): metric_args = { "n_char_order": char_order, @@ -70,11 +72,13 @@ def test_chrf_score_class( metric_class=CHRFScore, sk_metric=nltk_metric, dist_sync_on_step=dist_sync_on_step, + device=device, metric_args=metric_args, input_order=INPUT_ORDER.TARGETS_FIRST, ) - def test_chrf_score_functional(self, preds, targets, char_order, word_order, lowercase, whitespace): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_chrf_score_functional(self, preds, targets, char_order, word_order, lowercase, whitespace, device): metric_args = { "n_char_order": char_order, "n_word_order": word_order, @@ -91,6 +95,7 @@ def test_chrf_score_functional(self, preds, targets, char_order, word_order, low metric_functional=chrf_score, sk_metric=nltk_metric, metric_args=metric_args, + 
device=device, input_order=INPUT_ORDER.TARGETS_FIRST, ) diff --git a/tests/text/test_mer.py b/tests/text/test_mer.py index 3a9f126648a..c1dc4a040f4 100644 --- a/tests/text/test_mer.py +++ b/tests/text/test_mer.py @@ -1,7 +1,9 @@ from typing import Callable, List, Union import pytest +from pytest_cases import parametrize_with_cases +from tests.helpers.testers import MetricTesterDDPCases from tests.text.helpers import INPUT_ORDER, TextTester from tests.text.inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 from torchmetrics.utilities.imports import _JIWER_AVAILABLE @@ -28,9 +30,9 @@ def _compute_mer_metric_jiwer(prediction: Union[str, List[str]], reference: Unio ], ) class TestMatchErrorRate(TextTester): - @pytest.mark.parametrize("ddp", [False, True]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [False, True]) - def test_mer_class(self, ddp, dist_sync_on_step, preds, targets): + def test_mer_class(self, ddp, dist_sync_on_step, preds, targets, device): self.run_class_metric_test( ddp=ddp, @@ -39,16 +41,19 @@ def test_mer_class(self, ddp, dist_sync_on_step, preds, targets): metric_class=MatchErrorRate, sk_metric=_compute_mer_metric_jiwer, dist_sync_on_step=dist_sync_on_step, + device=device, input_order=INPUT_ORDER.PREDS_FIRST, ) - def test_mer_functional(self, preds, targets): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_mer_functional(self, preds, targets, device): self.run_functional_metric_test( preds, targets, metric_functional=match_error_rate, sk_metric=_compute_mer_metric_jiwer, + device=device, input_order=INPUT_ORDER.PREDS_FIRST, ) diff --git a/tests/text/test_rouge.py b/tests/text/test_rouge.py index 67b84cbc3bd..e94afa124a9 100644 --- a/tests/text/test_rouge.py +++ b/tests/text/test_rouge.py @@ -17,7 +17,9 @@ import pytest import torch +from pytest_cases import parametrize_with_cases +from tests.helpers.testers import MetricTesterDDPCases from tests.text.helpers import INPUT_ORDER, TextTester from tests.text.inputs import _inputs_multiple_references, _inputs_single_sentence_single_reference from torchmetrics.functional.text.rouge import rouge_score @@ -102,10 +104,10 @@ def _compute_rouge_score( ) @pytest.mark.parametrize("accumulate", ["avg", "best"]) class TestROUGEScore(TextTester): - @pytest.mark.parametrize("ddp", [False, True]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [False, True]) def test_rouge_score_class( - self, ddp, dist_sync_on_step, preds, targets, pl_rouge_metric_key, use_stemmer, accumulate + self, ddp, dist_sync_on_step, preds, targets, pl_rouge_metric_key, use_stemmer, accumulate, device ): metric_args = {"use_stemmer": use_stemmer, "accumulate": accumulate} rouge_level, metric = pl_rouge_metric_key.split("_") @@ -119,12 +121,14 @@ def test_rouge_score_class( metric_class=ROUGEScore, sk_metric=rouge_metric, dist_sync_on_step=dist_sync_on_step, + device=device, metric_args=metric_args, input_order=INPUT_ORDER.PREDS_FIRST, key=pl_rouge_metric_key, ) - def test_rouge_score_functional(self, preds, targets, pl_rouge_metric_key, use_stemmer, accumulate): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_rouge_score_functional(self, preds, targets, pl_rouge_metric_key, use_stemmer, accumulate, device): metric_args = {"use_stemmer": use_stemmer, "accumulate": accumulate} 
rouge_level, metric = pl_rouge_metric_key.split("_") @@ -136,6 +140,7 @@ def test_rouge_score_functional(self, preds, targets, pl_rouge_metric_key, use_s targets, metric_functional=rouge_score, sk_metric=rouge_metric, + device=device, metric_args=metric_args, input_order=INPUT_ORDER.PREDS_FIRST, key=pl_rouge_metric_key, diff --git a/tests/text/test_sacre_bleu.py b/tests/text/test_sacre_bleu.py index 8cd34a807ff..78af45145c8 100644 --- a/tests/text/test_sacre_bleu.py +++ b/tests/text/test_sacre_bleu.py @@ -16,8 +16,10 @@ from typing import Sequence import pytest +from pytest_cases import parametrize_with_cases from torch import Tensor, tensor +from tests.helpers.testers import MetricTesterDDPCases from tests.text.helpers import INPUT_ORDER, TextTester from tests.text.inputs import _inputs_multiple_references from torchmetrics.functional.text.sacre_bleu import sacre_bleu_score @@ -47,9 +49,9 @@ def sacrebleu_fn(targets: Sequence[Sequence[str]], preds: Sequence[str], tokeniz @pytest.mark.parametrize("tokenize", TOKENIZERS) @pytest.mark.skipif(not _SACREBLEU_AVAILABLE, reason="test requires sacrebleu") class TestSacreBLEUScore(TextTester): - @pytest.mark.parametrize("ddp", [False, True]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [False, True]) - def test_bleu_score_class(self, ddp, dist_sync_on_step, preds, targets, tokenize, lowercase): + def test_bleu_score_class(self, ddp, dist_sync_on_step, preds, targets, tokenize, lowercase, device): metric_args = {"tokenize": tokenize, "lowercase": lowercase} original_sacrebleu = partial(sacrebleu_fn, tokenize=tokenize, lowercase=lowercase) @@ -60,11 +62,13 @@ def test_bleu_score_class(self, ddp, dist_sync_on_step, preds, targets, tokenize metric_class=SacreBLEUScore, sk_metric=original_sacrebleu, dist_sync_on_step=dist_sync_on_step, + device=device, metric_args=metric_args, input_order=INPUT_ORDER.TARGETS_FIRST, ) - def test_bleu_score_functional(self, preds, targets, tokenize, lowercase): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_bleu_score_functional(self, preds, targets, tokenize, lowercase, device): metric_args = {"tokenize": tokenize, "lowercase": lowercase} original_sacrebleu = partial(sacrebleu_fn, tokenize=tokenize, lowercase=lowercase) @@ -73,6 +77,7 @@ def test_bleu_score_functional(self, preds, targets, tokenize, lowercase): targets, metric_functional=sacre_bleu_score, sk_metric=original_sacrebleu, + device=device, metric_args=metric_args, input_order=INPUT_ORDER.TARGETS_FIRST, ) diff --git a/tests/text/test_ter.py b/tests/text/test_ter.py index 4f49cc4665c..1335c038381 100644 --- a/tests/text/test_ter.py +++ b/tests/text/test_ter.py @@ -2,8 +2,10 @@ from typing import Sequence import pytest +from pytest_cases import parametrize_with_cases from torch import Tensor, tensor +from tests.helpers.testers import MetricTesterDDPCases from tests.text.helpers import INPUT_ORDER, TextTester from tests.text.inputs import _inputs_multiple_references, _inputs_single_sentence_multiple_references from torchmetrics.functional.text.ter import ter @@ -48,10 +50,10 @@ def sacrebleu_ter_fn( ) @pytest.mark.skipif(not _SACREBLEU_AVAILABLE, reason="test requires sacrebleu") class TestTER(TextTester): - @pytest.mark.parametrize("ddp", [False, True]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [False, True]) def test_chrf_score_class( - 
self, ddp, dist_sync_on_step, preds, targets, normalize, no_punctuation, asian_support, lowercase + self, ddp, dist_sync_on_step, preds, targets, normalize, no_punctuation, asian_support, lowercase, device ): metric_args = { "normalize": normalize, @@ -74,11 +76,13 @@ def test_chrf_score_class( metric_class=TER, sk_metric=nltk_metric, dist_sync_on_step=dist_sync_on_step, + device=device, metric_args=metric_args, input_order=INPUT_ORDER.TARGETS_FIRST, ) - def test_ter_score_functional(self, preds, targets, normalize, no_punctuation, asian_support, lowercase): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_ter_score_functional(self, preds, targets, normalize, no_punctuation, asian_support, lowercase, device): metric_args = { "normalize": normalize, "no_punctuation": no_punctuation, @@ -98,6 +102,7 @@ def test_ter_score_functional(self, preds, targets, normalize, no_punctuation, a targets, metric_functional=ter, sk_metric=nltk_metric, + device=device, metric_args=metric_args, input_order=INPUT_ORDER.TARGETS_FIRST, ) diff --git a/tests/text/test_wer.py b/tests/text/test_wer.py index f9791594317..72b0804a128 100644 --- a/tests/text/test_wer.py +++ b/tests/text/test_wer.py @@ -1,7 +1,9 @@ from typing import Callable, List, Union import pytest +from pytest_cases import parametrize_with_cases +from tests.helpers.testers import MetricTesterDDPCases from tests.text.helpers import INPUT_ORDER, TextTester from tests.text.inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 from torchmetrics.utilities.imports import _JIWER_AVAILABLE @@ -28,9 +30,9 @@ def _compute_wer_metric_jiwer(prediction: Union[str, List[str]], reference: Unio ], ) class TestWER(TextTester): - @pytest.mark.parametrize("ddp", [False, True]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [False, True]) - def test_wer_class(self, ddp, dist_sync_on_step, preds, targets): + def test_wer_class(self, ddp, dist_sync_on_step, preds, targets, device): self.run_class_metric_test( ddp=ddp, @@ -39,16 +41,19 @@ def test_wer_class(self, ddp, dist_sync_on_step, preds, targets): metric_class=WER, sk_metric=_compute_wer_metric_jiwer, dist_sync_on_step=dist_sync_on_step, + device=device, input_order=INPUT_ORDER.PREDS_FIRST, ) - def test_wer_functional(self, preds, targets): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_wer_functional(self, preds, targets, device): self.run_functional_metric_test( preds, targets, metric_functional=wer, sk_metric=_compute_wer_metric_jiwer, + device=device, input_order=INPUT_ORDER.PREDS_FIRST, ) diff --git a/tests/text/test_wil.py b/tests/text/test_wil.py index 08b42159d11..eb82f1d136c 100644 --- a/tests/text/test_wil.py +++ b/tests/text/test_wil.py @@ -2,7 +2,9 @@ import pytest from jiwer import wil +from pytest_cases import parametrize_with_cases +from tests.helpers.testers import MetricTesterDDPCases from tests.text.helpers import INPUT_ORDER, TextTester from tests.text.inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 from torchmetrics.functional.text.wil import word_information_lost @@ -23,9 +25,9 @@ def _compute_wil_metric_jiwer(prediction: Union[str, List[str]], reference: Unio ], ) class TestWordInfoLost(TextTester): - @pytest.mark.parametrize("ddp", [False, True]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") 
@pytest.mark.parametrize("dist_sync_on_step", [False, True]) - def test_wil_class(self, ddp, dist_sync_on_step, preds, targets): + def test_wil_class(self, ddp, dist_sync_on_step, preds, targets, device): self.run_class_metric_test( ddp=ddp, @@ -34,16 +36,19 @@ def test_wil_class(self, ddp, dist_sync_on_step, preds, targets): metric_class=WordInfoLost, sk_metric=_compute_wil_metric_jiwer, dist_sync_on_step=dist_sync_on_step, + device=device, input_order=INPUT_ORDER.PREDS_FIRST, ) - def test_wil_functional(self, preds, targets): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_wil_functional(self, preds, targets, device): self.run_functional_metric_test( preds, targets, metric_functional=word_information_lost, sk_metric=_compute_wil_metric_jiwer, + device=device, input_order=INPUT_ORDER.PREDS_FIRST, ) diff --git a/tests/text/test_wip.py b/tests/text/test_wip.py index 0d655823232..b22560dda84 100644 --- a/tests/text/test_wip.py +++ b/tests/text/test_wip.py @@ -2,7 +2,9 @@ import pytest from jiwer import wip +from pytest_cases import parametrize_with_cases +from tests.helpers.testers import MetricTesterDDPCases from tests.text.helpers import INPUT_ORDER, TextTester from tests.text.inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 from torchmetrics.functional.text.wip import word_information_preserved @@ -23,9 +25,9 @@ def _compute_wip_metric_jiwer(prediction: Union[str, List[str]], reference: Unio ], ) class TestWordInfoPreserved(TextTester): - @pytest.mark.parametrize("ddp", [False, True]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [False, True]) - def test_wip_class(self, ddp, dist_sync_on_step, preds, targets): + def test_wip_class(self, ddp, dist_sync_on_step, preds, targets, device): self.run_class_metric_test( ddp=ddp, @@ -34,16 +36,19 @@ def test_wip_class(self, ddp, dist_sync_on_step, preds, targets): metric_class=WordInfoPreserved, sk_metric=_compute_wip_metric_jiwer, dist_sync_on_step=dist_sync_on_step, + device=device, input_order=INPUT_ORDER.PREDS_FIRST, ) - def test_wip_functional(self, preds, targets): + @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + def test_wip_functional(self, preds, targets, device): self.run_functional_metric_test( preds, targets, metric_functional=word_information_preserved, sk_metric=_compute_wip_metric_jiwer, + device=device, input_order=INPUT_ORDER.PREDS_FIRST, ) From 70bce49133b77a4dcfa83ca2d150bc7dc6bc0689 Mon Sep 17 00:00:00 2001 From: twsl <45483159+twsl@users.noreply.github.com> Date: Sun, 19 Dec 2021 16:37:35 +0000 Subject: [PATCH 13/18] Add test strategy to wrappers --- tests/wrappers/test_multioutput.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/tests/wrappers/test_multioutput.py b/tests/wrappers/test_multioutput.py index 35ab90af092..8a1419d88cf 100644 --- a/tests/wrappers/test_multioutput.py +++ b/tests/wrappers/test_multioutput.py @@ -4,11 +4,12 @@ import pytest import torch +from pytest_cases import parametrize_with_cases from sklearn.metrics import accuracy_score from sklearn.metrics import r2_score as sk_r2score from tests.helpers import seed_all -from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, MetricTester +from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, MetricTester, MetricTesterDDPCases from torchmetrics import Metric from torchmetrics.classification 
import Accuracy from torchmetrics.regression import R2Score @@ -124,10 +125,19 @@ def _multi_target_sk_accuracy(preds, target, num_outputs): class TestMultioutputWrapper(MetricTester): """Test the MultioutputWrapper class with regression and classification inner metrics.""" - @pytest.mark.parametrize("ddp", [True, False]) + @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_multioutput_wrapper( - self, base_metric_class, compare_metric, preds, target, num_outputs, metric_kwargs, ddp, dist_sync_on_step + self, + base_metric_class, + compare_metric, + preds, + target, + num_outputs, + metric_kwargs, + ddp, + dist_sync_on_step, + device, ): """Test that the multioutput wrapper properly slices and computes outputs along the output dimension for both classification and regression metrics.""" @@ -138,5 +148,6 @@ def test_multioutput_wrapper( _MultioutputMetric, compare_metric, dist_sync_on_step, + device=device, metric_args=dict(num_outputs=num_outputs, base_metric_class=base_metric_class, **metric_kwargs), ) From 938ceeecd5af322eda36f9278a640e5a9b096e57 Mon Sep 17 00:00:00 2001 From: Maxim Grechkin Date: Tue, 21 Dec 2021 04:08:45 -0800 Subject: [PATCH 14/18] Update paper.md (#690) --- docs/paper_JOSS/paper.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/paper_JOSS/paper.md b/docs/paper_JOSS/paper.md index d421ad67be6..9cd5c901e43 100644 --- a/docs/paper_JOSS/paper.md +++ b/docs/paper_JOSS/paper.md @@ -24,6 +24,7 @@ authors: affiliation: '5' - name: Changsheng Quan affiliation: '6' + - name: Maxim Grechkin - name: William Falcon affiliation: '1,7' affiliations: From 0f49516246d64722e26e4ea5724ce9c1e6150827 Mon Sep 17 00:00:00 2001 From: twsl <45483159+twsl@users.noreply.github.com> Date: Thu, 23 Dec 2021 17:49:20 +0000 Subject: [PATCH 15/18] Add ddp skip condition --- tests/helpers/testers.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py index 702be5acd15..b1ff9d2dd8a 100644 --- a/tests/helpers/testers.py +++ b/tests/helpers/testers.py @@ -349,20 +349,31 @@ def case_device_gpu(self): def case_ddp_false_device_cpu(self): return False, "cpu" - @case(tags="strategy") + @case( + tags="strategy", + marks=[ + pytest.mark.skipif(not torch.distributed.is_available(), reason="Distributed mode is not available."), + ], + ) def case_ddp_true_device_cpu(self): return True, "cpu" - @case(tags="strategy", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU required")) + @case( + tags="strategy", + marks=[pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU required")], + ) def case_ddp_false_device_gpu(self): return False, "cuda" @case( tags="strategy", - marks=pytest.mark.skipif( - not torch.cuda.is_available() or torch.cuda.device_count() < 2, - reason="More than one GPU required for DDP", - ), + marks=[ + pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + reason="More than one GPU required for DDP", + ), + pytest.mark.skipif(not torch.distributed.is_available(), reason="Distributed mode is not available."), + ], ) def case_ddp_true_device_gpu(self): return True, "cuda" From f6eaf38f0f5fc8a3b09efd30735160da514745de Mon Sep 17 00:00:00 2001 From: twsl <45483159+twsl@users.noreply.github.com> Date: Sun, 26 Dec 2021 16:33:17 +0000 Subject: [PATCH 16/18] Switch to pure pytest solution --- requirements/test.txt | 1 - tests/audio/test_pesq.py | 
5 +- tests/audio/test_pit.py | 5 +- tests/audio/test_sdr.py | 5 +- tests/audio/test_si_sdr.py | 5 +- tests/audio/test_si_snr.py | 5 +- tests/audio/test_snr.py | 5 +- tests/audio/test_stoi.py | 5 +- tests/bases/test_aggregation.py | 3 +- tests/classification/test_accuracy.py | 5 +- tests/classification/test_auc.py | 5 +- tests/classification/test_auroc.py | 5 +- .../classification/test_average_precision.py | 5 +- .../test_binned_precision_recall.py | 5 +- .../classification/test_calibration_error.py | 5 +- tests/classification/test_cohen_kappa.py | 5 +- tests/classification/test_confusion_matrix.py | 5 +- tests/classification/test_f_beta.py | 5 +- tests/classification/test_hamming_distance.py | 5 +- tests/classification/test_hinge.py | 5 +- tests/classification/test_jaccard.py | 5 +- tests/classification/test_kl_divergence.py | 5 +- .../classification/test_matthews_corrcoef.py | 5 +- tests/classification/test_precision_recall.py | 5 +- .../test_precision_recall_curve.py | 5 +- tests/classification/test_roc.py | 5 +- tests/classification/test_specificity.py | 5 +- tests/classification/test_stat_scores.py | 7 +- tests/detection/test_map.py | 3 +- tests/helpers/testers.py | 85 +++++++++---------- tests/image/test_lpips.py | 3 +- tests/image/test_psnr.py | 5 +- tests/image/test_ssim.py | 5 +- tests/pairwise/test_pairwise_distance.py | 3 +- tests/regression/test_cosine_similarity.py | 5 +- tests/regression/test_explained_variance.py | 5 +- tests/regression/test_mean_error.py | 5 +- tests/regression/test_pearson.py | 5 +- tests/regression/test_r2.py | 5 +- tests/regression/test_spearman.py | 5 +- tests/regression/test_tweedie_deviance.py | 5 +- tests/retrieval/test_fallout.py | 7 +- tests/retrieval/test_hit_rate.py | 7 +- tests/retrieval/test_map.py | 7 +- tests/retrieval/test_mrr.py | 7 +- tests/retrieval/test_ndcg.py | 7 +- tests/retrieval/test_precision.py | 7 +- tests/retrieval/test_r_precision.py | 7 +- tests/retrieval/test_recall.py | 7 +- tests/text/test_bleu.py | 5 +- tests/text/test_cer.py | 5 +- tests/text/test_chrf.py | 5 +- tests/text/test_mer.py | 5 +- tests/text/test_rouge.py | 5 +- tests/text/test_sacre_bleu.py | 5 +- tests/text/test_ter.py | 5 +- tests/text/test_wer.py | 5 +- tests/text/test_wil.py | 5 +- tests/text/test_wip.py | 5 +- tests/wrappers/test_multioutput.py | 3 +- 60 files changed, 158 insertions(+), 226 deletions(-) diff --git a/requirements/test.txt b/requirements/test.txt index 6c57c0b2cda..c78e1e9a453 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -2,7 +2,6 @@ coverage>5.2 codecov>=2.1 pytest>=6.0 pytest-cov>2.10 -pytest-cases>=3.6.5 # pytest-xdist # pytest-flake8 flake8 diff --git a/tests/audio/test_pesq.py b/tests/audio/test_pesq.py index 80184351b8a..a45b041ad0f 100644 --- a/tests/audio/test_pesq.py +++ b/tests/audio/test_pesq.py @@ -17,7 +17,6 @@ import pytest import torch from pesq import pesq as pesq_backend -from pytest_cases import parametrize_with_cases from torch import Tensor from tests.helpers import seed_all @@ -76,7 +75,7 @@ def average_metric(preds, target, metric_func): class TestPESQ(MetricTester): atol = 1e-2 - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_pesq(self, preds, target, sk_metric, fs, mode, ddp, dist_sync_on_step, device): self.run_class_metric_test( @@ -90,7 +89,7 @@ def test_pesq(self, preds, target, 
sk_metric, fs, mode, ddp, dist_sync_on_step, metric_args=dict(fs=fs, mode=mode), ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_pesq_functional(self, preds, target, sk_metric, fs, mode, device): self.run_functional_metric_test( preds, diff --git a/tests/audio/test_pit.py b/tests/audio/test_pit.py index c01e95b67d8..0f16d12e480 100644 --- a/tests/audio/test_pit.py +++ b/tests/audio/test_pit.py @@ -18,7 +18,6 @@ import numpy as np import pytest import torch -from pytest_cases import parametrize_with_cases from scipy.optimize import linear_sum_assignment from torch import Tensor @@ -113,7 +112,7 @@ def _average_metric(preds: Tensor, target: Tensor, metric_func: Callable) -> Ten class TestPIT(MetricTester): atol = 1e-2 - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_pit(self, preds, target, sk_metric, metric_func, eval_func, ddp, device, dist_sync_on_step): self.run_class_metric_test( @@ -127,7 +126,7 @@ def test_pit(self, preds, target, sk_metric, metric_func, eval_func, ddp, device metric_args=dict(metric_func=metric_func, eval_func=eval_func), ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_pit_functional(self, preds, target, sk_metric, device, metric_func, eval_func): self.run_functional_metric_test( preds=preds, diff --git a/tests/audio/test_sdr.py b/tests/audio/test_sdr.py index eb3f58762f5..d618ba4fbd3 100644 --- a/tests/audio/test_sdr.py +++ b/tests/audio/test_sdr.py @@ -19,7 +19,6 @@ import pytest import torch from mir_eval.separation import bss_eval_sources -from pytest_cases import parametrize_with_cases from scipy.io import wavfile from torch import Tensor @@ -77,7 +76,7 @@ def average_metric(preds: Tensor, target: Tensor, metric_func: Callable) -> Tens class TestSDR(MetricTester): atol = 1e-2 - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_sdr(self, preds, target, sk_metric, ddp, dist_sync_on_step, device): self.run_class_metric_test( @@ -91,7 +90,7 @@ def test_sdr(self, preds, target, sk_metric, ddp, dist_sync_on_step, device): metric_args=dict(), ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_sdr_functional(self, preds, target, sk_metric, device): self.run_functional_metric_test( preds, diff --git a/tests/audio/test_si_sdr.py b/tests/audio/test_si_sdr.py index b172b3dd8f2..2634471b57e 100644 --- a/tests/audio/test_si_sdr.py +++ b/tests/audio/test_si_sdr.py @@ -17,7 +17,6 @@ import pytest import speechmetrics import torch -from pytest_cases import parametrize_with_cases from torch import Tensor from tests.helpers import seed_all @@ -78,7 +77,7 @@ def average_metric(preds, target, metric_func): class TestSISDR(MetricTester): atol = 1e-2 - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + 
@pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_si_sdr(self, preds, target, sk_metric, zero_mean, ddp, dist_sync_on_step, device): self.run_class_metric_test( @@ -92,7 +91,7 @@ def test_si_sdr(self, preds, target, sk_metric, zero_mean, ddp, dist_sync_on_ste metric_args=dict(zero_mean=zero_mean), ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_si_sdr_functional(self, preds, target, sk_metric, zero_mean, device): self.run_functional_metric_test( preds, diff --git a/tests/audio/test_si_snr.py b/tests/audio/test_si_snr.py index 744ed097da5..4f45706a7e0 100644 --- a/tests/audio/test_si_snr.py +++ b/tests/audio/test_si_snr.py @@ -17,7 +17,6 @@ import pytest import speechmetrics import torch -from pytest_cases import parametrize_with_cases from torch import Tensor from tests.helpers import seed_all @@ -73,7 +72,7 @@ def average_metric(preds, target, metric_func): class TestSISNR(MetricTester): atol = 1e-2 - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_si_snr(self, preds, target, sk_metric, ddp, dist_sync_on_step, device): self.run_class_metric_test( @@ -86,7 +85,7 @@ def test_si_snr(self, preds, target, sk_metric, ddp, dist_sync_on_step, device): device=device, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_si_snr_functional(self, preds, target, sk_metric, device): self.run_functional_metric_test( preds, diff --git a/tests/audio/test_snr.py b/tests/audio/test_snr.py index 758751a2cd8..2f5115428b6 100644 --- a/tests/audio/test_snr.py +++ b/tests/audio/test_snr.py @@ -18,7 +18,6 @@ import pytest import torch from mir_eval.separation import bss_eval_images as mir_eval_bss_eval_images -from pytest_cases import parametrize_with_cases from torch import Tensor from tests.helpers import seed_all @@ -80,7 +79,7 @@ def average_metric(preds: Tensor, target: Tensor, metric_func: Callable): class TestSNR(MetricTester): atol = 1e-2 - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_snr(self, preds, target, sk_metric, zero_mean, ddp, dist_sync_on_step, device): self.run_class_metric_test( @@ -94,7 +93,7 @@ def test_snr(self, preds, target, sk_metric, zero_mean, ddp, dist_sync_on_step, metric_args=dict(zero_mean=zero_mean), ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_snr_functional(self, preds, target, sk_metric, zero_mean, device): self.run_functional_metric_test( preds, diff --git a/tests/audio/test_stoi.py b/tests/audio/test_stoi.py index e53f13c6fea..004f3fee270 100644 --- a/tests/audio/test_stoi.py +++ b/tests/audio/test_stoi.py @@ -17,7 +17,6 @@ import pytest import torch from pystoi import stoi as stoi_backend -from pytest_cases 
import parametrize_with_cases from torch import Tensor from tests.helpers import seed_all @@ -76,7 +75,7 @@ def average_metric(preds, target, metric_func): class TestSTOI(MetricTester): atol = 1e-2 - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_stoi(self, preds, target, sk_metric, fs, extended, ddp, dist_sync_on_step, device): self.run_class_metric_test( @@ -90,7 +89,7 @@ def test_stoi(self, preds, target, sk_metric, fs, extended, ddp, dist_sync_on_st metric_args=dict(fs=fs, extended=extended), ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_stoi_functional(self, preds, target, sk_metric, fs, extended, device): self.run_functional_metric_test( preds, diff --git a/tests/bases/test_aggregation.py b/tests/bases/test_aggregation.py index 75668e1ad1c..07f7b5f5421 100644 --- a/tests/bases/test_aggregation.py +++ b/tests/bases/test_aggregation.py @@ -1,7 +1,6 @@ import numpy as np import pytest import torch -from pytest_cases import parametrize_with_cases from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester, MetricTesterDDPCases from torchmetrics.aggregation import CatMetric, MaxMetric, MeanMetric, MinMetric, SumMetric @@ -81,7 +80,7 @@ def update(self, values, weights): class TestAggregation(MetricTester): """Test aggregation metrics.""" - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [False]) def test_aggreagation(self, ddp, dist_sync_on_step, metric_class, compare_fn, values, weights, device): """test modular implementation.""" diff --git a/tests/classification/test_accuracy.py b/tests/classification/test_accuracy.py index eff44f5dddd..26702a5132a 100644 --- a/tests/classification/test_accuracy.py +++ b/tests/classification/test_accuracy.py @@ -16,7 +16,6 @@ import numpy as np import pytest import torch -from pytest_cases import parametrize_with_cases from sklearn.metrics import accuracy_score as sk_accuracy from torch import tensor @@ -82,7 +81,7 @@ def _sk_accuracy(preds, target, subset_accuracy): ], ) class TestAccuracies(MetricTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [False, True]) def test_accuracy_class(self, ddp, dist_sync_on_step, preds, target, subset_accuracy, device): self.run_class_metric_test( @@ -96,7 +95,7 @@ def test_accuracy_class(self, ddp, dist_sync_on_step, preds, target, subset_accu metric_args={"threshold": THRESHOLD, "subset_accuracy": subset_accuracy}, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_accuracy_fn(self, preds, target, subset_accuracy, device): self.run_functional_metric_test( preds, diff --git a/tests/classification/test_auc.py b/tests/classification/test_auc.py index 6bd6462e133..b8187c74196 100644 --- a/tests/classification/test_auc.py +++ 
b/tests/classification/test_auc.py @@ -16,7 +16,6 @@ import numpy as np import pytest -from pytest_cases import parametrize_with_cases from sklearn.metrics import auc as _sk_auc from torch import tensor @@ -56,7 +55,7 @@ def sk_auc(x, y, reorder=False): @pytest.mark.parametrize("x, y", _examples) class TestAUC(MetricTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_auc(self, x, y, ddp, dist_sync_on_step, device): self.run_class_metric_test( @@ -69,7 +68,7 @@ def test_auc(self, x, y, ddp, dist_sync_on_step, device): device=device, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) @pytest.mark.parametrize("reorder", [True, False]) def test_auc_functional(self, x, y, reorder, device): self.run_functional_metric_test( diff --git a/tests/classification/test_auroc.py b/tests/classification/test_auroc.py index d4d4c131fa9..99adca172a0 100644 --- a/tests/classification/test_auroc.py +++ b/tests/classification/test_auroc.py @@ -15,7 +15,6 @@ import pytest import torch -from pytest_cases import parametrize_with_cases from sklearn.metrics import roc_auc_score as sk_roc_auc_score from tests.classification.inputs import _input_binary_prob @@ -100,7 +99,7 @@ def _sk_auroc_multilabel_multidim_prob(preds, target, num_classes, average="macr ], ) class TestAUROC(MetricTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_auroc(self, preds, target, sk_metric, num_classes, average, max_fpr, ddp, dist_sync_on_step, device): # max_fpr different from None is not support in multi class @@ -126,7 +125,7 @@ def test_auroc(self, preds, target, sk_metric, num_classes, average, max_fpr, dd metric_args={"num_classes": num_classes, "average": average, "max_fpr": max_fpr}, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_auroc_functional(self, preds, target, sk_metric, num_classes, average, max_fpr, device): # max_fpr different from None is not support in multi class if max_fpr is not None and num_classes != 1: diff --git a/tests/classification/test_average_precision.py b/tests/classification/test_average_precision.py index d747fa96bd0..bce338c3f55 100644 --- a/tests/classification/test_average_precision.py +++ b/tests/classification/test_average_precision.py @@ -15,7 +15,6 @@ import numpy as np import pytest -from pytest_cases import parametrize_with_cases from sklearn.metrics import average_precision_score as sk_average_precision_score from torch import tensor @@ -88,7 +87,7 @@ def _sk_avg_prec_multidim_multiclass_prob(preds, target, num_classes=1, average= ) @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) class TestAveragePrecision(MetricTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, 
False]) def test_average_precision(self, preds, target, sk_metric, num_classes, average, ddp, dist_sync_on_step, device): if target.max() > 1 and average == "micro": @@ -105,7 +104,7 @@ def test_average_precision(self, preds, target, sk_metric, num_classes, average, metric_args={"num_classes": num_classes, "average": average}, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_average_precision_functional(self, preds, target, sk_metric, num_classes, average, device): if target.max() > 1 and average == "micro": pytest.skip("average=micro and multiclass input cannot be used together") diff --git a/tests/classification/test_binned_precision_recall.py b/tests/classification/test_binned_precision_recall.py index b5256d11407..524d4142b6e 100644 --- a/tests/classification/test_binned_precision_recall.py +++ b/tests/classification/test_binned_precision_recall.py @@ -18,7 +18,6 @@ import numpy as np import pytest import torch -from pytest_cases import parametrize_with_cases from sklearn.metrics import average_precision_score as _sk_average_precision_score from sklearn.metrics import precision_recall_curve as _sk_precision_recall_curve from torch import Tensor @@ -78,7 +77,7 @@ def _sk_avg_prec_multiclass(predictions, targets, num_classes): class TestBinnedRecallAtPrecision(MetricTester): atol = 0.02 - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("min_precision", [0.05, 0.1, 0.3, 0.5, 0.8, 0.95]) def test_binned_recall_at_precision( @@ -113,7 +112,7 @@ def test_binned_recall_at_precision( ], ) class TestBinnedAveragePrecision(MetricTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("thresholds", (301, torch.linspace(0.0, 1.0, 101))) def test_binned_average_precision( diff --git a/tests/classification/test_calibration_error.py b/tests/classification/test_calibration_error.py index 87ff211432d..735416d5cb4 100644 --- a/tests/classification/test_calibration_error.py +++ b/tests/classification/test_calibration_error.py @@ -3,7 +3,6 @@ import numpy as np import pytest -from pytest_cases import parametrize_with_cases from tests.classification.inputs import _input_binary_prob from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob @@ -52,7 +51,7 @@ def _sk_calibration(preds, target, n_bins, norm, debias=False): ], ) class TestCE(MetricTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_ce(self, preds, target, n_bins, ddp, dist_sync_on_step, norm, device): self.run_class_metric_test( @@ -66,7 +65,7 @@ def test_ce(self, preds, target, n_bins, ddp, dist_sync_on_step, norm, device): metric_args={"n_bins": n_bins, "norm": norm}, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), 
MetricTesterDDPCases.cases_device()) def test_ce_functional(self, preds, target, n_bins, norm, device): self.run_functional_metric_test( preds, diff --git a/tests/classification/test_cohen_kappa.py b/tests/classification/test_cohen_kappa.py index 83c91397f69..4f187eeb645 100644 --- a/tests/classification/test_cohen_kappa.py +++ b/tests/classification/test_cohen_kappa.py @@ -3,7 +3,6 @@ import numpy as np import pytest import torch -from pytest_cases import parametrize_with_cases from sklearn.metrics import cohen_kappa_score as sk_cohen_kappa from tests.classification.inputs import _input_binary, _input_binary_prob @@ -94,7 +93,7 @@ def _sk_cohen_kappa_multidim_multiclass(preds, target, weights=None): class TestCohenKappa(MetricTester): atol = 1e-5 - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_cohen_kappa(self, weights, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step, device): self.run_class_metric_test( @@ -108,7 +107,7 @@ def test_cohen_kappa(self, weights, preds, target, sk_metric, num_classes, ddp, metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "weights": weights}, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_cohen_kappa_functional(self, weights, preds, target, sk_metric, num_classes, device): self.run_functional_metric_test( preds, diff --git a/tests/classification/test_confusion_matrix.py b/tests/classification/test_confusion_matrix.py index e3631f11437..7cf1e27c04b 100644 --- a/tests/classification/test_confusion_matrix.py +++ b/tests/classification/test_confusion_matrix.py @@ -16,7 +16,6 @@ import numpy as np import pytest import torch -from pytest_cases import parametrize_with_cases from sklearn.metrics import confusion_matrix as sk_confusion_matrix from sklearn.metrics import multilabel_confusion_matrix as sk_multilabel_confusion_matrix @@ -129,7 +128,7 @@ def _sk_cm_multidim_multiclass(preds, target, normalize=None): ], ) class TestConfusionMatrix(MetricTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_confusion_matrix( self, normalize, preds, target, sk_metric, num_classes, multilabel, ddp, dist_sync_on_step, device @@ -150,7 +149,7 @@ def test_confusion_matrix( }, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_confusion_matrix_functional(self, normalize, preds, target, sk_metric, num_classes, multilabel, device): self.run_functional_metric_test( preds=preds, diff --git a/tests/classification/test_f_beta.py b/tests/classification/test_f_beta.py index 4dffad4c476..89884a1046d 100644 --- a/tests/classification/test_f_beta.py +++ b/tests/classification/test_f_beta.py @@ -17,7 +17,6 @@ import numpy as np import pytest import torch -from pytest_cases import parametrize_with_cases from sklearn.metrics import f1_score, fbeta_score from torch import Tensor @@ -240,7 +239,7 @@ def test_class_not_present(metric_class, metric_fn, 
ignore_index, expected): ], ) class TestFBeta(MetricTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_fbeta_f1( self, @@ -296,7 +295,7 @@ def test_fbeta_f1( check_batch=True, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_fbeta_f1_functional( self, preds: Tensor, diff --git a/tests/classification/test_hamming_distance.py b/tests/classification/test_hamming_distance.py index d75b4267fed..bdb445cf7d6 100644 --- a/tests/classification/test_hamming_distance.py +++ b/tests/classification/test_hamming_distance.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest -from pytest_cases import parametrize_with_cases from sklearn.metrics import hamming_loss as sk_hamming_loss from tests.classification.inputs import _input_binary, _input_binary_logits, _input_binary_prob @@ -62,7 +61,7 @@ def _sk_hamming_loss(preds, target): ], ) class TestHammingDistance(MetricTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [False, True]) def test_hamming_distance_class(self, ddp, dist_sync_on_step, preds, target, device): self.run_class_metric_test( @@ -76,7 +75,7 @@ def test_hamming_distance_class(self, ddp, dist_sync_on_step, preds, target, dev metric_args={"threshold": THRESHOLD}, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_hamming_distance_fn(self, preds, target, device): self.run_functional_metric_test( preds=preds, diff --git a/tests/classification/test_hinge.py b/tests/classification/test_hinge.py index 06bcdb55d9b..1d61446d31e 100644 --- a/tests/classification/test_hinge.py +++ b/tests/classification/test_hinge.py @@ -16,7 +16,6 @@ import numpy as np import pytest import torch -from pytest_cases import parametrize_with_cases from sklearn.metrics import hinge_loss as sk_hinge from sklearn.preprocessing import OneHotEncoder @@ -89,7 +88,7 @@ def _sk_hinge(preds, target, squared, multiclass_mode): ], ) class TestHinge(MetricTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_hinge_class(self, ddp, dist_sync_on_step, preds, target, squared, multiclass_mode, device): self.run_class_metric_test( @@ -106,7 +105,7 @@ def test_hinge_class(self, ddp, dist_sync_on_step, preds, target, squared, multi }, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_hinge_fn(self, preds, target, squared, multiclass_mode, device): self.run_functional_metric_test( preds=preds, diff --git a/tests/classification/test_jaccard.py b/tests/classification/test_jaccard.py index 4c62afc419b..f1c917aeeab 
100644 --- a/tests/classification/test_jaccard.py +++ b/tests/classification/test_jaccard.py @@ -16,7 +16,6 @@ import numpy as np import pytest import torch -from pytest_cases import parametrize_with_cases from sklearn.metrics import jaccard_score as sk_jaccard_score from torch import Tensor, tensor @@ -103,7 +102,7 @@ def _sk_jaccard_multidim_multiclass(preds, target, average=None): ], ) class TestJaccardIndex(MetricTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_jaccard(self, reduction, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step, device): average = "macro" if reduction == "elementwise_mean" else None # convert tags @@ -118,7 +117,7 @@ def test_jaccard(self, reduction, preds, target, sk_metric, num_classes, ddp, di metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "reduction": reduction}, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_jaccard_functional(self, reduction, preds, target, sk_metric, num_classes, device): average = "macro" if reduction == "elementwise_mean" else None # convert tags self.run_functional_metric_test( diff --git a/tests/classification/test_kl_divergence.py b/tests/classification/test_kl_divergence.py index a7a8ff64d34..4bef6cddf1f 100644 --- a/tests/classification/test_kl_divergence.py +++ b/tests/classification/test_kl_divergence.py @@ -18,7 +18,6 @@ import numpy as np import pytest import torch -from pytest_cases import parametrize_with_cases from scipy.stats import entropy from torch import Tensor @@ -61,7 +60,7 @@ def _sk_metric(p: Tensor, q: Tensor, log_prob: bool, reduction: Optional[str] = class TestKLDivergence(MetricTester): atol = 1e-6 - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_kldivergence(self, reduction, p, q, log_prob, ddp, dist_sync_on_step, device): self.run_class_metric_test( @@ -75,7 +74,7 @@ def test_kldivergence(self, reduction, p, q, log_prob, ddp, dist_sync_on_step, d metric_args=dict(log_prob=log_prob, reduction=reduction), ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_kldivergence_functional(self, reduction, p, q, log_prob, device): # todo: `num_outputs` is unused self.run_functional_metric_test( diff --git a/tests/classification/test_matthews_corrcoef.py b/tests/classification/test_matthews_corrcoef.py index cc1d0d674b0..3850cddfd2b 100644 --- a/tests/classification/test_matthews_corrcoef.py +++ b/tests/classification/test_matthews_corrcoef.py @@ -14,7 +14,6 @@ import numpy as np import pytest import torch -from pytest_cases import parametrize_with_cases from sklearn.metrics import matthews_corrcoef as sk_matthews_corrcoef from tests.classification.inputs import _input_binary, _input_binary_prob @@ -102,7 +101,7 @@ def _sk_matthews_corrcoef_multidim_multiclass(preds, target): ], ) class TestMatthewsCorrCoef(MetricTester): - @parametrize_with_cases("ddp,device", 
cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_matthews_corrcoef(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step, device): self.run_class_metric_test( @@ -119,7 +118,7 @@ def test_matthews_corrcoef(self, preds, target, sk_metric, num_classes, ddp, dis }, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_matthews_corrcoef_functional(self, preds, target, sk_metric, num_classes, device): self.run_functional_metric_test( preds, diff --git a/tests/classification/test_precision_recall.py b/tests/classification/test_precision_recall.py index ecdb6cefb12..7021bc6d687 100644 --- a/tests/classification/test_precision_recall.py +++ b/tests/classification/test_precision_recall.py @@ -17,7 +17,6 @@ import numpy as np import pytest import torch -from pytest_cases import parametrize_with_cases from sklearn.metrics import precision_score, recall_score from torch import Tensor, tensor @@ -211,7 +210,7 @@ def test_no_support(metric_class, metric_fn): ], ) class TestPrecisionRecall(MetricTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [False]) def test_precision_recall_class( self, @@ -268,7 +267,7 @@ def test_precision_recall_class( check_batch=True, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_precision_recall_fn( self, preds: Tensor, diff --git a/tests/classification/test_precision_recall_curve.py b/tests/classification/test_precision_recall_curve.py index 7921c57ce07..97b1a3561e5 100644 --- a/tests/classification/test_precision_recall_curve.py +++ b/tests/classification/test_precision_recall_curve.py @@ -16,7 +16,6 @@ import numpy as np import pytest import torch -from pytest_cases import parametrize_with_cases from sklearn.metrics import precision_recall_curve as sk_precision_recall_curve from torch import Tensor, tensor @@ -77,7 +76,7 @@ def _sk_prec_rc_multidim_multiclass_prob(preds, target, num_classes=1): ], ) class TestPrecisionRecallCurve(MetricTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_precision_recall_curve(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step, device): self.run_class_metric_test( @@ -91,7 +90,7 @@ def test_precision_recall_curve(self, preds, target, sk_metric, num_classes, ddp metric_args={"num_classes": num_classes}, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_precision_recall_curve_functional(self, preds, target, sk_metric, num_classes, device): self.run_functional_metric_test( preds, diff --git a/tests/classification/test_roc.py b/tests/classification/test_roc.py index 177ddb9c834..263d49f86ea 100644 --- 
a/tests/classification/test_roc.py +++ b/tests/classification/test_roc.py @@ -16,7 +16,6 @@ import numpy as np import pytest import torch -from pytest_cases import parametrize_with_cases from sklearn.metrics import roc_curve as sk_roc_curve from torch import tensor @@ -96,7 +95,7 @@ def _sk_roc_multilabel_multidim_prob(preds, target, num_classes=1): ], ) class TestROC(MetricTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_roc(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step, device): self.run_class_metric_test( @@ -110,7 +109,7 @@ def test_roc(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step metric_args={"num_classes": num_classes}, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_roc_functional(self, preds, target, sk_metric, num_classes, device): self.run_functional_metric_test( preds, diff --git a/tests/classification/test_specificity.py b/tests/classification/test_specificity.py index ccc24c85654..9f01127ab6a 100644 --- a/tests/classification/test_specificity.py +++ b/tests/classification/test_specificity.py @@ -18,7 +18,6 @@ import numpy as np import pytest import torch -from pytest_cases import parametrize_with_cases from sklearn.metrics import multilabel_confusion_matrix from torch import Tensor, tensor @@ -219,7 +218,7 @@ def test_no_support(metric_class, metric_fn): ], ) class TestSpecificity(MetricTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_specificity_class( self, @@ -274,7 +273,7 @@ def test_specificity_class( check_batch=True, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_specificity_fn( self, preds: Tensor, diff --git a/tests/classification/test_stat_scores.py b/tests/classification/test_stat_scores.py index fc5ba4c2b3d..a8564ef2321 100644 --- a/tests/classification/test_stat_scores.py +++ b/tests/classification/test_stat_scores.py @@ -17,7 +17,6 @@ import numpy as np import pytest import torch -from pytest_cases import parametrize_with_cases from sklearn.metrics import multilabel_confusion_matrix from torch import Tensor, tensor @@ -173,8 +172,8 @@ def test_wrong_params(reduce, mdmc_reduce, num_classes, inputs, ignore_index): class TestStatScores(MetricTester): # DDP tests temporarily disabled due to hanging issues @pytest.mark.parametrize("ddp", [False]) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") - # @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) + # @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_stat_scores_class( self, @@ -225,7 +224,7 @@ def test_stat_scores_class( check_batch=True, ) - 
@parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_stat_scores_fn( self, sk_fn: Callable, diff --git a/tests/detection/test_map.py b/tests/detection/test_map.py index eff35689ead..c1dffc6f061 100644 --- a/tests/detection/test_map.py +++ b/tests/detection/test_map.py @@ -16,7 +16,6 @@ import pytest import torch -from pytest_cases import parametrize_with_cases from tests.helpers.testers import MetricTester, MetricTesterDDPCases from torchmetrics.detection.map import MAP @@ -175,7 +174,7 @@ class TestMAP(MetricTester): atol = 1e-1 - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) def test_map(self, ddp, device): """Test modular implementation for correctness.""" self.run_class_metric_test( diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py index b1ff9d2dd8a..2a7f01d554d 100644 --- a/tests/helpers/testers.py +++ b/tests/helpers/testers.py @@ -20,7 +20,6 @@ import numpy as np import pytest import torch -from pytest_cases import case from torch import Tensor, tensor from torch.multiprocessing import Pool, set_start_method @@ -329,54 +328,46 @@ def _assert_half_support( # https://github.com/pytest-dev/pytest/issues/349 class MetricTesterDDPCases: - @case(tags="ddp") - def case_ddp_false(self): - return False - - @case(tags="ddp") - def case_ddp_true(self): - return True - - @case(tags="device") - def case_device_cpu(self): - return "cpu" - - @case(tags="device") - def case_device_gpu(self): - return "cuda" - - @case(tags="strategy") - def case_ddp_false_device_cpu(self): - return False, "cpu" - - @case( - tags="strategy", - marks=[ - pytest.mark.skipif(not torch.distributed.is_available(), reason="Distributed mode is not available."), - ], - ) - def case_ddp_true_device_cpu(self): - return True, "cpu" + def name_ddp(): + return "ddp" - @case( - tags="strategy", - marks=[pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU required")], - ) - def case_ddp_false_device_gpu(self): - return False, "cuda" - - @case( - tags="strategy", - marks=[ - pytest.mark.skipif( - not torch.cuda.is_available() or torch.cuda.device_count() < 2, - reason="More than one GPU required for DDP", + def cases_ddp(): + return [False, True] + + def name_device(): + return "device" + + def cases_device(): + return ["cpu", "cuda"] + + def name_strategy(): + return ",".join([MetricTesterDDPCases.name_ddp(), MetricTesterDDPCases.name_device()]) + + def cases_strategy(): + return [ + (False, "cpu"), + pytest.param( + True, + "cpu", + marks=pytest.mark.skipif( + not torch.distributed.is_available(), reason="Distributed mode is not available." + ), ), - pytest.mark.skipif(not torch.distributed.is_available(), reason="Distributed mode is not available."), - ], - ) - def case_ddp_true_device_gpu(self): - return True, "cuda" + pytest.param(False, "cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU required")), + pytest.param( + True, + "cuda", + marks=[ + pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + reason="More than one GPU required for DDP", + ), + pytest.mark.skipif( + not torch.distributed.is_available(), reason="Distributed mode is not available." 
+ ), + ], + ), + ] class MetricTester: diff --git a/tests/image/test_lpips.py b/tests/image/test_lpips.py index 5b42bdbbe54..2625ccd4138 100644 --- a/tests/image/test_lpips.py +++ b/tests/image/test_lpips.py @@ -17,7 +17,6 @@ import pytest import torch from lpips import LPIPS as reference_LPIPS -from pytest_cases import parametrize_with_cases from torch import Tensor from tests.helpers import seed_all @@ -47,7 +46,7 @@ def _compare_fn(img1: Tensor, img2: Tensor, net_type: str, reduction: str = "mea @pytest.mark.skipif(not _LPIPS_AVAILABLE, reason="test requires that lpips is installed") @pytest.mark.parametrize("net_type", ["vgg", "alex", "squeeze"]) class TestLPIPS(MetricTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) def test_lpips(self, net_type, ddp, device): """test modular implementation for correctness.""" self.run_class_metric_test( diff --git a/tests/image/test_psnr.py b/tests/image/test_psnr.py index 8c2a920b544..f5d61c22680 100644 --- a/tests/image/test_psnr.py +++ b/tests/image/test_psnr.py @@ -18,7 +18,6 @@ import numpy as np import pytest import torch -from pytest_cases import parametrize_with_cases from skimage.metrics import peak_signal_noise_ratio from tests.helpers import seed_all @@ -94,7 +93,7 @@ def _base_e_sk_psnr(preds, target, data_range, reduction, dim): ], ) class TestPSNR(MetricTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_psnr(self, preds, target, data_range, base, reduction, dim, sk_metric, ddp, dist_sync_on_step, device): _args = {"data_range": data_range, "base": base, "reduction": reduction, "dim": dim} @@ -109,7 +108,7 @@ def test_psnr(self, preds, target, data_range, base, reduction, dim, sk_metric, dist_sync_on_step=dist_sync_on_step, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_psnr_functional(self, preds, target, sk_metric, data_range, base, reduction, dim, device): _args = {"data_range": data_range, "base": base, "reduction": reduction, "dim": dim} self.run_functional_metric_test( diff --git a/tests/image/test_ssim.py b/tests/image/test_ssim.py index ca24b68bced..d45abf4023f 100644 --- a/tests/image/test_ssim.py +++ b/tests/image/test_ssim.py @@ -16,7 +16,6 @@ import pytest import torch -from pytest_cases import parametrize_with_cases from skimage.metrics import structural_similarity from tests.helpers import seed_all @@ -73,7 +72,7 @@ def _sk_ssim(preds, target, data_range, multichannel, kernel_size): class TestSSIM(MetricTester): atol = 6e-3 - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_ssim(self, preds, target, multichannel, kernel_size, ddp, dist_sync_on_step, device): self.run_class_metric_test( @@ -87,7 +86,7 @@ def test_ssim(self, preds, target, multichannel, kernel_size, ddp, dist_sync_on_ dist_sync_on_step=dist_sync_on_step, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + 
@pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_ssim_functional(self, preds, target, multichannel, kernel_size, device): self.run_functional_metric_test( preds, diff --git a/tests/pairwise/test_pairwise_distance.py b/tests/pairwise/test_pairwise_distance.py index a424df61358..b54a572b0f2 100644 --- a/tests/pairwise/test_pairwise_distance.py +++ b/tests/pairwise/test_pairwise_distance.py @@ -16,7 +16,6 @@ import pytest import torch -from pytest_cases import parametrize_with_cases from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances, linear_kernel, manhattan_distances from tests.helpers import seed_all @@ -82,7 +81,7 @@ class TestPairwise(MetricTester): atol = 1e-4 - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_pairwise_functional(self, x, y, metric_functional, sk_fn, reduction, device): """test functional pairwise implementations.""" self.run_functional_metric_test( diff --git a/tests/regression/test_cosine_similarity.py b/tests/regression/test_cosine_similarity.py index 04219236b86..5b97762b611 100644 --- a/tests/regression/test_cosine_similarity.py +++ b/tests/regression/test_cosine_similarity.py @@ -17,7 +17,6 @@ import numpy as np import pytest import torch -from pytest_cases import parametrize_with_cases from sklearn.metrics.pairwise import cosine_similarity as sk_cosine from tests.helpers import seed_all @@ -83,7 +82,7 @@ def _single_target_sk_metric(preds, target, reduction, sk_fn=sk_cosine): ], ) class TestCosineSimilarity(MetricTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_cosine_similarity(self, reduction, preds, target, sk_metric, ddp, dist_sync_on_step, device): self.run_class_metric_test( @@ -97,7 +96,7 @@ def test_cosine_similarity(self, reduction, preds, target, sk_metric, ddp, dist_ metric_args=dict(reduction=reduction), ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_cosine_similarity_functional(self, reduction, preds, target, sk_metric, device): self.run_functional_metric_test( preds, diff --git a/tests/regression/test_explained_variance.py b/tests/regression/test_explained_variance.py index 5f8e3eecf7e..b5b189ceb62 100644 --- a/tests/regression/test_explained_variance.py +++ b/tests/regression/test_explained_variance.py @@ -16,7 +16,6 @@ import pytest import torch -from pytest_cases import parametrize_with_cases from sklearn.metrics import explained_variance_score from tests.helpers import seed_all @@ -63,7 +62,7 @@ def _multi_target_sk_metric(preds, target, sk_fn=explained_variance_score): ], ) class TestExplainedVariance(MetricTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_explained_variance(self, multioutput, preds, target, sk_metric, ddp, dist_sync_on_step, device): self.run_class_metric_test( @@ -77,7 +76,7 @@ def test_explained_variance(self, multioutput, preds, 
target, sk_metric, ddp, di metric_args=dict(multioutput=multioutput), ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_explained_variance_functional(self, multioutput, preds, target, sk_metric, device): self.run_functional_metric_test( preds, diff --git a/tests/regression/test_mean_error.py b/tests/regression/test_mean_error.py index 7d4d9b70a7b..fd0b3e4493c 100644 --- a/tests/regression/test_mean_error.py +++ b/tests/regression/test_mean_error.py @@ -17,7 +17,6 @@ import pytest import torch -from pytest_cases import parametrize_with_cases from sklearn.metrics import mean_absolute_error as sk_mean_absolute_error from sklearn.metrics import mean_absolute_percentage_error as sk_mean_abs_percentage_error from sklearn.metrics import mean_squared_error as sk_mean_squared_error @@ -109,7 +108,7 @@ def _multi_target_sk_metric(preds, target, sk_fn, metric_args): ], ) class TestMeanError(MetricTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_mean_error_class( self, @@ -136,7 +135,7 @@ def test_mean_error_class( metric_args=metric_args, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_mean_error_functional( self, preds, target, sk_metric, metric_class, metric_functional, sk_fn, metric_args, device ): diff --git a/tests/regression/test_pearson.py b/tests/regression/test_pearson.py index cb0fb1917e0..2b006469765 100644 --- a/tests/regression/test_pearson.py +++ b/tests/regression/test_pearson.py @@ -15,7 +15,6 @@ import pytest import torch -from pytest_cases import parametrize_with_cases from scipy.stats import pearsonr from tests.helpers import seed_all @@ -54,7 +53,7 @@ def _sk_pearsonr(preds, target): class TestPearsonCorrcoef(MetricTester): atol = 1e-2 - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) def test_pearson_corrcoef(self, preds, target, ddp, device): self.run_class_metric_test( ddp=ddp, @@ -66,7 +65,7 @@ def test_pearson_corrcoef(self, preds, target, ddp, device): device=device, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_pearson_corrcoef_functional(self, preds, target, device): self.run_functional_metric_test( preds=preds, target=target, metric_functional=pearson_corrcoef, sk_metric=_sk_pearsonr, device=device diff --git a/tests/regression/test_r2.py b/tests/regression/test_r2.py index 4ad2b8f882d..94a432c13b2 100644 --- a/tests/regression/test_r2.py +++ b/tests/regression/test_r2.py @@ -16,7 +16,6 @@ import pytest import torch -from pytest_cases import parametrize_with_cases from sklearn.metrics import r2_score as sk_r2score from tests.helpers import seed_all @@ -70,7 +69,7 @@ def _multi_target_sk_metric(preds, target, adjusted, multioutput): ], ) class TestR2Score(MetricTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + 
@pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_r2(self, adjusted, multioutput, preds, target, sk_metric, num_outputs, ddp, dist_sync_on_step, device): self.run_class_metric_test( @@ -84,7 +83,7 @@ def test_r2(self, adjusted, multioutput, preds, target, sk_metric, num_outputs, metric_args=dict(adjusted=adjusted, multioutput=multioutput, num_outputs=num_outputs), ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_r2_functional(self, adjusted, multioutput, preds, target, sk_metric, num_outputs, device): # todo: `num_outputs` is unused self.run_functional_metric_test( diff --git a/tests/regression/test_spearman.py b/tests/regression/test_spearman.py index d2ed7b4c861..5fd8781f8de 100644 --- a/tests/regression/test_spearman.py +++ b/tests/regression/test_spearman.py @@ -15,7 +15,6 @@ import pytest import torch -from pytest_cases import parametrize_with_cases from scipy.stats import rankdata, spearmanr from tests.helpers import seed_all @@ -77,7 +76,7 @@ def _sk_metric(preds, target): class TestSpearmanCorrcoef(MetricTester): atol = 1e-2 - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_spearman_corrcoef(self, preds, target, ddp, dist_sync_on_step, device): self.run_class_metric_test( @@ -90,7 +89,7 @@ def test_spearman_corrcoef(self, preds, target, ddp, dist_sync_on_step, device): device=device, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_spearman_corrcoef_functional(self, preds, target, device): self.run_functional_metric_test(preds, target, spearman_corrcoef, _sk_metric, device=device) diff --git a/tests/regression/test_tweedie_deviance.py b/tests/regression/test_tweedie_deviance.py index 2a2298fa3ea..1b68db771f2 100644 --- a/tests/regression/test_tweedie_deviance.py +++ b/tests/regression/test_tweedie_deviance.py @@ -16,7 +16,6 @@ import pytest import torch -from pytest_cases import parametrize_with_cases from sklearn.metrics import mean_tweedie_deviance from torch import Tensor @@ -61,7 +60,7 @@ def _sk_deviance(preds: Tensor, targets: Tensor, power: float): ], ) class TestDevianceScore(MetricTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_deviance_scores_class(self, ddp, dist_sync_on_step, preds, targets, power, device): self.run_class_metric_test( @@ -75,7 +74,7 @@ def test_deviance_scores_class(self, ddp, dist_sync_on_step, preds, targets, pow metric_args=dict(power=power), ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_deviance_scores_functional(self, preds, targets, power, device): self.run_functional_metric_test( preds, diff --git a/tests/retrieval/test_fallout.py b/tests/retrieval/test_fallout.py index 
330976e4d8e..c66a54dc5ca 100644 --- a/tests/retrieval/test_fallout.py +++ b/tests/retrieval/test_fallout.py @@ -13,7 +13,6 @@ # limitations under the License. import numpy as np import pytest -from pytest_cases import parametrize_with_cases from torch import Tensor from tests.helpers import seed_all @@ -55,7 +54,7 @@ def _fallout_at_k(target: np.ndarray, preds: np.ndarray, k: int = None): class TestFallOut(RetrievalMetricTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) @pytest.mark.parametrize("ignore_index", [None, 1]) # avoid setting 0, otherwise test with all 0 targets will fail @@ -88,7 +87,7 @@ def test_class_metric( metric_args=metric_args, ) - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) @pytest.mark.parametrize("k", [None, 1, 4, 10]) @@ -119,7 +118,7 @@ def test_class_metric_ignore_index( metric_args=metric_args, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) @pytest.mark.parametrize(**_default_metric_functional_input_arguments) @pytest.mark.parametrize("k", [None, 1, 4, 10]) def test_functional_metric(self, preds: Tensor, target: Tensor, k: int, device: str): diff --git a/tests/retrieval/test_hit_rate.py b/tests/retrieval/test_hit_rate.py index 250305aecd7..13b0b539b94 100644 --- a/tests/retrieval/test_hit_rate.py +++ b/tests/retrieval/test_hit_rate.py @@ -13,7 +13,6 @@ # limitations under the License. 
import numpy as np import pytest -from pytest_cases import parametrize_with_cases from torch import Tensor from tests.helpers import seed_all @@ -52,7 +51,7 @@ def _hit_rate_at_k(target: np.ndarray, preds: np.ndarray, k: int = None): class TestHitRate(RetrievalMetricTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) @pytest.mark.parametrize("ignore_index", [None, 1]) # avoid setting 0, otherwise test with all 0 targets will fail @@ -84,7 +83,7 @@ def test_class_metric( metric_args=metric_args, ) - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) @pytest.mark.parametrize("k", [None, 1, 4, 10]) @@ -114,7 +113,7 @@ def test_class_metric_ignore_index( metric_args=metric_args, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) @pytest.mark.parametrize(**_default_metric_functional_input_arguments) @pytest.mark.parametrize("k", [None, 1, 4, 10]) def test_functional_metric(self, preds: Tensor, target: Tensor, k: int, device: str): diff --git a/tests/retrieval/test_map.py b/tests/retrieval/test_map.py index a874ff42fe7..a5027b0264c 100644 --- a/tests/retrieval/test_map.py +++ b/tests/retrieval/test_map.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import pytest -from pytest_cases import parametrize_with_cases from sklearn.metrics import average_precision_score as sk_average_precision_score from torch import Tensor @@ -35,7 +34,7 @@ class TestMAP(RetrievalMetricTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) @pytest.mark.parametrize("ignore_index", [None, 1]) # avoid setting 0, otherwise test with all 0 targets will fail @@ -65,7 +64,7 @@ def test_class_metric( metric_args=metric_args, ) - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) @pytest.mark.parametrize(**_default_metric_class_input_arguments_ignore_index) @@ -93,7 +92,7 @@ def test_class_metric_ignore_index( metric_args=metric_args, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) @pytest.mark.parametrize(**_default_metric_functional_input_arguments) def test_functional_metric(self, preds: Tensor, target: Tensor, device: str): self.run_functional_metric_test( diff --git a/tests/retrieval/test_mrr.py b/tests/retrieval/test_mrr.py index 19c01a8a854..eefc383f5e4 100644 --- a/tests/retrieval/test_mrr.py +++ b/tests/retrieval/test_mrr.py @@ -13,7 +13,6 @@ # limitations under the License. 
import numpy as np import pytest -from pytest_cases import parametrize_with_cases from sklearn.metrics import label_ranking_average_precision_score from torch import Tensor @@ -57,7 +56,7 @@ def _reciprocal_rank(target: np.ndarray, preds: np.ndarray): class TestMRR(RetrievalMetricTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) @pytest.mark.parametrize("ignore_index", [None, 1]) # avoid setting 0, otherwise test with all 0 targets will fail @@ -87,7 +86,7 @@ def test_class_metric( metric_args=metric_args, ) - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) @pytest.mark.parametrize(**_default_metric_class_input_arguments_ignore_index) @@ -115,7 +114,7 @@ def test_class_metric_ignore_index( metric_args=metric_args, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) @pytest.mark.parametrize(**_default_metric_functional_input_arguments) def test_functional_metric(self, preds: Tensor, target: Tensor, device: str): self.run_functional_metric_test( diff --git a/tests/retrieval/test_ndcg.py b/tests/retrieval/test_ndcg.py index cf8f31ea78f..6c98727b42f 100644 --- a/tests/retrieval/test_ndcg.py +++ b/tests/retrieval/test_ndcg.py @@ -13,7 +13,6 @@ # limitations under the License. 
import numpy as np import pytest -from pytest_cases import parametrize_with_cases from sklearn.metrics import ndcg_score from torch import Tensor @@ -51,7 +50,7 @@ def _ndcg_at_k(target: np.ndarray, preds: np.ndarray, k: int = None): class TestNDCG(RetrievalMetricTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) @pytest.mark.parametrize("ignore_index", [None, 3]) # avoid setting 0, otherwise test with all 0 targets will fail @@ -83,7 +82,7 @@ def test_class_metric( metric_args=metric_args, ) - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) @pytest.mark.parametrize("k", [None, 1, 4, 10]) @@ -113,7 +112,7 @@ def test_class_metric_ignore_index( metric_args=metric_args, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) @pytest.mark.parametrize(**_default_metric_functional_input_arguments_with_non_binary_target) @pytest.mark.parametrize("k", [None, 1, 4, 10]) def test_functional_metric(self, preds: Tensor, target: Tensor, k: int, device: str): diff --git a/tests/retrieval/test_precision.py b/tests/retrieval/test_precision.py index db7e2c83739..5b704aca272 100644 --- a/tests/retrieval/test_precision.py +++ b/tests/retrieval/test_precision.py @@ -13,7 +13,6 @@ # limitations under the License. 
import numpy as np import pytest -from pytest_cases import parametrize_with_cases from torch import Tensor from tests.helpers import seed_all @@ -56,7 +55,7 @@ def _precision_at_k(target: np.ndarray, preds: np.ndarray, k: int = None): class TestPrecision(RetrievalMetricTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) @pytest.mark.parametrize("ignore_index", [None, 1]) # avoid setting 0, otherwise test with all 0 targets will fail @@ -88,7 +87,7 @@ def test_class_metric( metric_args=metric_args, ) - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) @pytest.mark.parametrize("k", [None, 1, 4, 10]) @@ -118,7 +117,7 @@ def test_class_metric_ignore_index( metric_args=metric_args, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) @pytest.mark.parametrize(**_default_metric_functional_input_arguments) @pytest.mark.parametrize("k", [None, 1, 4, 10]) def test_functional_metric(self, preds: Tensor, target: Tensor, k: int, device: str): diff --git a/tests/retrieval/test_r_precision.py b/tests/retrieval/test_r_precision.py index b3e590cc09c..d26454729c9 100644 --- a/tests/retrieval/test_r_precision.py +++ b/tests/retrieval/test_r_precision.py @@ -13,7 +13,6 @@ # limitations under the License. 
import numpy as np import pytest -from pytest_cases import parametrize_with_cases from torch import Tensor from tests.helpers import seed_all @@ -51,7 +50,7 @@ def _r_precision(target: np.ndarray, preds: np.ndarray): class TestRPrecision(RetrievalMetricTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) @pytest.mark.parametrize("ignore_index", [None, 1]) # avoid setting 0, otherwise test with all 0 targets will fail @@ -81,7 +80,7 @@ def test_class_metric( metric_args=metric_args, ) - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) @pytest.mark.parametrize(**_default_metric_class_input_arguments_ignore_index) @@ -109,7 +108,7 @@ def test_class_metric_ignore_index( metric_args=metric_args, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) @pytest.mark.parametrize(**_default_metric_functional_input_arguments) def test_functional_metric(self, preds: Tensor, target: Tensor, device: str): self.run_functional_metric_test( diff --git a/tests/retrieval/test_recall.py b/tests/retrieval/test_recall.py index 63cbe890c70..4bdf17f763e 100644 --- a/tests/retrieval/test_recall.py +++ b/tests/retrieval/test_recall.py @@ -13,7 +13,6 @@ # limitations under the License. 
import numpy as np import pytest -from pytest_cases import parametrize_with_cases from torch import Tensor from tests.helpers import seed_all @@ -55,7 +54,7 @@ def _recall_at_k(target: np.ndarray, preds: np.ndarray, k: int = None): class TestRecall(RetrievalMetricTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) @pytest.mark.parametrize("ignore_index", [None, 1]) # avoid setting 0, otherwise test with all 0 targets will fail @@ -87,7 +86,7 @@ def test_class_metric( metric_args=metric_args, ) - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) @pytest.mark.parametrize("k", [None, 1, 4, 10]) @@ -117,7 +116,7 @@ def test_class_metric_ignore_index( metric_args=metric_args, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) @pytest.mark.parametrize(**_default_metric_functional_input_arguments) @pytest.mark.parametrize("k", [None, 1, 4, 10]) def test_functional_metric(self, preds: Tensor, target: Tensor, k: int, device: str): diff --git a/tests/text/test_bleu.py b/tests/text/test_bleu.py index 264273cc16f..ba569c6c8c6 100644 --- a/tests/text/test_bleu.py +++ b/tests/text/test_bleu.py @@ -16,7 +16,6 @@ import pytest from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu -from pytest_cases import parametrize_with_cases from torch import tensor from tests.helpers.testers import MetricTesterDDPCases @@ -55,7 +54,7 @@ def _compute_bleu_metric_nltk(list_of_references, hypotheses, weights, smoothing [(_inputs_multiple_references.preds, _inputs_multiple_references.targets)], ) class TestBLEUScore(TextTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [False, True]) def test_bleu_score_class( self, ddp, dist_sync_on_step, preds, targets, weights, n_gram, smooth_func, smooth, device @@ -75,7 +74,7 @@ def test_bleu_score_class( input_order=INPUT_ORDER.TARGETS_FIRST, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_bleu_score_functional(self, preds, targets, weights, n_gram, smooth_func, smooth, device): metric_args = {"n_gram": n_gram, "smooth": smooth} compute_bleu_metric_nltk = partial(_compute_bleu_metric_nltk, weights=weights, smoothing_function=smooth_func) diff --git a/tests/text/test_cer.py b/tests/text/test_cer.py index b29c271dba9..0cdd0238e8d 100644 --- a/tests/text/test_cer.py +++ b/tests/text/test_cer.py @@ -1,7 +1,6 @@ from typing import Callable, List, Union import pytest -from pytest_cases import parametrize_with_cases from tests.helpers.testers import MetricTesterDDPCases from tests.text.helpers import INPUT_ORDER, TextTester @@ -32,7 +31,7 @@ def 
compare_fn(prediction: Union[str, List[str]], reference: Union[str, List[str class TestCharErrorRate(TextTester): """test class for character error rate.""" - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [False, True]) def test_cer_class(self, ddp, dist_sync_on_step, preds, targets, device): """test modular version of cer.""" @@ -47,7 +46,7 @@ def test_cer_class(self, ddp, dist_sync_on_step, preds, targets, device): input_order=INPUT_ORDER.PREDS_FIRST, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_cer_functional(self, preds, targets, device): """test functional version of cer.""" self.run_functional_metric_test( diff --git a/tests/text/test_chrf.py b/tests/text/test_chrf.py index c938bc57d3e..f65a6290fde 100644 --- a/tests/text/test_chrf.py +++ b/tests/text/test_chrf.py @@ -2,7 +2,6 @@ from typing import Sequence import pytest -from pytest_cases import parametrize_with_cases from torch import Tensor, tensor from tests.helpers.testers import MetricTesterDDPCases @@ -50,7 +49,7 @@ def sacrebleu_chrf_fn( ) @pytest.mark.skipif(not _SACREBLEU_AVAILABLE, reason="test requires sacrebleu") class TestCHRFScore(TextTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [False, True]) def test_chrf_score_class( self, ddp, dist_sync_on_step, preds, targets, char_order, word_order, lowercase, whitespace, device @@ -77,7 +76,7 @@ def test_chrf_score_class( input_order=INPUT_ORDER.TARGETS_FIRST, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_chrf_score_functional(self, preds, targets, char_order, word_order, lowercase, whitespace, device): metric_args = { "n_char_order": char_order, diff --git a/tests/text/test_mer.py b/tests/text/test_mer.py index c1dc4a040f4..b453fad58b8 100644 --- a/tests/text/test_mer.py +++ b/tests/text/test_mer.py @@ -1,7 +1,6 @@ from typing import Callable, List, Union import pytest -from pytest_cases import parametrize_with_cases from tests.helpers.testers import MetricTesterDDPCases from tests.text.helpers import INPUT_ORDER, TextTester @@ -30,7 +29,7 @@ def _compute_mer_metric_jiwer(prediction: Union[str, List[str]], reference: Unio ], ) class TestMatchErrorRate(TextTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [False, True]) def test_mer_class(self, ddp, dist_sync_on_step, preds, targets, device): @@ -45,7 +44,7 @@ def test_mer_class(self, ddp, dist_sync_on_step, preds, targets, device): input_order=INPUT_ORDER.PREDS_FIRST, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_mer_functional(self, preds, targets, device): self.run_functional_metric_test( diff --git 
a/tests/text/test_rouge.py b/tests/text/test_rouge.py index e94afa124a9..95d9f8d002c 100644 --- a/tests/text/test_rouge.py +++ b/tests/text/test_rouge.py @@ -17,7 +17,6 @@ import pytest import torch -from pytest_cases import parametrize_with_cases from tests.helpers.testers import MetricTesterDDPCases from tests.text.helpers import INPUT_ORDER, TextTester @@ -104,7 +103,7 @@ def _compute_rouge_score( ) @pytest.mark.parametrize("accumulate", ["avg", "best"]) class TestROUGEScore(TextTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [False, True]) def test_rouge_score_class( self, ddp, dist_sync_on_step, preds, targets, pl_rouge_metric_key, use_stemmer, accumulate, device @@ -127,7 +126,7 @@ def test_rouge_score_class( key=pl_rouge_metric_key, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_rouge_score_functional(self, preds, targets, pl_rouge_metric_key, use_stemmer, accumulate, device): metric_args = {"use_stemmer": use_stemmer, "accumulate": accumulate} diff --git a/tests/text/test_sacre_bleu.py b/tests/text/test_sacre_bleu.py index 78af45145c8..da50a276f7b 100644 --- a/tests/text/test_sacre_bleu.py +++ b/tests/text/test_sacre_bleu.py @@ -16,7 +16,6 @@ from typing import Sequence import pytest -from pytest_cases import parametrize_with_cases from torch import Tensor, tensor from tests.helpers.testers import MetricTesterDDPCases @@ -49,7 +48,7 @@ def sacrebleu_fn(targets: Sequence[Sequence[str]], preds: Sequence[str], tokeniz @pytest.mark.parametrize("tokenize", TOKENIZERS) @pytest.mark.skipif(not _SACREBLEU_AVAILABLE, reason="test requires sacrebleu") class TestSacreBLEUScore(TextTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [False, True]) def test_bleu_score_class(self, ddp, dist_sync_on_step, preds, targets, tokenize, lowercase, device): metric_args = {"tokenize": tokenize, "lowercase": lowercase} @@ -67,7 +66,7 @@ def test_bleu_score_class(self, ddp, dist_sync_on_step, preds, targets, tokenize input_order=INPUT_ORDER.TARGETS_FIRST, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_bleu_score_functional(self, preds, targets, tokenize, lowercase, device): metric_args = {"tokenize": tokenize, "lowercase": lowercase} original_sacrebleu = partial(sacrebleu_fn, tokenize=tokenize, lowercase=lowercase) diff --git a/tests/text/test_ter.py b/tests/text/test_ter.py index 1335c038381..d0754882dd2 100644 --- a/tests/text/test_ter.py +++ b/tests/text/test_ter.py @@ -2,7 +2,6 @@ from typing import Sequence import pytest -from pytest_cases import parametrize_with_cases from torch import Tensor, tensor from tests.helpers.testers import MetricTesterDDPCases @@ -50,7 +49,7 @@ def sacrebleu_ter_fn( ) @pytest.mark.skipif(not _SACREBLEU_AVAILABLE, reason="test requires sacrebleu") class TestTER(TextTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + 
@pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [False, True]) def test_chrf_score_class( self, ddp, dist_sync_on_step, preds, targets, normalize, no_punctuation, asian_support, lowercase, device @@ -81,7 +80,7 @@ def test_chrf_score_class( input_order=INPUT_ORDER.TARGETS_FIRST, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_ter_score_functional(self, preds, targets, normalize, no_punctuation, asian_support, lowercase, device): metric_args = { "normalize": normalize, diff --git a/tests/text/test_wer.py b/tests/text/test_wer.py index 72b0804a128..1a899ccbc25 100644 --- a/tests/text/test_wer.py +++ b/tests/text/test_wer.py @@ -1,7 +1,6 @@ from typing import Callable, List, Union import pytest -from pytest_cases import parametrize_with_cases from tests.helpers.testers import MetricTesterDDPCases from tests.text.helpers import INPUT_ORDER, TextTester @@ -30,7 +29,7 @@ def _compute_wer_metric_jiwer(prediction: Union[str, List[str]], reference: Unio ], ) class TestWER(TextTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [False, True]) def test_wer_class(self, ddp, dist_sync_on_step, preds, targets, device): @@ -45,7 +44,7 @@ def test_wer_class(self, ddp, dist_sync_on_step, preds, targets, device): input_order=INPUT_ORDER.PREDS_FIRST, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_wer_functional(self, preds, targets, device): self.run_functional_metric_test( diff --git a/tests/text/test_wil.py b/tests/text/test_wil.py index eb82f1d136c..f11d6b8052a 100644 --- a/tests/text/test_wil.py +++ b/tests/text/test_wil.py @@ -2,7 +2,6 @@ import pytest from jiwer import wil -from pytest_cases import parametrize_with_cases from tests.helpers.testers import MetricTesterDDPCases from tests.text.helpers import INPUT_ORDER, TextTester @@ -25,7 +24,7 @@ def _compute_wil_metric_jiwer(prediction: Union[str, List[str]], reference: Unio ], ) class TestWordInfoLost(TextTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [False, True]) def test_wil_class(self, ddp, dist_sync_on_step, preds, targets, device): @@ -40,7 +39,7 @@ def test_wil_class(self, ddp, dist_sync_on_step, preds, targets, device): input_order=INPUT_ORDER.PREDS_FIRST, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_wil_functional(self, preds, targets, device): self.run_functional_metric_test( diff --git a/tests/text/test_wip.py b/tests/text/test_wip.py index b22560dda84..2751b43cd49 100644 --- a/tests/text/test_wip.py +++ b/tests/text/test_wip.py @@ -2,7 +2,6 @@ import pytest from jiwer import wip -from pytest_cases import parametrize_with_cases from tests.helpers.testers import MetricTesterDDPCases from tests.text.helpers import 
INPUT_ORDER, TextTester @@ -25,7 +24,7 @@ def _compute_wip_metric_jiwer(prediction: Union[str, List[str]], reference: Unio ], ) class TestWordInfoPreserved(TextTester): - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [False, True]) def test_wip_class(self, ddp, dist_sync_on_step, preds, targets, device): @@ -40,7 +39,7 @@ def test_wip_class(self, ddp, dist_sync_on_step, preds, targets, device): input_order=INPUT_ORDER.PREDS_FIRST, ) - @parametrize_with_cases("device", cases=MetricTesterDDPCases, has_tag="device") + @pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device()) def test_wip_functional(self, preds, targets, device): self.run_functional_metric_test( diff --git a/tests/wrappers/test_multioutput.py b/tests/wrappers/test_multioutput.py index 8a1419d88cf..807b435fd2a 100644 --- a/tests/wrappers/test_multioutput.py +++ b/tests/wrappers/test_multioutput.py @@ -4,7 +4,6 @@ import pytest import torch -from pytest_cases import parametrize_with_cases from sklearn.metrics import accuracy_score from sklearn.metrics import r2_score as sk_r2score @@ -125,7 +124,7 @@ def _multi_target_sk_accuracy(preds, target, num_outputs): class TestMultioutputWrapper(MetricTester): """Test the MultioutputWrapper class with regression and classification inner metrics.""" - @parametrize_with_cases("ddp,device", cases=MetricTesterDDPCases, has_tag="strategy") + @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_multioutput_wrapper( self, From debcd8f276c7a29ac899648d8c95b47a5d665ae4 Mon Sep 17 00:00:00 2001 From: twsl <45483159+twsl@users.noreply.github.com> Date: Sun, 26 Dec 2021 17:32:12 +0000 Subject: [PATCH 17/18] Add gpu device restriction --- tests/helpers/testers.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py index 2a7f01d554d..b228aca9f10 100644 --- a/tests/helpers/testers.py +++ b/tests/helpers/testers.py @@ -338,7 +338,10 @@ def name_device(): return "device" def cases_device(): - return ["cpu", "cuda"] + return [ + "cpu", + pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU required")), + ] def name_strategy(): return ",".join([MetricTesterDDPCases.name_ddp(), MetricTesterDDPCases.name_device()]) From f9d300db37b9e32778942b2f00adc8ca0c13adb1 Mon Sep 17 00:00:00 2001 From: twsl <45483159+twsl@users.noreply.github.com> Date: Sun, 26 Dec 2021 17:33:11 +0000 Subject: [PATCH 18/18] Add ddp restriction --- tests/helpers/testers.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py index b228aca9f10..46b6eaf38c1 100644 --- a/tests/helpers/testers.py +++ b/tests/helpers/testers.py @@ -332,7 +332,15 @@ def name_ddp(): return "ddp" def cases_ddp(): - return [False, True] + return [ + False, + pytest.param( + True, + marks=pytest.mark.skipif( + not torch.distributed.is_available(), reason="Distributed mode is not available." + ), + ), + ] def name_device(): return "device"
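
For reference, a minimal standalone sketch (not part of the patch series; the test name and assertion below are illustrative only) of the mechanism the final two patches rely on: embedding skip conditions directly into the parametrize case list via pytest.param(..., marks=pytest.mark.skipif(...)), so GPU- or DDP-dependent cases are skipped on machines without CUDA or torch.distributed rather than failing.

    import pytest
    import torch

    # Illustrative sketch only: this case list mirrors the shape of MetricTesterDDPCases.cases_device(),
    # where the "cuda" entry carries a skipif mark and is skipped when no GPU is available.
    DEVICE_CASES = [
        "cpu",
        pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU required")),
    ]

    @pytest.mark.parametrize("device", DEVICE_CASES)
    def test_device_case_is_usable(device):
        # Hypothetical assertion; the real tests instead forward `device` to
        # MetricTester.run_class_metric_test / run_functional_metric_test.
        assert torch.zeros(1, device=device).device.type == device

Because the skip marks live on the individual pytest.param entries returned by the cases_* helpers, test modules only need a plain @pytest.mark.parametrize call and no longer import pytest_cases.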