Skip to content

Commit

Permalink
Merge 4db02a1 into d882fec
Browse files Browse the repository at this point in the history
  • Loading branch information
liqikai9 authored Sep 23, 2022
2 parents d882fec + 4db02a1 commit fd78734
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@
in_channels=32,
out_channels=3 * 17,
deconv_out_channels=None,
loss=dict(type='KeypointMSELoss', use_target_weight=True),
loss=dict(type='CombinedTargetMSELoss', use_target_weight=True),
print_acc_pose=False,
decoder=codec),
test_cfg=dict(
flip_test=True,
Expand Down
2 changes: 1 addition & 1 deletion mmpose/codecs/utils/offset_heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def generate_offset_heatmap(
(N, K)
radius_factor (float): The radius factor of the binary label
map. The positive region is defined as the neighbor of the
keypoit with the radius :math:`r=radius_factor*max(W, H)`
keypoint with the radius :math:`r=radius_factor*max(W, H)`
Returns:
tuple:
Expand Down
21 changes: 13 additions & 8 deletions mmpose/models/heads/heatmap_heads/heatmap_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ class HeatmapHead(BaseHead):
keypoint coordinates from the network output. Defaults to ``None``
init_cfg (Config, optional): Config to control the initialization. See
:attr:`default_init_cfg` for default settings
print_acc_pose(bool): Whether to print the `acc_pose` during training.
Defaults to ``True``
.. _`Simple Baselines`: https://arxiv.org/abs/1804.06208
"""
Expand All @@ -84,7 +86,8 @@ def __init__(self,
loss: ConfigType = dict(
type='KeypointMSELoss', use_target_weight=True),
decoder: OptConfigType = None,
init_cfg: OptConfigType = None):
init_cfg: OptConfigType = None,
print_acc_pose: bool = True):

if init_cfg is None:
init_cfg = self.default_init_cfg
Expand All @@ -96,6 +99,7 @@ def __init__(self,
self.align_corners = align_corners
self.input_transform = input_transform
self.input_index = input_index
self.print_acc_pose = print_acc_pose
self.loss_module = MODELS.build(loss)
if decoder is not None:
self.decoder = KEYPOINT_CODECS.build(decoder)
Expand Down Expand Up @@ -332,13 +336,14 @@ def loss(self,
losses.update(loss_kpt=loss)

# calculate accuracy
_, avg_acc, _ = pose_pck_accuracy(
output=to_numpy(pred_fields),
target=to_numpy(gt_heatmaps),
mask=to_numpy(keypoint_weights) > 0)

acc_pose = torch.tensor(avg_acc, device=gt_heatmaps.device)
losses.update(acc_pose=acc_pose)
if self.print_acc_pose:
_, avg_acc, _ = pose_pck_accuracy(
output=to_numpy(pred_fields),
target=to_numpy(gt_heatmaps),
mask=to_numpy(keypoint_weights) > 0)

acc_pose = torch.tensor(avg_acc, device=gt_heatmaps.device)
losses.update(acc_pose=acc_pose)

return losses

Expand Down
6 changes: 5 additions & 1 deletion mmpose/models/heads/heatmap_heads/vipnas_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ class ViPNASHead(HeatmapHead):
keypoint coordinates from the network output. Defaults to ``None``
init_cfg (Config, optional): Config to control the initialization. See
:attr:`default_init_cfg` for default settings
print_acc_pose(bool): Whether to print the `acc_pose` during training.
Defaults to ``True``
.. _`ViPNAS`: https://arxiv.org/abs/2105.10154
.. _`Simple Baselines`: https://arxiv.org/abs/1804.06208
Expand All @@ -85,7 +87,8 @@ def __init__(self,
loss: ConfigType = dict(
type='KeypointMSELoss', use_target_weight=True),
decoder: OptConfigType = None,
init_cfg: OptConfigType = None):
init_cfg: OptConfigType = None,
print_acc_pose: bool = True):

if init_cfg is None:
init_cfg = self.default_init_cfg
Expand All @@ -97,6 +100,7 @@ def __init__(self,
self.align_corners = align_corners
self.input_transform = input_transform
self.input_index = input_index
self.print_acc_pose = print_acc_pose
self.loss_module = MODELS.build(loss)
if decoder is not None:
self.decoder = KEYPOINT_CODECS.build(decoder)
Expand Down
13 changes: 7 additions & 6 deletions mmpose/models/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@
from .classification_loss import BCELoss, JSDiscretLoss, KLDiscretLoss
from .heatmap_loss import AdaptiveWingLoss
from .loss_wrappers import MultipleLossWrapper
from .mse_loss import KeypointMSELoss, KeypointOHKMMSELoss
from .mse_loss import (CombinedTargetMSELoss, KeypointMSELoss,
KeypointOHKMMSELoss)
from .multi_loss_factory import AELoss, HeatmapLoss, MultiLossFactory
from .regression_loss import (BoneLoss, L1Loss, MPJPELoss, MSELoss, RLELoss,
SemiSupervisionLoss, SmoothL1Loss, SoftWingLoss,
WingLoss)

__all__ = [
'KeypointMSELoss', 'KeypointOHKMMSELoss', 'HeatmapLoss', 'AELoss',
'MultiLossFactory', 'SmoothL1Loss', 'WingLoss', 'MPJPELoss', 'MSELoss',
'L1Loss', 'BCELoss', 'BoneLoss', 'SemiSupervisionLoss', 'SoftWingLoss',
'AdaptiveWingLoss', 'RLELoss', 'KLDiscretLoss', 'MultipleLossWrapper',
'JSDiscretLoss'
'KeypointMSELoss', 'KeypointOHKMMSELoss', 'CombinedTargetMSELoss',
'HeatmapLoss', 'AELoss', 'MultiLossFactory', 'SmoothL1Loss', 'WingLoss',
'MPJPELoss', 'MSELoss', 'L1Loss', 'BCELoss', 'BoneLoss',
'SemiSupervisionLoss', 'SoftWingLoss', 'AdaptiveWingLoss', 'RLELoss',
'KLDiscretLoss', 'MultipleLossWrapper', 'JSDiscretLoss'
]

0 comments on commit fd78734

Please sign in to comment.