Skip to content

Commit

Permalink
[Refactor] Separate evaluation mappings from KeypointConverter (#2738)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tau-J authored Oct 9, 2023
1 parent ccb4d8d commit 0549504
Show file tree
Hide file tree
Showing 6 changed files with 189 additions and 87 deletions.
38 changes: 0 additions & 38 deletions mmpose/datasets/transforms/converting.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,44 +144,6 @@ def transform(self, results: dict) -> dict:

return results

def transform_sigmas(self, sigmas: Union[List, np.ndarray]):
"""Transforms the sigmas based on the mapping."""
list_input = False
if isinstance(sigmas, list):
sigmas = np.array(sigmas)
list_input = True

new_sigmas = np.ones(self.num_keypoints, dtype=sigmas.dtype)
new_sigmas[self.target_index] = sigmas[self.source_index]

if list_input:
new_sigmas = new_sigmas.tolist()

return new_sigmas

def transform_ann(self, ann_info: Union[dict, list]):
"""Transforms the annotations based on the mapping."""

list_input = True
if not isinstance(ann_info, list):
ann_info = [ann_info]
list_input = False

for ann in ann_info:
if 'keypoints' in ann:
keypoints = np.array(ann['keypoints']).reshape(-1, 3)
new_keypoints = np.zeros((self.num_keypoints, 3),
dtype=keypoints.dtype)
new_keypoints[self.target_index] = keypoints[self.source_index]
ann['keypoints'] = new_keypoints.reshape(-1).tolist()
if 'num_keypoints' in ann:
ann['num_keypoints'] = self.num_keypoints

if not list_input:
ann_info = ann_info[0]

return ann_info

def __repr__(self) -> str:
"""print the basic information of the transform.
Expand Down
3 changes: 2 additions & 1 deletion mmpose/evaluation/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
multilabel_classification_accuracy,
pose_pck_accuracy, simcc_pck_accuracy)
from .nms import nms, nms_torch, oks_nms, soft_oks_nms
from .transforms import transform_ann, transform_pred, transform_sigmas

__all__ = [
'keypoint_pck_accuracy', 'keypoint_auc', 'keypoint_nme', 'keypoint_epe',
'pose_pck_accuracy', 'multilabel_classification_accuracy',
'simcc_pck_accuracy', 'nms', 'oks_nms', 'soft_oks_nms', 'keypoint_mpjpe',
'nms_torch'
'nms_torch', 'transform_ann', 'transform_sigmas', 'transform_pred'
]
99 changes: 99 additions & 0 deletions mmpose/evaluation/functional/transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple, Union

import numpy as np


def transform_sigmas(sigmas: Union[List, np.ndarray], num_keypoints: int,
mapping: Union[List[Tuple[int, int]], List[Tuple[Tuple,
int]]]):
"""Transforms the sigmas based on the mapping."""
if len(mapping):
source_index, target_index = map(list, zip(*mapping))
else:
source_index, target_index = [], []

list_input = False
if isinstance(sigmas, list):
sigmas = np.array(sigmas)
list_input = True

new_sigmas = np.ones(num_keypoints, dtype=sigmas.dtype)
new_sigmas[target_index] = sigmas[source_index]

if list_input:
new_sigmas = new_sigmas.tolist()

return new_sigmas


def transform_ann(ann_info: Union[dict, list], num_keypoints: int,
mapping: Union[List[Tuple[int, int]], List[Tuple[Tuple,
int]]]):
"""Transforms COCO-format annotations based on the mapping."""
if len(mapping):
source_index, target_index = map(list, zip(*mapping))
else:
source_index, target_index = [], []

list_input = True
if not isinstance(ann_info, list):
ann_info = [ann_info]
list_input = False

for each in ann_info:
if 'keypoints' in each:
keypoints = np.array(each['keypoints'])

C = 3 # COCO-format: x, y, score
keypoints = keypoints.reshape(-1, C)
new_keypoints = np.zeros((num_keypoints, C), dtype=keypoints.dtype)
new_keypoints[target_index] = keypoints[source_index]
each['keypoints'] = new_keypoints.reshape(-1).tolist()

if 'num_keypoints' in each:
each['num_keypoints'] = num_keypoints

if not list_input:
ann_info = ann_info[0]

return ann_info


def transform_pred(pred_info: Union[dict, list], num_keypoints: int,
mapping: Union[List[Tuple[int, int]], List[Tuple[Tuple,
int]]]):
"""Transforms predictions based on the mapping."""
if len(mapping):
source_index, target_index = map(list, zip(*mapping))
else:
source_index, target_index = [], []

list_input = True
if not isinstance(pred_info, list):
pred_info = [pred_info]
list_input = False

for each in pred_info:
if 'keypoints' in each:
keypoints = np.array(each['keypoints'])

N, _, C = keypoints.shape
new_keypoints = np.zeros((N, num_keypoints, C),
dtype=keypoints.dtype)
new_keypoints[:, target_index] = keypoints[:, source_index]
each['keypoints'] = new_keypoints

keypoint_scores = np.array(each['keypoint_scores'])
new_scores = np.zeros((N, num_keypoints),
dtype=keypoint_scores.dtype)
new_scores[:, target_index] = keypoint_scores[:, source_index]
each['keypoint_scores'] = new_scores

if 'num_keypoints' in each:
each['num_keypoints'] = num_keypoints

if not list_input:
pred_info = pred_info[0]

return pred_info
51 changes: 31 additions & 20 deletions mmpose/evaluation/metrics/coco_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
from xtcocotools.coco import COCO
from xtcocotools.cocoeval import COCOeval

from mmpose.registry import METRICS, TRANSFORMS
from mmpose.registry import METRICS
from mmpose.structures.bbox import bbox_xyxy2xywh
from ..functional import oks_nms, soft_oks_nms
from ..functional import (oks_nms, soft_oks_nms, transform_ann, transform_pred,
transform_sigmas)


@METRICS.register_module()
Expand Down Expand Up @@ -73,10 +74,12 @@ class CocoMetric(BaseMetric):
test submission when the ground truth annotations are absent. If
set to ``True``, ``outfile_prefix`` should specify the path to
store the output results. Defaults to ``False``
pred_converter (dict, optional): Config dictionary for the prediction
converter. The dictionary has the same parameters as
'KeypointConverter'. Defaults to None.
gt_converter (dict, optional): Config dictionary for the ground truth
converter. The dictionary must contain the key 'type' set to
'KeypointConverter' to indicate the type of ground truth converter
to be used. Defaults to None.
converter. The dictionary has the same parameters as
'KeypointConverter'. Defaults to None.
outfile_prefix (str | None): The prefix of json files. It includes
the file path and the prefix of filename, e.g., ``'a/b/prefix'``.
If not specified, a temp file will be created. Defaults to ``None``
Expand All @@ -99,7 +102,8 @@ def __init__(self,
nms_mode: str = 'oks_nms',
nms_thr: float = 0.9,
format_only: bool = False,
gt_converter: Optional[dict] = None,
pred_converter: Dict = None,
gt_converter: Dict = None,
outfile_prefix: Optional[str] = None,
collect_device: str = 'cpu',
prefix: Optional[str] = None) -> None:
Expand Down Expand Up @@ -145,13 +149,8 @@ def __init__(self,

self.format_only = format_only
self.outfile_prefix = outfile_prefix

if gt_converter is not None:
assert gt_converter.get('type', None) == 'KeypointConverter', \
'the type of `gt_converter` must be \'KeypointConverter\''
self.gt_converter = TRANSFORMS.build(gt_converter)
else:
self.gt_converter = None
self.pred_converter = pred_converter
self.gt_converter = gt_converter

@property
def dataset_meta(self) -> Optional[dict]:
Expand All @@ -162,8 +161,9 @@ def dataset_meta(self) -> Optional[dict]:
def dataset_meta(self, dataset_meta: dict) -> None:
"""Set the dataset meta info to the metric."""
if self.gt_converter is not None:
dataset_meta['sigmas'] = self.gt_converter.transform_sigmas(
dataset_meta['sigmas'])
dataset_meta['sigmas'] = transform_sigmas(
dataset_meta['sigmas'], self.gt_converter['num_keypoints'],
self.gt_converter['mapping'])
dataset_meta['num_keypoints'] = len(dataset_meta['sigmas'])
self._dataset_meta = dataset_meta

Expand Down Expand Up @@ -394,19 +394,28 @@ def compute_metrics(self, results: list) -> Dict[str, float]:
self.coco = COCO(coco_json_path)
if self.gt_converter is not None:
for id_, ann in self.coco.anns.items():
self.coco.anns[id_] = self.gt_converter.transform_ann(ann)
self.coco.anns[id_] = transform_ann(
ann, self.gt_converter['num_keypoints'],
self.gt_converter['mapping'])

kpts = defaultdict(list)

# group the preds by img_id
for pred in preds:
img_id = pred['img_id']
for idx in range(len(pred['keypoints'])):

if self.pred_converter is not None:
pred = transform_pred(pred,
self.pred_converter['num_keypoints'],
self.pred_converter['mapping'])

for idx, keypoints in enumerate(pred['keypoints']):

instance = {
'id': pred['id'],
'img_id': pred['img_id'],
'category_id': pred['category_id'],
'keypoints': pred['keypoints'][idx],
'keypoints': keypoints,
'keypoint_scores': pred['keypoint_scores'][idx],
'bbox_score': pred['bbox_scores'][idx],
}
Expand All @@ -417,7 +426,6 @@ def compute_metrics(self, results: list) -> Dict[str, float]:
instance['area'] = pred['areas'][idx]
else:
# use keypoint to calculate bbox and get area
keypoints = pred['keypoints'][idx]
area = (
np.max(keypoints[:, 0]) - np.min(keypoints[:, 0])) * (
np.max(keypoints[:, 1]) - np.min(keypoints[:, 1]))
Expand All @@ -431,7 +439,10 @@ def compute_metrics(self, results: list) -> Dict[str, float]:
# score the prediction results according to `score_mode`
# and perform NMS according to `nms_mode`
valid_kpts = defaultdict(list)
num_keypoints = self.dataset_meta['num_keypoints']
if self.pred_converter is not None:
num_keypoints = self.pred_converter['num_keypoints']
else:
num_keypoints = self.dataset_meta['num_keypoints']
for img_id, instances in kpts.items():
for instance in instances:
# concatenate the keypoint coordinates and scores
Expand Down
29 changes: 1 addition & 28 deletions tests/test_datasets/test_transforms/test_converting.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy

from unittest import TestCase

import numpy as np
Expand Down Expand Up @@ -107,30 +107,3 @@ def test_transform(self):
self.assertTrue(
(results['keypoints_visible'][:, target_index, 0] ==
self.data_info['keypoints_visible'][:, source_index]).all())

def test_transform_sigmas(self):

mapping = [(3, 0), (6, 1), (16, 2), (5, 3)]
transform = KeypointConverter(num_keypoints=5, mapping=mapping)
sigmas = np.random.rand(17)
new_sigmas = transform.transform_sigmas(sigmas)
self.assertEqual(len(new_sigmas), 5)
for i, j in mapping:
self.assertEqual(sigmas[i], new_sigmas[j])

def test_transform_ann(self):
mapping = [(3, 0), (6, 1), (16, 2), (5, 3)]
transform = KeypointConverter(num_keypoints=5, mapping=mapping)

ann_info = dict(
num_keypoints=17,
keypoints=np.random.randint(3, size=(17 * 3, )).tolist())
ann_info_copy = deepcopy(ann_info)

_ = transform.transform_ann(ann_info)

self.assertEqual(ann_info['num_keypoints'], 5)
self.assertEqual(len(ann_info['keypoints']), 15)
for i, j in mapping:
self.assertListEqual(ann_info_copy['keypoints'][i * 3:i * 3 + 3],
ann_info['keypoints'][j * 3:j * 3 + 3])
56 changes: 56 additions & 0 deletions tests/test_evaluation/test_functional/test_transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from unittest import TestCase

import numpy as np

from mmpose.evaluation.functional import (transform_ann, transform_pred,
transform_sigmas)


class TestKeypointEval(TestCase):

def test_transform_sigmas(self):

mapping = [(3, 0), (6, 1), (16, 2), (5, 3)]
num_keypoints = 5
sigmas = np.random.rand(17)
new_sigmas = transform_sigmas(sigmas, num_keypoints, mapping)
self.assertEqual(len(new_sigmas), 5)
for i, j in mapping:
self.assertEqual(sigmas[i], new_sigmas[j])

def test_transform_ann(self):
mapping = [(3, 0), (6, 1), (16, 2), (5, 3)]
num_keypoints = 5

kpt_info = dict(
num_keypoints=17,
keypoints=np.random.randint(3, size=(17 * 3, )).tolist())
kpt_info_copy = deepcopy(kpt_info)

_ = transform_ann(kpt_info, num_keypoints, mapping)

self.assertEqual(kpt_info['num_keypoints'], 5)
self.assertEqual(len(kpt_info['keypoints']), 15)
for i, j in mapping:
self.assertListEqual(kpt_info_copy['keypoints'][i * 3:i * 3 + 3],
kpt_info['keypoints'][j * 3:j * 3 + 3])

def test_transform_pred(self):
mapping = [(3, 0), (6, 1), (16, 2), (5, 3)]
num_keypoints = 5

kpt_info = dict(
num_keypoints=17,
keypoints=np.random.randint(3, size=(
1,
17,
3,
)),
keypoint_scores=np.ones((1, 17)))

_ = transform_pred(kpt_info, num_keypoints, mapping)

self.assertEqual(kpt_info['num_keypoints'], 5)
self.assertEqual(len(kpt_info['keypoints']), 1)

0 comments on commit 0549504

Please sign in to comment.