Fix dtype issues #493

Merged
merged 14 commits on Sep 7, 2021
6 changes: 5 additions & 1 deletion CHANGELOG.md
@@ -21,17 +21,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `average` argument to `AveragePrecision` metric for reducing multilabel and multiclass problems ([#477](https://github.com/PyTorchLightning/metrics/pull/477))



### Changed

- `AveragePrecision` will now output the `macro` average by default for multilabel and multiclass problems ([#477](https://github.com/PyTorchLightning/metrics/pull/477))


- `half`, `double`, `float` will no longer change the dtype of the metric states. Use `metric.set_dtype` instead ([#493](https://github.com/PyTorchLightning/metrics/pull/493))


### Deprecated


### Removed

- Removed `dtype` property ([#493](https://github.com/PyTorchLightning/metrics/pull/493))


### Fixed

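The Changed entry above is the core behavioural shift of this PR: `half()`, `double()` and `float()` become no-ops for metric states, and `set_dtype` is the explicit replacement. Below is a minimal sketch of the new behaviour, assuming torchmetrics with this PR applied; `SumMetric` is a toy metric written only for illustration and is not part of the library.

import torch
from torchmetrics import Metric


class SumMetric(Metric):
    """Toy metric used only to illustrate the dtype handling described above."""

    def __init__(self):
        super().__init__()
        # a single metric state, created in the default dtype (float32)
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, x: torch.Tensor) -> None:
        self.total += x.sum()

    def compute(self) -> torch.Tensor:
        return self.total


metric = SumMetric()

# half()/double()/float() no longer touch the metric states ...
metric.half()
assert metric.total.dtype == torch.float32

# ... set_dtype is now the explicit way to cast them
metric.set_dtype(torch.double)
assert metric.total.dtype == torch.float64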
9 changes: 2 additions & 7 deletions tests/bases/test_metric.py
@@ -255,25 +255,20 @@ def test_device_and_dtype_transfer(tmpdir):
assert metric.x.is_cuda is False
assert metric.device == torch.device("cpu")
assert metric.x.dtype == torch.float32
assert metric.dtype == torch.float32

metric = metric.to(device="cuda")
assert metric.x.is_cuda
assert metric.device == torch.device("cuda")

metric = metric.double()
metric.set_dtype(torch.double)
assert metric.x.dtype == torch.float64
assert metric.dtype == torch.float64
metric.reset()
assert metric.x.dtype == torch.float64
assert metric.dtype == torch.float64

metric = metric.half()
metric.set_dtype(torch.half)
assert metric.x.dtype == torch.float16
assert metric.dtype == torch.float16
metric.reset()
assert metric.x.dtype == torch.float16
assert metric.dtype == torch.float16


def test_warning_on_compute_before_update():
89 changes: 48 additions & 41 deletions torchmetrics/metric.py
@@ -67,7 +67,7 @@ class Metric(Module, ABC):
will be used to perform the allgather.
"""

__jit_ignored_attributes__ = ["device", "dtype"]
__jit_ignored_attributes__ = ["device"]
__jit_unused_properties__ = ["is_differentiable"]

def __init__(
@@ -84,7 +84,6 @@ def __init__(
torch._C._log_api_usage_once(f"torchmetrics.metric.{self.__class__.__name__}")

self._LIGHTNING_GREATER_EQUAL_1_3 = _compare_version("pytorch_lightning", op.ge, "1.3.0")
self._dtype: Union[str, torch.dtype] = torch.get_default_dtype()
self._device = torch.device("cpu")

self.dist_sync_on_step = dist_sync_on_step
@@ -413,30 +412,36 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
self.update: Callable = self._wrap_update(self.update) # type: ignore
self.compute: Callable = self._wrap_compute(self.compute) # type: ignore

@property
def dtype(self) -> "torch.dtype":
"""Return the dtype of the metric."""
return self._dtype

@dtype.setter
def dtype(self, new_dtype: "torch.dtype") -> None:
# necessary to avoid infinite recursion
raise RuntimeError("Cannot set the dtype explicitly. Please use module.to(new_dtype).")

@property
def device(self) -> "torch.device":
"""Return the device of the metric."""
return self._device

def to(self, *args: Any, **kwargs: Any) -> "Metric":
"""Moves and/or casts the parameters and buffers.
"""Moves the parameters and buffers.

Works similar to nn.Module.to but also updates the metrics device and dtype properties
Normal dtype casting is not supported by this method; please use the `set_dtype` method instead.
"""
# the number of values returned by _parse_to differs between PyTorch versions
out = torch._C._nn._parse_to(*args, **kwargs)
self._update_properties(device=out[0], dtype=out[1])
return super().to(*args, **kwargs)
if len(out) == 4: # pytorch 1.5 and higher
device, dtype, non_blocking, convert_to_format = out
else: # pytorch 1.4 and lower
device, dtype, non_blocking = out
convert_to_format = None
dtype = None # prevent the dtype from being cast
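# `_parse_to` returns (device, dtype, non_blocking, convert_to_format) on
# PyTorch 1.5+ and only (device, dtype, non_blocking) on 1.4 and earlier.
# The parsed dtype is discarded either way, so `.to()` can move metric states
# between devices but can no longer change their dtype; `set_dtype` is the
# supported way to cast states.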

def convert(t: Tensor) -> Tensor:
if convert_to_format is not None and t.dim() in (4, 5):
return t.to(
device,
dtype if t.is_floating_point() or t.is_complex() else None,
non_blocking,
memory_format=convert_to_format,
)
return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)

self._device = device
return self._apply(convert)

def cuda(self, device: Optional[Union[torch.device, int]] = None) -> "Metric":
"""Moves all model parameters and buffers to the GPU.
@@ -446,46 +451,48 @@ def cuda(self, device: Optional[Union[torch.device, int]] = None) -> "Metric":
"""
if device is None or isinstance(device, int):
device = torch.device("cuda", index=device)
self._update_properties(device=device)
self._device = device
return super().cuda(device=device)

def cpu(self) -> "Metric":
"""Moves all model parameters and buffers to the CPU."""
self._update_properties(device=torch.device("cpu"))
self._device = torch.device("cpu")
return super().cpu()

def type(self, dst_type: Union[str, torch.dtype]) -> "Metric":
"""Casts all parameters and buffers to :attr:`dst_type`.
"""Method override default and prevent dtype casting.

Arguments:
dst_type (type or string): the desired type
Please use `metric.set_dtype(dtype)` instead.
"""
self._update_properties(dtype=dst_type)
return super().type(dst_type=dst_type)
return self

def float(self) -> "Metric":
"""Casts all floating point parameters and buffers to ``float`` datatype."""
self._update_properties(dtype=torch.float)
return super().float()
"""Method override default and prevent dtype casting.

Please use `metric.set_dtype(dtype)` instead.
"""
return self

def double(self) -> "Metric":
"""Casts all floating point parameters and buffers to ``double`` datatype."""
self._update_properties(dtype=torch.double)
return super().double()
"""Method override default and prevent dtype casting.

Please use `metric.set_dtype(dtype)` instead.
"""
return self

def half(self) -> "Metric":
"""Casts all floating point parameters and buffers to ``half`` datatype."""
self._update_properties(dtype=torch.half)
return super().half()
"""Method override default and prevent dtype casting.

def _update_properties(
self, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None
) -> None:
"""Updates the internal device and or dtype attributes of the metric."""
if device is not None:
self._device = device
if dtype is not None:
self._dtype = dtype
Please use `metric.set_dtype(dtype)` instead.
"""
return self

def set_dtype(self, dst_type: Union[str, torch.dtype]) -> None:
"""Special version of `type` for transferring all metric states to specific dtype
Arguments:
dst_type (type or string): the desired type
"""
return super().type(dst_type)

def _apply(self, fn: Callable) -> Module:
"""Overwrite _apply function such that we can also move metric states to the correct device when `.to`,
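Taken together, the changes above redefine how metrics interact with dtype transfers: `.to()` still handles device moves, while dtype casts go exclusively through `set_dtype` and survive `reset()`. The following short sketch of the resulting semantics assumes a metric instance with a single float32 tensor state `x`, like the one used in the updated test above; it is an illustration, not code from the PR.

import torch

metric = metric.to(torch.float64)        # the dtype argument is ignored by the new `to`
assert metric.x.dtype == torch.float32   # the state keeps its original dtype

if torch.cuda.is_available():
    metric = metric.to(device="cuda")    # device transfers still work
    assert metric.x.is_cuda

metric.set_dtype(torch.half)             # explicit cast of all metric states
assert metric.x.dtype == torch.float16
metric.reset()
assert metric.x.dtype == torch.float16   # the chosen dtype survives reset()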