From f68d1d0b78b2c03f5a71ea5715ea48e03d890749 Mon Sep 17 00:00:00 2001 From: stancld Date: Tue, 6 Sep 2022 10:21:18 +0200 Subject: [PATCH 1/5] [MPS support] Make Jaccard Index working on MPS Fixes #1196 * Change `_bincount` calculation for `MPS` to run for loop fallback * Use torch.where implementation to apply `absent_score` instead of relying on item assignment There's still a warning on PyTorch side (using CPU fallback for some operations), however, no actions on our users' side is now required and the results are obtained smoothly. --- src/torchmetrics/functional/classification/jaccard.py | 2 +- src/torchmetrics/utilities/data.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/torchmetrics/functional/classification/jaccard.py b/src/torchmetrics/functional/classification/jaccard.py index 4f9b9c2400e..947d87aab4f 100644 --- a/src/torchmetrics/functional/classification/jaccard.py +++ b/src/torchmetrics/functional/classification/jaccard.py @@ -62,7 +62,7 @@ def _jaccard_from_confmat( # If this class is absent in both target AND pred (union == 0), then use the absent_score for this class. scores = intersection.float() / union.float() - scores[union == 0] = absent_score + scores.where(union == 0, torch.tensor(absent_score, dtype=scores.dtype, device=scores.device)) if ignore_index is not None and 0 <= ignore_index < num_classes: scores = torch.cat( diff --git a/src/torchmetrics/utilities/data.py b/src/torchmetrics/utilities/data.py index abf19598343..7b9376e8a09 100644 --- a/src/torchmetrics/utilities/data.py +++ b/src/torchmetrics/utilities/data.py @@ -242,7 +242,10 @@ def _squeeze_if_scalar(data: Any) -> Any: def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor: - """``torch.bincount`` currently does not support deterministic mode on GPU. + """PyTorch currently does not support``torch.bincount`` for: + + - deterministic mode on GPU. + - MPS devices This implementation fallback to a for-loop counting occurrences in that case. @@ -253,7 +256,7 @@ def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor: Returns: Number of occurrences for each unique element in x """ - if x.is_cuda and deterministic(): + if (x.is_cuda and deterministic()) or x.is_mps: if minlength is None: minlength = len(torch.unique(x)) output = torch.zeros(minlength, device=x.device, dtype=torch.long) From ebc1390200d0472067a9f610ebc77c3f20967d84 Mon Sep 17 00:00:00 2001 From: stancld Date: Tue, 6 Sep 2022 10:35:06 +0200 Subject: [PATCH 2/5] Check for x.is_mps only when _TORCH_GREATER_EQUAL_1_12 --- src/torchmetrics/utilities/data.py | 9 +++++++-- src/torchmetrics/utilities/imports.py | 1 + 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/utilities/data.py b/src/torchmetrics/utilities/data.py index 7b9376e8a09..f9cd329fb11 100644 --- a/src/torchmetrics/utilities/data.py +++ b/src/torchmetrics/utilities/data.py @@ -16,7 +16,12 @@ import torch from torch import Tensor, tensor -from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6, _TORCH_GREATER_EQUAL_1_7, _TORCH_GREATER_EQUAL_1_8 +from torchmetrics.utilities.imports import ( + _TORCH_GREATER_EQUAL_1_6, + _TORCH_GREATER_EQUAL_1_7, + _TORCH_GREATER_EQUAL_1_8, + _TORCH_GREATER_EQUAL_1_12, +) if _TORCH_GREATER_EQUAL_1_8: deterministic = torch.are_deterministic_algorithms_enabled @@ -256,7 +261,7 @@ def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor: Returns: Number of occurrences for each unique element in x """ - if (x.is_cuda and deterministic()) or x.is_mps: + if x.is_cuda and deterministic() or _TORCH_GREATER_EQUAL_1_12 and x.is_mps: if minlength is None: minlength = len(torch.unique(x)) output = torch.zeros(minlength, device=x.device, dtype=torch.long) diff --git a/src/torchmetrics/utilities/imports.py b/src/torchmetrics/utilities/imports.py index 3622635d2ff..8e0f890cd63 100644 --- a/src/torchmetrics/utilities/imports.py +++ b/src/torchmetrics/utilities/imports.py @@ -104,6 +104,7 @@ def _compare_version(package: str, op: Callable, version: str) -> Optional[bool] _TORCH_GREATER_EQUAL_1_8: Optional[bool] = _compare_version("torch", operator.ge, "1.8.0") _TORCH_GREATER_EQUAL_1_10: Optional[bool] = _compare_version("torch", operator.ge, "1.10.0") _TORCH_GREATER_EQUAL_1_11: Optional[bool] = _compare_version("torch", operator.ge, "1.11.0") +_TORCH_GREATER_EQUAL_1_12: Optional[bool] = _compare_version("torch", operator.ge, "1.12.0") _JIWER_AVAILABLE: bool = _package_available("jiwer") _NLTK_AVAILABLE: bool = _package_available("nltk") From 11c291fa6616dd4727ed50b87987e0c659bbc53d Mon Sep 17 00:00:00 2001 From: stancld Date: Tue, 6 Sep 2022 12:18:00 +0200 Subject: [PATCH 3/5] Fix scores assignment --- src/torchmetrics/functional/classification/jaccard.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/classification/jaccard.py b/src/torchmetrics/functional/classification/jaccard.py index 947d87aab4f..e834d0691c5 100644 --- a/src/torchmetrics/functional/classification/jaccard.py +++ b/src/torchmetrics/functional/classification/jaccard.py @@ -62,7 +62,7 @@ def _jaccard_from_confmat( # If this class is absent in both target AND pred (union == 0), then use the absent_score for this class. scores = intersection.float() / union.float() - scores.where(union == 0, torch.tensor(absent_score, dtype=scores.dtype, device=scores.device)) + scores = scores.where(union != 0, torch.tensor(absent_score, dtype=scores.dtype, device=scores.device)) if ignore_index is not None and 0 <= ignore_index < num_classes: scores = torch.cat( From b145573f29fae76af14cb0c23cac3f55857997f1 Mon Sep 17 00:00:00 2001 From: stancld Date: Thu, 8 Sep 2022 16:18:03 +0200 Subject: [PATCH 4/5] Re-run CI From 9b051a7a66760d4ce2e128a86cb5f4b374e71eb2 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Fri, 9 Sep 2022 09:57:57 +0200 Subject: [PATCH 5/5] changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8ec86995483..3cac9ed61cf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -46,7 +46,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed a bug in `ssim` when `return_full_image=True` where the score was still reduced ([#1204](https://github.com/Lightning-AI/metrics/pull/1204)) -- +- Fixed mps support in jaccard index ([#1205](https://github.com/Lightning-AI/metrics/pull/1205)) ## [0.9.3] - 2022-08-22