Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki committed Feb 15, 2024
1 parent 442c463 commit 050a8c7
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 6 deletions.
10 changes: 5 additions & 5 deletions src/torchmetrics/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ def _cast_and_nan_check_input(
) -> Tuple[Tensor, Tensor]:
"""Convert input ``x`` to a tensor and check for Nans."""
if not isinstance(x, Tensor):
x = torch.as_tensor(x, dtype=torch.float32, device=self.device)
x = torch.as_tensor(x, dtype=self.dtype, device=self.device)
if weight is not None and not isinstance(weight, Tensor):
weight = torch.as_tensor(weight, dtype=torch.float32, device=self.device)
weight = torch.as_tensor(weight, dtype=self.dtype, device=self.device)

nans = torch.isnan(x)
if weight is not None:
Expand All @@ -101,7 +101,7 @@ def _cast_and_nan_check_input(
x[nans | nans_weight] = self.nan_strategy
weight[nans | nans_weight] = self.nan_strategy

return x.float(), weight.float()
return x.to(self.dtype), weight.to(self.dtype)

def update(self, value: Union[float, Tensor]) -> None:
"""Overwrite in child class."""
Expand Down Expand Up @@ -557,9 +557,9 @@ def update(self, value: Union[float, Tensor], weight: Union[float, Tensor] = 1.0
"""
# broadcast weight to value shape
if not isinstance(value, Tensor):
value = torch.as_tensor(value, dtype=torch.float32, device=self.device)
value = torch.as_tensor(value, dtype=self.dtype, device=self.device)
if weight is not None and not isinstance(weight, Tensor):
weight = torch.as_tensor(weight, dtype=torch.float32, device=self.device)
weight = torch.as_tensor(weight, dtype=self.dtype, device=self.device)
weight = torch.broadcast_to(weight, value.shape)
value, weight = self._cast_and_nan_check_input(value, weight)

Expand Down
3 changes: 2 additions & 1 deletion tests/unittests/bases/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,11 @@ def test_with_default_dtype(metric_class, compare_function):
"""Test that the metric works with a default dtype of float64."""
torch.set_default_dtype(torch.float64)
metric = metric_class()
assert metric.dtype == torch.float64
values = torch.randn(10000)
metric.update(values)
result = metric.compute()
assert result.dtype == torch.float64
assert result.dtype == values.dtype
assert torch.allclose(result, compare_function(values), atol=1e-12)
assert result == compare_function(values)
torch.set_default_dtype(torch.float32)
16 changes: 16 additions & 0 deletions tests/unittests/bases/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,3 +573,19 @@ def test_update_properties(metric, method):
m.reset()
assert not m.update_called
assert m.update_count == 0


def test_dtype_property():
"""Test that dtype property works as expected."""
metric = DummyMetricSum()
assert metric.dtype == torch.float32
metric.set_dtype(torch.float64)
assert metric.dtype == torch.float64

torch.set_default_dtype(torch.float64)
metric = DummyMetricSum()
assert metric.dtype == torch.float64
torch.set_default_dtype(torch.float32)
assert metric.dtype == torch.float64 # should not change after initialization
metric.set_dtype(torch.float32)
assert metric.dtype == torch.float32

0 comments on commit 050a8c7

Please sign in to comment.