From 3a4ff87bdf1f0fdf394b414041e5858adc819b12 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Tue, 7 Sep 2021 10:48:25 +0200
Subject: [PATCH] Fix dtype issues (#493)

* remove dtype

* yaml testing

* fix to and tests

* chlog

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
---
 CHANGELOG.md               |  6 ++-
 tests/bases/test_metric.py |  9 +---
 torchmetrics/metric.py     | 89 ++++++++++++++++++++------------------
 3 files changed, 55 insertions(+), 49 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c2d6f007cf6..a23a6128fc2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -21,17 +21,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Added `average` argument to `AveragePrecision` metric for reducing multilabel and multiclass problems ([#477](https://github.com/PyTorchLightning/metrics/pull/477))
 
-
 ### Changed
 
 - `AveragePrecision` will now as default output the `macro` average for multilabel and multiclass problems ([#477](https://github.com/PyTorchLightning/metrics/pull/477))
 
+- `half`, `double`, `float` will no longer change the dtype of the metric states. Use `metric.set_dtype` instead ([#493](https://github.com/PyTorchLightning/metrics/pull/493))
+
+
 ### Deprecated
 
 
 ### Removed
 
 
+- Removed `dtype` property ([#493](https://github.com/PyTorchLightning/metrics/pull/493))
+
 
 ### Fixed
 
diff --git a/tests/bases/test_metric.py b/tests/bases/test_metric.py
index deddc8d827f..88b8ee27c0e 100644
--- a/tests/bases/test_metric.py
+++ b/tests/bases/test_metric.py
@@ -255,25 +255,20 @@ def test_device_and_dtype_transfer(tmpdir):
     assert metric.x.is_cuda is False
     assert metric.device == torch.device("cpu")
     assert metric.x.dtype == torch.float32
-    assert metric.dtype == torch.float32
 
     metric = metric.to(device="cuda")
     assert metric.x.is_cuda
     assert metric.device == torch.device("cuda")
 
-    metric = metric.double()
+    metric.set_dtype(torch.double)
     assert metric.x.dtype == torch.float64
-    assert metric.dtype == torch.float64
     metric.reset()
     assert metric.x.dtype == torch.float64
-    assert metric.dtype == torch.float64
 
-    metric = metric.half()
+    metric.set_dtype(torch.half)
     assert metric.x.dtype == torch.float16
-    assert metric.dtype == torch.float16
     metric.reset()
     assert metric.x.dtype == torch.float16
-    assert metric.dtype == torch.float16
 
 
 def test_warning_on_compute_before_update():
diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py
index 3e5dc5bbfbb..698ae799cef 100644
--- a/torchmetrics/metric.py
+++ b/torchmetrics/metric.py
@@ -67,7 +67,7 @@ class Metric(Module, ABC):
             will be used to perform the allgather.
""" - __jit_ignored_attributes__ = ["device", "dtype"] + __jit_ignored_attributes__ = ["device"] __jit_unused_properties__ = ["is_differentiable"] def __init__( @@ -84,7 +84,6 @@ def __init__( torch._C._log_api_usage_once(f"torchmetrics.metric.{self.__class__.__name__}") self._LIGHTNING_GREATER_EQUAL_1_3 = _compare_version("pytorch_lightning", op.ge, "1.3.0") - self._dtype: Union[str, torch.dtype] = torch.get_default_dtype() self._device = torch.device("cpu") self.dist_sync_on_step = dist_sync_on_step @@ -413,30 +412,36 @@ def __setstate__(self, state: Dict[str, Any]) -> None: self.update: Callable = self._wrap_update(self.update) # type: ignore self.compute: Callable = self._wrap_compute(self.compute) # type: ignore - @property - def dtype(self) -> "torch.dtype": - """Return the dtype of the metric.""" - return self._dtype - - @dtype.setter - def dtype(self, new_dtype: "torch.dtype") -> None: - # necessary to avoid infinite recursion - raise RuntimeError("Cannot set the dtype explicitly. Please use module.to(new_dtype).") - @property def device(self) -> "torch.device": """Return the device of the metric.""" return self._device def to(self, *args: Any, **kwargs: Any) -> "Metric": - """Moves and/or casts the parameters and buffers. + """Moves the parameters and buffers. - Works similar to nn.Module.to but also updates the metrics device and dtype properties + Normal dtype casting is not supported by this method instead use the `set_dtype` method instead. """ - # there is diff nb vars in PT 1.5 out = torch._C._nn._parse_to(*args, **kwargs) - self._update_properties(device=out[0], dtype=out[1]) - return super().to(*args, **kwargs) + if len(out) == 4: # pytorch 1.5 and higher + device, dtype, non_blocking, convert_to_format = out + else: # pytorch 1.4 and lower + device, dtype, non_blocking = out + convert_to_format = None + dtype = None # prevent dtype being casted + + def convert(t: Tensor) -> Tensor: + if convert_to_format is not None and t.dim() in (4, 5): + return t.to( + device, + dtype if t.is_floating_point() or t.is_complex() else None, + non_blocking, + memory_format=convert_to_format, + ) + return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking) + + self._device = device + return self._apply(convert) def cuda(self, device: Optional[Union[torch.device, int]] = None) -> "Metric": """Moves all model parameters and buffers to the GPU. @@ -446,46 +451,48 @@ def cuda(self, device: Optional[Union[torch.device, int]] = None) -> "Metric": """ if device is None or isinstance(device, int): device = torch.device("cuda", index=device) - self._update_properties(device=device) + self._device = device return super().cuda(device=device) def cpu(self) -> "Metric": """Moves all model parameters and buffers to the CPU.""" - self._update_properties(device=torch.device("cpu")) + self._device = torch.device("cpu") return super().cpu() def type(self, dst_type: Union[str, torch.dtype]) -> "Metric": - """Casts all parameters and buffers to :attr:`dst_type`. + """Method override default and prevent dtype casting. - Arguments: - dst_type (type or string): the desired type + Please use `metric.set_dtype(dtype)` instead. """ - self._update_properties(dtype=dst_type) - return super().type(dst_type=dst_type) + return self def float(self) -> "Metric": - """Casts all floating point parameters and buffers to ``float`` datatype.""" - self._update_properties(dtype=torch.float) - return super().float() + """Method override default and prevent dtype casting. 
+
+        Please use `metric.set_dtype(dtype)` instead.
+        """
+        return self
 
     def double(self) -> "Metric":
-        """Casts all floating point parameters and buffers to ``double`` datatype."""
-        self._update_properties(dtype=torch.double)
-        return super().double()
+        """Overrides the default method and prevents dtype casting.
+
+        Please use `metric.set_dtype(dtype)` instead.
+        """
+        return self
 
     def half(self) -> "Metric":
-        """Casts all floating point parameters and buffers to ``half`` datatype."""
-        self._update_properties(dtype=torch.half)
-        return super().half()
+        """Overrides the default method and prevents dtype casting.
 
-    def _update_properties(
-        self, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None
-    ) -> None:
-        """Updates the internal device and or dtype attributes of the metric."""
-        if device is not None:
-            self._device = device
-        if dtype is not None:
-            self._dtype = dtype
+        Please use `metric.set_dtype(dtype)` instead.
+        """
+        return self
+
+    def set_dtype(self, dst_type: Union[str, torch.dtype]) -> None:
+        """Special version of `type` for transferring all metric states to a specific dtype.
+        Arguments:
+            dst_type (type or string): the desired type
+        """
+        return super().type(dst_type)
 
     def _apply(self, fn: Callable) -> Module:
         """Overwrite _apply function such that we can also move metric states to the correct device when `.to`,
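
Editor's note (not part of the patch): taken together, these changes make `float()`, `double()` and `half()` no-ops on a `Metric`, route all dtype changes through the new `set_dtype` method, and restrict `to()` to device moves only. The sketch below illustrates the resulting behaviour; it assumes a torchmetrics installation that includes this patch, and `DummyMetric` is a hypothetical stand-in for the small metric class used by the updated test, whose definition is not shown in this diff.

import torch
from torchmetrics import Metric


class DummyMetric(Metric):
    """Minimal metric with a single float32 state, mirroring the updated test."""

    def __init__(self):
        super().__init__()
        self.add_state("x", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, y: torch.Tensor) -> None:
        self.x += y.sum()

    def compute(self) -> torch.Tensor:
        return self.x


metric = DummyMetric()
assert metric.x.dtype == torch.float32

# With this patch, half()/float()/double() no longer touch the metric states ...
metric = metric.half()
assert metric.x.dtype == torch.float32

# ... and set_dtype is the supported way to cast all metric states.
metric.set_dtype(torch.double)
assert metric.x.dtype == torch.float64

# to() still moves states between devices but never changes their dtype.
metric = metric.to(device="cpu")
assert metric.x.dtype == torch.float64

These assertions mirror the updated `test_device_and_dtype_transfer` test above.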