Skip to content

Commit

Permalink
[Fix] Fix udp regress (#1682)
Browse files Browse the repository at this point in the history
  • Loading branch information
liqikai9 authored and ly015 committed Oct 14, 2022
1 parent af2af15 commit a3b71ee
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,9 @@
in_channels=32,
out_channels=3 * 17,
deconv_out_channels=None,
loss=dict(type='KeypointMSELoss', use_target_weight=True),
loss=dict(type='CombinedTargetMSELoss', use_target_weight=True),
decoder=codec),
train_cfg=dict(compute_acc=False),
test_cfg=dict(
flip_test=True,
flip_mode='udp_combined',
Expand Down
2 changes: 1 addition & 1 deletion mmpose/codecs/utils/offset_heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def generate_offset_heatmap(
(N, K)
radius_factor (float): The radius factor of the binary label
map. The positive region is defined as the neighbor of the
keypoit with the radius :math:`r=radius_factor*max(W, H)`
keypoint with the radius :math:`r=radius_factor*max(W, H)`
Returns:
tuple:
Expand Down
15 changes: 8 additions & 7 deletions mmpose/models/heads/heatmap_heads/heatmap_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,13 +332,14 @@ def loss(self,
losses.update(loss_kpt=loss)

# calculate accuracy
_, avg_acc, _ = pose_pck_accuracy(
output=to_numpy(pred_fields),
target=to_numpy(gt_heatmaps),
mask=to_numpy(keypoint_weights) > 0)

acc_pose = torch.tensor(avg_acc, device=gt_heatmaps.device)
losses.update(acc_pose=acc_pose)
if train_cfg.get('compute_acc', True):
_, avg_acc, _ = pose_pck_accuracy(
output=to_numpy(pred_fields),
target=to_numpy(gt_heatmaps),
mask=to_numpy(keypoint_weights) > 0)

acc_pose = torch.tensor(avg_acc, device=gt_heatmaps.device)
losses.update(acc_pose=acc_pose)

return losses

Expand Down
13 changes: 7 additions & 6 deletions mmpose/models/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@
from .classification_loss import BCELoss, JSDiscretLoss, KLDiscretLoss
from .heatmap_loss import AdaptiveWingLoss
from .loss_wrappers import MultipleLossWrapper
from .mse_loss import KeypointMSELoss, KeypointOHKMMSELoss
from .mse_loss import (CombinedTargetMSELoss, KeypointMSELoss,
KeypointOHKMMSELoss)
from .multi_loss_factory import AELoss, HeatmapLoss, MultiLossFactory
from .regression_loss import (BoneLoss, L1Loss, MPJPELoss, MSELoss, RLELoss,
SemiSupervisionLoss, SmoothL1Loss, SoftWingLoss,
WingLoss)

__all__ = [
'KeypointMSELoss', 'KeypointOHKMMSELoss', 'HeatmapLoss', 'AELoss',
'MultiLossFactory', 'SmoothL1Loss', 'WingLoss', 'MPJPELoss', 'MSELoss',
'L1Loss', 'BCELoss', 'BoneLoss', 'SemiSupervisionLoss', 'SoftWingLoss',
'AdaptiveWingLoss', 'RLELoss', 'KLDiscretLoss', 'MultipleLossWrapper',
'JSDiscretLoss'
'KeypointMSELoss', 'KeypointOHKMMSELoss', 'CombinedTargetMSELoss',
'HeatmapLoss', 'AELoss', 'MultiLossFactory', 'SmoothL1Loss', 'WingLoss',
'MPJPELoss', 'MSELoss', 'L1Loss', 'BCELoss', 'BoneLoss',
'SemiSupervisionLoss', 'SoftWingLoss', 'AdaptiveWingLoss', 'RLELoss',
'KLDiscretLoss', 'MultipleLossWrapper', 'JSDiscretLoss'
]

0 comments on commit a3b71ee

Please sign in to comment.