add multi-stage loss #204

Merged: 2 commits, Oct 21, 2020
mmpose/models/detectors/top_down.py (60 changes: 46 additions & 14 deletions)
@@ -4,6 +4,7 @@
 import mmcv
 import numpy as np
 import torch
+import torch.nn as nn
 from mmcv.image import imwrite
 from mmcv.visualization.image import imshow

@@ -43,7 +44,6 @@ def __init__(self,
         self.keypoint_head = builder.build_head(keypoint_head)
         self.train_cfg = train_cfg
         self.test_cfg = test_cfg
-
         self.loss = builder.build_loss(loss_pose)
         self.init_weights(pretrained=pretrained)

@@ -115,30 +115,62 @@ def forward_train(self, img, target, target_weight, img_metas, **kwargs):
         # if return loss
         losses = dict()
         if isinstance(output, list):
-            # multi-stage models
-            for output_i in output:
+            if target.dim() == 5 and target_weight.dim() == 4:
+                # target: [batch_size, num_outputs, num_joints, h, w]
+                # target_weight: [batch_size, num_outputs, num_joints, 1]
+                assert target.size(1) == len(output)
+            if isinstance(self.loss, nn.Sequential):
+                assert len(self.loss) == len(output)
+            if 'loss_weights' in self.train_cfg and self.train_cfg[
+                    'loss_weights'] is not None:
+                assert len(self.train_cfg['loss_weights']) == len(output)
+            for i in range(len(output)):
+                if target.dim() == 5 and target_weight.dim() == 4:
+                    target_i = target[:, i, :, :, :]
+                    target_weight_i = target_weight[:, i, :, :]
+                else:
+                    target_i = target
+                    target_weight_i = target_weight
+                if isinstance(self.loss, nn.Sequential):
+                    loss_func = self.loss[i]
+                else:
+                    loss_func = self.loss
+
+                loss_i = loss_func(output[i], target_i, target_weight_i)
+                if 'loss_weights' in self.train_cfg and self.train_cfg[
+                        'loss_weights']:
+                    loss_i = loss_i * self.train_cfg['loss_weights'][i]
                 if 'mse_loss' not in losses:
-                    losses['mse_loss'] = self.loss(output_i, target,
-                                                   target_weight)
+                    losses['mse_loss'] = loss_i
                 else:
-                    losses['mse_loss'] += self.loss(output_i, target,
-                                                    target_weight)
+                    losses['mse_loss'] += loss_i
         else:
+            assert not isinstance(self.loss, nn.Sequential)
+            assert target.dim() == 4 and target_weight.dim() == 3
+            # target: [batch_size, num_joints, h, w]
+            # target_weight: [batch_size, num_joints, 1]
             losses['mse_loss'] = self.loss(output, target, target_weight)

         if isinstance(output, list):
-            _, avg_acc, cnt = pose_pck_accuracy(
-                output[-1][target_weight.squeeze(-1) > 0].unsqueeze(
-                    0).detach().cpu().numpy(),
-                target[target_weight.squeeze(-1) > 0].unsqueeze(
-                    0).detach().cpu().numpy())
+            if target.dim() == 5 and target_weight.dim() == 4:
+                _, avg_acc, _ = pose_pck_accuracy(
+                    output[-1][target_weight[:, -1, :, :].squeeze(-1) > 0].
+                    unsqueeze(0).detach().cpu().numpy(),
+                    target[:, -1, :, :, :][target_weight[:, -1, :, :].squeeze(
+                        -1) > 0].unsqueeze(0).detach().cpu().numpy())
+                # Only use the last output for prediction
+            else:
+                _, avg_acc, _ = pose_pck_accuracy(
+                    output[-1][target_weight.squeeze(-1) > 0].unsqueeze(
+                        0).detach().cpu().numpy(),
+                    target[target_weight.squeeze(-1) > 0].unsqueeze(
+                        0).detach().cpu().numpy())
         else:
-            _, avg_acc, cnt = pose_pck_accuracy(
+            _, avg_acc, _ = pose_pck_accuracy(
                 output[target_weight.squeeze(-1) > 0].unsqueeze(
                     0).detach().cpu().numpy(),
                 target[target_weight.squeeze(-1) > 0].unsqueeze(
                     0).detach().cpu().numpy())

         losses['acc_pose'] = float(avg_acc)

         return losses

Review thread on the new pose_pck_accuracy call over output[-1]:

Contributor: question: why is the last output special

@jin-s13 (Collaborator), Oct 21, 2020: Generally speaking, the last output is the final prediction (used for calculating accuracy), while the intermediate outputs are used for training.

Contributor: cool, @wusize may add a line of comment explaining this magic number
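To make the convention discussed in this thread concrete, here is a minimal standalone sketch in plain PyTorch (illustrative only, not mmpose code; the helper multi_stage_loss and all tensor shapes are invented): each stage output contributes a loss term, optionally scaled by a per-stage weight, while only the last output is treated as the prediction used for accuracy.

# Illustrative sketch of multi-stage supervision, not the mmpose implementation.
import torch
import torch.nn as nn

mse = nn.MSELoss()

def multi_stage_loss(outputs, target, loss_weights=None):
    """Sum per-stage losses, optionally scaling stage i by loss_weights[i]."""
    total = 0
    for i, out in enumerate(outputs):
        loss_i = mse(out, target)
        if loss_weights is not None:
            loss_i = loss_i * loss_weights[i]
        total = total + loss_i
    return total

# Two stage outputs from a stacked network, all predicting the same 17 heatmaps.
outputs = [torch.rand(2, 17, 64, 64) for _ in range(2)]
target = torch.rand(2, 17, 64, 64)

loss = multi_stage_loss(outputs, target, loss_weights=[0.5, 1.0])
prediction = outputs[-1]  # only the final stage is used as the prediction / for accuracy

With loss_weights=[0.5, 1.0] this reduces to 0.5 * loss of stage 1 plus 1.0 * loss of stage 2, which mirrors how the loop above accumulates losses['mse_loss'].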
tests/test_model/test_top_down_forward.py (59 changes: 55 additions & 4 deletions)
@@ -21,7 +21,7 @@ def test_topdown_forward():
             shift_heatmap=True,
             unbiased_decoding=False,
             modulate_kernel=11),
-        loss_pose=dict(type='JointsMSELoss', use_target_weight=False))
+        loss_pose=dict(type='JointsMSELoss', use_target_weight=True))

     detector = TopDown(model_cfg['backbone'], model_cfg['keypoint_head'],
                        model_cfg['train_cfg'], model_cfg['test_cfg'],
@@ -84,8 +84,54 @@ def test_topdown_forward():
     with torch.no_grad():
         _ = detector.forward(imgs, img_metas=img_metas, return_loss=False)

+    model_cfg = dict(
+        type='TopDown',
+        pretrained=None,
+        backbone=dict(
+            type='HourglassNet',
+            num_stacks=1,
+        ),
+        keypoint_head=dict(
+            type='TopDownMultiStageHead',
+            in_channels=256,
+            out_channels=17,
+            num_stages=1,
+            num_deconv_layers=0,
+            extra=dict(final_conv_kernel=1, ),
+        ),
+        train_cfg=dict(loss_weights=([1])),
+        test_cfg=dict(
+            flip_test=False,
+            post_process=True,
+            shift_heatmap=True,
+            unbiased_decoding=False,
+            modulate_kernel=11),
+        loss_pose=[dict(type='JointsMSELoss', use_target_weight=True)])
+
+    detector = TopDown(model_cfg['backbone'], model_cfg['keypoint_head'],
+                       model_cfg['train_cfg'], model_cfg['test_cfg'],
+                       model_cfg['pretrained'], model_cfg['loss_pose'])
+
+    detector.init_weights()
+
+    input_shape = (1, 3, 256, 256)
+    mm_inputs = _demo_mm_inputs(input_shape, num_outputs=1)
+
+    imgs = mm_inputs.pop('imgs')
+    target = mm_inputs.pop('target')
+    target_weight = mm_inputs.pop('target_weight')
+    img_metas = mm_inputs.pop('img_metas')
+
+    # Test forward train
+    losses = detector.forward(
+        imgs, target, target_weight, img_metas, return_loss=True)
+    assert isinstance(losses, dict)
+    # Test forward test
+    with torch.no_grad():
+        _ = detector.forward(imgs, img_metas=img_metas, return_loss=False)


-def _demo_mm_inputs(input_shape=(1, 3, 256, 256)):
+def _demo_mm_inputs(input_shape=(1, 3, 256, 256), num_outputs=None):
     """Create a superset of inputs needed to run test or train batches.

     Args:
@@ -97,8 +143,13 @@ def _demo_mm_inputs(input_shape=(1, 3, 256, 256)):
     rng = np.random.RandomState(0)

     imgs = rng.rand(*input_shape)
-    target = np.zeros([N, 17, H // 4, W // 4], dtype=np.float32)
-    target_weight = np.ones([N, 17], dtype=np.float32)
+    if num_outputs is not None:
+        target = np.zeros([N, num_outputs, 17, H // 4, W // 4],
+                          dtype=np.float32)
+        target_weight = np.ones([N, num_outputs, 17, 1], dtype=np.float32)
+    else:
+        target = np.zeros([N, 17, H // 4, W // 4], dtype=np.float32)
+        target_weight = np.ones([N, 17, 1], dtype=np.float32)

     img_metas = [{
         'img_shape': (H, W, C),
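The new test above exercises the single-stage case (num_stacks=1, one loss in the list). As a hypothetical sketch of where this feature points, a two-stage setup might be configured as below; the field values are assumptions, and it relies on build_loss wrapping a list of loss configs into nn.Sequential, as the new isinstance check in top_down.py implies.

# Hypothetical two-stage config, not part of this PR; values are illustrative.
model_cfg = dict(
    type='TopDown',
    pretrained=None,
    backbone=dict(type='HourglassNet', num_stacks=2),
    keypoint_head=dict(
        type='TopDownMultiStageHead',
        in_channels=256,
        out_channels=17,
        num_stages=2,
        num_deconv_layers=0,
        extra=dict(final_conv_kernel=1)),
    # one weight and one loss per stage; stage losses are summed into 'mse_loss'
    train_cfg=dict(loss_weights=[0.5, 1.0]),
    test_cfg=dict(
        flip_test=False,
        post_process=True,
        shift_heatmap=True,
        unbiased_decoding=False,
        modulate_kernel=11),
    loss_pose=[
        dict(type='JointsMSELoss', use_target_weight=True),
        dict(type='JointsMSELoss', use_target_weight=True),
    ])

Targets for such a model would take the 5-dimensional form produced by _demo_mm_inputs(..., num_outputs=2): target of shape [batch_size, num_outputs, num_joints, h, w] and target_weight of shape [batch_size, num_outputs, num_joints, 1].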