Skip to content

Commit

Permalink
more Ruff checks (#1760)
Browse files Browse the repository at this point in the history
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
  • Loading branch information
SkafteNicki and Borda authored May 9, 2023
1 parent 1e8fede commit cf7604d
Show file tree
Hide file tree
Showing 23 changed files with 62 additions and 57 deletions.
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ select = [
"D", # see: https://pypi.org/project/pydocstyle
"N", # see: https://pypi.org/project/pep8-naming
"S", # see: https://pypi.org/project/flake8-bandit
"T10", # see: https://pypi.org/project/flake8-debugger/
"Q", # see: https://pypi.org/project/flake8-quotes/
]
extend-select = [
"C4", # see: https://pypi.org/project/flake8-comprehensions
Expand Down Expand Up @@ -114,7 +116,7 @@ unfixable = ["F401"]
"setup.py" = ["ANN202", "ANN401"]
"src/**" = ["ANN401"]
"tests/**" = [
"S101", "ANN001", "ANN002", "ANN003", "ANN201", "ANN202", "ANN204", "ANN205", "ANN401"
"S101", "ANN001", "ANN201", "ANN202", "ANN401"
]

[tool.ruff.pydocstyle]
Expand Down
6 changes: 3 additions & 3 deletions src/torchmetrics/functional/audio/pit.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,11 @@ def _find_best_perm_by_exhaustive_method(

# find the metric of each permutation
perm_num = ps.shape[-1]
# shape [batch_size, spk_num, perm_num]
# shape of [batch_size, spk_num, perm_num]
bps = ps[None, ...].expand(batch_size, spk_num, perm_num)
# shape [batch_size, spk_num, perm_num]
# shape of [batch_size, spk_num, perm_num]
metric_of_ps_details = torch.gather(metric_mtx, 2, bps)
# shape [batch_size, perm_num]
# shape of [batch_size, perm_num]
metric_of_ps = metric_of_ps_details.mean(dim=1)

# find the best metric and best permutation
Expand Down
10 changes: 5 additions & 5 deletions tests/integrations/lightning/boring_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ def __init__(self, size, length) -> None:
self.len = length
self.data = torch.randn(length, size)

def __getitem__(self, index):
def __getitem__(self, index) -> dict:
"""Get datapoint."""
return {"id": str(index), "x": self.data[index]}

def __len__(self):
def __len__(self) -> int:
"""Return length of dataset."""
return self.len

Expand All @@ -45,11 +45,11 @@ def __init__(self, size, length) -> None:
self.len = length
self.data = torch.randn(length, size)

def __getitem__(self, index):
def __getitem__(self, index) -> torch.Tensor:
"""Get datapoint."""
return self.data[index]

def __len__(self):
def __len__(self) -> int:
"""Get length of dataset."""
return self.len

Expand Down Expand Up @@ -80,7 +80,7 @@ def forward(self, x):
return self.layer(x)

@staticmethod
def loss(_, prediction):
def loss(_, prediction) -> torch.Tensor:
"""Arbitrary loss."""
return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))

Expand Down
6 changes: 3 additions & 3 deletions tests/integrations/test_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,13 @@ def _step(self, stage, batch):

return loss

def training_step(self, batch, batch_idx, *args, **kwargs):
def training_step(self, batch, batch_idx):
return self._step("train", batch)

def validation_step(self, batch, batch_idx, *args, **kwargs):
def validation_step(self, batch, batch_idx):
return self._step("val", batch)

def test_step(self, batch, batch_idx, *args, **kwargs):
def test_step(self, batch, batch_idx):
return self._step("test", batch)

def _assert_epoch_end(self, stage):
Expand Down
1 change: 0 additions & 1 deletion tests/unittests/audio/test_pit.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def naive_implementation_pit_scipy(
for e in range(spk_num):
metric_mtx[:, t, e] = metric_func(preds[:, e, ...], target[:, t, ...])

# pit_r = PermutationInvariantTraining(metric_func, eval_func)(preds, target)
metric_mtx = metric_mtx.detach().cpu().numpy()
best_metrics = []
best_perms = []
Expand Down
2 changes: 0 additions & 2 deletions tests/unittests/audio/test_sdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,7 @@ def _average_metric(preds: Tensor, target: Tensor, metric_func: Callable) -> Ten
"preds, target, ref_metric",
[
(inputs_1spk.preds, inputs_1spk.target, original_impl_compute_permutation),
# (inputs_1spk.preds, inputs_1spk.target, original_impl_no_compute_permutation, False),
(inputs_2spk.preds, inputs_2spk.target, original_impl_compute_permutation),
# (inputs_2spk.preds, inputs_2spk.target, original_impl_no_compute_permutation, False),
],
)
class TestSDR(MetricTester):
Expand Down
7 changes: 4 additions & 3 deletions tests/unittests/bases/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import pickle
import time
from copy import deepcopy
from typing import Any

import pytest
import torch
Expand Down Expand Up @@ -285,8 +286,8 @@ class DummyMetric(Metric):
def __init__(self) -> None:
super().__init__()

def update(self, *args, kwarg):
print("Entered DummyMetric")
def update(self, *args: Any, kwarg: Any):
pass

def compute(self):
return
Expand All @@ -298,7 +299,7 @@ def __init__(self) -> None:
super().__init__()

def update(self, preds, target, kwarg2):
print("Entered MyAccuracy")
pass

def compute(self):
return
Expand Down
3 changes: 2 additions & 1 deletion tests/unittests/bases/test_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from operator import neg, pos
from typing import Any

import pytest
import torch
Expand All @@ -30,7 +31,7 @@ def __init__(self, val_to_return) -> None:
self.add_state("_num_updates", tensor(0), dist_reduce_fx="sum")
self._val_to_return = val_to_return

def update(self, *args, **kwargs) -> None:
def update(self, *args: Any, **kwargs: Any) -> None:
"""Compute state."""
self._num_updates += 1

Expand Down
2 changes: 1 addition & 1 deletion tests/unittests/bases/test_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def update(self, x):
def compute(self):
return self.x // self.c

def __repr__(self):
def __repr__(self) -> str:
return f"DummyCatMetric(x={self.x}, c={self.c})"

metric = DummyCatMetric()
Expand Down
3 changes: 2 additions & 1 deletion tests/unittests/bases/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import os
import pickle
from collections import OrderedDict
from typing import Any
from unittest.mock import Mock

import cloudpickle
Expand Down Expand Up @@ -473,7 +474,7 @@ def test_no_warning_on_custom_forward(metric_class):
class UnsetProperty(metric_class):
full_state_update = None

def forward(self, *args, **kwargs):
def forward(self, *args: Any, **kwargs: Any):
self.update(*args, **kwargs)

with no_warning_call(
Expand Down
6 changes: 3 additions & 3 deletions tests/unittests/classification/test_group_fairness.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def run_differentiability_test(
metric_functional: Optional[Callable] = None,
metric_args: Optional[dict] = None,
groups: Optional[Tensor] = None,
):
) -> None:
"""Test if a metric is differentiable or not.
Args:
Expand Down Expand Up @@ -148,7 +148,7 @@ def run_precision_test_cpu(
metric_args: Optional[dict] = None,
dtype: torch.dtype = torch.half,
**kwargs_update: Any,
):
) -> None:
"""Test if a metric can be used with half precision tensors on cpu.
Args:
Expand Down Expand Up @@ -185,7 +185,7 @@ def run_precision_test_gpu(
metric_args: Optional[dict] = None,
dtype: torch.dtype = torch.half,
**kwargs_update: Any,
):
) -> None:
"""Test if a metric can be used with half precision tensors on gpu.
Args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,6 @@ def _sklearn_precision_recall_curve_multiclass(preds, target, ignore_index=None)
precision.append(res[0])
recall.append(res[1])
thresholds.append(res[2])
# return precision, recall, thresholds
return [np.nan_to_num(x, nan=0.0) for x in [precision, recall, thresholds]]


Expand Down
10 changes: 5 additions & 5 deletions tests/unittests/helpers/testers.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ def run_precision_test_cpu(
metric_args: Optional[dict] = None,
dtype: torch.dtype = torch.half,
**kwargs_update: Any,
):
) -> None:
"""Test if a metric can be used with half precision tensors on cpu.
Args:
Expand Down Expand Up @@ -482,7 +482,7 @@ def run_precision_test_gpu(
metric_args: Optional[dict] = None,
dtype: torch.dtype = torch.half,
**kwargs_update: Any,
):
) -> None:
"""Test if a metric can be used with half precision tensors on gpu.
Args:
Expand Down Expand Up @@ -513,7 +513,7 @@ def run_differentiability_test(
metric_module: Metric,
metric_functional: Optional[Callable] = None,
metric_args: Optional[dict] = None,
):
) -> None:
"""Test if a metric is differentiable or not.
Args:
Expand Down Expand Up @@ -549,7 +549,7 @@ class DummyMetric(Metric):
name = "Dummy"
full_state_update: Optional[bool] = True

