
Add Test strategy #688

Closed · wants to merge 23 commits
11 changes: 7 additions & 4 deletions tests/audio/test_pesq.py
@@ -20,7 +20,7 @@
from torch import Tensor

from tests.helpers import seed_all
from tests.helpers.testers import MetricTester
from tests.helpers.testers import MetricTester, MetricTesterDDPCases
from torchmetrics.audio.pesq import PESQ
from torchmetrics.functional.audio.pesq import pesq
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6
@@ -75,25 +75,28 @@ def average_metric(preds, target, metric_func):
class TestPESQ(MetricTester):
atol = 1e-2

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy())
Member

I would say it is cleaner if it were a constant or a function returning a tuple:

def ddp_name_options():
    return "ddp", (True, False)

and later:

@pytest.mark.parametrize(*ddp_name_options())

Member

But anyway, why is this needed?

Contributor Author

The reason I added this strategy is that now all available test cases are generated and you can test them explicitly locally. This way you can guarantee full coverage, you don't depend on the CI machines or some launch flags, and each test states on which device it runs, so you don't have to check the machine stats to find out, e.g., whether a test failed on a single- or multi-GPU setup.
To optimize the tests, an additional skip condition based on an env var could achieve the same. Especially for test optimization you could then run ddp=False on a machine with a single GPU and ddp=True on one with multiple GPUs.
I'm fine with putting it into a single function; my idea was to stick with the two-parameter style of parametrize.
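For context, a minimal sketch of what such a case-generating helper could look like, assuming the method names used in this diff; the actual MetricTesterDDPCases added to tests/helpers/testers.py may differ:

import pytest
import torch


class MetricTesterDDPCases:
    """Illustrative sketch only; not the helper shipped in this PR."""

    @staticmethod
    def name_strategy():
        # argument names consumed by the class-metric tests
        return "ddp,device"

    @staticmethod
    def cases_strategy():
        # generate every (ddp, device) combination; missing hardware is skipped, not hidden
        needs_gpu = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU")
        return [
            pytest.param(False, "cpu"),
            pytest.param(True, "cpu"),
            pytest.param(False, "cuda", marks=needs_gpu),
            pytest.param(True, "cuda", marks=needs_gpu),
        ]

    @staticmethod
    def name_device():
        return "device"

    @staticmethod
    def cases_device():
        # devices for the functional tests, which never run under DDP
        needs_gpu = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU")
        return ["cpu", pytest.param("cuda", marks=needs_gpu)]

With that shape, @pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy()) expands into one explicitly named test per strategy/device combination, and unavailable hardware shows up as a skip rather than a silently absent case.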

Member

I agree with @Borda that it would be better to return a tuple there.

On the general approach: I think we should split this between arguments that apply to every test (like devices and maybe also DDP strategies), for which we should use pytest fixtures (that's what they are there for) so that you could call pytest --devices=XX tests, and test-specific arguments, for which this approach is fine IMO.
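A minimal sketch of that fixture-based alternative, assuming a hypothetical --devices command-line option wired up in a conftest.py (not part of this PR):

# conftest.py
import pytest


def pytest_addoption(parser):
    # let each machine declare which devices it can actually cover
    parser.addoption("--devices", action="store", default="cpu", help="comma-separated devices to test on")


def pytest_generate_tests(metafunc):
    # every test that requests a `device` argument is parametrized from that option
    if "device" in metafunc.fixturenames:
        metafunc.parametrize("device", metafunc.config.getoption("--devices").split(","))

A GPU runner would then call pytest --devices=cpu,cuda tests, while the default run stays CPU-only.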

Contributor Author

I'm unsure how to continue with this. Any suggestions?

@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_pesq(self, preds, target, sk_metric, fs, mode, ddp, dist_sync_on_step):
def test_pesq(self, preds, target, sk_metric, fs, mode, ddp, dist_sync_on_step, device):
self.run_class_metric_test(
ddp,
preds,
target,
PESQ,
sk_metric=partial(average_metric, metric_func=sk_metric),
dist_sync_on_step=dist_sync_on_step,
device=device,
metric_args=dict(fs=fs, mode=mode),
)

def test_pesq_functional(self, preds, target, sk_metric, fs, mode):
@pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device())
def test_pesq_functional(self, preds, target, sk_metric, fs, mode, device):
self.run_functional_metric_test(
preds,
target,
pesq,
sk_metric,
device=device,
metric_args=dict(fs=fs, mode=mode),
)

11 changes: 7 additions & 4 deletions tests/audio/test_pit.py
@@ -22,7 +22,7 @@
from torch import Tensor

from tests.helpers import seed_all
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester, MetricTesterDDPCases
from torchmetrics.audio import PIT
from torchmetrics.functional import pit, si_sdr, snr
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6
@@ -112,25 +112,28 @@ def _average_metric(preds: Tensor, target: Tensor, metric_func: Callable) -> Ten
class TestPIT(MetricTester):
atol = 1e-2

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy())
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_pit(self, preds, target, sk_metric, metric_func, eval_func, ddp, dist_sync_on_step):
def test_pit(self, preds, target, sk_metric, metric_func, eval_func, ddp, device, dist_sync_on_step):
self.run_class_metric_test(
ddp,
preds,
target,
PIT,
sk_metric=partial(_average_metric, metric_func=sk_metric),
dist_sync_on_step=dist_sync_on_step,
device=device,
metric_args=dict(metric_func=metric_func, eval_func=eval_func),
)

def test_pit_functional(self, preds, target, sk_metric, metric_func, eval_func):
@pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device())
def test_pit_functional(self, preds, target, sk_metric, device, metric_func, eval_func):
self.run_functional_metric_test(
preds=preds,
target=target,
metric_functional=pit,
sk_metric=sk_metric,
device=device,
metric_args=dict(metric_func=metric_func, eval_func=eval_func),
)

11 changes: 7 additions & 4 deletions tests/audio/test_sdr.py
@@ -23,7 +23,7 @@
from torch import Tensor

from tests.helpers import seed_all
from tests.helpers.testers import MetricTester
from tests.helpers.testers import MetricTester, MetricTesterDDPCases
from torchmetrics.audio import SDR
from torchmetrics.functional import sdr
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6, _TORCH_GREATER_EQUAL_1_8
@@ -76,25 +76,28 @@ def average_metric(preds: Tensor, target: Tensor, metric_func: Callable) -> Tens
class TestSDR(MetricTester):
atol = 1e-2

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy())
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_sdr(self, preds, target, sk_metric, ddp, dist_sync_on_step):
def test_sdr(self, preds, target, sk_metric, ddp, dist_sync_on_step, device):
self.run_class_metric_test(
ddp,
preds,
target,
SDR,
sk_metric=partial(average_metric, metric_func=sk_metric),
dist_sync_on_step=dist_sync_on_step,
device=device,
metric_args=dict(),
)

def test_sdr_functional(self, preds, target, sk_metric):
@pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device())
def test_sdr_functional(self, preds, target, sk_metric, device):
self.run_functional_metric_test(
preds,
target,
sdr,
sk_metric,
device=device,
metric_args=dict(),
)

11 changes: 7 additions & 4 deletions tests/audio/test_si_sdr.py
@@ -20,7 +20,7 @@
from torch import Tensor

from tests.helpers import seed_all
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester, MetricTesterDDPCases
from torchmetrics.audio import SI_SDR
from torchmetrics.functional import si_sdr
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6
@@ -77,25 +77,28 @@ def average_metric(preds, target, metric_func):
class TestSISDR(MetricTester):
atol = 1e-2

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy())
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_si_sdr(self, preds, target, sk_metric, zero_mean, ddp, dist_sync_on_step):
def test_si_sdr(self, preds, target, sk_metric, zero_mean, ddp, dist_sync_on_step, device):
self.run_class_metric_test(
ddp,
preds,
target,
SI_SDR,
sk_metric=partial(average_metric, metric_func=sk_metric),
dist_sync_on_step=dist_sync_on_step,
device=device,
metric_args=dict(zero_mean=zero_mean),
)

def test_si_sdr_functional(self, preds, target, sk_metric, zero_mean):
@pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device())
def test_si_sdr_functional(self, preds, target, sk_metric, zero_mean, device):
self.run_functional_metric_test(
preds,
target,
si_sdr,
sk_metric,
device=device,
metric_args=dict(zero_mean=zero_mean),
)

11 changes: 7 additions & 4 deletions tests/audio/test_si_snr.py
@@ -20,7 +20,7 @@
from torch import Tensor

from tests.helpers import seed_all
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester, MetricTesterDDPCases
from torchmetrics.audio import SI_SNR
from torchmetrics.functional import si_snr
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6
@@ -72,24 +72,27 @@ def average_metric(preds, target, metric_func):
class TestSISNR(MetricTester):
atol = 1e-2

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy())
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_si_snr(self, preds, target, sk_metric, ddp, dist_sync_on_step):
def test_si_snr(self, preds, target, sk_metric, ddp, dist_sync_on_step, device):
self.run_class_metric_test(
ddp,
preds,
target,
SI_SNR,
sk_metric=partial(average_metric, metric_func=sk_metric),
dist_sync_on_step=dist_sync_on_step,
device=device,
)

def test_si_snr_functional(self, preds, target, sk_metric):
@pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device())
def test_si_snr_functional(self, preds, target, sk_metric, device):
self.run_functional_metric_test(
preds,
target,
si_snr,
sk_metric,
device=device,
)

def test_si_snr_differentiability(self, preds, target, sk_metric):
11 changes: 7 additions & 4 deletions tests/audio/test_snr.py
@@ -21,7 +21,7 @@
from torch import Tensor

from tests.helpers import seed_all
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester, MetricTesterDDPCases
from torchmetrics.audio import SNR
from torchmetrics.functional import snr
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6
@@ -79,25 +79,28 @@ def average_metric(preds: Tensor, target: Tensor, metric_func: Callable):
class TestSNR(MetricTester):
atol = 1e-2

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy())
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_snr(self, preds, target, sk_metric, zero_mean, ddp, dist_sync_on_step):
def test_snr(self, preds, target, sk_metric, zero_mean, ddp, dist_sync_on_step, device):
self.run_class_metric_test(
ddp,
preds,
target,
SNR,
sk_metric=partial(average_metric, metric_func=sk_metric),
dist_sync_on_step=dist_sync_on_step,
device=device,
metric_args=dict(zero_mean=zero_mean),
)

def test_snr_functional(self, preds, target, sk_metric, zero_mean):
@pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device())
def test_snr_functional(self, preds, target, sk_metric, zero_mean, device):
self.run_functional_metric_test(
preds,
target,
snr,
sk_metric,
device=device,
metric_args=dict(zero_mean=zero_mean),
)

11 changes: 7 additions & 4 deletions tests/audio/test_stoi.py
@@ -20,7 +20,7 @@
from torch import Tensor

from tests.helpers import seed_all
from tests.helpers.testers import MetricTester
from tests.helpers.testers import MetricTester, MetricTesterDDPCases
from torchmetrics.audio.stoi import STOI
from torchmetrics.functional.audio.stoi import stoi
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6
@@ -75,25 +75,28 @@ def average_metric(preds, target, metric_func):
class TestSTOI(MetricTester):
atol = 1e-2

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy())
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_stoi(self, preds, target, sk_metric, fs, extended, ddp, dist_sync_on_step):
def test_stoi(self, preds, target, sk_metric, fs, extended, ddp, dist_sync_on_step, device):
self.run_class_metric_test(
ddp,
preds,
target,
STOI,
sk_metric=partial(average_metric, metric_func=sk_metric),
dist_sync_on_step=dist_sync_on_step,
device=device,
metric_args=dict(fs=fs, extended=extended),
)

def test_stoi_functional(self, preds, target, sk_metric, fs, extended):
@pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device())
def test_stoi_functional(self, preds, target, sk_metric, fs, extended, device):
self.run_functional_metric_test(
preds,
target,
stoi,
sk_metric,
device=device,
metric_args=dict(fs=fs, extended=extended),
)

7 changes: 4 additions & 3 deletions tests/bases/test_aggregation.py
@@ -2,7 +2,7 @@
import pytest
import torch

from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester, MetricTesterDDPCases
from torchmetrics.aggregation import CatMetric, MaxMetric, MeanMetric, MinMetric, SumMetric


@@ -80,15 +80,16 @@ def update(self, values, weights):
class TestAggregation(MetricTester):
"""Test aggregation metrics."""

@pytest.mark.parametrize("ddp", [False, True])
@pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy())
@pytest.mark.parametrize("dist_sync_on_step", [False])
def test_aggreagation(self, ddp, dist_sync_on_step, metric_class, compare_fn, values, weights):
def test_aggreagation(self, ddp, dist_sync_on_step, metric_class, compare_fn, values, weights, device):
"""test modular implementation."""
self.run_class_metric_test(
ddp=ddp,
dist_sync_on_step=dist_sync_on_step,
metric_class=metric_class,
sk_metric=compare_fn,
device=device,
check_scriptable=True,
# Abuse of names here
preds=values,
11 changes: 7 additions & 4 deletions tests/classification/test_accuracy.py
@@ -32,7 +32,7 @@
from tests.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob
from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.helpers import seed_all
from tests.helpers.testers import NUM_BATCHES, NUM_CLASSES, THRESHOLD, MetricTester
from tests.helpers.testers import NUM_BATCHES, NUM_CLASSES, THRESHOLD, MetricTester, MetricTesterDDPCases
from torchmetrics import Accuracy
from torchmetrics.functional import accuracy
from torchmetrics.utilities.checks import _input_format_classification
@@ -81,25 +81,28 @@ def _sk_accuracy(preds, target, subset_accuracy):
],
)
class TestAccuracies(MetricTester):
@pytest.mark.parametrize("ddp", [False, True])
@pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy())
@pytest.mark.parametrize("dist_sync_on_step", [False, True])
def test_accuracy_class(self, ddp, dist_sync_on_step, preds, target, subset_accuracy):
def test_accuracy_class(self, ddp, dist_sync_on_step, preds, target, subset_accuracy, device):
self.run_class_metric_test(
ddp=ddp,
preds=preds,
target=target,
metric_class=Accuracy,
sk_metric=partial(_sk_accuracy, subset_accuracy=subset_accuracy),
dist_sync_on_step=dist_sync_on_step,
device=device,
metric_args={"threshold": THRESHOLD, "subset_accuracy": subset_accuracy},
)

def test_accuracy_fn(self, preds, target, subset_accuracy):
@pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device())
def test_accuracy_fn(self, preds, target, subset_accuracy, device):
self.run_functional_metric_test(
preds,
target,
metric_functional=accuracy,
sk_metric=partial(_sk_accuracy, subset_accuracy=subset_accuracy),
device=device,
metric_args={"threshold": THRESHOLD, "subset_accuracy": subset_accuracy},
)

17 changes: 12 additions & 5 deletions tests/classification/test_auc.py
@@ -20,7 +20,7 @@
from torch import tensor

from tests.helpers import seed_all
from tests.helpers.testers import NUM_BATCHES, MetricTester
from tests.helpers.testers import NUM_BATCHES, MetricTester, MetricTesterDDPCases
from torchmetrics.classification.auc import AUC
from torchmetrics.functional import auc

@@ -55,22 +55,29 @@

@pytest.mark.parametrize("x, y", _examples)
class TestAUC(MetricTester):
@pytest.mark.parametrize("ddp", [False])
@pytest.mark.parametrize(MetricTesterDDPCases.name_strategy(), MetricTesterDDPCases.cases_strategy())
Contributor Author

DDP apparently fails and was therefore probably excluded from the tests, yet there is no mention in the docs that DDP is not supported.

Contributor Author

@Borda @SkafteNicki any idea how to handle this? Removing the full DDP/device strategy here and opening an issue to fix AUC would be my recommendation.

Member

It's hard for me to remember, but basically AUC needs ordered input to work, and while we are testing on ordered input, when we move to DDP it somehow gets unordered. I wonder if we should just set the reorder flag to True whenever ddp=True.
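To illustrate the ordering problem (a hedged example, not part of this diff): AUC integrates y over x with the trapezoidal rule, so once gathering across ranks shuffles the samples, x is no longer monotonic and the metric fails unless reorder sorts it first.

import torch
from torchmetrics.functional import auc

x = torch.tensor([0.0, 1.0, 2.0, 3.0])
y = torch.tensor([0.0, 1.0, 2.0, 3.0])
print(auc(x, y))  # tensor(4.5000), area under y(x)

perm = torch.tensor([2, 0, 3, 1])  # simulate the shuffling introduced by an all-gather
print(auc(x[perm], y[perm], reorder=True))  # tensor(4.5000) again
# without reorder, auc(x[perm], y[perm]) fails because x is neither increasing nor decreasing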

@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_auc(self, x, y, ddp, dist_sync_on_step):
def test_auc(self, x, y, ddp, dist_sync_on_step, device):
self.run_class_metric_test(
ddp=ddp,
preds=x,
target=y,
metric_class=AUC,
sk_metric=sk_auc,
dist_sync_on_step=dist_sync_on_step,
device=device,
)

@pytest.mark.parametrize(MetricTesterDDPCases.name_device(), MetricTesterDDPCases.cases_device())
@pytest.mark.parametrize("reorder", [True, False])
def test_auc_functional(self, x, y, reorder):
def test_auc_functional(self, x, y, reorder, device):
self.run_functional_metric_test(
x, y, metric_functional=auc, sk_metric=partial(sk_auc, reorder=reorder), metric_args={"reorder": reorder}
x,
y,
metric_functional=auc,
sk_metric=partial(sk_auc, reorder=reorder),
device=device,
metric_args={"reorder": reorder},
)

@pytest.mark.parametrize("reorder", [True, False])