[Feature] Add visibility prediction head #2417

Merged 5 commits on Jun 19, 2023
10 changes: 10 additions & 0 deletions mmpose/datasets/transforms/common_transforms.py
@@ -1029,6 +1029,16 @@ def transform(self, results: Dict) -> Optional[dict]:

results.update(encoded)

if results.get('keypoint_weights', None) is not None:
results['transformed_keypoints_visible'] = results[
'keypoint_weights']
elif results.get('keypoints', None) is not None:
results['transformed_keypoints_visible'] = results[
'keypoints_visible']
else:
raise ValueError('GenerateTarget requires \'keypoint_weights\' or'
' \'keypoints_visible\' in the results.')

return results

def __repr__(self) -> str:
10 changes: 8 additions & 2 deletions mmpose/datasets/transforms/formatting.py
@@ -129,7 +129,8 @@ class PackPoseInputs(BaseTransform):
'keypoint_x_labels': 'keypoint_x_labels',
'keypoint_y_labels': 'keypoint_y_labels',
'keypoint_weights': 'keypoint_weights',
'instance_coords': 'instance_coords'
'instance_coords': 'instance_coords',
'transformed_keypoints_visible': 'keypoints_visible',
}

# items in `field_mapping_table` will be packed into
@@ -196,6 +197,10 @@ def transform(self, results: dict) -> dict:
if self.pack_transformed and 'transformed_keypoints' in results:
gt_instances.set_field(results['transformed_keypoints'],
'transformed_keypoints')
if self.pack_transformed and \
'transformed_keypoints_visible' in results:
gt_instances.set_field(results['transformed_keypoints_visible'],
'transformed_keypoints_visible')

data_sample.gt_instances = gt_instances

@@ -205,7 +210,8 @@ def transform(self, results: dict) -> dict:
if key in results:
# For pose-lifting, store only target-related fields
if 'lifting_target_label' in results and key in {
'keypoint_labels', 'keypoint_weights'
'keypoint_labels', 'keypoint_weights',
'transformed_keypoints_visible'
}:
continue
if isinstance(results[key], list):
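To illustrate how the two data-pipeline changes above work together, here is a minimal config sketch (not part of this PR) in which GenerateTarget records 'transformed_keypoints_visible' and PackPoseInputs packs it when pack_transformed=True; the codec, input size, and transform list are illustrative placeholders in the usual mmpose 1.x style.

```python
# Hedged sketch: a top-down training pipeline that keeps the transformed
# keypoint visibility. The codec and sizes are assumed for illustration.
codec = dict(
    type='MSRAHeatmap', input_size=(192, 256), heatmap_size=(48, 64), sigma=2)

train_pipeline = [
    dict(type='LoadImage'),
    dict(type='GetBBoxCenterScale'),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    # stores 'transformed_keypoints_visible' in results (see diff above)
    dict(type='GenerateTarget', encoder=codec),
    # packs it into gt_instances when pack_transformed=True
    dict(type='PackPoseInputs', pack_transformed=True),
]
```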
7 changes: 4 additions & 3 deletions mmpose/models/heads/__init__.py
@@ -3,14 +3,15 @@
from .coord_cls_heads import RTMCCHead, SimCCHead
from .heatmap_heads import (AssociativeEmbeddingHead, CIDHead, CPMHead,
HeatmapHead, MSPNHead, ViPNASHead)
from .hybrid_heads import DEKRHead
from .hybrid_heads import DEKRHead, VisPredictHead
from .regression_heads import (DSNTHead, IntegralRegressionHead,
RegressionHead, RLEHead, TemporalRegressionHead,
TrajectoryRegressionHead)

__all__ = [
'BaseHead', 'HeatmapHead', 'CPMHead', 'MSPNHead', 'ViPNASHead',
'RegressionHead', 'IntegralRegressionHead', 'SimCCHead', 'RLEHead',
'DSNTHead', 'AssociativeEmbeddingHead', 'DEKRHead', 'CIDHead', 'RTMCCHead',
'TemporalRegressionHead', 'TrajectoryRegressionHead'
'DSNTHead', 'AssociativeEmbeddingHead', 'DEKRHead', 'VisPredictHead',
'CIDHead', 'RTMCCHead', 'TemporalRegressionHead',
'TrajectoryRegressionHead'
]
5 changes: 2 additions & 3 deletions mmpose/models/heads/hybrid_heads/__init__.py
@@ -1,6 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .dekr_head import DEKRHead
from .vis_head import VisPredictHead

__all__ = [
'DEKRHead',
]
__all__ = ['DEKRHead', 'VisPredictHead']
229 changes: 229 additions & 0 deletions mmpose/models/heads/hybrid_heads/vis_head.py
@@ -0,0 +1,229 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple, Union

import torch
from torch import Tensor, nn

from mmpose.models.utils.tta import flip_visibility
from mmpose.registry import MODELS
from mmpose.utils.tensor_utils import to_numpy
from mmpose.utils.typing import (ConfigType, InstanceList, OptConfigType,
OptSampleList, Predictions)
from ..base_head import BaseHead


@MODELS.register_module()
class VisPredictHead(BaseHead):
"""VisPredictHead must be used together with other heads. It can predict
keypoints coordinates of and their visibility simultaneously. In the
current version, it only supports top-down approaches.
Args:
pose_cfg (Config): Config to construct keypoints prediction head
loss (Config): Config for visibility loss. Defaults to use
:class:`BCELoss`
use_sigmoid (bool): Whether to use sigmoid activation function
init_cfg (Config, optional): Config to control the initialization. See
:attr:`default_init_cfg` for default settings
"""

def __init__(self,
pose_cfg: ConfigType,
loss: ConfigType = dict(
type='BCELoss', use_target_weight=False,
with_logits=True),
use_sigmoid: bool = False,
init_cfg: OptConfigType = None):

if init_cfg is None:
init_cfg = self.default_init_cfg

super().__init__(init_cfg)

self.in_channels = pose_cfg['in_channels']
if pose_cfg.get('num_joints', None) is not None:
self.out_channels = pose_cfg['num_joints']
elif pose_cfg.get('out_channels', None) is not None:
self.out_channels = pose_cfg['out_channels']
else:
raise ValueError('VisPredictHead requires \'num_joints\' or'
' \'out_channels\' in the pose_cfg.')

self.loss_module = MODELS.build(loss)

self.pose_head = MODELS.build(pose_cfg)
self.pose_cfg = pose_cfg

self.use_sigmoid = use_sigmoid

modules = [
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Linear(self.in_channels, self.out_channels)
]
if use_sigmoid:
modules.append(nn.Sigmoid())

self.vis_head = nn.Sequential(*modules)

def vis_forward(self, feats: Tuple[Tensor]):
"""Forward the vis_head. The input is multi scale feature maps and the
output is coordinates visibility.
Args:
feats (Tuple[Tensor]): Multi scale feature maps.
Returns:
Tensor: output keypoint visibility.
"""
x = feats[-1]
while len(x.shape) < 4:
x.unsqueeze_(-1)
x = self.vis_head(x)
return x.reshape(-1, self.out_channels)

def forward(self, feats: Tuple[Tensor]):
"""Forward the network. The input is multi scale feature maps and the
output is coordinates and coordinates visibility.
Args:
feats (Tuple[Tensor]): Multi scale feature maps.
Returns:
Tuple[Tensor]: output keypoint coordinates and their visibility.
"""
x_pose = self.pose_head.forward(feats)
x_vis = self.vis_forward(feats)

return x_pose, x_vis

def integrate(self, batch_vis: Tensor,
pose_preds: Union[Tuple, Predictions]) -> InstanceList:
"""Add keypoints visibility prediction to pose prediction.
Overwrite the original keypoint_scores.
"""
if isinstance(pose_preds, tuple):
pose_pred_instances, pose_pred_fields = pose_preds
else:
pose_pred_instances = pose_preds
pose_pred_fields = None

batch_vis_np = to_numpy(batch_vis, unzip=True)

assert len(pose_pred_instances) == len(batch_vis_np)
for index, _ in enumerate(pose_pred_instances):
pose_pred_instances[index].keypoint_scores = batch_vis_np[index]

return pose_pred_instances, pose_pred_fields

def predict(self,
feats: Tuple[Tensor],
batch_data_samples: OptSampleList,
test_cfg: ConfigType = {}) -> Predictions:
"""Predict results from features.
Args:
feats (Tuple[Tensor] | List[Tuple[Tensor]]): The multi-stage
features (or multiple multi-stage features in TTA)
batch_data_samples (List[:obj:`PoseDataSample`]): The batch
data samples
test_cfg (dict): The runtime config for testing process. Defaults
to {}
Returns:
Union[InstanceList | Tuple[InstanceList | PixelDataList]]: If
posehead's ``test_cfg['output_heatmap']==True``, return both
pose and heatmap prediction; otherwise only return the pose
prediction.
The pose prediction is a list of ``InstanceData``, each contains
the following fields:
- keypoints (np.ndarray): predicted keypoint coordinates in
shape (num_instances, K, D) where K is the keypoint number
and D is the keypoint dimension
- keypoint_scores (np.ndarray): predicted keypoint scores in
shape (num_instances, K)
- keypoint_visibility (np.ndarray): predicted keypoints
visibility in shape (num_instances, K)
The heatmap prediction is a list of ``PixelData``, each contains
the following fields:
- heatmaps (Tensor): The predicted heatmaps in shape (K, h, w)
"""
if test_cfg.get('flip_test', False):
# TTA: flip test -> feats = [orig, flipped]
assert isinstance(feats, list) and len(feats) == 2
flip_indices = batch_data_samples[0].metainfo['flip_indices']
_feats, _feats_flip = feats

_batch_vis = self.vis_forward(_feats)
_batch_vis_flip = flip_visibility(
self.vis_forward(_feats_flip), flip_indices=flip_indices)
batch_vis = (_batch_vis + _batch_vis_flip) * 0.5
else:
batch_vis = self.vis_forward(feats)  # (B, K)

batch_vis.unsqueeze_(dim=1)  # (B, N=1, K)

if not self.use_sigmoid:
batch_vis = torch.sigmoid(batch_vis)
Comment on lines +171 to +172

Collaborator:
Why is sigmoid applied to batch_vis when use_sigmoid is False?

Contributor Author:
I think the value of the visibility prediction should be between 0 and 1. If use_sigmoid is False, there won't be a sigmoid in vis_forward.

Collaborator:
If the sigmoid function is always applied to the head outputs, what is the role of the use_sigmoid argument?

Collaborator:
Why is sigmoid applied to batch_vis when use_sigmoid is False?

@Ben-Louis When use_sigmoid is True, an nn.Sigmoid() is part of vis_head, so batch_vis is already between 0 and 1. Otherwise, when use_sigmoid is False, we need to apply sigmoid manually.

Collaborator:
Using nn.Sigmoid() and BCELoss separately is not officially recommended by PyTorch; a better way is to use BCEWithLogitsLoss and apply sigmoid manually when a 0-1 score is needed. So I suggested adding use_sigmoid to keep that flexibility.

Collaborator:
@Tau-J, thanks for the clarification. This can indeed be puzzling; perhaps we could add some comments for clarity? Also, focusing on a single recommended practice could be beneficial. BTW, in MMDetection this parameter is generally integrated into loss modules like mmdet.CrossEntropyLoss. We might consider refining the BCELoss module with this parameter, rather than introducing a completely new loss module.


batch_pose = self.pose_head.predict(feats, batch_data_samples,
test_cfg)

return self.integrate(batch_vis, batch_pose)

def vis_accuracy(self, vis_pred_outputs, vis_labels):
"""Calculate visibility prediction accuracy."""
probabilities = torch.sigmoid(torch.flatten(vis_pred_outputs))
threshold = 0.5
predictions = (probabilities >= threshold).int()
labels = torch.flatten(vis_labels)
correct = torch.sum(predictions == labels).item()
accuracy = correct / len(labels)
return torch.tensor(accuracy)

def loss(self,
feats: Tuple[Tensor],
batch_data_samples: OptSampleList,
train_cfg: OptConfigType = {}) -> dict:
"""Calculate losses from a batch of inputs and data samples.
Args:
feats (Tuple[Tensor]): The multi-stage features
batch_data_samples (List[:obj:`PoseDataSample`]): The batch
data samples
train_cfg (dict): The runtime config for training process.
Defaults to {}
Returns:
dict: A dictionary of losses.
"""
vis_pred_outputs = self.vis_forward(feats)
vis_labels = torch.cat([
d.gt_instance_labels.keypoint_weights for d in batch_data_samples
])

# calculate vis losses
losses = dict()
loss_vis = self.loss_module(vis_pred_outputs, vis_labels)

losses.update(loss_vis=loss_vis)

# calculate vis accuracy
acc_vis = self.vis_accuracy(vis_pred_outputs, vis_labels)
losses.update(acc_vis=acc_vis)

# calculate keypoints losses
loss_kpt = self.pose_head.loss(feats, batch_data_samples)
losses.update(loss_kpt)

return losses

@property
def default_init_cfg(self):
init_cfg = [dict(type='Normal', layer=['Linear'], std=0.01, bias=0)]
return init_cfg
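As a usage sketch (not part of this diff), VisPredictHead is configured by wrapping an existing top-down head in pose_cfg; the wrapped HeatmapHead, its channel numbers, loss, and codec below are illustrative assumptions.

```python
# Hedged sketch: wrapping a heatmap head with VisPredictHead in a config.
# Channel numbers, loss settings and the codec are assumed for illustration.
head = dict(
    type='VisPredictHead',
    loss=dict(type='BCELoss', use_target_weight=False, with_logits=True),
    pose_cfg=dict(
        type='HeatmapHead',
        in_channels=32,    # read by VisPredictHead as self.in_channels
        out_channels=17,   # read as self.out_channels (number of keypoints)
        loss=dict(type='KeypointMSELoss', use_target_weight=True),
        decoder=codec))    # codec defined elsewhere in the config
```

At prediction time the head forwards both branches and, via integrate(), overwrites keypoint_scores with the predicted visibility.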
9 changes: 7 additions & 2 deletions mmpose/models/losses/classification_loss.py
@@ -14,11 +14,16 @@ class BCELoss(nn.Module):
use_target_weight (bool): Option to use weighted loss.
Different joint types may have different target weights.
loss_weight (float): Weight of the loss. Default: 1.0.
with_logits (bool): Whether to use BCEWithLogitsLoss. Default: False.
"""

def __init__(self, use_target_weight=False, loss_weight=1.):
def __init__(self,
use_target_weight=False,
loss_weight=1.,
with_logits=False):
super().__init__()
self.criterion = F.binary_cross_entropy
self.criterion = F.binary_cross_entropy if not with_logits\
else F.binary_cross_entropy_with_logits
self.use_target_weight = use_target_weight
self.loss_weight = loss_weight

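The new with_logits switch follows the standard PyTorch distinction between BCE on probabilities and BCE on raw logits; a minimal, self-contained sketch of the equivalence (made-up tensors, not from this PR):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(2, 17)                     # raw head outputs, shape (B, K)
targets = torch.randint(0, 2, (2, 17)).float()  # 0/1 visibility labels

# with_logits=True -> BCEWithLogitsLoss-style call on raw logits (more stable)
loss_with_logits = F.binary_cross_entropy_with_logits(logits, targets)

# with_logits=False -> caller must pass probabilities, e.g. via sigmoid
loss_plain = F.binary_cross_entropy(torch.sigmoid(logits), targets)

assert torch.allclose(loss_with_logits, loss_plain, atol=1e-5)
```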
15 changes: 15 additions & 0 deletions mmpose/models/utils/tta.py
@@ -114,6 +114,21 @@ def flip_coordinates(coords: Tensor, flip_indices: List[int],
return coords


def flip_visibility(vis: Tensor, flip_indices: List[int]):
"""Flip keypoints visibility for test-time augmentation.
Args:
vis (Tensor): The keypoints visibility to flip. Should be a tensor
in shape [B, K]
flip_indices (List[int]): The indices of each keypoint's symmetric
keypoint
"""
assert vis.ndim == 2

vis = vis[:, flip_indices]
return vis


def aggregate_heatmaps(heatmaps: List[Tensor],
size: Optional[Tuple[int, int]],
align_corners: bool = False,
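A quick usage sketch for flip_visibility with a hypothetical 5-keypoint symmetry map (keypoint 0 is self-symmetric, pairs (1, 2) and (3, 4) are mirrored):

```python
import torch
from mmpose.models.utils.tta import flip_visibility

flip_indices = [0, 2, 1, 4, 3]                   # hypothetical symmetry map
vis = torch.tensor([[0.9, 0.8, 0.1, 0.7, 0.2]])  # shape (B=1, K=5)

print(flip_visibility(vis, flip_indices))
# tensor([[0.9000, 0.1000, 0.8000, 0.2000, 0.7000]])
```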