diff --git a/CHANGELOG.md b/CHANGELOG.md index 6d33c26fcb9..bb5aa4ba71d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed unintentional downloading of `nltk.punkt` when `lsum` not in `rouge_keys` ([#1258](https://github.com/Lightning-AI/metrics/pull/1258)) +- Fixed type casting in `MAP` metric between `bool` and `float32` ([#1150](https://github.com/Lightning-AI/metrics/pull/1150)) + + ## [0.10.0] - 2022-10-04 ### Added diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py index 602e2bf8920..fe2eca20de4 100644 --- a/src/torchmetrics/detection/mean_ap.py +++ b/src/torchmetrics/detection/mean_ap.py @@ -500,7 +500,7 @@ def __evaluate_image_gt_no_preds( return { "dtMatches": torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool, device=self.device), "gtMatches": torch.zeros((nb_iou_thrs, nb_gt), dtype=torch.bool, device=self.device), - "dtScores": torch.zeros(nb_det, dtype=torch.bool, device=self.device), + "dtScores": torch.zeros(nb_det, dtype=torch.float32, device=self.device), "gtIgnore": gt_ignore, "dtIgnore": det_ignore, } diff --git a/tests/unittests/detection/test_map.py b/tests/unittests/detection/test_map.py index 6ec1341ee69..98601c596f2 100644 --- a/tests/unittests/detection/test_map.py +++ b/tests/unittests/detection/test_map.py @@ -168,17 +168,31 @@ # Test empty preds case, to ensure bool inputs are properly casted to uint8 # From https://github.com/Lightning-AI/metrics/issues/981 +# and https://github.com/Lightning-AI/metrics/issues/1147 _inputs3 = Input( preds=[ + [ + dict( + boxes=Tensor([[258.0, 41.0, 606.0, 285.0]]), + scores=Tensor([0.536]), + labels=IntTensor([0]), + ), + ], [ dict(boxes=Tensor([]), scores=Tensor([]), labels=Tensor([])), ], ], target=[ + [ + dict( + boxes=Tensor([[214.0, 41.0, 562.0, 285.0]]), + labels=IntTensor([0]), + ) + ], [ dict( boxes=Tensor([[1.0, 2.0, 3.0, 4.0]]), - scores=Tensor([0.8]), + scores=Tensor([0.8]), # target does not have scores labels=Tensor([1]), ), ],