
Commit

ruff
SkafteNicki committed Feb 6, 2023
1 parent debd53c commit 1678d2c
Showing 32 changed files with 81 additions and 80 deletions.
4 changes: 0 additions & 4 deletions pyproject.toml
@@ -121,11 +121,7 @@ unfixable = ["F401"]
"D101", # Missing docstring in public class
"D102", # Missing docstring in public method
"D103", # Missing docstring in public function
"D104", # Missing docstring in public package
"D105", # todo # Missing docstring in magic method
"D205", # todo # 1 blank line required between summary line and description
"D401", # todo # First line of docstring should be in imperative mood: ...
"D403", # todo # First word of the first line should be properly capitalized
"D415", # todo # First line should end with a period, question mark, or exclamation point"
"D417", # todo # Missing argument descriptions in the docstring: ...
]
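For context, a minimal sketch (not part of the commit) of docstrings written to satisfy the pydocstyle rules described in the list above, assuming those rules are enforced once dropped from this ignore list:

# Hypothetical example, not from the repository: capitalized first word,
# terminal punctuation, a blank line between summary and description, and
# docstrings on magic methods, per the rule descriptions listed above.
class RunningSum:
    """Accumulate a running sum of observed values."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def __len__(self) -> int:
        """Return the number of values added so far."""  # magic methods documented too
        return self.count

    def add(self, value: float) -> None:
        """Add a single value to the running sum.

        The blank line above separates the summary line from this description.
        """
        self.total += value
        self.count += 1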
4 changes: 4 additions & 0 deletions tests/integrations/lightning/boring_model.py
@@ -23,9 +23,11 @@ def __init__(self, size, length):
self.data = torch.randn(length, size)

def __getitem__(self, index):
"""Get datapoint."""
return {"id": str(index), "x": self.data[index]}

def __len__(self):
"""Return length of dataset."""
return self.len


@@ -35,9 +37,11 @@ def __init__(self, size, length):
self.data = torch.randn(length, size)

def __getitem__(self, index):
"""Get datapoint."""
return self.data[index]

def __len__(self):
"""Get length of dataset."""
return self.len


2 changes: 1 addition & 1 deletion tests/unittests/audio/test_pit.py
@@ -86,7 +86,7 @@ def naive_implementation_pit_scipy(


def _average_metric(preds: Tensor, target: Tensor, metric_func: Callable) -> Tensor:
"""average the metric values.
"""Average the metric values.
Args:
preds: predictions, shape[batch, spk, time]
16 changes: 8 additions & 8 deletions tests/unittests/bases/test_aggregation.py
@@ -7,22 +7,22 @@


def compare_mean(values, weights):
"""reference implementation for mean aggregation."""
"""Reference implementation for mean aggregation."""
return np.average(values.numpy(), weights=weights)


def compare_sum(values, weights):
"""reference implementation for sum aggregation."""
"""Reference implementation for sum aggregation."""
return np.sum(values.numpy())


def compare_min(values, weights):
"""reference implementation for min aggregation."""
"""Reference implementation for min aggregation."""
return np.min(values.numpy())


def compare_max(values, weights):
"""reference implementation for max aggregation."""
"""Reference implementation for max aggregation."""
return np.max(values.numpy())


@@ -104,7 +104,7 @@ def test_aggreagation(self, ddp, dist_sync_on_step, metric_class, compare_fn, va
@pytest.mark.parametrize("nan_strategy", ["error", "warn"])
@pytest.mark.parametrize("metric_class", [MinMetric, MaxMetric, SumMetric, MeanMetric, CatMetric])
def test_nan_error(value, nan_strategy, metric_class):
"""test correct errors are raised."""
"""Test correct errors are raised."""
metric = metric_class(nan_strategy=nan_strategy)
if nan_strategy == "error":
with pytest.raises(RuntimeError, match="Encounted `nan` values in tensor"):
@@ -141,7 +141,7 @@ def test_nan_error(value, nan_strategy, metric_class):
],
)
def test_nan_expected(metric_class, nan_strategy, value, expected):
"""test that nan values are handled correctly."""
"""Test that nan values are handled correctly."""
metric = metric_class(nan_strategy=nan_strategy)
metric.update(value.clone())
out = metric.compute()
@@ -150,7 +150,7 @@ def test_nan_expected(metric_class, nan_strategy, value, expected):

@pytest.mark.parametrize("metric_class", [MinMetric, MaxMetric, SumMetric, MeanMetric, CatMetric])
def test_error_on_wrong_nan_strategy(metric_class):
"""test error raised on wrong nan_strategy argument."""
"""Test error raised on wrong nan_strategy argument."""
with pytest.raises(ValueError, match="Arg `nan_strategy` should either .*"):
metric_class(nan_strategy=[])

@@ -160,7 +160,7 @@ def test_error_on_wrong_nan_strategy(metric_class):
"weights, expected", [(1, 11.5), (torch.ones(2, 1, 1), 11.5), (torch.tensor([1, 2]).reshape(2, 1, 1), 13.5)]
)
def test_mean_metric_broadcasting(weights, expected):
"""check that weight broadcasting works for mean metric."""
"""Check that weight broadcasting works for mean metric."""
values = torch.arange(24).reshape(2, 3, 4)
avg = MeanMetric()

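A hedged, standalone sketch of the behaviour these aggregation tests exercise (weight broadcasting plus a `nan_strategy`); the values mirror the `test_mean_metric_broadcasting` parametrization above:

import torch
from torchmetrics import MeanMetric

values = torch.arange(24).reshape(2, 3, 4).float()
avg = MeanMetric(nan_strategy="warn")  # "error" raises on NaN input, "warn" only warns
avg.update(values, weight=torch.tensor([1, 2]).reshape(2, 1, 1))  # weight broadcast over values
print(avg.compute())  # tensor(13.5000), the expected value asserted in the test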
8 changes: 2 additions & 6 deletions tests/unittests/bases/test_collections.py
@@ -122,9 +122,7 @@ def test_metric_collection_wrong_input(tmpdir):


def test_metric_collection_args_kwargs(tmpdir):
"""Check that args and kwargs gets passed correctly in metric collection, Checks both update and forward
method.
"""
"""Check that args and kwargs gets passed correctly in metric collection, checks both update and forward."""
m1 = DummyMetricSum()
m2 = DummyMetricDiff()

@@ -450,9 +448,7 @@ def test_check_compute_groups_correctness(self, metrics, expected, preds, target

@pytest.mark.parametrize("method", ["items", "values", "keys"])
def test_check_compute_groups_items_and_values(self, metrics, expected, preds, target, method):
"""Check that whenever user call a methods that give access to the indivitual metric that state are copied
instead of just passed by reference.
"""
"""Check states are copied instead of passed by ref when a single metric in the collection is access."""
m = MetricCollection(deepcopy(metrics), compute_groups=True)
m2 = MetricCollection(deepcopy(metrics), compute_groups=False)

2 changes: 1 addition & 1 deletion tests/unittests/bases/test_composition.py
@@ -533,7 +533,7 @@ def test_metrics_getitem(value, idx, expected_result):


def test_compositional_metrics_update():
"""test update method for compositional metrics."""
"""Test update method for compositional metrics."""
compos = DummyMetric(5) + DummyMetric(4)

assert isinstance(compos, CompositionalMetric)
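A hedged sketch of the compositional behaviour `test_compositional_metrics_update` relies on; the test uses an internal `DummyMetric`, so real aggregation metrics are substituted here purely for illustration:

import torch
from torchmetrics import MeanMetric, SumMetric

compos = SumMetric() + MeanMetric()           # metric arithmetic builds a compositional metric
print(type(compos).__name__)                  # CompositionalMetric
compos.update(torch.tensor([1.0, 2.0, 3.0]))  # the update is forwarded to both child metrics
print(compos.compute())                       # sum (6.0) + mean (2.0) -> tensor(8.)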
4 changes: 1 addition & 3 deletions tests/unittests/bases/test_ddp.py
@@ -240,9 +240,7 @@ def reload_state_dict(state_dict, expected_x, expected_c):

@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_state_dict_is_synced(tmpdir):
"""This test asserts that metrics are synced while creating the state dict but restored after to continue
accumulation.
"""
"""Tests that metrics are synced while creating the state dict but restored after to continue accumulation."""
torch.multiprocessing.spawn(_test_state_dict_is_synced, args=(2, tmpdir), nprocs=2)


12 changes: 6 additions & 6 deletions tests/unittests/bases/test_metric.py
@@ -230,7 +230,7 @@ def test_pickle(tmpdir):


def test_state_dict(tmpdir):
"""test that metric states can be removed and added to state dict."""
"""Test that metric states can be removed and added to state dict."""
metric = DummyMetric()
assert metric.state_dict() == OrderedDict()
metric.persistent(True)
Expand All @@ -240,7 +240,7 @@ def test_state_dict(tmpdir):


def test_load_state_dict(tmpdir):
"""test that metric states can be loaded with state dict."""
"""Test that metric states can be loaded with state dict."""
metric = DummyMetricSum()
metric.persistent(True)
metric.update(5)
Expand All @@ -250,7 +250,7 @@ def test_load_state_dict(tmpdir):


def test_child_metric_state_dict():
"""test that child metric states will be added to parent state dict."""
"""Test that child metric states will be added to parent state dict."""

class TestModule(Module):
def __init__(self):
@@ -292,7 +292,7 @@ def test_device_and_dtype_transfer(tmpdir):


def test_warning_on_compute_before_update():
"""test that an warning is raised if user tries to call compute before update."""
"""Test that an warning is raised if user tries to call compute before update."""
metric = DummyMetricSum()

# make sure everything is fine with forward
@@ -315,13 +315,13 @@ def test_warning_on_compute_before_update():


def test_metric_scripts():
"""test that metrics are scriptable."""
"""Test that metrics are scriptable."""
torch.jit.script(DummyMetric())
torch.jit.script(DummyMetricSum())


def test_metric_forward_cache_reset():
"""test that forward cache is reset when `reset` is called."""
"""Test that forward cache is reset when `reset` is called."""
metric = DummyMetricSum()
_ = metric(2.0)
assert metric._forward_cache == 2.0
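A hedged sketch of the persistence behaviour `test_state_dict` asserts for its internal `DummyMetric`, with `SumMetric` standing in here:

import torch
from torchmetrics import SumMetric

metric = SumMetric()
print(metric.state_dict())        # expected to be empty: metric states are non-persistent by default
metric.persistent(True)           # mark states as persistent so they are saved
metric.update(torch.tensor(5.0))
print(metric.state_dict())        # now expected to contain the metric's internal state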
2 changes: 1 addition & 1 deletion tests/unittests/classification/test_auroc.py
@@ -363,7 +363,7 @@ def test_multilabel_auroc_threshold_arg(self, input, average):
)
@pytest.mark.parametrize("thresholds", [None, 100, [0.3, 0.5, 0.7, 0.9], torch.linspace(0, 1, 10)])
def test_valid_input_thresholds(metric, thresholds):
"""test valid formats of the threshold argument."""
"""Test valid formats of the threshold argument."""
with pytest.warns(None) as record:
metric(thresholds=thresholds)
assert len(record) == 0
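The `thresholds` parametrization above recurs across several of the classification test files in this commit; a hedged sketch of the accepted formats (None for exact computation, an int, a list of floats, or a tensor of thresholds), with `BinaryAUROC` as a stand-in for the parametrized metrics:

import torch
from torchmetrics.classification import BinaryAUROC

preds = torch.rand(20)
target = torch.tensor([0, 1] * 10)
for thresholds in (None, 100, [0.3, 0.5, 0.7, 0.9], torch.linspace(0, 1, 10)):
    metric = BinaryAUROC(thresholds=thresholds)
    print(thresholds, metric(preds, target))  # no warnings expected for any valid format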
2 changes: 1 addition & 1 deletion tests/unittests/classification/test_average_precision.py
@@ -367,7 +367,7 @@ def test_multilabel_average_precision_threshold_arg(self, input, average):
)
@pytest.mark.parametrize("thresholds", [None, 100, [0.3, 0.5, 0.7, 0.9], torch.linspace(0, 1, 10)])
def test_valid_input_thresholds(metric, thresholds):
"""test valid formats of the threshold argument."""
"""Test valid formats of the threshold argument."""
with pytest.warns(None) as record:
metric(thresholds=thresholds)
assert len(record) == 0
@@ -385,7 +385,7 @@ def test_multilabel_error_on_wrong_dtypes(self, input):
)
@pytest.mark.parametrize("thresholds", [None, 100, [0.3, 0.5, 0.7, 0.9], torch.linspace(0, 1, 10)])
def test_valid_input_thresholds(metric, thresholds):
"""test valid formats of the threshold argument."""
"""Test valid formats of the threshold argument."""
with pytest.warns(None) as record:
metric(thresholds=thresholds)
assert len(record) == 0
@@ -391,7 +391,7 @@ def test_multilabel_recall_at_fixed_precision_threshold_arg(self, input, min_pre
)
@pytest.mark.parametrize("thresholds", [None, 100, [0.3, 0.5, 0.7, 0.9], torch.linspace(0, 1, 10)])
def test_valid_input_thresholds(metric, thresholds):
"""test valid formats of the threshold argument."""
"""Test valid formats of the threshold argument."""
with pytest.warns(None) as record:
metric(min_precision=0.5, thresholds=thresholds)
assert len(record) == 0
2 changes: 1 addition & 1 deletion tests/unittests/classification/test_roc.py
@@ -344,7 +344,7 @@ def test_multilabel_roc_threshold_arg(self, input, threshold_fn):
)
@pytest.mark.parametrize("thresholds", [None, 100, [0.3, 0.5, 0.7, 0.9], torch.linspace(0, 1, 10)])
def test_valid_input_thresholds(metric, thresholds):
"""test valid formats of the threshold argument."""
"""Test valid formats of the threshold argument."""
with pytest.warns(None) as record:
metric(thresholds=thresholds)
assert len(record) == 0
2 changes: 1 addition & 1 deletion tests/unittests/classification/test_specificity.py
@@ -34,7 +34,7 @@


def _calc_specificity(tn, fp):
"""safely calculate specificity."""
"""Safely calculate specificity."""
denom = tn + fp
if np.isscalar(tn):
denom = 1.0 if denom == 0 else denom
@@ -426,7 +426,7 @@ def test_multilabel_specificity_at_sensitivity_threshold_arg(self, input, min_se
)
@pytest.mark.parametrize("thresholds", [None, 100, [0.3, 0.5, 0.7, 0.9], torch.linspace(0, 1, 10)])
def test_valid_input_thresholds(metric, thresholds):
"""test valid formats of the threshold argument."""
"""Test valid formats of the threshold argument."""
with pytest.warns(None) as record:
metric(min_sensitivity=0.5, thresholds=thresholds)
assert len(record) == 0
12 changes: 5 additions & 7 deletions tests/unittests/helpers/testers.py
@@ -95,9 +95,7 @@ def _assert_tensor(pl_result: Any, key: Optional[str] = None) -> None:


def _assert_requires_grad(metric: Metric, pl_result: Any, key: Optional[str] = None) -> None:
"""Utility function for recursively asserting that metric output is consistent with the `is_differentiable`
attribute.
"""
"""Function for recursively asserting that metric output is consistent with the `is_differentiable` attribute."""
if isinstance(pl_result, Sequence):
for plr in pl_result:
_assert_requires_grad(metric, plr, key=key)
@@ -346,11 +344,11 @@ def _assert_dtype_support(


class MetricTester:
"""Class used for efficiently run alot of parametrized tests in ddp mode. Makes sure that ddp is only setup
once and that pool of processes are used for all tests.
"""General test class for all metrics
All tests should subclass from this and implement a new method called `test_metric_name` where the method
`self.run_metric_test` is called inside.
Class used for efficiently run alot of parametrized tests in ddp mode. Makes sure that ddp is only setup once and
that pool of processes are used for all tests. All tests should subclass from this and implement a new method called
`test_metric_name` where the method `self.run_metric_test` is called inside.
"""

atol: float = 1e-8
8 changes: 3 additions & 5 deletions tests/unittests/image/test_fid.py
@@ -28,7 +28,7 @@

@pytest.mark.parametrize("matrix_size", [2, 10, 100, 500])
def test_matrix_sqrt(matrix_size):
"""test that metrix sqrt function works as expected."""
"""Test that metrix sqrt function works as expected."""

def generate_cov(n):
data = torch.randn(2 * n, n)
@@ -91,9 +91,7 @@ def test_fid_raises_errors_and_warnings():
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
@pytest.mark.parametrize("feature", [64, 192, 768, 2048])
def test_fid_same_input(feature):
"""if real and fake are update on the same data the fid score should be
0.
"""
"""If real and fake are update on the same data the fid score should be 0."""
metric = FrechetInceptionDistance(feature=feature)

for _ in range(2):
@@ -124,7 +122,7 @@ def __len__(self):
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
@pytest.mark.parametrize("equal_size", [False, True])
def test_compare_fid(tmpdir, equal_size, feature=768):
"""check that the hole pipeline give the same result as torch-fidelity."""
"""Check that the hole pipeline give the same result as torch-fidelity."""
from torch_fidelity import calculate_metrics

metric = FrechetInceptionDistance(feature=feature).cuda()
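A hedged sketch of the property `test_fid_same_input` checks: updating with the same images as both real and fake should give an FID of (approximately) zero. Assumes torch-fidelity is installed so the feature extractor can be loaded:

import torch
from torchmetrics.image.fid import FrechetInceptionDistance

metric = FrechetInceptionDistance(feature=64)
imgs = torch.randint(0, 255, (10, 3, 299, 299), dtype=torch.uint8)
metric.update(imgs, real=True)   # treat the batch as the "real" distribution
metric.update(imgs, real=False)  # ...and as the "fake" distribution
print(metric.compute())          # expected to be numerically close to 0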
4 changes: 1 addition & 3 deletions tests/unittests/image/test_image_gradients.py
@@ -36,9 +36,7 @@ def test_invalid_input_ndims(batch_size=1, height=5, width=5, channels=1):


def test_multi_batch_image_gradients(batch_size=5, height=5, width=5, channels=1):
"""Test whether the module correctly calculates gradients for known input with non-unity batch size.Example
input-output pair taken from TF's implementation of i mage-gradients.
"""
"""Test whether the module correctly calculates gradients for known input with non-unity batch size."""
single_channel_img = torch.arange(0, 1 * height * width * channels, dtype=torch.float32)
single_channel_img = torch.reshape(single_channel_img, (channels, height, width))
image = torch.stack([single_channel_img for _ in range(batch_size)], dim=0)
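A hedged sketch of the functional under test here, `image_gradients`, which returns per-pixel gradients along height and width for an NCHW batch:

import torch
from torchmetrics.functional import image_gradients

image = torch.arange(0, 25, dtype=torch.float32).reshape(1, 1, 5, 5)
dy, dx = image_gradients(image)
print(dy[0, 0])  # rows of 5s, with the last row zero (gradient along height)
print(dx[0, 0])  # columns of 1s, with the last column zero (gradient along width)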
2 changes: 1 addition & 1 deletion tests/unittests/image/test_inception.py
@@ -108,7 +108,7 @@ def __len__(self):
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
@pytest.mark.parametrize("compute_on_cpu", [True, False])
def test_compare_is(tmpdir, compute_on_cpu):
"""check that the hole pipeline give the same result as torch-fidelity."""
"""Check that the hole pipeline give the same result as torch-fidelity."""
from torch_fidelity import calculate_metrics

metric = InceptionScore(splits=1, compute_on_cpu=compute_on_cpu).cuda()
4 changes: 2 additions & 2 deletions tests/unittests/image/test_kid.py
@@ -105,7 +105,7 @@ def test_kid_extra_parameters():
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
@pytest.mark.parametrize("feature", [64, 192, 768, 2048])
def test_kid_same_input(feature):
"""test that the metric works."""
"""Test that the metric works."""
metric = KernelInceptionDistance(feature=feature, subsets=5, subset_size=2)

for _ in range(2):
@@ -134,7 +134,7 @@ def __len__(self):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test is too slow without gpu")
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
def test_compare_kid(tmpdir, feature=2048):
"""check that the hole pipeline give the same result as torch-fidelity."""
"""Check that the hole pipeline give the same result as torch-fidelity."""
from torch_fidelity import calculate_metrics

metric = KernelInceptionDistance(feature=feature, subsets=1, subset_size=100).cuda()
4 changes: 2 additions & 2 deletions tests/unittests/image/test_lpips.py
@@ -35,7 +35,7 @@


def _compare_fn(img1: Tensor, img2: Tensor, net_type: str, normalize: bool, reduction: str = "mean") -> Tensor:
"""comparison function for tm implementation."""
"""Comparison function for tm implementation."""
ref = LPIPS_reference(net=net_type)
res = ref(img1, img2, normalize=normalize).detach().cpu().numpy()
if reduction == "mean":
@@ -104,7 +104,7 @@ def test_error_on_wrong_init():
],
)
def test_error_on_wrong_update(inp1, inp2):
"""test error is raised on wrong input to update method."""
"""Test error is raised on wrong input to update method."""
metric = LearnedPerceptualImagePatchSimilarity()
with pytest.raises(ValueError, match="Expected both input arguments to be normalized tensors .*"):
metric(inp1, inp2)
4 changes: 3 additions & 1 deletion tests/unittests/image/test_ssim.py
@@ -231,7 +231,7 @@ def test_ssim_half_gpu(self, preds, target, sigma):
],
)
def test_ssim_invalid_inputs(pred, target, kernel, sigma):
"""Test that an value errors are raised if input sizes are different, kernel length and sigma does not match
"""Test for invalid input.
Checks that that an value errors are raised if input sizes are different, kernel length and sigma does not match
size or invalid values are provided.
"""
pred = torch.rand(pred)
2 changes: 1 addition & 1 deletion tests/unittests/image/test_tv.py
@@ -108,7 +108,7 @@ def test_sam_half_gpu(self, preds, target, reduction):


def test_correct_args():
"""that that arguments have the right type and sizes."""
"""That that arguments have the right type and sizes."""
with pytest.raises(ValueError, match="Expected argument `reduction`.*"):
_ = TotalVariation(reduction="diff")
