Add rsn backbone, head and pre-processing (open-mmlab#221)
* Add rsn

* Add unit test

* Modification

* Some corrections

* Corrections

* Correction in top_down_transform.py

* Correction in top_down_transform.py

* Add mspn

* unit test for mspn

* remove mspn

* rename mspn_head to msmu_head

* Add unit test for backbone and head

* fix a bug in rsn.py

* rm unnecessary codes

* add comment for 3-sigma rule

* fix some bugs

Co-authored-by: jinsheng <jinsheng@sensetime.com>
wusize and jin-s13 authored Nov 12, 2020
1 parent 122a63f commit b0fe89a
Showing 10 changed files with 1,286 additions and 37 deletions.
157 changes: 157 additions & 0 deletions configs/top_down/rsn/coco/single_ctf_rsn18_coco_256x192.py
@@ -0,0 +1,157 @@
log_level = 'INFO'
load_from = None
resume_from = None
dist_params = dict(backend='nccl')
workflow = [('train', 1)]
checkpoint_config = dict(interval=10)
evaluation = dict(interval=10, metric='mAP', key_indicator='AP')

optimizer = dict(
type='Adam',
lr=5e-3,
)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(
policy='poly',
warmup='linear',
warmup_iters=500,
warmup_ratio=0.001,
min_lr=5e-5)
total_epochs = 210
log_config = dict(
interval=50, hooks=[
dict(type='TextLoggerHook'),
])

channel_cfg = dict(
num_output_channels=17,
dataset_joints=17,
dataset_channel=[
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
],
inference_channel=[
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
])

# model settings
model = dict(
type='TopDown',
pretrained=None,
backbone=dict(
type='RSN',
unit_channels=256,
num_stages=1,
num_units=4,
num_blocks=[2, 2, 2, 2],
num_steps=4,
norm_cfg=dict(type='BN')),
keypoint_head=dict(
type='TopDownMSMUHead',
out_shape=(64, 48),
unit_channels=256,
out_channels=channel_cfg['num_output_channels'],
num_stages=1,
num_units=4,
use_prm=False,
norm_cfg=dict(type='BN')),
train_cfg=dict(num_units=4, loss_weights=[0.25] * 3 + [1]),
test_cfg=dict(
flip_test=True,
post_process=True,
shift_heatmap=True,
unbiased_decoding=False,
modulate_kernel=11),
loss_pose=[dict(type='JointsMSELoss', use_target_weight=True)] * 3 +
[dict(type='JointsOHKMMSELoss', use_target_weight=True)])

data_cfg = dict(
image_size=[192, 256],
heatmap_size=[48, 64],
num_output_channels=channel_cfg['num_output_channels'],
num_joints=channel_cfg['dataset_joints'],
dataset_channel=channel_cfg['dataset_channel'],
inference_channel=channel_cfg['inference_channel'],
soft_nms=False,
nms_thr=1.0,
oks_thr=0.9,
vis_thr=0.2,
bbox_thr=1.0,
use_gt_bbox=False,
image_thr=0.0,
bbox_file='data/coco/person_detection_results/'
'COCO_val2017_detections_AP_H_56_person.json',
)

train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='TopDownRandomFlip', flip_prob=0.5),
dict(
type='TopDownHalfBodyTransform',
num_joints_half_body=8,
prob_half_body=0.3),
dict(
type='TopDownGetRandomScaleRotation', rot_factor=40, scale_factor=0.5),
dict(type='TopDownAffine'),
dict(type='ToTensor'),
dict(
type='NormalizeTensor',
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
dict(
type='TopDownGenerateTarget',
kernel=[(11, 11), (9, 9), (7, 7), (5, 5)],
encoding='Megvii'),
dict(
type='Collect',
keys=['img', 'target', 'target_weight'],
meta_keys=[
'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
'rotation', 'bbox_score', 'flip_pairs'
]),
]

val_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='TopDownAffine'),
dict(type='ToTensor'),
dict(
type='NormalizeTensor',
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
dict(
type='Collect',
keys=[
'img',
],
meta_keys=[
'image_file', 'center', 'scale', 'rotation', 'bbox_score',
'flip_pairs'
]),
]

test_pipeline = val_pipeline

data_root = 'data/coco'
data = dict(
samples_per_gpu=32,
workers_per_gpu=2,
train=dict(
type='TopDownCocoDataset',
ann_file=f'{data_root}/annotations/person_keypoints_train2017.json',
img_prefix=f'{data_root}/train2017/',
data_cfg=data_cfg,
pipeline=train_pipeline),
val=dict(
type='TopDownCocoDataset',
ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
img_prefix=f'{data_root}/val2017/',
data_cfg=data_cfg,
pipeline=val_pipeline),
test=dict(
type='TopDownCocoDataset',
ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
img_prefix=f'{data_root}/val2017/',
data_cfg=data_cfg,
pipeline=val_pipeline),
)
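
For orientation, the TopDownGenerateTarget step in the train_pipeline above uses encoding='Megvii' with four Gaussian kernels, one per RSN unit, so it emits a stack of four target heatmap sets instead of a single one. The following is a minimal standalone sketch of that behaviour (plain NumPy/OpenCV, mirroring _megvii_generate_target in the diff below; the keypoint values are made-up placeholders and this is not the mmpose pipeline code itself):

import cv2
import numpy as np

def megvii_target(joints, visible, image_size, heatmap_size, kernel):
    # One heatmap per joint: place a single 1 at the keypoint location,
    # blur it, then rescale so the peak equals 255 (the Megvii convention).
    W, H = heatmap_size
    num_joints = joints.shape[0]
    heatmaps = np.zeros((num_joints, H, W), dtype=np.float32)
    weights = np.zeros((num_joints, 1), dtype=np.float32)
    for i in range(num_joints):
        weights[i] = visible[i, 0]
        if weights[i] < 1:
            continue
        x = int(joints[i, 0] * W / image_size[0])
        y = int(joints[i, 1] * H / image_size[1])
        if not (0 <= x < W and 0 <= y < H):
            weights[i] = 0
            continue
        heatmaps[i, y, x] = 1
        heatmaps[i] = cv2.GaussianBlur(heatmaps[i], kernel, 0)
        heatmaps[i] /= heatmaps[i, y, x] / 255
    return heatmaps, weights

# Four kernels (coarse to fine) -> four stacked target sets, matching num_units=4.
kernels = [(11, 11), (9, 9), (7, 7), (5, 5)]
joints = np.random.rand(17, 3) * [192, 256, 0]    # hypothetical keypoints in pixels
visible = np.ones((17, 3), dtype=np.float32)
targets = np.stack([megvii_target(joints, visible, (192, 256), (48, 64), k)[0]
                    for k in kernels])
print(targets.shape)  # (4, 17, 64, 48)
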
153 changes: 122 additions & 31 deletions mmpose/datasets/pipelines/top_down_transform.py
@@ -205,25 +205,35 @@ class TopDownGenerateTarget():
Modified keys: 'target', and 'target_weight'.
Args:
sigma: Sigma of heatmap gaussian.
sigma: Sigma of heatmap gaussian for 'MSRA' approach.
kernel: Kernel of heatmap gaussian for 'Megvii' approach.
encoding (str): Approach to generate target heatmaps.
Currently supported approaches: 'MSRA', 'Megvii'. Default:'MSRA'
unbiased_encoding (bool): Option to use unbiased
encoding methods.
Paper ref: Zhang et al. Distribution-Aware Coordinate
Representation for Human Pose Estimation (CVPR 2020).
"""

def __init__(self, sigma=2, unbiased_encoding=False):
def __init__(self,
sigma=2,
kernel=(11, 11),
encoding='MSRA',
unbiased_encoding=False):
self.sigma = sigma
self.unbiased_encoding = unbiased_encoding
self.kernel = kernel
self.encoding = encoding

def _generate_target(self, cfg, joints_3d, joints_3d_visible):
"""Generate the target heatmap.
def _msra_generate_target(self, cfg, joints_3d, joints_3d_visible, sigma):
"""Generate the target heatmap via "MSRA" approach.
Args:
cfg (dict): data config
joints_3d: np.ndarray ([num_joints, 3])
joints_3d_visible: np.ndarray ([num_joints, 3])
sigma: Sigma of heatmap gaussian
Returns:
tuple: A tuple containing targets.
@@ -232,54 +242,50 @@ def _generate_target(self, cfg, joints_3d, joints_3d_visible):
"""
num_joints = cfg['num_joints']
image_size = cfg['image_size']
heatmap_size = cfg['heatmap_size']
W, H = cfg['heatmap_size']
joint_weights = cfg['joint_weights']
use_different_joint_weights = cfg['use_different_joint_weights']

target_weight = np.zeros((num_joints, 1), dtype=np.float32)
target = np.zeros((num_joints, heatmap_size[1], heatmap_size[0]),
dtype=np.float32)
target = np.zeros((num_joints, H, W), dtype=np.float32)

tmp_size = self.sigma * 3
# 3-sigma rule
tmp_size = sigma * 3

if self.unbiased_encoding:
for joint_id in range(num_joints):
heatmap_vis = joints_3d_visible[joint_id, 0]
target_weight[joint_id] = heatmap_vis
target_weight[joint_id] = joints_3d_visible[joint_id, 0]

feat_stride = image_size / heatmap_size
feat_stride = image_size / [W, H]
mu_x = joints_3d[joint_id][0] / feat_stride[0]
mu_y = joints_3d[joint_id][1] / feat_stride[1]
# Check that any part of the gaussian is in-bounds
ul = [mu_x - tmp_size, mu_y - tmp_size]
br = [mu_x + tmp_size + 1, mu_y + tmp_size + 1]
if ul[0] >= heatmap_size[0] or ul[1] >= heatmap_size[1] or br[
0] < 0 or br[1] < 0:
if ul[0] >= W or ul[1] >= H or br[0] < 0 or br[1] < 0:
target_weight[joint_id] = 0

if target_weight[joint_id] == 0:
continue

x = np.arange(0, heatmap_size[0], 1, np.float32)
y = np.arange(0, heatmap_size[1], 1, np.float32)
x = np.arange(0, W, 1, np.float32)
y = np.arange(0, H, 1, np.float32)
y = y[:, None]

if target_weight[joint_id] > 0.5:
target[joint_id] = np.exp(
-((x - mu_x)**2 + (y - mu_y)**2) / (2 * self.sigma**2))
-((x - mu_x)**2 + (y - mu_y)**2) / (2 * sigma**2))
else:
for joint_id in range(num_joints):
heatmap_vis = joints_3d_visible[joint_id, 0]
target_weight[joint_id] = heatmap_vis
target_weight[joint_id] = joints_3d_visible[joint_id, 0]

feat_stride = image_size / heatmap_size
feat_stride = image_size / [W, H]
mu_x = int(joints_3d[joint_id][0] / feat_stride[0] + 0.5)
mu_y = int(joints_3d[joint_id][1] / feat_stride[1] + 0.5)
# Check that any part of the gaussian is in-bounds
ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
if ul[0] >= heatmap_size[0] or ul[1] >= heatmap_size[1] or br[
0] < 0 or br[1] < 0:
if ul[0] >= W or ul[1] >= H or br[0] < 0 or br[1] < 0:
target_weight[joint_id] = 0

if target_weight[joint_id] > 0.5:
@@ -289,15 +295,14 @@ def _generate_target(self, cfg, joints_3d, joints_3d_visible):
x0 = y0 = size // 2
# The gaussian is not normalized,
# we want the center value to equal 1
g = np.exp(-((x - x0)**2 + (y - y0)**2) /
(2 * self.sigma**2))
g = np.exp(-((x - x0)**2 + (y - y0)**2) / (2 * sigma**2))

# Usable gaussian range
g_x = max(0, -ul[0]), min(br[0], heatmap_size[0]) - ul[0]
g_y = max(0, -ul[1]), min(br[1], heatmap_size[1]) - ul[1]
g_x = max(0, -ul[0]), min(br[0], W) - ul[0]
g_y = max(0, -ul[1]), min(br[1], H) - ul[1]
# Image range
img_x = max(0, ul[0]), min(br[0], heatmap_size[0])
img_y = max(0, ul[1]), min(br[1], heatmap_size[1])
img_x = max(0, ul[0]), min(br[0], W)
img_y = max(0, ul[1]), min(br[1], H)

target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \
g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
@@ -307,14 +312,100 @@ def _generate_target(self, cfg, joints_3d, joints_3d_visible):

return target, target_weight

def _megvii_generate_target(self, cfg, joints_3d, joints_3d_visible,
kernel):
"""Generate the target heatmap via "Megvii" approach.
Args:
cfg (dict): data config
joints_3d: np.ndarray ([num_joints, 3])
joints_3d_visible: np.ndarray ([num_joints, 3])
kernel: Kernel of heatmap gaussian
Returns:
tuple: A tuple containing targets.
- target: Target heatmaps.
- target_weight: (1: visible, 0: invisible)
"""

num_joints = cfg['num_joints']
image_size = cfg['image_size']
W, H = cfg['heatmap_size']
heatmaps = np.zeros((num_joints, H, W), dtype='float32')
target_weight = np.zeros((num_joints, 1), dtype=np.float32)

for i in range(num_joints):
target_weight[i] = joints_3d_visible[i, 0]

if target_weight[i] < 1:
continue

target_y = int(joints_3d[i, 1] * H / image_size[1])
target_x = int(joints_3d[i, 0] * W / image_size[0])

if (target_x >= W or target_x < 0) \
or (target_y >= H or target_y < 0):
target_weight[i] = 0
continue

heatmaps[i, target_y, target_x] = 1
heatmaps[i] = cv2.GaussianBlur(heatmaps[i], kernel, 0)
maxi = heatmaps[i, target_y, target_x]

heatmaps[i] /= maxi / 255

return heatmaps, target_weight

def __call__(self, results):
"""Generate the target heatmap."""
joints_3d = results['joints_3d']
joints_3d_visible = results['joints_3d_visible']

target, target_weight = self._generate_target(results['ann_info'],
joints_3d,
joints_3d_visible)
assert self.encoding in ['MSRA', 'Megvii']

if self.encoding == 'MSRA':
if isinstance(self.sigma, list):
num_sigmas = len(self.sigma)
cfg = results['ann_info']
num_joints = cfg['num_joints']
heatmap_size = cfg['heatmap_size']

target = np.empty(
(0, num_joints, heatmap_size[1], heatmap_size[0]),
dtype=np.float32)
target_weight = np.empty((0, num_joints, 1), dtype=np.float32)
for i in range(num_sigmas):
target_i, target_weight_i = self._msra_generate_target(
cfg, joints_3d, joints_3d_visible, self.sigma[i])
target = np.concatenate([target, target_i[None]], axis=0)
target_weight = np.concatenate(
[target_weight, target_weight_i[None]], axis=0)
else:
target, target_weight = self._msra_generate_target(
results['ann_info'], joints_3d, joints_3d_visible,
self.sigma)
elif self.encoding == 'Megvii':
if isinstance(self.kernel, list):
num_kernels = len(self.kernel)
cfg = results['ann_info']
num_joints = cfg['num_joints']
W, H = cfg['heatmap_size']

target = np.empty((0, num_joints, H, W), dtype=np.float32)
target_weight = np.empty((0, num_joints, 1), dtype=np.float32)
for i in range(num_kernels):
target_i, target_weight_i = self._megvii_generate_target(
cfg, joints_3d, joints_3d_visible, self.kernel[i])
target = np.concatenate([target, target_i[None]], axis=0)
target_weight = np.concatenate(
[target_weight, target_weight_i[None]], axis=0)
else:
target, target_weight = self._megvii_generate_target(
results['ann_info'], joints_3d, joints_3d_visible,
self.kernel)
else:
raise ValueError(
f'Encoding approach {self.encoding} is not supported!')

results['target'] = target
results['target_weight'] = target_weight
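
As a rough illustration of how such a stacked target is consumed downstream, the config above pairs the four per-unit outputs with loss_weights=[0.25] * 3 + [1]. The snippet below is a plain NumPy sketch of that weighted per-unit mean-squared error, not the TopDownMSMUHead or the JointsMSELoss/JointsOHKMMSELoss implementations; all tensor values are random placeholders:

import numpy as np

loss_weights = [0.25, 0.25, 0.25, 1.0]                      # from train_cfg above
outputs = np.random.rand(4, 17, 64, 48).astype(np.float32)  # hypothetical per-unit predictions
targets = np.random.rand(4, 17, 64, 48).astype(np.float32)  # stacked per-unit targets

# Weighted sum of per-unit MSE: intermediate units get 0.25, the final unit full weight.
total_loss = sum(w * np.mean((o - t) ** 2)
                 for w, o, t in zip(loss_weights, outputs, targets))
print(total_loss)
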