Skip to content

Commit

Permalink
MeanAveragePrecision: Skip box conversion if no boxes are present (#1097
Browse files Browse the repository at this point in the history
)

* Skip box conversion if no boxes are present

The `box_convert` function from torchvision expects the input to be a
Tensor[N, 4], where N > 0. Should N == 0 and in_fmt != out_fmt, `unbind`
will error out on the boxes tensor during the conversion process.

The workaround is, therefore to skip the box conversion if boxes is an
empty tensor.

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
(cherry picked from commit d29919c)
  • Loading branch information
kouyk authored and Borda committed Jun 20, 2022
1 parent 2ce4451 commit 5c4a8c2
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 39 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Fixed mAP calculation for areas with 0 predictions ([#1080](https://github.com/PyTorchLightning/metrics/pull/1080))


- Fixed bug where avg precision state and auroc state was not merge when using MetricCollections ([#1086](https://github.com/PyTorchLightning/metrics/pull/1086))


- Skip box conversion if no boxes are present in `MeanAveragePrecision` ([#1097](https://github.com/PyTorchLightning/metrics/pull/1097))


## [0.9.1] - 2022-06-08

### Added
Expand Down
94 changes: 56 additions & 38 deletions tests/detection/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@
],
)


_inputs = Input(
preds=[
[
Expand Down Expand Up @@ -299,7 +298,6 @@ class TestMAP(MetricTester):

@pytest.mark.parametrize("ddp", [False, True])
def test_map_bbox(self, compute_on_cpu, ddp):

"""Test modular implementation for correctness."""
self.run_class_metric_test(
ddp=ddp,
Expand Down Expand Up @@ -344,12 +342,8 @@ def test_empty_preds():
metric = MeanAveragePrecision()

metric.update(
[
dict(boxes=Tensor([]), scores=Tensor([]), labels=IntTensor([])),
],
[
dict(boxes=Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]), labels=IntTensor([4])),
],
[dict(boxes=Tensor([]), scores=Tensor([]), labels=IntTensor([]))],
[dict(boxes=Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]), labels=IntTensor([4]))],
)
metric.compute()

Expand All @@ -360,16 +354,56 @@ def test_empty_ground_truths():
metric = MeanAveragePrecision()

metric.update(
[
dict(
boxes=Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]),
scores=Tensor([0.5]),
labels=IntTensor([4]),
),
],
[
dict(boxes=Tensor([]), labels=IntTensor([])),
],
[dict(boxes=Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]), scores=Tensor([0.5]), labels=IntTensor([4]))],
[dict(boxes=Tensor([]), labels=IntTensor([]))],
)
metric.compute()


@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed")
def test_empty_ground_truths_xywh():
"""Test empty ground truths in xywh format."""
metric = MeanAveragePrecision(box_format="xywh")

metric.update(
[dict(boxes=Tensor([[214.1500, 41.2900, 348.2600, 243.7800]]), scores=Tensor([0.5]), labels=IntTensor([4]))],
[dict(boxes=Tensor([]), labels=IntTensor([]))],
)
metric.compute()


@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed")
def test_empty_preds_xywh():
"""Test empty predictions in xywh format."""
metric = MeanAveragePrecision(box_format="xywh")

metric.update(
[dict(boxes=Tensor([]), scores=Tensor([]), labels=IntTensor([]))],
[dict(boxes=Tensor([[214.1500, 41.2900, 348.2600, 243.7800]]), labels=IntTensor([4]))],
)
metric.compute()


@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed")
def test_empty_ground_truths_cxcywh():
"""Test empty ground truths in cxcywh format."""
metric = MeanAveragePrecision(box_format="cxcywh")

metric.update(
[dict(boxes=Tensor([[388.2800, 163.1800, 348.2600, 243.7800]]), scores=Tensor([0.5]), labels=IntTensor([4]))],
[dict(boxes=Tensor([]), labels=IntTensor([]))],
)
metric.compute()


@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed")
def test_empty_preds_cxcywh():
"""Test empty predictions in cxcywh format."""
metric = MeanAveragePrecision(box_format="cxcywh")

metric.update(
[dict(boxes=Tensor([]), scores=Tensor([]), labels=IntTensor([]))],
[dict(boxes=Tensor([[388.2800, 163.1800, 348.2600, 243.7800]]), labels=IntTensor([4]))],
)
metric.compute()

Expand Down Expand Up @@ -467,16 +501,8 @@ def test_segm_iou_empty_gt_mask():
metric = MeanAveragePrecision(iou_type="segm")

metric.update(
[
dict(
masks=torch.randint(0, 1, (1, 10, 10)).bool(),
scores=Tensor([0.5]),
labels=IntTensor([4]),
),
],
[
dict(masks=Tensor([]), labels=IntTensor([])),
],
[dict(masks=torch.randint(0, 1, (1, 10, 10)).bool(), scores=Tensor([0.5]), labels=IntTensor([4]))],
[dict(masks=Tensor([]), labels=IntTensor([]))],
)

metric.compute()
Expand All @@ -488,16 +514,8 @@ def test_segm_iou_empty_pred_mask():
metric = MeanAveragePrecision(iou_type="segm")

metric.update(
[
dict(
masks=torch.BoolTensor([]),
scores=Tensor([]),
labels=IntTensor([]),
),
],
[
dict(masks=torch.randint(0, 1, (1, 10, 10)).bool(), labels=IntTensor([4])),
],
[dict(masks=torch.BoolTensor([]), scores=Tensor([]), labels=IntTensor([]))],
[dict(masks=torch.randint(0, 1, (1, 10, 10)).bool(), labels=IntTensor([4]))],
)

metric.compute()
Expand Down
3 changes: 2 additions & 1 deletion torchmetrics/detection/mean_ap.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,8 @@ def _get_safe_item_values(self, item: Dict[str, Any]) -> Union[Tensor, Tuple]:

if self.iou_type == "bbox":
boxes = _fix_empty_tensors(item["boxes"])
boxes = box_convert(boxes, in_fmt=self.box_format, out_fmt="xyxy")
if boxes.numel() > 0:
boxes = box_convert(boxes, in_fmt=self.box_format, out_fmt="xyxy")
return boxes
elif self.iou_type == "segm":
masks = []
Expand Down

0 comments on commit 5c4a8c2

Please sign in to comment.