Skip to content

Commit

Permalink
[Fix] Fix the incorrect labels for training vis_head with combined da…
Browse files Browse the repository at this point in the history
…tasets (#2550)
  • Loading branch information
Ben-Louis authored Jul 27, 2023
1 parent 93c5723 commit abe09d3
Show file tree
Hide file tree
Showing 14 changed files with 315 additions and 73 deletions.
6 changes: 6 additions & 0 deletions configs/body_2d_keypoint/topdown_heatmap/coco/resnet_coco.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,9 @@ Results on COCO val2017 with detector having human AP of 56.4 on COCO val2017 da
| [pose_resnet_101](/configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_res101_8xb32-210e_coco-384x288.py) | 384x288 | 0.749 | 0.906 | 0.817 | 0.799 | 0.941 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/topdown_heatmap/coco/td-hm_res101_8xb64-210e_coco-256x192-065d3625_20220926.pth) | [log](https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/topdown_heatmap/coco/td-hm_res101_8xb64-210e_coco-256x192_20220926.log) |
| [pose_resnet_152](/configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_res152_8xb32-210e_coco-256x192.py) | 256x192 | 0.736 | 0.904 | 0.818 | 0.791 | 0.942 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/topdown_heatmap/coco/td-hm_res152_8xb32-210e_coco-256x192-0345f330_20220928.pth) | [log](https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/topdown_heatmap/coco/td-hm_res152_8xb32-210e_coco-256x192_20220928.log) |
| [pose_resnet_152](/configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_res152_8xb32-210e_coco-384x288.py) | 384x288 | 0.750 | 0.908 | 0.821 | 0.800 | 0.942 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/topdown_heatmap/coco/td-hm_res152_8xb32-210e_coco-384x288-7fbb906f_20220927.pth) | [log](https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/topdown_heatmap/coco/td-hm_res152_8xb32-210e_coco-384x288_20220927.log) |

The following model is equipped with a visibility prediction head and has been trained using COCO and AIC datasets.

| Arch | Input Size | AP | AP<sup>50</sup> | AP<sup>75</sup> | AR | AR<sup>50</sup> | ckpt | log |
| :-------------------------------------------- | :--------: | :---: | :-------------: | :-------------: | :---: | :-------------: | :-------------------------------------------: | :-------------------------------------------: |
| [pose_resnet_50](/configs/body_2d_keypoint/topdown_heatmap/coco/td-hm-vis_res50_8xb64-210e_coco-aic-256x192-merge.py) | 256x192 | 0.729 | 0.900 | 0.807 | 0.783 | 0.938 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/topdown_heatmap/coco/td-hm-vis_res50_8xb64-210e_coco-aic-256x192-merge-21815b2c_20230726.pth) | [log](https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/topdown_heatmap/coco/td-hm_res50_8xb64-210e_coco-256x192_20220923.log) |
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
_base_ = ['../../../_base_/default_runtime.py']

# runtime
train_cfg = dict(max_epochs=210, val_interval=10)

# optimizer
optim_wrapper = dict(optimizer=dict(
type='Adam',
lr=5e-4,
))

# learning policy
param_scheduler = [
dict(
type='LinearLR', begin=0, end=500, start_factor=0.001,
by_epoch=False), # warm-up
dict(
type='MultiStepLR',
begin=0,
end=210,
milestones=[170, 200],
gamma=0.1,
by_epoch=True)
]

# automatically scaling LR based on the actual training batch size
auto_scale_lr = dict(base_batch_size=512)

# hooks
default_hooks = dict(checkpoint=dict(save_best='coco/AP', rule='greater'))

# codec settings
codec = dict(
type='MSRAHeatmap', input_size=(192, 256), heatmap_size=(48, 64), sigma=2)

