diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8ec86995483..3cac9ed61cf 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -46,7 +46,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fixed a bug in `ssim` when `return_full_image=True` where the score was still reduced ([#1204](https://github.com/Lightning-AI/metrics/pull/1204))
 
--
+- Fixed MPS support in jaccard index ([#1205](https://github.com/Lightning-AI/metrics/pull/1205))
 
 
 ## [0.9.3] - 2022-08-22
 
diff --git a/src/torchmetrics/functional/classification/jaccard.py b/src/torchmetrics/functional/classification/jaccard.py
index 4f9b9c2400e..e834d0691c5 100644
--- a/src/torchmetrics/functional/classification/jaccard.py
+++ b/src/torchmetrics/functional/classification/jaccard.py
@@ -62,7 +62,7 @@ def _jaccard_from_confmat(
 
     # If this class is absent in both target AND pred (union == 0), then use the absent_score for this class.
     scores = intersection.float() / union.float()
-    scores[union == 0] = absent_score
+    scores = scores.where(union != 0, torch.tensor(absent_score, dtype=scores.dtype, device=scores.device))
 
     if ignore_index is not None and 0 <= ignore_index < num_classes:
         scores = torch.cat(
diff --git a/src/torchmetrics/utilities/data.py b/src/torchmetrics/utilities/data.py
index abf19598343..f9cd329fb11 100644
--- a/src/torchmetrics/utilities/data.py
+++ b/src/torchmetrics/utilities/data.py
@@ -16,7 +16,12 @@
 import torch
 from torch import Tensor, tensor
 
-from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6, _TORCH_GREATER_EQUAL_1_7, _TORCH_GREATER_EQUAL_1_8
+from torchmetrics.utilities.imports import (
+    _TORCH_GREATER_EQUAL_1_6,
+    _TORCH_GREATER_EQUAL_1_7,
+    _TORCH_GREATER_EQUAL_1_8,
+    _TORCH_GREATER_EQUAL_1_12,
+)
 
 if _TORCH_GREATER_EQUAL_1_8:
     deterministic = torch.are_deterministic_algorithms_enabled
@@ -242,7 +247,10 @@ def _squeeze_if_scalar(data: Any) -> Any:
 
 
 def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor:
-    """``torch.bincount`` currently does not support deterministic mode on GPU.
+    """PyTorch currently does not support ``torch.bincount`` for:
+
+    - deterministic mode on GPU
+    - MPS devices
 
     This implementation fallback to a for-loop counting occurrences in that case.
 
@@ -253,7 +261,7 @@ def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor:
     Returns:
         Number of occurrences for each unique element in x
     """
-    if x.is_cuda and deterministic():
+    if x.is_cuda and deterministic() or _TORCH_GREATER_EQUAL_1_12 and x.is_mps:
         if minlength is None:
             minlength = len(torch.unique(x))
         output = torch.zeros(minlength, device=x.device, dtype=torch.long)
diff --git a/src/torchmetrics/utilities/imports.py b/src/torchmetrics/utilities/imports.py
index 3622635d2ff..8e0f890cd63 100644
--- a/src/torchmetrics/utilities/imports.py
+++ b/src/torchmetrics/utilities/imports.py
@@ -104,6 +104,7 @@ def _compare_version(package: str, op: Callable, version: str) -> Optional[bool]
 _TORCH_GREATER_EQUAL_1_8: Optional[bool] = _compare_version("torch", operator.ge, "1.8.0")
 _TORCH_GREATER_EQUAL_1_10: Optional[bool] = _compare_version("torch", operator.ge, "1.10.0")
 _TORCH_GREATER_EQUAL_1_11: Optional[bool] = _compare_version("torch", operator.ge, "1.11.0")
+_TORCH_GREATER_EQUAL_1_12: Optional[bool] = _compare_version("torch", operator.ge, "1.12.0")
 
 _JIWER_AVAILABLE: bool = _package_available("jiwer")
 _NLTK_AVAILABLE: bool = _package_available("nltk")
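
For context on the `jaccard.py` change: in-place boolean-mask assignment (`scores[union == 0] = absent_score`) is not supported on the MPS backend, so the diff switches to `Tensor.where`, which builds the result out of place. Below is a minimal, self-contained sketch of that pattern; the function name and toy inputs are illustrative, not part of the library.

```python
import torch


def jaccard_scores_sketch(intersection: torch.Tensor, union: torch.Tensor, absent_score: float = 0.0) -> torch.Tensor:
    """Sketch of the per-class Jaccard computation touched by this diff."""
    scores = intersection.float() / union.float()
    fill = torch.tensor(absent_score, dtype=scores.dtype, device=scores.device)
    # Keep the computed score where the union is non-zero; use absent_score elsewhere.
    # Tensor.where avoids the in-place masked assignment that fails on MPS.
    return scores.where(union != 0, fill)


# Hypothetical 3-class example: the middle class is absent in both target and preds.
inter = torch.tensor([2, 0, 5])
union = torch.tensor([4, 0, 5])
print(jaccard_scores_sketch(inter, union))  # tensor([0.5000, 0.0000, 1.0000])
```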
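For the `_bincount` change: when `torch.bincount` cannot be used (deterministic CUDA mode, or an MPS tensor on torch >= 1.12), the function falls back to counting occurrences with elementwise comparisons. The hunk only shows the start of that branch, so the loop body below is an illustrative sketch under that assumption, not a copy of the library code.

```python
from typing import Optional

import torch


def bincount_fallback_sketch(x: torch.Tensor, minlength: Optional[int] = None) -> torch.Tensor:
    """Sketch of a loop-based bincount fallback for backends without torch.bincount support."""
    if minlength is None:
        minlength = len(torch.unique(x))
    output = torch.zeros(minlength, device=x.device, dtype=torch.long)
    # Count each value 0..minlength-1 via comparisons instead of the bincount kernel.
    for i in range(minlength):
        output[i] = (x == i).sum()
    return output


x = torch.tensor([0, 1, 1, 2, 2, 2])
print(bincount_fallback_sketch(x, minlength=4))  # tensor([1, 2, 3, 0])
```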
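The new `_TORCH_GREATER_EQUAL_1_12` flag reuses the existing `_compare_version` helper from `imports.py`. A hypothetical stand-in for such a check, using the `packaging` library (an assumption; the real helper may differ), could look like this:

```python
import operator
from typing import Callable, Optional

from packaging.version import Version  # assumption: packaging is installed


def compare_torch_version(op: Callable, version: str) -> Optional[bool]:
    """Hypothetical version guard: compare the installed torch release against a
    threshold, returning None when torch cannot be imported."""
    try:
        import torch
    except ImportError:
        return None
    return op(Version(torch.__version__).release, Version(version).release)


# Mirrors the spirit of the flag added in imports.py.
TORCH_GREATER_EQUAL_1_12 = compare_torch_version(operator.ge, "1.12.0")
```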