Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve testing time #820

Merged
merged 22 commits into from
Feb 4, 2022
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

_REQUEST_TIMEOUT = 10
_PATH_ROOT = os.path.dirname(os.path.dirname(__file__))
_PKG_WIDE_SUBPACKAGES = ("utilities",)
_PKG_WIDE_SUBPACKAGES = ("utilities", "helpers")
LUT_PYTHON_TORCH = {
"3.8": "1.4",
"3.9": "1.7.1",
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/ci_test-full.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ jobs:
env:
PYTEST_ARTEFACT: test-results-${{ matrix.os }}-py${{ matrix.python-version }}-${{ matrix.requires }}.xml
PYTORCH_URL: https://download.pytorch.org/whl/cpu/torch_stable.html
TRANSFORMERS_CACHE: .cache/huggingface/
Borda marked this conversation as resolved.
Show resolved Hide resolved

# Timeout: https://stackoverflow.com/a/59076067/4521646
# seems that MacOS jobs take much longer than other OS
Expand Down
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Improve mAP performance ([#742](https://github.com/PyTorchLightning/metrics/pull/742))


- Improved testing speed ([#820](https://github.com/PyTorchLightning/metrics/pull/820))


## [0.7.0] - 2022-01-17

### Added
Expand Down
1 change: 1 addition & 0 deletions requirements/image.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ scipy
torchvision # this is needed to internally set TV version according to the installed PT
torch-fidelity
lpips
Pillow==8.4.0
Borda marked this conversation as resolved.
Show resolved Hide resolved
3 changes: 2 additions & 1 deletion tests/audio/test_snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from torch import Tensor

from tests.helpers import seed_all
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from tests.helpers.testers import NUM_BATCHES, MetricTester
from torchmetrics.audio import SignalNoiseRatio
from torchmetrics.functional import signal_noise_ratio
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6
Expand All @@ -32,6 +32,7 @@

Input = namedtuple("Input", ["preds", "target"])

BATCH_SIZE = 2
inputs = Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, TIME),
target=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, TIME),
Expand Down
3 changes: 3 additions & 0 deletions tests/classification/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@

import torch

from tests.helpers import seed_all
from tests.helpers.testers import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES

seed_all(1)

Input = namedtuple("Input", ["preds", "target"])

_input_binary_prob = Input(
Expand Down
11 changes: 4 additions & 7 deletions tests/classification/test_auroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ def _sk_auroc_multilabel_multidim_prob(preds, target, num_classes, average="macr
)


@pytest.mark.parametrize("average", ["macro", "weighted", "micro"])
@pytest.mark.parametrize("max_fpr", [None, 0.8, 0.5])
@pytest.mark.parametrize(
"preds, target, sk_metric, num_classes",
Expand All @@ -99,6 +98,7 @@ def _sk_auroc_multilabel_multidim_prob(preds, target, num_classes, average="macr
],
)
class TestAUROC(MetricTester):
@pytest.mark.parametrize("average", ["macro", "weighted", "micro"])
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_auroc(self, preds, target, sk_metric, num_classes, average, max_fpr, ddp, dist_sync_on_step):
Expand All @@ -124,6 +124,7 @@ def test_auroc(self, preds, target, sk_metric, num_classes, average, max_fpr, dd
metric_args={"num_classes": num_classes, "average": average, "max_fpr": max_fpr},
)

@pytest.mark.parametrize("average", ["macro", "weighted", "micro"])
def test_auroc_functional(self, preds, target, sk_metric, num_classes, average, max_fpr):
# max_fpr different from None is not supported in multiclass
if max_fpr is not None and num_classes != 1:
Expand All @@ -145,7 +146,7 @@ def test_auroc_functional(self, preds, target, sk_metric, num_classes, average,
metric_args={"num_classes": num_classes, "average": average, "max_fpr": max_fpr},
)

def test_auroc_differentiability(self, preds, target, sk_metric, num_classes, average, max_fpr):
def test_auroc_differentiability(self, preds, target, sk_metric, num_classes, max_fpr):
# max_fpr different from None is not supported in multiclass
if max_fpr is not None and num_classes != 1:
pytest.skip("max_fpr parameter not support for multi class or multi label")
Expand All @@ -154,16 +155,12 @@ def test_auroc_differentiability(self, preds, target, sk_metric, num_classes, av
if max_fpr is not None and _TORCH_LOWER_1_6:
pytest.skip("requires torch v1.6 or higher to test max_fpr argument")

# average='micro' only supported for multilabel
if average == "micro" and preds.ndim > 2 and preds.ndim == target.ndim + 1:
pytest.skip("micro argument only support for multilabel input")

self.run_differentiability_test(
preds=preds,
target=target,
metric_module=AUROC,
metric_functional=auroc,
metric_args={"num_classes": num_classes, "average": average, "max_fpr": max_fpr},
metric_args={"num_classes": num_classes, "max_fpr": max_fpr},
)


Expand Down
5 changes: 3 additions & 2 deletions tests/classification/test_average_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ def _sk_avg_prec_multidim_multiclass_prob(preds, target, num_classes=1, average=
(_input_multilabel.preds, _input_multilabel.target, _sk_avg_prec_multilabel_prob, NUM_CLASSES),
],
)
@pytest.mark.parametrize("average", ["micro", "macro", "weighted", None])
class TestAveragePrecision(MetricTester):
@pytest.mark.parametrize("average", ["micro", "macro", "weighted", None])
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_average_precision(self, preds, target, sk_metric, num_classes, average, ddp, dist_sync_on_step):
Expand All @@ -103,6 +103,7 @@ def test_average_precision(self, preds, target, sk_metric, num_classes, average,
metric_args={"num_classes": num_classes, "average": average},
)

@pytest.mark.parametrize("average", ["micro", "macro", "weighted", None])
def test_average_precision_functional(self, preds, target, sk_metric, num_classes, average):
if target.max() > 1 and average == "micro":
pytest.skip("average=micro and multiclass input cannot be used together")
Expand All @@ -115,7 +116,7 @@ def test_average_precision_functional(self, preds, target, sk_metric, num_classe
metric_args={"num_classes": num_classes, "average": average},
)

def test_average_precision_differentiability(self, preds, sk_metric, target, num_classes, average):
def test_average_precision_differentiability(self, preds, sk_metric, target, num_classes):
self.run_differentiability_test(
preds=preds,
target=target,
Expand Down
6 changes: 3 additions & 3 deletions tests/helpers/testers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
pass

NUM_PROCESSES = 2
NUM_BATCHES = 10
NUM_BATCHES = 4 # Needs to be divisible by the number of processes
BATCH_SIZE = 32
NUM_CLASSES = 5
EXTRA_DIM = 3
Expand Down Expand Up @@ -545,15 +545,15 @@ def run_differentiability_test(
metric = metric_module(**metric_args)
if preds.is_floating_point():
preds.requires_grad = True
out = metric(preds[0], target[0])
out = metric(preds[0, :2], target[0, :2])

# Check if requires_grad matches is_differentiable attribute
_assert_requires_grad(metric, out)

if metric.is_differentiable and metric_functional is not None:
# check for numerical correctness
assert torch.autograd.gradcheck(
partial(metric_functional, **metric_args), (preds[0].double(), target[0])
partial(metric_functional, **metric_args), (preds[0, :2].double(), target[0, :2])
)

# reset as else it will carry over to other tests
Expand Down