def __init__(self, **kwargs) -> None:
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.add_state("x", tensor(0.0), dist_reduce_fx="sum")

Expand All @@ -568,7 +568,7 @@ class DummyListMetric(Metric):
name = "DummyList"
full_state_update: Optional[bool] = True

def __init__(self, **kwargs) -> None:
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.add_state("x", [], dist_reduce_fx="cat")

Expand Down
4 changes: 2 additions & 2 deletions tests/unittests/image/test_fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,10 @@ class _ImgDataset(Dataset):
def __init__(self, imgs) -> None:
self.imgs = imgs

def __getitem__(self, idx):
def __getitem__(self, idx) -> torch.Tensor:
return self.imgs[idx]

def __len__(self):
def __len__(self) -> int:
return self.imgs.shape[0]


Expand Down
4 changes: 2 additions & 2 deletions tests/unittests/image/test_inception.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,10 @@ class _ImgDataset(Dataset):
def __init__(self, imgs) -> None:
self.imgs = imgs

def __getitem__(self, idx):
def __getitem__(self, idx) -> torch.Tensor:
return self.imgs[idx]

def __len__(self):
def __len__(self) -> int:
return self.imgs.shape[0]


Expand Down
4 changes: 2 additions & 2 deletions tests/unittests/image/test_kid.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,10 @@ class _ImgDataset(Dataset):
def __init__(self, imgs) -> None:
self.imgs = imgs

