diff --git a/src/torchmetrics/functional/classification/jaccard.py b/src/torchmetrics/functional/classification/jaccard.py index 947d87aab4f..96f59c1e7e0 100644 --- a/src/torchmetrics/functional/classification/jaccard.py +++ b/src/torchmetrics/functional/classification/jaccard.py @@ -62,7 +62,7 @@ def _jaccard_from_confmat( # If this class is absent in both target AND pred (union == 0), then use the absent_score for this class. scores = intersection.float() / union.float() - scores.where(union == 0, torch.tensor(absent_score, dtype=scores.dtype, device=scores.device)) + scores = scores.where(union == 0, torch.tensor(absent_score, dtype=scores.dtype, device=scores.device)) if ignore_index is not None and 0 <= ignore_index < num_classes: scores = torch.cat(