# model settings
model = dict(
type='TopdownPoseEstimator',
data_preprocessor=dict(
type='PoseDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True),
backbone=dict(
type='ResNet',
depth=50,
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'),
),
head=dict(
type='VisPredictHead',
loss=dict(
type='BCELoss',
use_target_weight=True,
use_sigmoid=True,
loss_weight=1e-3,
),
pose_cfg=dict(
type='HeatmapHead',
in_channels=2048,
out_channels=17,
loss=dict(type='KeypointMSELoss', use_target_weight=True),
decoder=codec)),
test_cfg=dict(
flip_test=True,
flip_mode='heatmap',
shift_heatmap=True,
))

# base dataset settings
dataset_type = 'CocoDataset'
data_mode = 'topdown'
data_root = 'data/coco/'

# pipelines
train_pipeline = [
dict(type='LoadImage'),
dict(type='GetBBoxCenterScale'),
dict(type='RandomFlip', direction='horizontal'),
dict(type='RandomHalfBody'),
dict(type='RandomBBoxTransform'),
dict(type='TopdownAffine', input_size=codec['input_size']),
dict(type='GenerateTarget', encoder=codec),
dict(type='PackPoseInputs')
]
val_pipeline = [
dict(type='LoadImage'),
dict(type='GetBBoxCenterScale'),
dict(type='TopdownAffine', input_size=codec['input_size']),
dict(type='PackPoseInputs')
]

# train datasets
dataset_coco = dict(
type=dataset_type,
data_root=data_root,
data_mode=data_mode,
ann_file='annotations/person_keypoints_train2017.json',
data_prefix=dict(img='train2017/'),
pipeline=[],
)

dataset_aic = dict(
type='AicDataset',
data_root='data/aic/',
data_mode=data_mode,
ann_file='annotations/aic_train.json',
data_prefix=dict(img='ai_challenger_keypoint_train_20170902/'
'keypoint_train_images_20170902/'),
pipeline=[
dict(
type='KeypointConverter',
num_keypoints=17,
mapping=[
(0, 6),
(1, 8),
(2, 10),
(3, 5),
(4, 7),
(5, 9),
(6, 12),
(7, 14),
(8, 16),
(9, 11),
(10, 13),
(11, 15),
])
],
)

# data loaders
train_dataloader = dict(
batch_size=64,
num_workers=2,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type='CombinedDataset',
metainfo=dict(from_file='configs/_base_/datasets/coco.py'),
datasets=[dataset_coco, dataset_aic],
pipeline=train_pipeline,
test_mode=False,
))
val_dataloader = dict(
batch_size=32,
num_workers=2,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_mode=data_mode,
ann_file='annotations/person_keypoints_val2017.json',
bbox_file='data/coco/person_detection_results/'
'COCO_val2017_detections_AP_H_56_person.json',
data_prefix=dict(img='val2017/'),
test_mode=True,
pipeline=val_pipeline,
))
test_dataloader = val_dataloader

