Add option for extended summary in MeanAveragePrecision #1983

Merged · 8 commits · Aug 8, 2023
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -29,6 +29,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added warning to `MeanAveragePrecision` if too many detections are observed ([#1978](https://github.com/Lightning-AI/torchmetrics/pull/1978))


- Added argument `extended_summary` to `MeanAveragePrecision` so that precision, recall and IoU values can easily be returned ([#1983](https://github.com/Lightning-AI/torchmetrics/pull/1983))

### Changed

-
59 changes: 45 additions & 14 deletions src/torchmetrics/detection/mean_ap.py
@@ -25,6 +25,7 @@
from torchmetrics.detection.helpers import _fix_empty_tensors, _input_validator
from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.data import apply_to_collection
from torchmetrics.utilities.imports import (
_MATPLOTLIB_AVAILABLE,
_PYCOCOTOOLS_AVAILABLE,
@@ -168,7 +169,23 @@ class MeanAveragePrecision(Metric):
Thresholds on max detections per image. If set to `None` will use thresholds ``[1, 10, 100]``.
Else, please provide a list of ints.
class_metrics:
Option to enable per-class metrics for mAP and mAR_100. Has a performance impact.
Option to enable per-class metrics for mAP and mAR_100. Has a performance impact that scales linearly with
the number of classes in the dataset.
extended_summary:
Option to enable extended summary with additional metrics including IOU, precision and recall. The output
dictionary will contain the following extra key-values:

- ``ious``: a dictionary containing the IoU values for every image/class combination e.g.
``ious[(0,0)]`` would contain the IoU for image 0 and class 0. Each value is a tensor with shape
``(n,m)`` where ``n`` is the number of detections and ``m`` is the number of ground truth boxes for
that image/class combination.
- ``precision``: a tensor of shape ``(TxRxKxAxM)`` containing the precision values. Here ``T`` is the
number of IoU thresholds, ``R`` is the number of recall thresholds, ``K`` is the number of classes,
``A`` is the number of areas and ``M`` is the number of max detections per image.
- ``recall``: a tensor of shape ``(TxKxAxM)`` containing the recall values. Here ``T`` is the number of
IoU thresholds, ``K`` is the number of classes, ``A`` is the number of areas and ``M`` is the number
of max detections per image.

kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
@@ -250,6 +267,7 @@ def __init__(
rec_thresholds: Optional[List[float]] = None,
max_detection_thresholds: Optional[List[int]] = None,
class_metrics: bool = False,
extended_summary: bool = False,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
@@ -299,6 +317,10 @@ def __init__(
raise ValueError("Expected argument `class_metrics` to be a boolean")
self.class_metrics = class_metrics

if not isinstance(extended_summary, bool):
raise ValueError("Expected argument `extended_summary` to be a boolean")
self.extended_summary = extended_summary

self.add_state("detections", default=[], dist_reduce_fx=None)
self.add_state("detection_scores", default=[], dist_reduce_fx=None)
self.add_state("detection_labels", default=[], dist_reduce_fx=None)
@@ -358,27 +380,35 @@ def compute(self) -> dict:
coco_target.createIndex()
coco_preds.createIndex()

self.coco_eval = COCOeval(coco_target, coco_preds, iouType=self.iou_type)
self.coco_eval.params.iouThrs = np.array(self.iou_thresholds, dtype=np.float64)
self.coco_eval.params.recThrs = np.array(self.rec_thresholds, dtype=np.float64)
self.coco_eval.params.maxDets = self.max_detection_thresholds
coco_eval = COCOeval(coco_target, coco_preds, iouType=self.iou_type)
coco_eval.params.iouThrs = np.array(self.iou_thresholds, dtype=np.float64)
coco_eval.params.recThrs = np.array(self.rec_thresholds, dtype=np.float64)
coco_eval.params.maxDets = self.max_detection_thresholds

coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()
stats = coco_eval.stats

self.coco_eval.evaluate()
self.coco_eval.accumulate()
self.coco_eval.summarize()
stats = self.coco_eval.stats
summary = {}
if self.extended_summary:
summary = {
"ious": apply_to_collection(coco_eval.ious, np.ndarray, lambda x: torch.tensor(x, dtype=torch.float32)),
"precision": torch.tensor(coco_eval.eval["precision"]), # precision has shape (TxRxKxAxM)
"recall": torch.tensor(coco_eval.eval["recall"]), # recall has shape (TxKxAxM)
}

# if class mode is enabled, evaluate metrics per class
if self.class_metrics:
map_per_class_list = []
mar_100_per_class_list = []
for class_id in self._get_classes():
self.coco_eval.params.catIds = [class_id]
coco_eval.params.catIds = [class_id]
with contextlib.redirect_stdout(io.StringIO()):
self.coco_eval.evaluate()
self.coco_eval.accumulate()
self.coco_eval.summarize()
class_stats = self.coco_eval.stats
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()
class_stats = coco_eval.stats

map_per_class_list.append(torch.tensor([class_stats[0]]))
mar_100_per_class_list.append(torch.tensor([class_stats[8]]))
@@ -405,6 +435,7 @@ def compute(self) -> dict:
"map_per_class": map_per_class_values,
"mar_100_per_class": mar_100_per_class_values,
"classes": torch.tensor(self._get_classes(), dtype=torch.int32),
**summary,
}

@staticmethod
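For context, here is a minimal usage sketch of the new `extended_summary` option as described in the docstring above; the toy predictions and targets are illustrative only and are not taken from the PR:

```python
import torch
from torchmetrics.detection import MeanAveragePrecision

# Illustrative toy inputs: one predicted box and one ground-truth box for class 0.
preds = [
    {
        "boxes": torch.tensor([[0.5, 0.5, 1.0, 1.0]]),
        "scores": torch.tensor([1.0]),
        "labels": torch.tensor([0]),
    }
]
target = [{"boxes": torch.tensor([[0.0, 0.0, 1.0, 1.0]]), "labels": torch.tensor([0])}]

metric = MeanAveragePrecision(extended_summary=True)
metric.update(preds, target)
result = metric.compute()

# Extra keys returned when `extended_summary=True`:
# "ious" is keyed by (image_index, class_index); each value has shape (n_detections, n_gt_boxes).
print(result["ious"][(0, 0)])
# "precision" has shape (T, R, K, A, M); "recall" has shape (T, K, A, M).
print(result["precision"].shape, result["recall"].shape)
```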
51 changes: 49 additions & 2 deletions tests/unittests/detection/test_map.py
@@ -16,6 +16,7 @@
from collections import namedtuple
from copy import deepcopy
from functools import partial
from itertools import product

import numpy as np
import pytest
@@ -685,11 +686,11 @@ def test_for_box_format(box_format, iou_val_expected, map_val_expected):

targets = [{"boxes": torch.tensor([[0, 0, 1, 1]]), "labels": torch.tensor([0])}]

metric = MeanAveragePrecision(box_format=box_format, iou_thresholds=[0.2])
metric = MeanAveragePrecision(box_format=box_format, iou_thresholds=[0.2], extended_summary=True)
metric.update(predictions, targets)
result = metric.compute()
assert result["map"].item() == map_val_expected
assert round(float(metric.coco_eval.ious[(0, 0)]), 3) == iou_val_expected
assert round(float(result["ious"][(0, 0)]), 3) == iou_val_expected


@pytest.mark.parametrize("iou_type", ["bbox", "segm"])
@@ -710,3 +711,49 @@ def test_warning_on_many_detections(iou_type):
metric = MeanAveragePrecision(iou_type=iou_type)
with pytest.warns(UserWarning, match="Encountered more than 100 detections in a single image.*"):
metric.update(preds, targets)


@pytest.mark.parametrize(
("preds", "target", "expected_iou_len", "iou_keys", "precision_shape", "recall_shape"),
[
(
[[{"boxes": torch.tensor([[0.5, 0.5, 1, 1]]), "scores": torch.tensor([1.0]), "labels": torch.tensor([0])}]],
[[{"boxes": torch.tensor([[0, 0, 1, 1]]), "labels": torch.tensor([0])}]],
1, # 1 image x 1 class = 1
[(0, 0)],
(10, 101, 1, 4, 3),
(10, 1, 4, 3),
),
(
_inputs.preds,
_inputs.target,
24, # 4 images x 6 classes = 24
product([0, 1, 2, 3], [0, 1, 2, 3, 4, 49]),
(10, 101, 6, 4, 3),
(10, 6, 4, 3),
),
],
)
def test_for_extended_stats(preds, target, expected_iou_len, iou_keys, precision_shape, recall_shape):
"""Test that extended stats are computed correctly."""
metric = MeanAveragePrecision(extended_summary=True)
for p, t in zip(preds, target):
metric.update(p, t)
result = metric.compute()

ious = result["ious"]
assert isinstance(ious, dict)
assert len(ious) == expected_iou_len
for key in ious:
assert key in iou_keys

precision = result["precision"]
assert isinstance(precision, Tensor)
assert precision.shape == precision_shape

recall = result["recall"]
assert isinstance(recall, Tensor)
assert recall.shape == recall_shape
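
As a follow-up to the usage sketch above, one way to reduce the `precision` tensor to a single AP@0.5 value is sketched below. It assumes the default thresholds (T=10 IoU thresholds starting at 0.5, R=101 recall thresholds, A=4 area ranges with index 0 being "all", M=3 max-detection limits with index 2 being 100), which are the same settings that produce the shapes asserted in the test above; the index choices and the -1 sentinel handling follow the usual COCO convention rather than anything introduced by this PR.

```python
# Continuing from the previous sketch: `result` is the output of metric.compute()
# with extended_summary=True and default thresholds.
precision = result["precision"]  # shape (T, R, K, A, M), here (10, 101, 1, 4, 3)

iou_50_idx = 0       # first IoU threshold, i.e. 0.5, assuming the default range
area_all_idx = 0     # area range "all"
max_det_100_idx = 2  # max detections = 100, assuming the default [1, 10, 100]

# Slice to (R, K): precision over recall thresholds and classes at IoU=0.5.
prec_50 = precision[iou_50_idx, :, :, area_all_idx, max_det_100_idx]

# COCO marks settings without ground truth as -1; exclude them before averaging.
ap50 = prec_50[prec_50 > -1].mean()
print(float(ap50))
```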