Skip to content

Commit

Permalink
add multi-stage loss (open-mmlab#204)
Browse files Browse the repository at this point in the history
* add multi-stage loss

* Modifications
  • Loading branch information
wusize authored Oct 21, 2020
1 parent ef553f5 commit 113e4da
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 18 deletions.
60 changes: 46 additions & 14 deletions mmpose/models/detectors/top_down.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import mmcv
import numpy as np
import torch
import torch.nn as nn
from mmcv.image import imwrite
from mmcv.visualization.image import imshow

Expand Down Expand Up @@ -43,7 +44,6 @@ def __init__(self,
self.keypoint_head = builder.build_head(keypoint_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg

self.loss = builder.build_loss(loss_pose)
self.init_weights(pretrained=pretrained)

Expand Down Expand Up @@ -115,30 +115,62 @@ def forward_train(self, img, target, target_weight, img_metas, **kwargs):
# if return loss
losses = dict()
if isinstance(output, list):
# multi-stage models
for output_i in output:
if target.dim() == 5 and target_weight.dim() == 4:
# target: [batch_size, num_outputs, num_joints, h, w]
# target_weight: [batch_size, num_outputs, num_joints, 1]
assert target.size(1) == len(output)
if isinstance(self.loss, nn.Sequential):
assert len(self.loss) == len(output)
if 'loss_weights' in self.train_cfg and self.train_cfg[
'loss_weights'] is not None:
assert len(self.train_cfg['loss_weights']) == len(output)
for i in range(len(output)):
if target.dim() == 5 and target_weight.dim() == 4:
target_i = target[:, i, :, :, :]
target_weight_i = target_weight[:, i, :, :]
else:
target_i = target
target_weight_i = target_weight
if isinstance(self.loss, nn.Sequential):
loss_func = self.loss[i]
else:
loss_func = self.loss

loss_i = loss_func(output[i], target_i, target_weight_i)
if 'loss_weights' in self.train_cfg and self.train_cfg[
'loss_weights']:
loss_i = loss_i * self.train_cfg['loss_weights'][i]
if 'mse_loss' not in losses:
losses['mse_loss'] = self.loss(output_i, target,
target_weight)
losses['mse_loss'] = loss_i
else:
losses['mse_loss'] += self.loss(output_i, target,
target_weight)
losses['mse_loss'] += loss_i
else:
assert not isinstance(self.loss, nn.Sequential)
assert target.dim() == 4 and target_weight.dim() == 3
# target: [batch_size, num_joints, h, w]
# target_weight: [batch_size, num_joints, 1]
losses['mse_loss'] = self.loss(output, target, target_weight)

if isinstance(output, list):
_, avg_acc, cnt = pose_pck_accuracy(
output[-1][target_weight.squeeze(-1) > 0].unsqueeze(
0).detach().cpu().numpy(),
target[target_weight.squeeze(-1) > 0].unsqueeze(
0).detach().cpu().numpy())
if target.dim() == 5 and target_weight.dim() == 4:
_, avg_acc, _ = pose_pck_accuracy(
output[-1][target_weight[:, -1, :, :].squeeze(-1) > 0].
unsqueeze(0).detach().cpu().numpy(),
target[:, -1, :, :, :][target_weight[:, -1, :, :].squeeze(
-1) > 0].unsqueeze(0).detach().cpu().numpy())
# Only use the last output for prediction
else:
_, avg_acc, _ = pose_pck_accuracy(
output[-1][target_weight.squeeze(-1) > 0].unsqueeze(
0).detach().cpu().numpy(),
target[target_weight.squeeze(-1) > 0].unsqueeze(
0).detach().cpu().numpy())
else:
_, avg_acc, cnt = pose_pck_accuracy(
_, avg_acc, _ = pose_pck_accuracy(
output[target_weight.squeeze(-1) > 0].unsqueeze(
0).detach().cpu().numpy(),
target[target_weight.squeeze(-1) > 0].unsqueeze(
0).detach().cpu().numpy())

losses['acc_pose'] = float(avg_acc)

return losses
Expand Down
59 changes: 55 additions & 4 deletions tests/test_model/test_top_down_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_topdown_forward():
shift_heatmap=True,
unbiased_decoding=False,
modulate_kernel=11),
loss_pose=dict(type='JointsMSELoss', use_target_weight=False))
loss_pose=dict(type='JointsMSELoss', use_target_weight=True))

detector = TopDown(model_cfg['backbone'], model_cfg['keypoint_head'],
model_cfg['train_cfg'], model_cfg['test_cfg'],
Expand Down Expand Up @@ -84,8 +84,54 @@ def test_topdown_forward():
with torch.no_grad():
_ = detector.forward(imgs, img_metas=img_metas, return_loss=False)

model_cfg = dict(
type='TopDown',
pretrained=None,
backbone=dict(
type='HourglassNet',
num_stacks=1,
),
keypoint_head=dict(
type='TopDownMultiStageHead',
in_channels=256,
out_channels=17,
num_stages=1,
num_deconv_layers=0,
extra=dict(final_conv_kernel=1, ),
),
train_cfg=dict(loss_weights=([1])),
test_cfg=dict(
flip_test=False,
post_process=True,
shift_heatmap=True,
unbiased_decoding=False,
modulate_kernel=11),
loss_pose=[dict(type='JointsMSELoss', use_target_weight=True)])

detector = TopDown(model_cfg['backbone'], model_cfg['keypoint_head'],
model_cfg['train_cfg'], model_cfg['test_cfg'],
model_cfg['pretrained'], model_cfg['loss_pose'])

detector.init_weights()

input_shape = (1, 3, 256, 256)
mm_inputs = _demo_mm_inputs(input_shape, num_outputs=1)

imgs = mm_inputs.pop('imgs')
target = mm_inputs.pop('target')
target_weight = mm_inputs.pop('target_weight')
img_metas = mm_inputs.pop('img_metas')

# Test forward train
losses = detector.forward(
imgs, target, target_weight, img_metas, return_loss=True)
assert isinstance(losses, dict)
# Test forward test
with torch.no_grad():
_ = detector.forward(imgs, img_metas=img_metas, return_loss=False)


def _demo_mm_inputs(input_shape=(1, 3, 256, 256)):
def _demo_mm_inputs(input_shape=(1, 3, 256, 256), num_outputs=None):
"""Create a superset of inputs needed to run test or train batches.
Args:
Expand All @@ -97,8 +143,13 @@ def _demo_mm_inputs(input_shape=(1, 3, 256, 256)):
rng = np.random.RandomState(0)

imgs = rng.rand(*input_shape)
target = np.zeros([N, 17, H // 4, W // 4], dtype=np.float32)
target_weight = np.ones([N, 17], dtype=np.float32)
if num_outputs is not None:
target = np.zeros([N, num_outputs, 17, H // 4, W // 4],
dtype=np.float32)
target_weight = np.ones([N, num_outputs, 17, 1], dtype=np.float32)
else:
target = np.zeros([N, 17, H // 4, W // 4], dtype=np.float32)
target_weight = np.ones([N, 17, 1], dtype=np.float32)

img_metas = [{
'img_shape': (H, W, C),
Expand Down

0 comments on commit 113e4da

Please sign in to comment.