# evaluators
val_evaluator = dict(
type='CocoMetric',
# score_mode='bbox',
ann_file=data_root + 'annotations/person_keypoints_val2017.json')
test_evaluator = val_evaluator
24 changes: 24 additions & 0 deletions docs/en/advanced_guides/implement_new_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,27 @@ class YourNewHead(BaseHead):
```

Finally, please remember to import your new prediction head in `[__init__.py](https://github.com/open-mmlab/mmpose/blob/main/mmpose/models/heads/__init__.py)` .

### Head with Keypoints Visibility Prediction

Many models predict keypoint visibility based on confidence in coordinate predictions. However, this approach is suboptimal. Our [`VisPredictHead`](https://github.com/open-mmlab/mmpose/blob/dev-1.x/mmpose/models/heads/hybrid_heads/vis_head.py) wrapper enables heads to directly predict keypoint visibility from ground truth training data, improving reliability. To add visibility prediction, wrap your head module with VisPredictHead in the config file.

```python
model=dict(
...
head=dict(
type='VisPredictHead',
loss=dict(
type='BCELoss',
use_target_weight=True,
use_sigmoid=True,
loss_weight=1e-3),
pose_cfg=dict(
type='HeatmapHead',
in_channels=2048,
out_channels=17,
loss=dict(type='KeypointMSELoss', use_target_weight=True),
decoder=codec)),
...
)
```
24 changes: 24 additions & 0 deletions docs/zh_cn/advanced_guides/implement_new_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,27 @@ class YourNewHead(BaseHead):
```

最后,请记得在 [heads/\_\_init\_\_.py](https://github.com/open-mmlab/mmpose/blob/main/mmpose/models/heads/__init__.py) 中导入你的新预测头部。

### 关键点可见性预测头部

许多模型都是通过对关键点坐标预测的置信度来判断关键点的可见性的。然而,这种解决方案并非最优。我们提供了一个叫做 [`VisPredictHead`](https://github.com/open-mmlab/mmpose/blob/dev-1.x/mmpose/models/heads/hybrid_heads/vis_head.py) 的头部模块包装器,使得头部模块能够直接预测关键点的可见性。这个包装器是用训练数据中关键点可见性真值来训练的。因此,其预测会更加可靠。用户可以通过修改配置文件来对自己的头部模块加上这个包装器。下面是一个例子:

```python
model=dict(
...
head=dict(
type='VisPredictHead',
loss=dict(
type='BCELoss',
use_target_weight=True,
use_sigmoid=True,
loss_weight=1e-3),
pose_cfg=dict(
type='HeatmapHead',
in_channels=2048,
out_channels=17,
loss=dict(type='KeypointMSELoss', use_target_weight=True),
decoder=codec)),
...
)
```
20 changes: 8 additions & 12 deletions mmpose/datasets/transforms/common_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def _random_select_half_body(self, keypoints_visible: np.ndarray,
Args:
keypoints_visible (np.ndarray, optional): The visibility of
keypoints in shape (N, K, 1).
keypoints in shape (N, K, 1) or (N, K, 2).
upper_body_ids (list): The list of upper body keypoint indices
lower_body_ids (list): The list of lower body keypoint indices
Expand All @@ -349,6 +349,9 @@ def _random_select_half_body(self, keypoints_visible: np.ndarray,
of each instance. ``None`` means not applying half-body transform.
"""

if keypoints_visible.ndim == 3:
keypoints_visible = keypoints_visible[..., 0]

half_body_ids = []

for visible in keypoints_visible:
Expand Down Expand Up @@ -390,7 +393,6 @@ def transform(self, results: Dict) -> Optional[dict]:
Returns:
dict: The result dict.
"""

half_body_ids = self._random_select_half_body(
keypoints_visible=results['keypoints_visible'],
upper_body_ids=results['upper_body_ids'],
Expand Down Expand Up @@ -952,6 +954,10 @@ def transform(self, results: Dict) -> Optional[dict]:
' \'keypoints\' in the results.')

keypoints_visible = results['keypoints_visible']
if keypoints_visible.ndim == 3 and keypoints_visible.shape[2] == 2:
keypoints_visible, keypoints_visible_weights = \
keypoints_visible[..., 0], keypoints_visible[..., 1]
results['keypoints_visible_weights'] = keypoints_visible_weights

# Encoded items from the encoder(s) will be updated into the results.
# Please refer to the document of the specific codec for details about
Expand Down Expand Up @@ -1031,16 +1037,6 @@ def transform(self, results: Dict) -> Optional[dict]:

results.update(encoded)

if results.get('keypoint_weights', None) is not None:
results['transformed_keypoints_visible'] = results[
'keypoint_weights']
elif results.get('keypoints', None) is not None:
results['transformed_keypoints_visible'] = results[
'keypoints_visible']
else:
raise ValueError('GenerateTarget requires \'keypoint_weights\' or'
' \'keypoints_visible\' in the results.')

return results

def __repr__(self) -> str:
Expand Down
15 changes: 12 additions & 3 deletions mmpose/datasets/transforms/converting.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,18 @@ def __init__(self, num_keypoints: int,
self.interpolation = interpolation

def transform(self, results: dict) -> dict:
"""Transforms the keypoint results to match the target keypoints."""
num_instances = results['keypoints'].shape[0]

# Initialize output arrays
keypoints = np.zeros((num_instances, self.num_keypoints, 2))
keypoints_visible = np.zeros((num_instances, self.num_keypoints))

# When paired source_indexes are input,
# perform interpolation with self.source_index and self.source_index2
# Create a mask to weight visibility loss
keypoints_visible_weights = keypoints_visible.copy()
keypoints_visible_weights[:, self.target_index] = 1.0

# Interpolate keypoints if pairs of source indexes provided
if self.interpolation:
keypoints[:, self.target_index] = 0.5 * (
results['keypoints'][:, self.source_index] +
Expand All @@ -102,15 +107,19 @@ def transform(self, results: dict) -> dict:
keypoints_visible[:, self.target_index] = results[
'keypoints_visible'][:, self.source_index] * \
results['keypoints_visible'][:, self.source_index2]

# Otherwise just copy from the source index
else:
keypoints[:,
self.target_index] = results['keypoints'][:, self.
source_index]
keypoints_visible[:, self.target_index] = results[
'keypoints_visible'][:, self.source_index]

# Update the results dict
results['keypoints'] = keypoints
results['keypoints_visible'] = keypoints_visible
results['keypoints_visible'] = np.stack(
[keypoints_visible, keypoints_visible_weights], axis=2)
return results

def __repr__(self) -> str:
Expand Down
6 changes: 1 addition & 5 deletions mmpose/datasets/transforms/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class PackPoseInputs(BaseTransform):
'keypoint_y_labels': 'keypoint_y_labels',
'keypoint_weights': 'keypoint_weights',
'instance_coords': 'instance_coords',
'transformed_keypoints_visible': 'keypoints_visible',
'keypoints_visible_weights': 'keypoints_visible_weights'
}

# items in `field_mapping_table` will be packed into
Expand Down Expand Up @@ -195,10 +195,6 @@ def transform(self, results: dict) -> dict:
if self.pack_transformed and 'transformed_keypoints' in results:
gt_instances.set_field(results['transformed_keypoints'],
'transformed_keypoints')
if self.pack_transformed and \
'transformed_keypoints_visible' in results:
gt_instances.set_field(results['transformed_keypoints_visible'],
'transformed_keypoints_visible')

data_sample.gt_instances = gt_instances

Expand Down
Loading

0 comments on commit abe09d3

Please sign in to comment.