Skip to content

Commit

Permalink
[Enhance] Enable CocoMetric to get ann_file from MessageHub (#2678)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben-Louis authored Sep 8, 2023
1 parent ebed085 commit 0750eae
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 2 deletions.
8 changes: 8 additions & 0 deletions mmpose/datasets/datasets/base/base_coco_style_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
from mmengine.dataset import BaseDataset, force_full_init
from mmengine.fileio import exists, get_local_path, load
from mmengine.logging import MessageHub
from mmengine.utils import is_list_of
from xtcocotools.coco import COCO

Expand Down Expand Up @@ -112,6 +113,13 @@ def __init__(self,
lazy_init=lazy_init,
max_refetch=max_refetch)

if self.test_mode:
# save the ann_file into MessageHub for CocoMetric
message = MessageHub.get_current_instance()
dataset_name = self.metainfo['dataset_name']
message.update_info_dict(
{f'{dataset_name}_ann_file': self.ann_file})

@classmethod
def _load_metainfo(cls, metainfo: dict = None) -> dict:
"""Collect meta information from the dictionary of meta.
Expand Down
7 changes: 6 additions & 1 deletion mmpose/datasets/transforms/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,12 @@ class PackPoseInputs(BaseTransform):
bbox='bboxes',
bbox_score='bbox_scores',
keypoints='keypoints',
keypoints_visible='keypoints_visible')
keypoints_visible='keypoints_visible',
# In CocoMetric, the area of predicted instances will be calculated
# using gt_instances.bbox_scales. To unsure correspondence with
# previous version, this key is preserved here.
bbox_scale='bbox_scales',
)

# items in `field_mapping_table` will be packed into
# PoseDataSample.gt_fields and converted to Tensor. These items will be
Expand Down
14 changes: 13 additions & 1 deletion mmpose/evaluation/metrics/coco_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np
from mmengine.evaluator import BaseMetric
from mmengine.fileio import dump, get_local_path, load
from mmengine.logging import MMLogger
from mmengine.logging import MessageHub, MMLogger, print_log
from xtcocotools.coco import COCO
from xtcocotools.cocoeval import COCOeval

Expand Down Expand Up @@ -166,6 +166,18 @@ def dataset_meta(self, dataset_meta: dict) -> None:
dataset_meta['num_keypoints'] = len(dataset_meta['sigmas'])
self._dataset_meta = dataset_meta

if self.coco is None:
message = MessageHub.get_current_instance()
ann_file = message.get_info(
f"{dataset_meta['dataset_name']}_ann_file", None)
if ann_file is not None:
with get_local_path(ann_file) as local_path:
self.coco = COCO(local_path)
print_log(
f'CocoMetric for dataset '
f"{dataset_meta['dataset_name']} has successfully "
f'loaded the annotation file from {ann_file}', 'current')

def process(self, data_batch: Sequence[dict],
data_samples: Sequence[dict]) -> None:
"""Process one batch of data samples and predictions. The processed
Expand Down
18 changes: 18 additions & 0 deletions tests/test_evaluation/test_metrics/test_coco_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@

import numpy as np
from mmengine.fileio import dump, load
from mmengine.logging import MessageHub
from xtcocotools.coco import COCO

from mmpose.datasets.datasets import CocoDataset
from mmpose.datasets.datasets.utils import parse_pose_metainfo
from mmpose.evaluation.metrics import CocoMetric

Expand All @@ -21,6 +23,12 @@ def setUp(self):
TestCase calls functions in this order: setUp() -> testMethod() ->
tearDown() -> cleanUp()
"""

# during CI on github, the unit tests for datasets will save ann_file
# into MessageHub, which will influence the unit tests for CocoMetric
msg = MessageHub.get_current_instance()
msg.runtime_info.clear()

self.tmp_dir = tempfile.TemporaryDirectory()

self.ann_file_coco = 'tests/data/coco/test_coco.json'
Expand Down Expand Up @@ -664,3 +672,13 @@ def test_gt_converter(self):
self.assertTrue(
osp.isfile(
osp.join(self.tmp_dir.name, 'test_convert.keypoints.json')))

def test_get_ann_file_from_dataset(self):
_ = CocoDataset(ann_file=self.ann_file_coco, test_mode=True)
metric = CocoMetric(ann_file=None)
metric.dataset_meta = self.dataset_meta_coco
self.assertIsNotNone(metric.coco)

# clear message to avoid disturbing other tests
message = MessageHub.get_current_instance()
message.pop_info('coco_ann_file')
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import numpy as np
from mmengine.fileio import dump, load
from mmengine.logging import MessageHub
from xtcocotools.coco import COCO

from mmpose.datasets.datasets.utils import parse_pose_metainfo
Expand All @@ -21,6 +22,13 @@ def setUp(self):
TestCase calls functions in this order: setUp() -> testMethod() ->
tearDown() -> cleanUp()
"""

# during CI on github, the unit tests for datasets will save ann_file
# into MessageHub, which will influence the unit tests for
# CocoWholeBodyMetric
msg = MessageHub.get_current_instance()
msg.runtime_info.clear()

self.tmp_dir = tempfile.TemporaryDirectory()

self.ann_file_coco = 'tests/data/coco/test_coco_wholebody.json'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import numpy as np
from mmengine.fileio import load
from mmengine.logging import MessageHub
from mmengine.structures import InstanceData
from xtcocotools.coco import COCO

Expand All @@ -22,6 +23,12 @@ def setUp(self):
TestCase calls functions in this order: setUp() -> testMethod() ->
tearDown() -> cleanUp()
"""

# during CI on github, the unit tests for datasets will save ann_file
# into MessageHub, which will influence the unit tests for CocoMetric
msg = MessageHub.get_current_instance()
msg.runtime_info.clear()

self.tmp_dir = tempfile.TemporaryDirectory()

self.ann_file_coco = \
Expand Down

0 comments on commit 0750eae

Please sign in to comment.