Skip to content

Commit ccb4d8d

Browse files
authored
[Refactor] Align test accuracy for AE (#2737)
1 parent e8ac800 commit ccb4d8d

File tree

6 files changed

+248
-97
lines changed

6 files changed

+248
-97
lines changed

configs/body_2d_keypoint/associative_embedding/coco/ae_hrnet-w32_8xb24-300e_coco-512x512.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
input_size=(512, 512),
3737
heatmap_size=(128, 128),
3838
sigma=2,
39+
decode_topk=30,
40+
decode_center_shift=0.5,
3941
decode_keypoint_order=[
4042
0, 1, 2, 3, 4, 5, 6, 11, 12, 7, 8, 9, 10, 13, 14, 15, 16
4143
],
@@ -97,7 +99,7 @@
9799
test_cfg=dict(
98100
multiscale_test=False,
99101
flip_test=True,
100-
shift_heatmap=True,
102+
shift_heatmap=False,
101103
restore_heatmap_size=True,
102104
align_corners=False))
103105

@@ -113,9 +115,14 @@
113115
dict(
114116
type='BottomupResize',
115117
input_size=codec['input_size'],
116-
size_factor=32,
118+
size_factor=64,
117119
resize_mode='expand'),
118-
dict(type='PackPoseInputs')
120+
dict(
121+
type='PackPoseInputs',
122+
meta_keys=('id', 'img_id', 'img_path', 'crowd_index', 'ori_shape',
123+
'img_shape', 'input_size', 'input_center', 'input_scale',
124+
'flip', 'flip_direction', 'flip_indices', 'raw_ann_info',
125+
'skeleton_links'))
119126
]
120127

