From b1a20f952e11a792c84d99b66f8644bf27293a74 Mon Sep 17 00:00:00 2001 From: tobias-kupek-swarm Date: Fri, 29 Oct 2021 10:48:50 +0200 Subject: [PATCH 1/5] map metric - fix empty predictions --- tests/detection/test_map.py | 20 ++++++++++++++++++++ torchmetrics/detection/map.py | 2 +- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/tests/detection/test_map.py b/tests/detection/test_map.py index 00fbc19848c..2b77c60ca0a 100644 --- a/tests/detection/test_map.py +++ b/tests/detection/test_map.py @@ -206,6 +206,26 @@ def test_error_on_wrong_init(): MAP(class_metrics=0) +@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed") +def test_empty_preds(): + """Test empty predictions.""" + + metric = MAP() + + metric.update([dict( + boxes=torch.Tensor([[]]), + scores=torch.Tensor([]), + labels=torch.IntTensor([]), + ), + ], [ + dict( + boxes=torch.Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]), + labels=torch.IntTensor([4]), + ), + ]) + metric.compute() + + @pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed") def test_error_on_wrong_input(): """Test class input validation.""" diff --git a/torchmetrics/detection/map.py b/torchmetrics/detection/map.py index 3e3b9cbb154..a5d2397de50 100644 --- a/torchmetrics/detection/map.py +++ b/torchmetrics/detection/map.py @@ -365,7 +365,7 @@ def _get_coco_format( annotations = [] annotation_id = 1 # has to start with 1, otherwise COCOEval results are wrong - boxes = [box_convert(box, in_fmt="xyxy", out_fmt="xywh") for box in boxes] + boxes = [box_convert(box, in_fmt="xyxy", out_fmt="xywh") if boxes[0].size(1) == 4 else box for box in boxes] for image_id, (image_boxes, image_labels) in enumerate(zip(boxes, labels)): image_boxes = image_boxes.cpu().tolist() image_labels = image_labels.cpu().tolist() From 3c46378f95ffd160c99b3c84e8f5564dcbd47ada Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 29 Oct 2021 08:50:29 +0000 Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/detection/test_map.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/tests/detection/test_map.py b/tests/detection/test_map.py index 2b77c60ca0a..00d80a479bd 100644 --- a/tests/detection/test_map.py +++ b/tests/detection/test_map.py @@ -212,17 +212,21 @@ def test_empty_preds(): metric = MAP() - metric.update([dict( - boxes=torch.Tensor([[]]), - scores=torch.Tensor([]), - labels=torch.IntTensor([]), - ), - ], [ - dict( - boxes=torch.Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]), - labels=torch.IntTensor([4]), - ), - ]) + metric.update( + [ + dict( + boxes=torch.Tensor([[]]), + scores=torch.Tensor([]), + labels=torch.IntTensor([]), + ), + ], + [ + dict( + boxes=torch.Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]), + labels=torch.IntTensor([4]), + ), + ], + ) metric.compute() From 6f375896c5592b84867a05f8c3ecbd03fbd24494 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 29 Oct 2021 11:34:02 +0200 Subject: [PATCH 3/5] Apply suggestions from code review --- tests/detection/test_map.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/detection/test_map.py b/tests/detection/test_map.py index 00d80a479bd..5755afd2a57 100644 --- a/tests/detection/test_map.py +++ b/tests/detection/test_map.py @@ -217,13 +217,13 @@ def test_empty_preds(): dict( boxes=torch.Tensor([[]]), scores=torch.Tensor([]), - labels=torch.IntTensor([]), + labels=torch.IntTensor([]) ), ], [ dict( boxes=torch.Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]), - labels=torch.IntTensor([4]), + labels=torch.IntTensor([4]) ), ], ) From 5559fc9be3d15c4d8e80deca492bd2bbb9fcf3b8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 29 Oct 2021 09:34:32 +0000 Subject: [PATCH 4/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/detection/test_map.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/tests/detection/test_map.py b/tests/detection/test_map.py index 5755afd2a57..610a275c63e 100644 --- a/tests/detection/test_map.py +++ b/tests/detection/test_map.py @@ -214,17 +214,10 @@ def test_empty_preds(): metric.update( [ - dict( - boxes=torch.Tensor([[]]), - scores=torch.Tensor([]), - labels=torch.IntTensor([]) - ), + dict(boxes=torch.Tensor([[]]), scores=torch.Tensor([]), labels=torch.IntTensor([])), ], [ - dict( - boxes=torch.Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]), - labels=torch.IntTensor([4]) - ), + dict(boxes=torch.Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]), labels=torch.IntTensor([4])), ], ) metric.compute() From caff264b87b63a47c9cb801c63e8d69b5ee87cb7 Mon Sep 17 00:00:00 2001 From: Jirka Date: Fri, 29 Oct 2021 11:39:04 +0200 Subject: [PATCH 5/5] chlog --- CHANGELOG.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 47801a831fa..5707c02edb0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,16 +23,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fix empty predictions in MAP metric ([#594](https://github.com/PyTorchLightning/metrics/pull/594)) + ## [0.6.0] - 2021-10-28 ### Added - Added audio metrics: - - Perceptual Evaluation of Speech Quality (PESQ) ([#353](https://github.com/PyTorchLightning/metrics/issues/353)) - - Short Term Objective Intelligibility (STOI) ([#353](https://github.com/PyTorchLightning/metrics/issues/353)) + - Perceptual Evaluation of Speech Quality (PESQ) ([#353](https://github.com/PyTorchLightning/metrics/pull/353)) + - Short Term Objective Intelligibility (STOI) ([#353](https://github.com/PyTorchLightning/metrics/pull/353)) - Added Information retrieval metrics: - - `RetrievalRPrecision` ([#577](https://github.com/PyTorchLightning/metrics/pull/577/)) + - `RetrievalRPrecision` ([#577](https://github.com/PyTorchLightning/metrics/pull/577)) - `RetrievalHitRate` ([#576](https://github.com/PyTorchLightning/metrics/pull/576)) - Added NLP metrics: - `SacreBLEUScore` ([#546](https://github.com/PyTorchLightning/metrics/pull/546))