From 1b601b530d6ffddda087e8266ec4a9a84dcf3e91 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Tue, 14 Nov 2023 14:46:37 +0100
Subject: [PATCH 1/2] impl + tests

---
 src/torchmetrics/detection/mean_ap.py |  4 ++++
 tests/unittests/detection/test_map.py | 10 ++++++++--
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py
index e155a579010..a33c9000d87 100644
--- a/src/torchmetrics/detection/mean_ap.py
+++ b/src/torchmetrics/detection/mean_ap.py
@@ -202,6 +202,9 @@ class MeanAveragePrecision(Metric):
             - ``recall``: a tensor of shape ``(TxKxAxM)`` containing the recall values. Here ``T`` is the number of
               IoU thresholds, ``K`` is the number of classes, ``A`` is the number of areas and ``M`` is the number of
               max detections per image.
+            - ``scores``: a tensor of shape ``(TxRxKxAxM)`` containing the confidence scores. Here ``T`` is the
+              number of IoU thresholds, ``R`` is the number of recall thresholds, ``K`` is the number of classes,
+              ``A`` is the number of areas and ``M`` is the number of max detections per image.
 
         average:
             Method for averaging scores over labels. Choose between "``"macro"`` and ``"micro"``.
@@ -531,6 +534,7 @@ def compute(self) -> dict:
                 ),
                 f"{prefix}precision": torch.tensor(coco_eval.eval["precision"]),
                 f"{prefix}recall": torch.tensor(coco_eval.eval["recall"]),
+                f"{prefix}scores": torch.tensor(coco_eval.eval["scores"]),
             }
             result_dict.update(summary)
 
diff --git a/tests/unittests/detection/test_map.py b/tests/unittests/detection/test_map.py
index 3027035c4da..e657a768596 100644
--- a/tests/unittests/detection/test_map.py
+++ b/tests/unittests/detection/test_map.py
@@ -741,7 +741,7 @@ def test_warning_on_many_detections(self, iou_type, backend):
         metric.update(preds, targets)
 
     @pytest.mark.parametrize(
-        ("preds", "target", "expected_iou_len", "iou_keys", "precision_shape", "recall_shape"),
+        ("preds", "target", "expected_iou_len", "iou_keys", "precision_shape", "recall_shape", "scores_shape"),
         [
             (
                 [
@@ -758,6 +758,7 @@ def test_warning_on_many_detections(self, iou_type, backend):
                 [(0, 0)],
                 (10, 101, 1, 4, 3),
                 (10, 1, 4, 3),
+                (10, 101, 1, 4, 3),
             ),
             (
                 _inputs["preds"],
@@ -766,11 +767,12 @@ def test_warning_on_many_detections(self, iou_type, backend):
                 list(product([0, 1, 2, 3], [0, 1, 2, 3, 4, 49])),
                 (10, 101, 6, 4, 3),
                 (10, 6, 4, 3),
+                (10, 101, 6, 4, 3),
             ),
         ],
     )
     def test_for_extended_stats(
-        self, preds, target, expected_iou_len, iou_keys, precision_shape, recall_shape, backend
+        self, preds, target, expected_iou_len, iou_keys, precision_shape, recall_shape, scores_shape, backend
    ):
        """Test that extended stats are computed correctly."""
        metric = MeanAveragePrecision(extended_summary=True, backend=backend)
@@ -793,6 +795,10 @@ def test_for_extended_stats(
         assert isinstance(recall, Tensor)
         assert recall.shape == recall_shape
 
+        scores = result["scores"]
+        assert isinstance(scores, Tensor)
+        assert scores.shape == scores_shape
+
     @pytest.mark.parametrize("class_metrics", [False, True])
     def test_average_argument(self, class_metrics, backend):
         """Test that average argument works.
From cf8289ffcf96af7a46b1a1b5a4002d561d0ea4d1 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Tue, 14 Nov 2023 14:48:43 +0100
Subject: [PATCH 2/2] changelog

---
 CHANGELOG.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 361f564dbb7..109bba8acdd 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -19,6 +19,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Added support for Pytorch v2.1 ([#2142](https://github.com/Lightning-AI/torchmetrics/pull/2142))
 
+- Added confidence scores when `extended_summary=True` in `MeanAveragePrecision` ([#2212](https://github.com/Lightning-AI/torchmetrics/pull/2212))
+
 ### Changed
 
 - Change default state of `SpectralAngleMapper` and `UniversalImageQualityIndex` to be tensors ([#2089](https://github.com/Lightning-AI/torchmetrics/pull/2089))
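
For reference, a minimal usage sketch of what this patch series exposes: with
`extended_summary=True`, `compute()` now also returns a `"scores"` tensor of
shape `(T, R, K, A, M)`. The single prediction/target pair below is
illustrative only (the boxes are from the metric's standard example, not from
this patch); the printed shape matches the first parametrization in the test
above.

    import torch
    from torchmetrics.detection import MeanAveragePrecision

    metric = MeanAveragePrecision(extended_summary=True)
    preds = [{
        "boxes": torch.tensor([[258.0, 41.0, 606.0, 285.0]]),
        "scores": torch.tensor([0.536]),
        "labels": torch.tensor([0]),
    }]
    target = [{
        "boxes": torch.tensor([[214.0, 41.0, 562.0, 285.0]]),
        "labels": torch.tensor([0]),
    }]
    metric.update(preds, target)
    result = metric.compute()

    # T=10 IoU thresholds, R=101 recall thresholds, K=1 class,
    # A=4 area ranges, M=3 max-detection limits -> (T, R, K, A, M)
    print(result["scores"].shape)  # torch.Size([10, 101, 1, 4, 3])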