121128
# data loaders
@@ -154,6 +161,6 @@
154161
type='CocoMetric',
155162
ann_file=data_root + 'annotations/person_keypoints_val2017.json',
156163
nms_mode='none',
157-
score_mode='keypoint',
164+
score_mode='bbox',
158165
)
159166
test_evaluator = val_evaluator
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
<!-- [ALGORITHM] -->
2+
3+
<details>
4+
<summary align="right"><a href="https://arxiv.org/abs/1611.05424">Associative Embedding (NIPS'2017)</a></summary>
5+
6+
```bibtex
7+
@inproceedings{newell2017associative,
8+
title={Associative embedding: End-to-end learning for joint detection and grouping},
9+
author={Newell, Alejandro and Huang, Zhiao and Deng, Jia},
10+
booktitle={Advances in neural information processing systems},
11+
pages={2277--2287},
12+
year={2017}
13+
}
14+
```
15+
16+
</details>
17+
18+
<!-- [ALGORITHM] -->
19+
20+
<details>
21+
<summary align="right"><a href="http://openaccess.thecvf.com/content_CVPR_2019/html/Sun_Deep_High-Resolution_Representation_Learning_for_Human_Pose_Estimation_CVPR_2019_paper.html">HRNet (CVPR'2019)</a></summary>
22+
23+
```bibtex
24+
@inproceedings{sun2019deep,
25+
title={Deep high-resolution representation learning for human pose estimation},
26+
author={Sun, Ke and Xiao, Bin and Liu, Dong and Wang, Jingdong},
27+
booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
28+
pages={5693--5703},
29+
year={2019}
30+
}
31+
```
32+
33+
</details>
34+
35+
<!-- [DATASET] -->
36+
37+
<details>
38+
<summary align="right"><a href="https://link.springer.com/chapter/10.1007/978-3-319-10602-1_48">COCO (ECCV'2014)</a></summary>
39+
40+
```bibtex
41+
@inproceedings{lin2014microsoft,
42+
title={Microsoft coco: Common objects in context},
43+
author={Lin, Tsung-Yi and Maire, Michael and Belongie, Serge and Hays, James and Perona, Pietro and Ramanan, Deva and Doll{\'a}r, Piotr and Zitnick, C Lawrence},
44+
booktitle={European conference on computer vision},
45+
pages={740--755},
46+
year={2014},
47+
organization={Springer}
48+
}
49+
```
50+
51+
</details>
52+
53+
Results on COCO val2017 without multi-scale test
54+
55+
| Arch | Input Size | AP | AP<sup>50</sup> | AP<sup>75</sup> | AR | AR<sup>50</sup> | ckpt | log |
56+
| :-------------------------------------------- | :--------: | :---: | :-------------: | :-------------: | :---: | :-------------: | :-------------------------------------------: | :-------------------------------------------: |
57+
| [HRNet-w32](/configs/body_2d_keypoint/associative_embedding/coco/ae_hrnet-w32_8xb24-300e_coco-512x512.py) | 512x512 | 0.656 | 0.864 | 0.719 | 0.711 | 0.893 | [ckpt](https://download.openmmlab.com/mmpose/bottom_up/hrnet_w32_coco_512x512-bcb8c247_20200816.pth) | [log](https://download.openmmlab.com/mmpose/bottom_up/hrnet_w32_coco_512x512_20200816.log.json) |
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
Collections:
2+
- Name: AE
3+
Paper:
4+
Title: "Associative embedding: End-to-end learning for joint detection and grouping"
5+
URL: https://arxiv.org/abs/1611.05424
6+
README: https://github.com/open-mmlab/mmpose/blob/main/docs/src/papers/algorithms/associative_embedding.md
7+
Models:
8+
- Config: configs/body_2d_keypoint/associative_embedding/coco/ae_hrnet-w32_8xb24-300e_coco-512x512.py
9+
In Collection: AE
10+
Metadata:
11+
Architecture:
12+
- AE
13+
- HRNet
14+
Training Data: COCO
15+
Name: ae_hrnet-w32_8xb24-300e_coco-512x512
16+
Results:
17+
- Dataset: COCO
18+
Metrics:
19+
AP: 0.656
20+
AP@0.5: 0.864
21+
AP@0.75: 0.719
22+
AR: 0.711
23+
AR@0.5: 0.893
24+
Task: Body 2D Keypoint
25+
Weights: https://download.openmmlab.com/mmpose/bottom_up/hrnet_w32_coco_512x512-bcb8c247_20200816.pth

mmpose/codecs/associative_embedding.py

+94-84
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2-
from collections import namedtuple
32
from itertools import product
43
from typing import Any, List, Optional, Tuple
54

@@ -16,6 +15,21 @@
1615
refine_keypoints_dark_udp)
1716

1817

18+
def _py_max_match(scores):
19+
"""Apply munkres algorithm to get the best match.
20+
21+
Args:
22+
scores(np.ndarray): cost matrix.
23+
24+
Returns:
25+
np.ndarray: best match.
26+
"""
27+
m = Munkres()
28+
tmp = m.compute(scores)
29+
tmp = np.array(tmp).astype(int)
30+
return tmp
31+
32+
1933
def _group_keypoints_by_tags(vals: np.ndarray,
2034
tags: np.ndarray,
2135
locs: np.ndarray,
@@ -54,89 +68,78 @@ def _group_keypoints_by_tags(vals: np.ndarray,
5468
np.ndarray: grouped keypoints in shape (G, K, D+1), where the last
5569
dimenssion is the concatenated keypoint coordinates and scores.
5670
"""
71+
72+
tag_k, loc_k, val_k = tags, locs, vals
5773
K, M, D = locs.shape
5874
assert vals.shape == tags.shape[:2] == (K, M)
5975
assert len(keypoint_order) == K
6076

61-
# Build Munkres instance
62-
munkres = Munkres()
63-
64-
# Build a group pool, each group contains the keypoints of an instance
65-
groups = []
77+
default_ = np.zeros((K, 3 + tag_k.shape[2]), dtype=np.float32)
6678

67-
Group = namedtuple('Group', field_names=['kpts', 'scores', 'tag_list'])
79+
joint_dict = {}
80+
tag_dict = {}
81+
for i in range(K):
82+
idx = keypoint_order[i]
6883

69-
def _init_group():
70-
"""Initialize a group, which is composed of the keypoints, keypoint
71-
scores and the tag of each keypoint."""
72-
_group = Group(
73-
kpts=np.zeros((K, D), dtype=np.float32),
74-
scores=np.zeros(K, dtype=np.float32),
75-
tag_list=[])
76-
return _group
84+
tags = tag_k[idx]
85+
joints = np.concatenate((loc_k[idx], val_k[idx, :, None], tags), 1)
86+
mask = joints[:, 2] > val_thr
87+
tags = tags[mask] # shape: [M, L]
88+
joints = joints[mask] # shape: [M, 3 + L], 3: x, y, val
7789

78-
for i in keypoint_order:
79-
# Get all valid candidate of the i-th keypoints
80-
valid = vals[i] > val_thr
81-
if not valid.any():
90+
if joints.shape[0] == 0:
8291
continue
8392

84-
tags_i = tags[i, valid] # (M', L)
85-
vals_i = vals[i, valid] # (M',)
86-
locs_i = locs[i, valid] # (M', D)
87-
88-
if len(groups) == 0: # Initialize the group pool
89-
for tag, val, loc in zip(tags_i, vals_i, locs_i):
90-
group = _init_group()
91-
group.kpts[i] = loc
92-
group.scores[i] = val
93-
group.tag_list.append(tag)
94-
95-
groups.append(group)
96-
97-
else: # Match keypoints to existing groups
98-
groups = groups[:max_groups]
99-
group_tags = [np.mean(g.tag_list, axis=0) for g in groups]
100-
101-
# Calculate distance matrix between group tags and tag candidates
102-
# of the i-th keypoint
103-
# Shape: (M', 1, L) , (1, G, L) -> (M', G, L)
104-
diff = tags_i[:, None] - np.array(group_tags)[None]
105-
dists = np.linalg.norm(diff, ord=2, axis=2)
106-
num_kpts, num_groups = dists.shape[:2]
107-
108-
# Experimental cost function for keypoint-group matching
109-
costs = np.round(dists) * 100 - vals_i[..., None]
110-
if num_kpts > num_groups:
111-
padding = np.full((num_kpts, num_kpts - num_groups),
112-
1e10,
113-
dtype=np.float32)
114-
costs = np.concatenate((costs, padding), axis=1)
115-
116-
# Match keypoints and groups by Munkres algorithm
117-
matches = munkres.compute(costs)
118-
for kpt_idx, group_idx in matches:
119-
if group_idx < num_groups and dists[kpt_idx,
120-
group_idx] < tag_thr:
121-
# Add the keypoint to the matched group
122-
group = groups[group_idx]
93+
if i == 0 or len(joint_dict) == 0:
94+
for tag, joint in zip(tags, joints):
95+
key = tag[0]
96+
joint_dict.setdefault(key, np.copy(default_))[idx] = joint
97+
tag_dict[key] = [tag]
98+
else:
99+
# shape: [M]
100+
grouped_keys = list(joint_dict.keys())
101+
# shape: [M, L]
102+
grouped_tags = [np.mean(tag_dict[i], axis=0) for i in grouped_keys]
103+
104+
# shape: [M, M, L]
105+
diff = joints[:, None, 3:] - np.array(grouped_tags)[None, :, :]
106+
# shape: [M, M]
107+
diff_normed = np.linalg.norm(diff, ord=2, axis=2)
108+
diff_saved = np.copy(diff_normed)
109+
diff_normed = np.round(diff_normed) * 100 - joints[:, 2:3]
110+
111+
num_added = diff.shape[0]
112+
num_grouped = diff.shape[1]
113+
114+
if num_added > num_grouped:
115+
diff_normed = np.concatenate(
116+
(diff_normed,
117+
np.zeros((num_added, num_added - num_grouped),
118+
dtype=np.float32) + 1e10),
119+
axis=1)
120+
121+
pairs = _py_max_match(diff_normed)
122+
for row, col in pairs:
123+
if (row < num_added and col < num_grouped
124+
and diff_saved[row][col] < tag_thr):
125+
key = grouped_keys[col]
126+
joint_dict[key][idx] = joints[row]
127+
tag_dict[key].append(tags[row])
123128
else:
124-
# Initialize a new group with unmatched keypoint
125-
group = _init_group()
126-
groups.append(group)
127-
128-
group.kpts[i] = locs_i[kpt_idx]
129-
group.scores[i] = vals_i[kpt_idx]
130-
group.tag_list.append(tags_i[kpt_idx])
131-
132-
groups = groups[:max_groups]
133-
if groups:
134-
grouped_keypoints = np.stack(
135-
[np.r_['1', g.kpts, g.scores[:, None]] for g in groups])
136-
else:
137-
grouped_keypoints = np.empty((0, K, D + 1))
129+
key = tags[row][0]
130+
joint_dict.setdefault(key, np.copy(default_))[idx] = \
131+
joints[row]
132+
tag_dict[key] = [tags[row]]
138133

139-
return grouped_keypoints
134+
joint_dict_keys = list(joint_dict.keys())[:max_groups]
135+
136+
if joint_dict_keys:
137+
results = np.array([joint_dict[i]
138+
for i in joint_dict_keys]).astype(np.float32)
139+
results = results[..., :D + 1]
140+
else:
141+
results = np.empty((0, K, D + 1), dtype=np.float32)
142+
return results
140143

141144

142145
@KEYPOINT_CODECS.register_module()
@@ -210,7 +213,8 @@ def __init__(
210213
decode_gaussian_kernel: int = 3,
211214
decode_keypoint_thr: float = 0.1,
212215
decode_tag_thr: float = 1.0,
213-
decode_topk: int = 20,
216+
decode_topk: int = 30,
217+
decode_center_shift=0.0,
214218
decode_max_instances: Optional[int] = None,
215219
) -> None:
216220
super().__init__()
@@ -222,8 +226,9 @@ def __init__(
222226
self.decode_keypoint_thr = decode_keypoint_thr
223227
self.decode_tag_thr = decode_tag_thr
224228
self.decode_topk = decode_topk
229+
self.decode_center_shift = decode_center_shift
225230
self.decode_max_instances = decode_max_instances
226-
self.dedecode_keypoint_order = decode_keypoint_order.copy()
231+
self.decode_keypoint_order = decode_keypoint_order.copy()
227232

228233
if self.use_udp:
229234
self.scale_factor = ((np.array(input_size) - 1) /
@@ -376,7 +381,7 @@ def _group_func(inputs: Tuple):
376381
vals,
377382
tags,
378383
locs,
379-
keypoint_order=self.dedecode_keypoint_order,
384+
keypoint_order=self.decode_keypoint_order,
380385
val_thr=self.decode_keypoint_thr,
381386
tag_thr=self.decode_tag_thr,
382387
max_groups=self.decode_max_instances)
@@ -463,13 +468,13 @@ def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor
463468
f'tagging map ({batch_tags.shape})')
464469

465470
# Heatmap NMS
466-
batch_heatmaps = batch_heatmap_nms(batch_heatmaps,
467-
self.decode_nms_kernel)
471+
batch_heatmaps_peak = batch_heatmap_nms(batch_heatmaps,
472+
self.decode_nms_kernel)
468473

469474
# Get top-k in each heatmap and and convert to numpy
470475
batch_topk_vals, batch_topk_tags, batch_topk_locs = to_numpy(
471476
self._get_batch_topk(
472-
batch_heatmaps, batch_tags, k=self.decode_topk))
477+
batch_heatmaps_peak, batch_tags, k=self.decode_topk))
473478

474479
# Group keypoint candidates into groups (instances)
475480
batch_groups = self._group_keypoints(batch_topk_vals, batch_topk_tags,
@@ -482,16 +487,14 @@ def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor
482487
# Refine the keypoint prediction
483488
batch_keypoints = []
484489
batch_keypoint_scores = []
490+
batch_instance_scores = []
485491
for i, (groups, heatmaps, tags) in enumerate(
486492
zip(batch_groups, batch_heatmaps_np, batch_tags_np)):
487493

488494
keypoints, scores = groups[..., :-1], groups[..., -1]
495+
instance_scores = scores.mean(axis=-1)
489496

490497
if keypoints.size > 0:
491-
# identify missing keypoints
492-
keypoints, scores = self._fill_missing_keypoints(
493-
keypoints, scores, heatmaps, tags)
494-
495498
# refine keypoint coordinates according to heatmap distribution
496499
if self.use_udp:
497500
keypoints = refine_keypoints_dark_udp(
@@ -500,13 +503,20 @@ def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor
500503
blur_kernel_size=self.decode_gaussian_kernel)
501504
else:
502505
keypoints = refine_keypoints(keypoints, heatmaps)
506+
keypoints += self.decode_center_shift * \
507+
(scores > 0).astype(keypoints.dtype)[..., None]
508+
509+
# identify missing keypoints
510+
keypoints, scores = self._fill_missing_keypoints(
511+
keypoints, scores, heatmaps, tags)
503512

504513
batch_keypoints.append(keypoints)
505514
batch_keypoint_scores.append(scores)
515+
batch_instance_scores.append(instance_scores)
506516

507517
# restore keypoint scale
508518
batch_keypoints = [
509519
kpts * self.scale_factor for kpts in batch_keypoints
510520
]
511521

512-
return batch_keypoints, batch_keypoint_scores
522+
return batch_keypoints, batch_keypoint_scores, batch_instance_scores

0 commit comments

Comments
 (0)