diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py index daed5c69501..7b1abdd7e32 100644 --- a/tests/helpers/testers.py +++ b/tests/helpers/testers.py @@ -16,11 +16,11 @@ import sys from functools import partial from typing import Any, Callable, Dict, Optional, Sequence -import yaml import numpy as np import pytest import torch +import yaml from torch import Tensor, tensor from torch.multiprocessing import Pool, set_start_method diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 05c113d04c3..698ae799cef 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -418,11 +418,12 @@ def device(self) -> "torch.device": return self._device def to(self, *args: Any, **kwargs: Any) -> "Metric": - """Moves the parameters and buffers. Normal dtype casting is not supported by this method - instead use the `set_dtype` method instead. + """Moves the parameters and buffers. + + Normal dtype casting is not supported by this method; use the `set_dtype` method instead. + """ out = torch._C._nn._parse_to(*args, **kwargs) - if len(out)==4: # pytorch 1.5 and higher + if len(out) == 4: # pytorch 1.5 and higher device, dtype, non_blocking, convert_to_format = out else: # pytorch 1.4 and lower device, dtype, non_blocking = out @@ -431,8 +432,12 @@ def to(self, *args: Any, **kwargs: Any) -> "Metric": def convert(t: Tensor) -> Tensor: if convert_to_format is not None and t.dim() in (4, 5): - return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, - non_blocking, memory_format=convert_to_format) + return t.to( + device, + dtype if t.is_floating_point() or t.is_complex() else None, + non_blocking, + memory_format=convert_to_format, + ) return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking) self._device = device