diff --git a/CHANGELOG.md b/CHANGELOG.md index 9fcd0e8f1cf..b886a933fb8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,7 +30,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added new detection metric `PanopticQuality` ([#929](https://github.com/PyTorchLightning/metrics/pull/929)) -- Add `ClassificationTask` Enum and use in metrics ([#1479](https://github.com/Lightning-AI/metrics/pull/1479)) +- Added `ClassificationTask` Enum and use in metrics ([#1479](https://github.com/Lightning-AI/metrics/pull/1479)) + + +- Added `ignore_index` option to `exact_match` metric ([#1540](https://github.com/Lightning-AI/metrics/pull/1540)) ### Changed diff --git a/src/torchmetrics/classification/exact_match.py b/src/torchmetrics/classification/exact_match.py index 0ad8f036096..745208fce9f 100644 --- a/src/torchmetrics/classification/exact_match.py +++ b/src/torchmetrics/classification/exact_match.py @@ -124,7 +124,8 @@ def update(self, preds, target) -> None: preds, target, self.num_classes, self.multidim_average, self.ignore_index ) preds, target = _multiclass_stat_scores_format(preds, target, 1) - correct, total = _multiclass_exact_match_update(preds, target, self.multidim_average) + + correct, total = _multiclass_exact_match_update(preds, target, self.multidim_average, self.ignore_index) if self.multidim_average == "samplewise": self.correct.append(correct) self.total = total diff --git a/src/torchmetrics/functional/classification/exact_match.py b/src/torchmetrics/functional/classification/exact_match.py index da41cb8c016..43f39932617 100644 --- a/src/torchmetrics/functional/classification/exact_match.py +++ b/src/torchmetrics/functional/classification/exact_match.py @@ -41,8 +41,13 @@ def _multiclass_exact_match_update( preds: Tensor, target: Tensor, multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, ) -> Tuple[Tensor, Tensor]: """Compute the statistics.""" + if ignore_index is not None: + preds = preds.clone() + preds[target == ignore_index] = ignore_index + correct = (preds == target).sum(1) == preds.shape[1] correct = correct if multidim_average == "samplewise" else correct.sum() total = torch.tensor(preds.shape[0] if multidim_average == "global" else 1, device=correct.device) @@ -109,7 +114,7 @@ def multiclass_exact_match( _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index) preds, target = _multiclass_stat_scores_format(preds, target, top_k) - correct, total = _multiclass_exact_match_update(preds, target, multidim_average) + correct, total = _multiclass_exact_match_update(preds, target, multidim_average, ignore_index) return _exact_match_reduce(correct, total) diff --git a/tests/unittests/classification/test_exact_match.py b/tests/unittests/classification/test_exact_match.py index 73eb4cf738a..32f5a2aedd1 100644 --- a/tests/unittests/classification/test_exact_match.py +++ b/tests/unittests/classification/test_exact_match.py @@ -34,8 +34,8 @@ def _baseline_exact_match_multiclass(preds, target, ignore_index, multidim_avera target = target.numpy() if ignore_index is not None: - target = np.copy(target) - target[target == ignore_index] = -1 + preds = np.copy(preds) + preds[target == ignore_index] = ignore_index correct = (preds == target).sum(-1) == preds.shape[1] correct = correct.sum() if multidim_average == "global" else correct