diff --git a/CHANGELOG.md b/CHANGELOG.md index 40ed1b62c79..820f242e754 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed axis names with Precision-Recall curve ([#2462](https://github.com/Lightning-AI/torchmetrics/pull/2462)) +- Fixed warnings being suppressed in `MeanAveragePrecision` when requested ([#2501](https://github.com/Lightning-AI/torchmetrics/pull/2501)) + + ## [1.3.2] - 2024-03-18 ### Fixed diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py index 9f37208212f..b36e4ed89a6 100644 --- a/src/torchmetrics/detection/mean_ap.py +++ b/src/torchmetrics/detection/mean_ap.py @@ -827,8 +827,9 @@ def _get_safe_item_values( rle = self.mask_utils.encode(np.asfortranarray(i)) masks.append((tuple(rle["size"]), rle["counts"])) output[1] = tuple(masks) # type: ignore[call-overload] - if (output[0] is not None and len(output[0]) > self.max_detection_thresholds[-1]) or ( - output[1] is not None and len(output[1]) > self.max_detection_thresholds[-1] + if warn and ( + (output[0] is not None and len(output[0]) > self.max_detection_thresholds[-1]) + or (output[1] is not None and len(output[1]) > self.max_detection_thresholds[-1]) ): _warning_on_too_many_detections(self.max_detection_thresholds[-1]) return output # type: ignore[return-value] diff --git a/tests/unittests/detection/test_map.py b/tests/unittests/detection/test_map.py index f0dcdff52f2..64af4aab4ab 100644 --- a/tests/unittests/detection/test_map.py +++ b/tests/unittests/detection/test_map.py @@ -712,7 +712,8 @@ def test_for_box_format(self, box_format, iou_val_expected, map_val_expected, ba assert round(float(result["ious"][(0, 0)]), 3) == iou_val_expected @pytest.mark.parametrize("iou_type", ["bbox", "segm"]) - def test_warning_on_many_detections(self, iou_type, backend): + @pytest.mark.parametrize("warn_on_many_detections", [False, True]) + def test_warning_on_many_detections(self, iou_type, warn_on_many_detections, backend, recwarn): """Test that a warning is raised when there are many detections.""" if iou_type == "bbox": preds = [ @@ -727,8 +728,13 @@ def test_warning_on_many_detections(self, iou_type, backend): preds, targets = _generate_random_segm_input("cpu", 1, 101, 10, False) metric = MeanAveragePrecision(iou_type=iou_type, backend=backend) - with pytest.warns(UserWarning, match="Encountered more than 100 detections in a single image.*"): - metric.update(preds, targets) + metric.warn_on_many_detections = warn_on_many_detections + + if warn_on_many_detections: + with pytest.warns(UserWarning, match="Encountered more than 100 detections in a single image.*"): + metric.update(preds, targets) + else: + assert len(recwarn) == 0 @pytest.mark.parametrize( ("preds", "target", "expected_iou_len", "iou_keys", "precision_shape", "recall_shape", "scores_shape"),