Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhance] Enable CocoMetric to get ann_file from MessageHub #2678

Merged
merged 7 commits into from
Sep 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions mmpose/datasets/datasets/base/base_coco_style_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
from mmengine.dataset import BaseDataset, force_full_init
from mmengine.fileio import exists, get_local_path, load
from mmengine.logging import MessageHub
from mmengine.utils import is_list_of
from xtcocotools.coco import COCO

Expand Down Expand Up @@ -112,6 +113,13 @@ def __init__(self,
lazy_init=lazy_init,
max_refetch=max_refetch)

if self.test_mode:
# save the ann_file into MessageHub for CocoMetric
message = MessageHub.get_current_instance()
dataset_name = self.metainfo['dataset_name']
message.update_info_dict(
{f'{dataset_name}_ann_file': self.ann_file})

@classmethod
def _load_metainfo(cls, metainfo: dict = None) -> dict:
"""Collect meta information from the dictionary of meta.
Expand Down
7 changes: 6 additions & 1 deletion mmpose/datasets/transforms/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,12 @@ class PackPoseInputs(BaseTransform):
bbox='bboxes',
bbox_score='bbox_scores',
keypoints='keypoints',
keypoints_visible='keypoints_visible')
keypoints_visible='keypoints_visible',
# In CocoMetric, the area of predicted instances will be calculated
# using gt_instances.bbox_scales. To unsure correspondence with
# previous version, this key is preserved here.
bbox_scale='bbox_scales',
)

# items in `field_mapping_table` will be packed into
# PoseDataSample.gt_fields and converted to Tensor. These items will be
Expand Down
14 changes: 13 additions & 1 deletion mmpose/evaluation/metrics/coco_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np
from mmengine.evaluator import BaseMetric
from mmengine.fileio import dump, get_local_path, load
from mmengine.logging import MMLogger
from mmengine.logging import MessageHub, MMLogger, print_log
from xtcocotools.coco import COCO
from xtcocotools.cocoeval import COCOeval

Expand Down Expand Up @@ -166,6 +166,18 @@ def dataset_meta(self, dataset_meta: dict) -> None:
dataset_meta['num_keypoints'] = len(dataset_meta['sigmas'])
self._dataset_meta = dataset_meta

if self.coco is None:
message = MessageHub.get_current_instance()
ann_file = message.get_info(
f"{dataset_meta['dataset_name']}_ann_file", None)
if ann_file is not None:
with get_local_path(ann_file) as local_path:
self.coco = COCO(local_path)
print_log(
f'CocoMetric for dataset '
f"{dataset_meta['dataset_name']} has successfully "
f'loaded the annotation file from {ann_file}', 'current')

def process(self, data_batch: Sequence[dict],
data_samples: Sequence[dict]) -> None:
"""Process one batch of data samples and predictions. The processed
Expand Down
18 changes: 18 additions & 0 deletions tests/test_evaluation/test_metrics/test_coco_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@

import numpy as np
from mmengine.fileio import dump, load
from mmengine.logging import MessageHub
from xtcocotools.coco import COCO

from mmpose.datasets.datasets import CocoDataset
from mmpose.datasets.datasets.utils import parse_pose_metainfo
from mmpose.evaluation.metrics import CocoMetric

Expand All @@ -21,6 +23,12 @@ def setUp(self):
TestCase calls functions in this order: setUp() -> testMethod() ->
tearDown() -> cleanUp()
"""

# during CI on github, the unit tests for datasets will save ann_file
# into MessageHub, which will influence the unit tests for CocoMetric
msg = MessageHub.get_current_instance()
msg.runtime_info.clear()

self.tmp_dir = tempfile.TemporaryDirectory()

self.ann_file_coco = 'tests/data/coco/test_coco.json'
Expand Down Expand Up @@ -664,3 +672,13 @@ def test_gt_converter(self):
self.assertTrue(
osp.isfile(
osp.join(self.tmp_dir.name, 'test_convert.keypoints.json')))

def test_get_ann_file_from_dataset(self):
_ = CocoDataset(ann_file=self.ann_file_coco, test_mode=True)
metric = CocoMetric(ann_file=None)
metric.dataset_meta = self.dataset_meta_coco
self.assertIsNotNone(metric.coco)

# clear message to avoid disturbing other tests
message = MessageHub.get_current_instance()
message.pop_info('coco_ann_file')
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import numpy as np
from mmengine.fileio import dump, load
from mmengine.logging import MessageHub
from xtcocotools.coco import COCO

from mmpose.datasets.datasets.utils import parse_pose_metainfo
Expand All @@ -21,6 +22,13 @@ def setUp(self):
TestCase calls functions in this order: setUp() -> testMethod() ->
tearDown() -> cleanUp()
"""

# during CI on github, the unit tests for datasets will save ann_file
# into MessageHub, which will influence the unit tests for
# CocoWholeBodyMetric
msg = MessageHub.get_current_instance()
msg.runtime_info.clear()

self.tmp_dir = tempfile.TemporaryDirectory()

self.ann_file_coco = 'tests/data/coco/test_coco_wholebody.json'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import numpy as np
from mmengine.fileio import load
from mmengine.logging import MessageHub
from mmengine.structures import InstanceData
from xtcocotools.coco import COCO

Expand All @@ -22,6 +23,12 @@ def setUp(self):
TestCase calls functions in this order: setUp() -> testMethod() ->
tearDown() -> cleanUp()
"""

# during CI on github, the unit tests for datasets will save ann_file
# into MessageHub, which will influence the unit tests for CocoMetric
msg = MessageHub.get_current_instance()
msg.runtime_info.clear()

self.tmp_dir = tempfile.TemporaryDirectory()

self.ann_file_coco = \
Expand Down