diff --git a/src/torchmetrics/aggregation.py b/src/torchmetrics/aggregation.py index a4332589898..ee4f86ffdc3 100644 --- a/src/torchmetrics/aggregation.py +++ b/src/torchmetrics/aggregation.py @@ -77,9 +77,9 @@ def _cast_and_nan_check_input( ) -> Tuple[Tensor, Tensor]: """Convert input ``x`` to a tensor and check for Nans.""" if not isinstance(x, Tensor): - x = torch.as_tensor(x, dtype=torch.float32, device=self.device) + x = torch.as_tensor(x, dtype=self.dtype, device=self.device) if weight is not None and not isinstance(weight, Tensor): - weight = torch.as_tensor(weight, dtype=torch.float32, device=self.device) + weight = torch.as_tensor(weight, dtype=self.dtype, device=self.device) nans = torch.isnan(x) if weight is not None: @@ -101,7 +101,7 @@ def _cast_and_nan_check_input( x[nans | nans_weight] = self.nan_strategy weight[nans | nans_weight] = self.nan_strategy - return x.float(), weight.float() + return x.to(self.dtype), weight.to(self.dtype) def update(self, value: Union[float, Tensor]) -> None: """Overwrite in child class.""" @@ -557,9 +557,9 @@ def update(self, value: Union[float, Tensor], weight: Union[float, Tensor] = 1.0 """ # broadcast weight to value shape if not isinstance(value, Tensor): - value = torch.as_tensor(value, dtype=torch.float32, device=self.device) + value = torch.as_tensor(value, dtype=self.dtype, device=self.device) if weight is not None and not isinstance(weight, Tensor): - weight = torch.as_tensor(weight, dtype=torch.float32, device=self.device) + weight = torch.as_tensor(weight, dtype=self.dtype, device=self.device) weight = torch.broadcast_to(weight, value.shape) value, weight = self._cast_and_nan_check_input(value, weight) diff --git a/tests/unittests/bases/test_aggregation.py b/tests/unittests/bases/test_aggregation.py index a2d34cb3469..5a65eaa3fb4 100644 --- a/tests/unittests/bases/test_aggregation.py +++ b/tests/unittests/bases/test_aggregation.py @@ -214,10 +214,11 @@ def test_with_default_dtype(metric_class, compare_function): """Test that the metric works with a default dtype of float64.""" torch.set_default_dtype(torch.float64) metric = metric_class() + assert metric.dtype == torch.float64 values = torch.randn(10000) metric.update(values) result = metric.compute() assert result.dtype == torch.float64 assert result.dtype == values.dtype - assert torch.allclose(result, compare_function(values), atol=1e-12) + assert result == compare_function(values) torch.set_default_dtype(torch.float32) diff --git a/tests/unittests/bases/test_metric.py b/tests/unittests/bases/test_metric.py index a7211f2982e..1cf2f1fb4ed 100644 --- a/tests/unittests/bases/test_metric.py +++ b/tests/unittests/bases/test_metric.py @@ -573,3 +573,19 @@ def test_update_properties(metric, method): m.reset() assert not m.update_called assert m.update_count == 0 + + +def test_dtype_property(): + """Test that dtype property works as expected.""" + metric = DummyMetricSum() + assert metric.dtype == torch.float32 + metric.set_dtype(torch.float64) + assert metric.dtype == torch.float64 + + torch.set_default_dtype(torch.float64) + metric = DummyMetricSum() + assert metric.dtype == torch.float64 + torch.set_default_dtype(torch.float32) + assert metric.dtype == torch.float64 # should not change after initialization + metric.set_dtype(torch.float32) + assert metric.dtype == torch.float32