Skip to content

Commit 78e65cc

Browse files
Add interhand3D pipeline. (open-mmlab#575)
* Add interhand3D pipeline. * reorganize pipelines. * reorganize interhand3d pipelines. * modify bceloss. * modify 3d heatmap generation pipeline * modify codes according to reviews. * modify some comments.
1 parent a406b4f commit 78e65cc

13 files changed

+431
-7
lines changed

mmpose/datasets/datasets/hand/interhand3d_dataset.py

+10
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@ def __init__(self,
9797
assert self.ann_info['num_joints'] == 42
9898
self.ann_info['joint_weights'] = \
9999
np.ones((self.ann_info['num_joints'], 1), dtype=np.float32)
100+
self.ann_info['heatmap3d_depth_bound'] = data_cfg[
101+
'heatmap3d_depth_bound']
102+
self.ann_info['heatmap_size_root'] = data_cfg['heatmap_size_root']
103+
self.ann_info['root_depth_bound'] = data_cfg['root_depth_bound']
100104

101105
self.dataset_name = 'interhand3d'
102106
self.camera_file = camera_file
@@ -257,7 +261,12 @@ def _get_db(self):
257261
# the bboxes have been extended
258262
center, scale = self._xywh2cs(*bbox, 1.0)
259263
abs_depth = rootnet_ann_data['abs_depth']
264+
# 41: 'l_wrist', left hand root
265+
# 20: 'r_wrist', right hand root
260266
rel_root_depth = joint_cam[41, 2] - joint_cam[20, 2]
267+
# if root is not valid, root-relative 3D depth is also invalid.
268+
rel_root_valid = joint_valid[20] * joint_valid[41]
269+
261270
# if root is not valid -> root-relative 3D pose is also not valid.
262271
# Therefore, mark all joints as invalid
263272
joint_valid[:20] *= joint_valid[20]
@@ -280,6 +289,7 @@ def _get_db(self):
280289
'hand_type': hand_type,
281290
'hand_type_valid': hand_type_valid,
282291
'rel_root_depth': rel_root_depth,
292+
'rel_root_valid': rel_root_valid,
283293
'abs_depth': abs_depth,
284294
'joints_cam': joint_cam,
285295
'focal': focal,

mmpose/datasets/pipelines/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .bottom_up_transform import * # noqa
2+
from .hand_transform import * # noqa
23
from .loading import LoadImageFromFile # noqa
34
from .mesh_transform import * # noqa
45
from .pose3d_transform import * # noqa
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import numpy as np
2+
3+
from mmpose.datasets.registry import PIPELINES
4+
from .top_down_transform import TopDownRandomFlip
5+
6+
7+
@PIPELINES.register_module()
class HandRandomFlip(TopDownRandomFlip):
    """Random horizontal flip for two-hand samples.

    Extends TopDownRandomFlip: the parent performs the image/keypoint flip
    and records whether it happened in ``results['flipped']``; this class
    then updates the hand-specific annotations accordingly.

    Required keys: 'img', 'joints_3d', 'joints_3d_visible', 'center',
    'hand_type', 'rel_root_depth' and 'ann_info'.
    Modifies key: 'img', 'joints_3d', 'joints_3d_visible', 'center',
    'hand_type', 'rel_root_depth'.

    Args:
        flip_prob (float): Probability of flip.
    """

    def __call__(self, results):
        """Flip the sample, then fix up hand type and relative root depth.

        Args:
            results (dict): Sample being processed by the pipeline.

        Returns:
            dict: The same ``results`` object, updated in place.
        """
        # the parent flips image and joints and sets results['flipped']
        super().__call__(results)

        handedness = results['hand_type']
        depth_offset = results['rel_root_depth']
        if results['flipped']:
            # the two entries of 'hand_type' trade places (swapped in place)
            # and the relative depth between the two roots changes sign
            handedness[0], handedness[1] = handedness[1], handedness[0]
            depth_offset = -depth_offset

        results.update(hand_type=handedness, rel_root_depth=depth_offset)
        return results
36+
37+
38+
@PIPELINES.register_module()
class HandGenerateRelDepthTarget:
    """Generate the target relative root depth.

    The relative root depth is rescaled from ``root_depth_bound`` units to
    ``heatmap_size_root`` bins, with zero depth landing on the middle bin.

    Required keys: 'rel_root_depth', 'rel_root_valid', 'ann_info'. Modified
    keys: 'target', 'target_weight'.
    """

    def __init__(self):
        """This transform takes no configuration."""

    def __call__(self, results):
        """Compute the relative-root-depth target and its weight.

        Args:
            results (dict): Sample being processed by the pipeline.

        Returns:
            dict: ``results`` with 'target' and 'target_weight' set to
            float32-multiplied arrays of length 1.
        """
        depth = results['rel_root_depth']
        depth_valid = results['rel_root_valid']
        ann_info = results['ann_info']
        num_bins = ann_info['heatmap_size_root']
        depth_range = ann_info['root_depth_bound']

        # [-depth_range / 2, depth_range / 2] -> [0, num_bins]
        depth_bin = (depth / depth_range + 0.5) * num_bins

        # zero weight when the root is invalid or the depth is out of range
        not_below = depth_bin >= 0
        not_above = depth_bin <= num_bins
        weight = depth_valid * not_below * not_above

        unit = np.ones(1, dtype=np.float32)
        results['target'] = depth_bin * unit
        results['target_weight'] = weight * unit
        return results

mmpose/datasets/pipelines/pose3d_transform.py

+61
Original file line numberDiff line numberDiff line change
@@ -295,3 +295,64 @@ def __call__(self, results):
295295
results[self.item] = torch.from_numpy(seq)
296296

297297
return results
298+
299+
300+
@PIPELINES.register_module()
class Generate3DHeatmapTarget:
    """Generate the target 3d heatmap.

    Required keys: 'joints_3d', 'joints_3d_visible', 'ann_info'.
    Modified keys: 'target', and 'target_weight'.

    Args:
        sigma: Sigma of heatmap gaussian.
        joint_indices (list): Indices of joints used for heatmap generation.
            If None (default) is given, all joints will be used.
    """

    def __init__(self, sigma=2, joint_indices=None):
        self.sigma = sigma
        self.joint_indices = joint_indices

    def __call__(self, results):
        """Generate the target heatmap.

        Args:
            results (dict): Sample being processed by the pipeline.

        Returns:
            dict: ``results`` with 'target' of shape [K, D, H, W] and
            'target_weight' of shape [K, 1], where K is the number of
            selected joints and (W, H, D) is ``ann_info['heatmap_size']``.
        """
        joints_3d = results['joints_3d']
        joints_3d_visible = results['joints_3d_visible']
        cfg = results['ann_info']
        image_size = cfg['image_size']
        W, H, D = cfg['heatmap_size']
        heatmap3d_depth_bound = cfg['heatmap3d_depth_bound']
        joint_weights = cfg['joint_weights']
        use_different_joint_weights = cfg['use_different_joint_weights']

        # keep only the joints this target is generated for
        if self.joint_indices is not None:
            joints_3d = joints_3d[self.joint_indices, ...]
            joints_3d_visible = joints_3d_visible[self.joint_indices, ...]
            joint_weights = joint_weights[self.joint_indices, ...]

        # joint locations in heatmap coordinates; depth is mapped from
        # [-bound / 2, bound / 2] to [0, D]
        mu_x = joints_3d[:, 0] * W / image_size[0]
        mu_y = joints_3d[:, 1] * H / image_size[1]
        mu_z = (joints_3d[:, 2] / heatmap3d_depth_bound + 0.5) * D

        # joints that are invisible or fall outside the depth range get
        # zero weight
        target_weight = joints_3d_visible[:, 0]
        target_weight = target_weight * (mu_z >= 0) * (mu_z < D)
        # expand to [K, 1] *before* applying joint_weights: multiplying a
        # (K,) vector by a (K, 1) column would broadcast to (K, K)
        target_weight = target_weight[:, None]
        if use_different_joint_weights:
            target_weight = target_weight * np.reshape(joint_weights, (-1, 1))

        # 'ij' indexing yields (D, H, W) grids in the order of the inputs;
        # the default 'xy' indexing would swap the first two axes and give
        # (H, D, W) instead
        x, y, z = np.arange(W), np.arange(H), np.arange(D)
        zz, yy, xx = np.meshgrid(z, y, x, indexing='ij')
        xx = xx[None, ...].astype(np.float32)
        yy = yy[None, ...].astype(np.float32)
        zz = zz[None, ...].astype(np.float32)

        # [K] -> [K, 1, 1, 1] so each joint broadcasts against the grid
        mu_x = mu_x[..., None, None, None]
        mu_y = mu_y[..., None, None, None]
        mu_z = mu_z[..., None, None, None]

        target = np.exp(-((xx - mu_x)**2 + (yy - mu_y)**2 + (zz - mu_z)**2) /
                        (2 * self.sigma**2))

        results['target'] = target
        results['target_weight'] = target_weight
        return results

mmpose/datasets/pipelines/shared_transform.py

+57
Original file line numberDiff line numberDiff line change
@@ -403,3 +403,60 @@ def __repr__(self):
403403
f'{self.saturation_upper}), '
404404
f'hue_delta={self.hue_delta})')
405405
return repr_str
406+
407+
408+
@PIPELINES.register_module()
class MultitaskGatherTarget:
    """Gather the targets for multitask heads.

    Every pipeline is run on the sample and its 'target' /
    'target_weight' are collected; the per-head lists are then assembled
    by looking up, for each head, the pipeline it is mapped to.

    Args:
        pipeline_list (list[list]): List of pipelines for all heads.
        pipeline_indices (list[int]): Pipeline index of each head.
    """

    def __init__(self, pipeline_list, pipeline_indices):
        self.pipelines = [Compose(pipeline) for pipeline in pipeline_list]
        self.pipeline_indices = pipeline_indices

    def __call__(self, results):
        """Run all pipelines and gather their targets per head.

        Args:
            results (dict): Sample being processed by the pipeline.

        Returns:
            dict: ``results`` with 'target' and 'target_weight' replaced by
            lists holding one entry per element of ``pipeline_indices``.
        """
        # Run each pipeline once. The outputs are read out right away
        # because all pipelines receive the same ``results`` object.
        generated = []
        for head_pipeline in self.pipelines:
            output = head_pipeline(results)
            generated.append((output['target'], output['target_weight']))

        # reorganize the generated targets and target weights according
        # to self.pipeline_indices
        targets = [generated[idx][0] for idx in self.pipeline_indices]
        weights = [generated[idx][1] for idx in self.pipeline_indices]

        results['target'] = targets
        results['target_weight'] = weights
        return results
441+
442+
443+
@PIPELINES.register_module()
class RenameKeys:
    """Rename the keys.

    Args:
        key_pairs (Sequence[tuple]): Required keys to be renamed. If a tuple
            (key_src, key_tgt) is given as an element, the item retrieved by
            key_src will be renamed as key_tgt.
    """

    def __init__(self, key_pairs):
        self.key_pairs = key_pairs

    def __call__(self, results):
        """Move each source key of ``results`` to its target key.

        Pairs are applied one after another in the given order, and the
        source key is removed from ``results``.

        Args:
            results (dict): Sample being processed by the pipeline.

        Returns:
            dict: The same ``results`` object with the keys renamed.
        """
        for pair in self.key_pairs:
            assert len(pair) == 2
            old_key, new_key = pair
            value = results.pop(old_key)
            results[new_key] = value
        return results

mmpose/datasets/pipelines/top_down_transform.py

+36-3
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@ class TopDownRandomFlip:
1212
"""Data augmentation with random image flip.
1313
1414
Required keys: 'img', 'joints_3d', 'joints_3d_visible', 'center' and
15-
'ann_info'. Modifies key: 'img', 'joints_3d', 'joints_3d_visible' and
16-
'center'.
15+
'ann_info'.
16+
Modifies key: 'img', 'joints_3d', 'joints_3d_visible', 'center' and
17+
'flipped'.
1718
1819
Args:
1920
flip (bool): Option to perform random flip.
@@ -30,9 +31,12 @@ def __call__(self, results):
3031
joints_3d_visible = results['joints_3d_visible']
3132
center = results['center']
3233

34+
# A flag indicating whether the image is flipped,
35+
# which can be used by child class.
36+
flipped = False
3337
if np.random.rand() <= self.flip_prob:
38+
flipped = True
3439
img = img[:, ::-1, :]
35-
3640
joints_3d, joints_3d_visible = fliplr_joints(
3741
joints_3d, joints_3d_visible, img.shape[1],
3842
results['ann_info']['flip_pairs'])
@@ -42,6 +46,7 @@ def __call__(self, results):
4246
results['joints_3d'] = joints_3d
4347
results['joints_3d_visible'] = joints_3d_visible
4448
results['center'] = center
49+
results['flipped'] = flipped
4550

4651
return results
4752

@@ -645,3 +650,31 @@ def __call__(self, results):
645650
results['target_weight'] = target_weight
646651

647652
return results
653+
654+
655+
@PIPELINES.register_module()
class TopDownRandomTranslation:
    """Data augmentation with random translation.

    Required key: 'scale' and 'center'. Modifies key: 'center'.

    Notes:
        bbox height: H
        bbox width: W

    Args:
        trans_factor (float): Translating center to
            ``[-trans_factor, trans_factor] * [W, H] + center``.
    """

    def __init__(self, trans_factor=0.15):
        self.trans_factor = trans_factor

    def __call__(self, results):
        """Shift the bbox center by a random fraction of the bbox size.

        Args:
            results (dict): Sample being processed by the pipeline.

        Returns:
            dict: ``results`` with 'center' translated (updated in place).
        """
        center, scale = results['center'], results['scale']

        # one independent draw per axis, rescaled from [0, 1) to [-1, 1)
        direction = 2 * np.random.rand(2) - 1

        # reference bbox size is [200, 200] pixels
        center += self.trans_factor * direction * scale * 200

        results['center'] = center
        return results

mmpose/models/losses/classfication_loss.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,19 @@ def forward(self, output, target, target_weight):
1919
2020
Note:
2121
batch_size: N
22-
num_keypoints: K
22+
num_labels: K
2323
2424
Args:
2525
output (torch.Tensor[N, K]): Output classification.
2626
target (torch.Tensor[N, K]): Target classification.
27-
target_weight (torch.Tensor[N, K]):
28-
Weights across different joint types.
27+
target_weight (torch.Tensor[N, K] or torch.Tensor[N]):
28+
Weights across different labels.
2929
"""
30+
3031
if self.use_target_weight:
3132
loss = self.criterion(output, target, reduction='none')
33+
if target_weight.dim() == 1:
34+
target_weight = target_weight[:, None]
3235
loss = (loss * target_weight).mean()
3336
else:
3437
loss = self.criterion(output, target)

tests/test_datasets/test_hand_dataset.py

+3
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,9 @@ def test_top_down_InterHand3D_dataset():
367367
data_cfg = dict(
368368
image_size=[256, 256],
369369
heatmap_size=[64, 64, 64],
370+
heatmap3d_depth_bound=400.0,
371+
heatmap_size_root=64,
372+
root_depth_bound=400.0,
370373
num_output_channels=channel_cfg['num_output_channels'],
371374
num_joints=channel_cfg['dataset_joints'],
372375
dataset_channel=channel_cfg['dataset_channel'],
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import torch
2+
3+
4+
def test_bce_loss():
    """Check BCELoss with and without target weights."""
    from mmpose.models import build_loss

    # BCE of predicting 0.5 for a label of 0
    half_loss = -torch.log(torch.tensor(0.5))

    # test BCE loss without target weight
    criterion = build_loss(dict(type='BCELoss'))

    # a perfect prediction gives zero loss
    pred = torch.zeros((1, 2))
    label = torch.zeros((1, 2))
    assert torch.allclose(criterion(pred, label, None), torch.tensor(0.))

    pred = torch.full((1, 2), 0.5)
    label = torch.zeros((1, 2))
    assert torch.allclose(criterion(pred, label, None), half_loss)

    # test BCE loss with target weight
    criterion = build_loss(dict(type='BCELoss', use_target_weight=True))

    pred = torch.full((1, 2), 0.5)
    label = torch.zeros((1, 2))
    weight = torch.ones((1, 2))
    assert torch.allclose(criterion(pred, label, weight), half_loss)

    # masking one of the two labels halves the mean loss
    weight[:, 0] = 0
    assert torch.allclose(criterion(pred, label, weight), 0.5 * half_loss)

    # a per-sample weight of shape [N] is accepted as well
    weight = torch.ones(1)
    assert torch.allclose(criterion(pred, label, weight), half_loss)

0 commit comments

Comments
 (0)