def __getitem__(self, idx):
def __getitem__(self, idx) -> torch.Tensor:
return self.imgs[idx]

def __len__(self):
def __len__(self) -> int:
return self.imgs.shape[0]


Expand Down
3 changes: 2 additions & 1 deletion tests/unittests/image/test_tv.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from collections import namedtuple
from functools import partial
from typing import Any

import pytest
import torch
Expand All @@ -30,7 +31,7 @@
class TotalVariationTester(TotalVariation):
"""Tester class for `TotalVariation` metric overriding its update method."""

def update(self, img, *args):
def update(self, img, *args: Any):
"""Update metric."""
super().update(img=img)

Expand Down
10 changes: 5 additions & 5 deletions tests/unittests/retrieval/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Callable, Dict, List, Tuple, Type, Union
from typing import Any, Callable, Dict, List, Tuple, Type, Union

import numpy as np
import pytest
Expand Down Expand Up @@ -78,7 +78,7 @@ def _compute_sklearn_metric(
empty_target_action: str = "skip",
ignore_index: int = None,
reverse: bool = False,
**kwargs,
**kwargs: Any,
) -> Tensor:
"""Compute metric with multiple iterations over every query predictions set."""
if indexes is None:
Expand Down Expand Up @@ -466,7 +466,7 @@ def run_functional_metric_test(
reference_metric: Callable,
metric_args: dict,
reverse: bool = False,
**kwargs,
**kwargs: Any,
):
"""Test functional implementation of metric."""
_ref_metric_adapted = partial(_compute_sklearn_metric, metric=reference_metric, reverse=reverse, **metric_args)
Expand Down Expand Up @@ -537,7 +537,7 @@ def run_metric_class_arguments_test(
metric_args: dict = None,
exception_type: Type[Exception] = ValueError,
kwargs_update: dict = None,
):
) -> None:
"""Test that specific errors are raised for incorrect input."""
_errors_test_class_metric(
indexes=indexes,
Expand All @@ -558,7 +558,7 @@ def run_functional_metric_arguments_test(
message: str = "",
exception_type: Type[Exception] = ValueError,
kwargs_update: dict = None,
):
) -> None:
"""Test that specific errors are raised for incorrect input."""
_errors_test_functional_metric(
preds=preds,
Expand Down
Loading

0 comments on commit cf7604d

Please sign in to comment.