diff --git a/CHANGELOG.md b/CHANGELOG.md index 2bc1234c5da..f4c13f971d7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fix compatibility between XLA in `_bincount` function ([#1471](https://github.com/Lightning-AI/metrics/pull/1471)) + + - diff --git a/src/torchmetrics/utilities/data.py b/src/torchmetrics/utilities/data.py index b5b72f6a9e1..214ee5a954b 100644 --- a/src/torchmetrics/utilities/data.py +++ b/src/torchmetrics/utilities/data.py @@ -16,7 +16,7 @@ import torch from torch import Tensor -from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12 +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _XLA_AVAILABLE METRIC_EPS = 1e-6 @@ -220,7 +220,7 @@ def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor: """ if minlength is None: minlength = len(torch.unique(x)) - if torch.are_deterministic_algorithms_enabled() or _TORCH_GREATER_EQUAL_1_12 and x.is_mps: + if torch.are_deterministic_algorithms_enabled() or _XLA_AVAILABLE or _TORCH_GREATER_EQUAL_1_12 and x.is_mps: output = torch.zeros(minlength, device=x.device, dtype=torch.long) for i in range(minlength): output[i] = (x == i).sum() diff --git a/src/torchmetrics/utilities/imports.py b/src/torchmetrics/utilities/imports.py index f4f3080d854..8b2e723f017 100644 --- a/src/torchmetrics/utilities/imports.py +++ b/src/torchmetrics/utilities/imports.py @@ -118,3 +118,4 @@ def _compare_version(package: str, op: Callable, version: str) -> Optional[bool] _PYSTOI_AVAILABLE: bool = _package_available("pystoi") _FAST_BSS_EVAL_AVAILABLE: bool = _package_available("fast_bss_eval") _MULTIPROCESSING_AVAILABLE: bool = _package_available("multiprocessing") +_XLA_AVAILABLE: bool = _package_available("torch_xla")