diff --git a/CHANGELOG.md b/CHANGELOG.md
index 353d704f766..6c757194dc5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -112,6 +112,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Fixed unsafe log operation in `TweedieDevianceScore` for power=1 ([#847](https://github.com/PyTorchLightning/metrics/pull/847))
 
+- Fixed bug in MAP metric related to either no ground truth or no predictions ([#884](https://github.com/PyTorchLightning/metrics/pull/884))
+
+
 ## [0.7.2] - 2022-02-10
 
 ### Fixed
diff --git a/tests/detection/test_map.py b/tests/detection/test_map.py
index 40d5e140770..273ca8b8e19 100644
--- a/tests/detection/test_map.py
+++ b/tests/detection/test_map.py
@@ -266,6 +266,50 @@ def test_empty_metric():
     metric.compute()
 
 
+@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision>=0.8.0 are installed")
+def test_missing_pred():
+    """One good detection, one false negative.
+
+    mAP should be lower than 1. Actually it is 0.5, but the exact value depends on where we are sampling (i.e. the
+    recall values).
+    """
+    gts = [
+        dict(boxes=torch.Tensor([[10, 20, 15, 25]]), labels=torch.IntTensor([0])),
+        dict(boxes=torch.Tensor([[10, 20, 15, 25]]), labels=torch.IntTensor([0])),
+    ]
+    preds = [
+        dict(boxes=torch.Tensor([[10, 20, 15, 25]]), scores=torch.Tensor([0.9]), labels=torch.IntTensor([0])),
+        # Empty prediction
+        dict(boxes=torch.Tensor([]), scores=torch.Tensor([]), labels=torch.IntTensor([])),
+    ]
+    metric = MeanAveragePrecision()
+    metric.update(preds, gts)
+    result = metric.compute()
+    assert result["map"] < 1, "mAP cannot be 1, as there is a missing prediction."
+
+
+@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision>=0.8.0 are installed")
+def test_missing_gt():
+    """The symmetric case of test_missing_pred.
+
+    One good detection, one false positive. mAP should be lower than 1. Actually it is 0.5, but the exact value
+    depends on where we are sampling (i.e. the recall values).
+    """
+    gts = [
+        dict(boxes=torch.Tensor([[10, 20, 15, 25]]), labels=torch.IntTensor([0])),
+        dict(boxes=torch.Tensor([]), labels=torch.IntTensor([])),
+    ]
+    preds = [
+        dict(boxes=torch.Tensor([[10, 20, 15, 25]]), scores=torch.Tensor([0.9]), labels=torch.IntTensor([0])),
+        dict(boxes=torch.Tensor([[10, 20, 15, 25]]), scores=torch.Tensor([0.95]), labels=torch.IntTensor([0])),
+    ]
+
+    metric = MeanAveragePrecision()
+    metric.update(preds, gts)
+    result = metric.compute()
+    assert result["map"] < 1, "mAP cannot be 1, as there is an image with no ground truth, but some predictions."
+
+
 @pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed")
 def test_error_on_wrong_input():
     """Test class input validation."""
diff --git a/torchmetrics/detection/map.py b/torchmetrics/detection/map.py
index 1395a014b56..bf70362dc56 100644
--- a/torchmetrics/detection/map.py
+++ b/torchmetrics/detection/map.py
@@ -367,6 +367,60 @@ def _compute_iou(self, idx: int, class_id: int, max_det: int) -> Tensor:
         ious = box_iou(det, gt)
         return ious
 
+    def __evaluate_image_gt_no_preds(
+        self, gt: Tensor, gt_label_mask: Tensor, area_range: Tuple[int, int], nb_iou_thrs: int
+    ) -> Dict[str, Any]:
+        """Some GT but no predictions."""
+        # GTs
+        gt = gt[gt_label_mask]
+        nb_gt = len(gt)
+        areas = box_area(gt)
+        ignore_area = (areas < area_range[0]) | (areas > area_range[1])
+        gt_ignore, _ = torch.sort(ignore_area.to(torch.uint8))
+        gt_ignore = gt_ignore.to(torch.bool)
+
+        # Detections
+        nb_det = 0
+        det_ignore = torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool, device=self.device)
+
+        return {
+            "dtMatches": torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool, device=self.device),
+            "gtMatches": torch.zeros((nb_iou_thrs, nb_gt), dtype=torch.bool, device=self.device),
+            "dtScores": torch.zeros(nb_det, dtype=torch.float32, device=self.device),
+            "gtIgnore": gt_ignore,
+            "dtIgnore": det_ignore,
+        }
+
+    def __evaluate_image_preds_no_gt(
+        self, det: Tensor, idx: int, det_label_mask: Tensor, max_det: int, area_range: Tuple[int, int], nb_iou_thrs: int
+    ) -> Dict[str, Any]:
+        """Some predictions but no GT."""
+        # GTs
+        nb_gt = 0
+        gt_ignore = torch.zeros(nb_gt, dtype=torch.bool, device=self.device)
+
+        # Detections
+        det = det[det_label_mask]
+        scores = self.detection_scores[idx]
+        scores_filtered = scores[det_label_mask]
+        scores_sorted, dtind = torch.sort(scores_filtered, descending=True)
+        det = det[dtind]
+        if len(det) > max_det:
+            det = det[:max_det]
+        nb_det = len(det)
+        det_areas = box_area(det).to(self.device)
+        det_ignore_area = (det_areas < area_range[0]) | (det_areas > area_range[1])
+        ar = det_ignore_area.reshape((1, nb_det))
+        det_ignore = torch.repeat_interleave(ar, nb_iou_thrs, 0)
+
+        return {
+            "dtMatches": torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool, device=self.device),
+            "gtMatches": torch.zeros((nb_iou_thrs, nb_gt), dtype=torch.bool, device=self.device),
+            "dtScores": scores_sorted,
+            "gtIgnore": gt_ignore,
+            "dtIgnore": det_ignore,
+        }
+
     def _evaluate_image(
         self, idx: int, class_id: int, area_range: Tuple[int, int], max_det: int, ious: dict
     ) -> Optional[dict]:
@@ -388,11 +442,24 @@ def _evaluate_image(
         det = self.detection_boxes[idx]
         gt_label_mask = self.groundtruth_labels[idx] == class_id
         det_label_mask = self.detection_labels[idx] == class_id
-        if len(gt_label_mask) == 0 or len(det_label_mask) == 0:
+
+        # No GT and no predictions --> ignore image
+        if len(gt_label_mask) == 0 and len(det_label_mask) == 0:
             return None
+
+        nb_iou_thrs = len(self.iou_thresholds)
+
+        # Some GT but no predictions
+        if len(gt_label_mask) > 0 and len(det_label_mask) == 0:
+            return self.__evaluate_image_gt_no_preds(gt, gt_label_mask, area_range, nb_iou_thrs)
+
+        # Some predictions but no GT
+        if len(gt_label_mask) == 0 and len(det_label_mask) > 0:
+            return self.__evaluate_image_preds_no_gt(det, idx, det_label_mask, max_det, area_range, nb_iou_thrs)
+
         gt = gt[gt_label_mask]
         det = det[det_label_mask]
-        if len(gt) == 0 and len(det) == 0:
+        if gt.numel() == 0 and det.numel() == 0:
             return None
 
         areas = box_area(gt)
@@ -424,10 +491,11 @@ def _evaluate_image(
         for idx_iou, t in enumerate(self.iou_thresholds):
             for idx_det, _ in enumerate(det):
                 m = MeanAveragePrecision._find_best_gt_match(t, gt_matches, idx_iou, gt_ignore, ious, idx_det)
-                if m != -1:
-                    det_ignore[idx_iou, idx_det] = gt_ignore[m]
-                    det_matches[idx_iou, idx_det] = 1
-                    gt_matches[idx_iou, m] = 1
+                if m == -1:
+                    continue
+                det_ignore[idx_iou, idx_det] = gt_ignore[m]
+                det_matches[idx_iou, idx_det] = 1
+                gt_matches[idx_iou, m] = 1
 
         # set unmatched detections outside of area range to ignore
         det_areas = box_area(det)
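For reviewers who want to exercise the new behavior end to end, here is a minimal sketch mirroring `test_missing_gt` above. It assumes this patch is applied and that `MeanAveragePrecision` is importable from `torchmetrics.detection.map`, the module modified here; the exact import path may differ across torchmetrics versions.

```python
import torch

# Assumed import path, matching the file touched in this diff.
from torchmetrics.detection.map import MeanAveragePrecision

metric = MeanAveragePrecision()

# Image 1: one correct detection. Image 2: a confident prediction on an image
# with no ground truth at all, the case _evaluate_image previously skipped.
preds = [
    dict(boxes=torch.Tensor([[10, 20, 15, 25]]), scores=torch.Tensor([0.9]), labels=torch.IntTensor([0])),
    dict(boxes=torch.Tensor([[10, 20, 15, 25]]), scores=torch.Tensor([0.95]), labels=torch.IntTensor([0])),
]
gts = [
    dict(boxes=torch.Tensor([[10, 20, 15, 25]]), labels=torch.IntTensor([0])),
    dict(boxes=torch.Tensor([]), labels=torch.IntTensor([])),  # no ground truth boxes
]

metric.update(preds, gts)
result = metric.compute()

# With the fix, the unmatched detection counts as a false positive, so mAP
# comes out below 1 (roughly 0.5 with the default recall sampling).
print(result["map"])
```

The design point behind both new helpers: an image with predictions but no ground truth must contribute false positives, and an image with ground truth but no predictions must contribute false negatives. Returning `None` for either case dropped that evidence and inflated the score.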