Skip to content

Commit

Permalink
[Feature] Add visibility prediction head (#2417)
Browse files Browse the repository at this point in the history
  • Loading branch information
Billccx authored Jun 19, 2023
1 parent 4679b30 commit 1340c3a
Show file tree
Hide file tree
Showing 8 changed files with 465 additions and 10 deletions.
10 changes: 10 additions & 0 deletions mmpose/datasets/transforms/common_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1031,6 +1031,16 @@ def transform(self, results: Dict) -> Optional[dict]:

results.update(encoded)

# Expose a per-keypoint visibility target for downstream heads. Prefer the
# encoded `keypoint_weights`; otherwise fall back to the raw
# `keypoints_visible` annotation.
if results.get('keypoint_weights', None) is not None:
    results['transformed_keypoints_visible'] = results[
        'keypoint_weights']
# Test the key that is actually read below (the original tested
# 'keypoints', which could raise a bare KeyError or reject valid input).
elif results.get('keypoints_visible', None) is not None:
    results['transformed_keypoints_visible'] = results[
        'keypoints_visible']
else:
    raise ValueError('GenerateTarget requires \'keypoint_weights\' or'
                     ' \'keypoints_visible\' in the results.')

return results

def __repr__(self) -> str:
Expand Down
10 changes: 8 additions & 2 deletions mmpose/datasets/transforms/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ class PackPoseInputs(BaseTransform):
'keypoint_x_labels': 'keypoint_x_labels',
'keypoint_y_labels': 'keypoint_y_labels',
'keypoint_weights': 'keypoint_weights',
'instance_coords': 'instance_coords'
'instance_coords': 'instance_coords',
'transformed_keypoints_visible': 'keypoints_visible',
}

# items in `field_mapping_table` will be packed into
Expand Down Expand Up @@ -196,6 +197,10 @@ def transform(self, results: dict) -> dict:
if self.pack_transformed and 'transformed_keypoints' in results:
gt_instances.set_field(results['transformed_keypoints'],
'transformed_keypoints')
if self.pack_transformed and \
'transformed_keypoints_visible' in results:
gt_instances.set_field(results['transformed_keypoints_visible'],
'transformed_keypoints_visible')

data_sample.gt_instances = gt_instances

Expand All @@ -205,7 +210,8 @@ def transform(self, results: dict) -> dict:
if key in results:
# For pose-lifting, store only target-related fields
if 'lifting_target_label' in results and key in {
'keypoint_labels', 'keypoint_weights'
'keypoint_labels', 'keypoint_weights',
'transformed_keypoints_visible'
}:
continue
if isinstance(results[key], list):
Expand Down
7 changes: 4 additions & 3 deletions mmpose/models/heads/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
from .coord_cls_heads import RTMCCHead, SimCCHead
from .heatmap_heads import (AssociativeEmbeddingHead, CIDHead, CPMHead,
HeatmapHead, MSPNHead, ViPNASHead)
from .hybrid_heads import DEKRHead
from .hybrid_heads import DEKRHead, VisPredictHead
from .regression_heads import (DSNTHead, IntegralRegressionHead,
RegressionHead, RLEHead, TemporalRegressionHead,
TrajectoryRegressionHead)

__all__ = [
'BaseHead', 'HeatmapHead', 'CPMHead', 'MSPNHead', 'ViPNASHead',
'RegressionHead', 'IntegralRegressionHead', 'SimCCHead', 'RLEHead',
'DSNTHead', 'AssociativeEmbeddingHead', 'DEKRHead', 'CIDHead', 'RTMCCHead',
'TemporalRegressionHead', 'TrajectoryRegressionHead'
'DSNTHead', 'AssociativeEmbeddingHead', 'DEKRHead', 'VisPredictHead',
'CIDHead', 'RTMCCHead', 'TemporalRegressionHead',
'TrajectoryRegressionHead'
]
5 changes: 2 additions & 3 deletions mmpose/models/heads/hybrid_heads/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .dekr_head import DEKRHead
from .vis_head import VisPredictHead

__all__ = [
'DEKRHead',
]
__all__ = ['DEKRHead', 'VisPredictHead']
229 changes: 229 additions & 0 deletions mmpose/models/heads/hybrid_heads/vis_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple, Union

import torch
from torch import Tensor, nn

from mmpose.models.utils.tta import flip_visibility
from mmpose.registry import MODELS
from mmpose.utils.tensor_utils import to_numpy
from mmpose.utils.typing import (ConfigType, InstanceList, OptConfigType,
OptSampleList, Predictions)
from ..base_head import BaseHead


@MODELS.register_module()
class VisPredictHead(BaseHead):
    """Wrap a keypoint head and add a keypoint-visibility prediction branch.

    VisPredictHead must be used together with another head. It predicts
    keypoint coordinates (through the wrapped pose head) and their visibility
    (through a pooled linear classifier) simultaneously. In the current
    version, it only supports top-down approaches.

    Args:
        pose_cfg (Config): Config to construct the keypoint prediction head.
            Must contain ``in_channels`` and either ``num_joints`` or
            ``out_channels``.
        loss (Config): Config for the visibility loss. Defaults to
            :class:`BCELoss` with ``with_logits=True``.
        use_sigmoid (bool): Whether to append a sigmoid activation to the
            visibility branch. When ``True`` the branch already outputs
            probabilities, so the loss should not apply a sigmoid again
            (i.e. use ``with_logits=False``). Defaults to ``False``.
        init_cfg (Config, optional): Config to control the initialization. See
            :attr:`default_init_cfg` for default settings.
    """

    def __init__(self,
                 pose_cfg: ConfigType,
                 loss: ConfigType = dict(
                     type='BCELoss', use_target_weight=False,
                     with_logits=True),
                 use_sigmoid: bool = False,
                 init_cfg: OptConfigType = None):

        if init_cfg is None:
            init_cfg = self.default_init_cfg

        super().__init__(init_cfg)

        self.in_channels = pose_cfg['in_channels']
        # The number of visibility outputs equals the number of keypoints,
        # which pose heads expose as either `num_joints` or `out_channels`.
        if pose_cfg.get('num_joints', None) is not None:
            self.out_channels = pose_cfg['num_joints']
        elif pose_cfg.get('out_channels', None) is not None:
            self.out_channels = pose_cfg['out_channels']
        else:
            raise ValueError('VisPredictHead requires \'num_joints\' or'
                             ' \'out_channels\' in the pose_cfg.')

        self.loss_module = MODELS.build(loss)

        self.pose_head = MODELS.build(pose_cfg)
        self.pose_cfg = pose_cfg

        self.use_sigmoid = use_sigmoid

        # Visibility branch: global average pool -> flatten -> linear,
        # optionally followed by a sigmoid.
        modules = [
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(self.in_channels, self.out_channels)
        ]
        if use_sigmoid:
            modules.append(nn.Sigmoid())

        self.vis_head = nn.Sequential(*modules)

    def vis_forward(self, feats: Tuple[Tensor]):
        """Forward the visibility branch.

        Only the last feature map in ``feats`` is used.

        Args:
            feats (Tuple[Tensor]): Multi scale feature maps.

        Returns:
            Tensor: Visibility output reshaped to ``(-1, out_channels)``.
            These are logits unless ``use_sigmoid`` is ``True``.
        """
        x = feats[-1]
        # Pad trailing singleton dims so the 2D pooling layer accepts the
        # input. Done out-of-place: an in-place `unsqueeze_` would change the
        # shape of the caller's tensor, which the pose head also consumes.
        while x.ndim < 4:
            x = x.unsqueeze(-1)
        x = self.vis_head(x)
        return x.reshape(-1, self.out_channels)

    def forward(self, feats: Tuple[Tensor]):
        """Forward both the pose head and the visibility branch.

        Args:
            feats (Tuple[Tensor]): Multi scale feature maps.

        Returns:
            Tuple[Tensor]: Output of the wrapped pose head and the visibility
            output of :meth:`vis_forward`.
        """
        x_pose = self.pose_head.forward(feats)
        x_vis = self.vis_forward(feats)

        return x_pose, x_vis

    def integrate(self, batch_vis: Tensor,
                  pose_preds: Union[Tuple, Predictions]) -> Tuple:
        """Merge visibility predictions into the pose predictions.

        The ``keypoint_scores`` field of every predicted instance is
        overwritten with the corresponding visibility prediction.

        Args:
            batch_vis (Tensor): Visibility predictions, one entry per sample
                along the first dimension.
            pose_preds (Tuple | Predictions): Output of the pose head's
                ``predict``: either the instance list alone, or a tuple of
                ``(instances, fields)``.

        Returns:
            Tuple: ``(pose_pred_instances, pose_pred_fields)``, where
            ``pose_pred_fields`` is ``None`` if the pose head returned only
            instances.
        """
        if isinstance(pose_preds, tuple):
            pose_pred_instances, pose_pred_fields = pose_preds
        else:
            pose_pred_instances = pose_preds
            pose_pred_fields = None

        # NOTE(review): presumably `unzip=True` splits the batch into one
        # numpy array per sample - confirm against `to_numpy`.
        batch_vis_np = to_numpy(batch_vis, unzip=True)

        assert len(pose_pred_instances) == len(batch_vis_np)
        for instances, vis in zip(pose_pred_instances, batch_vis_np):
            instances.keypoint_scores = vis

        return pose_pred_instances, pose_pred_fields

    def predict(self,
                feats: Tuple[Tensor],
                batch_data_samples: OptSampleList,
                test_cfg: ConfigType = {}) -> Predictions:
        """Predict results from features.

        Args:
            feats (Tuple[Tensor] | List[Tuple[Tensor]]): The multi-stage
                features (or multiple multi-stage features in TTA)
            batch_data_samples (List[:obj:`PoseDataSample`]): The batch
                data samples
            test_cfg (dict): The runtime config for testing process. Defaults
                to {}

        Returns:
            Tuple: ``(instances, fields)`` as returned by :meth:`integrate`.
            ``instances`` is the pose head's list of ``InstanceData`` whose
            ``keypoint_scores`` field has been overwritten with the predicted
            visibility probabilities (shape ``(1, K)`` per sample before
            conversion). ``fields`` is the pose head's second output (e.g.
            heatmaps when the pose head returns them), otherwise ``None``.
        """
        if test_cfg.get('flip_test', False):
            # TTA: flip test -> feats = [orig, flipped]
            assert isinstance(feats, list) and len(feats) == 2
            flip_indices = batch_data_samples[0].metainfo['flip_indices']
            _feats, _feats_flip = feats

            _batch_vis = self.vis_forward(_feats)
            # Swap symmetric keypoints back before averaging the two views.
            _batch_vis_flip = flip_visibility(
                self.vis_forward(_feats_flip), flip_indices=flip_indices)
            batch_vis = (_batch_vis + _batch_vis_flip) * 0.5
        else:
            batch_vis = self.vis_forward(feats)  # (B, K)

        # Insert the instance dimension expected per sample.
        batch_vis = batch_vis.unsqueeze(dim=1)  # (B, 1, K)

        # Convert logits to probabilities unless the branch already did.
        if not self.use_sigmoid:
            batch_vis = torch.sigmoid(batch_vis)

        batch_pose = self.pose_head.predict(feats, batch_data_samples,
                                            test_cfg)

        return self.integrate(batch_vis, batch_pose)

    def vis_accuracy(self, vis_pred_outputs, vis_labels):
        """Calculate visibility prediction accuracy.

        Args:
            vis_pred_outputs (Tensor): Output of :meth:`vis_forward`.
            vis_labels (Tensor): Ground-truth visibility labels with the same
                number of elements as ``vis_pred_outputs``.

        Returns:
            Tensor: Scalar tensor holding the fraction of elements whose
            thresholded prediction (probability >= 0.5) equals the label.
            ``0`` if there are no labels.
        """
        probabilities = torch.flatten(vis_pred_outputs)
        # With `use_sigmoid=True` the branch already outputs probabilities;
        # a second sigmoid would push every value to >= 0.5.
        if not self.use_sigmoid:
            probabilities = torch.sigmoid(probabilities)

        labels = torch.flatten(vis_labels)
        num_labels = labels.numel()
        if num_labels == 0:
            # Avoid a ZeroDivisionError on an empty batch.
            return torch.tensor(0.)

        threshold = 0.5
        predictions = (probabilities >= threshold).int()
        correct = torch.sum(predictions == labels).item()
        return torch.tensor(correct / num_labels)

    def loss(self,
             feats: Tuple[Tensor],
             batch_data_samples: OptSampleList,
             train_cfg: OptConfigType = {}) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            feats (Tuple[Tensor]): The multi-stage features
            batch_data_samples (List[:obj:`PoseDataSample`]): The batch
                data samples
            train_cfg (dict): The runtime config for training process.
                Defaults to {}

        Returns:
            dict: A dictionary of losses. Contains ``loss_vis`` and
            ``acc_vis`` from the visibility branch, plus every entry
            returned by the wrapped pose head's ``loss``.
        """
        vis_pred_outputs = self.vis_forward(feats)
        # `keypoint_weights` of each sample serve as the visibility labels.
        vis_labels = torch.cat([
            d.gt_instance_labels.keypoint_weights for d in batch_data_samples
        ])

        # calculate vis losses
        losses = dict()
        loss_vis = self.loss_module(vis_pred_outputs, vis_labels)

        losses.update(loss_vis=loss_vis)

        # calculate vis accuracy
        acc_vis = self.vis_accuracy(vis_pred_outputs, vis_labels)
        losses.update(acc_vis=acc_vis)

        # calculate keypoints losses; forward the training config so the
        # wrapped head sees the same runtime settings as this head.
        loss_kpt = self.pose_head.loss(
            feats, batch_data_samples, train_cfg=train_cfg)
        losses.update(loss_kpt)

        return losses

    @property
    def default_init_cfg(self):
        """list[dict]: Default init config (normal init for Linear layers)."""
        init_cfg = [dict(type='Normal', layer=['Linear'], std=0.01, bias=0)]
        return init_cfg
9 changes: 7 additions & 2 deletions mmpose/models/losses/classification_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,16 @@ class BCELoss(nn.Module):
use_target_weight (bool): Option to use weighted loss.
Different joint types may have different target weights.
loss_weight (float): Weight of the loss. Default: 1.0.
with_logits (bool): Whether to use BCEWithLogitsLoss. Default: False.
"""

def __init__(self, use_target_weight=False, loss_weight=1.):
def __init__(self,
use_target_weight=False,
loss_weight=1.,
with_logits=False):
super().__init__()
self.criterion = F.binary_cross_entropy
self.criterion = F.binary_cross_entropy if not with_logits\
else F.binary_cross_entropy_with_logits
self.use_target_weight = use_target_weight
self.loss_weight = loss_weight

Expand Down
15 changes: 15 additions & 0 deletions mmpose/models/utils/tta.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,21 @@ def flip_coordinates(coords: Tensor, flip_indices: List[int],
return coords


def flip_visibility(vis: Tensor, flip_indices: List[int]):
    """Reorder keypoint visibility to undo a horizontal flip in TTA.

    Each keypoint takes the visibility of its symmetric counterpart, so that
    predictions made on a flipped input line up with the original keypoint
    order.

    Args:
        vis (Tensor): The keypoints visibility to flip. Must be a 2-D tensor
            in shape [B, K]
        flip_indices (List[int]): The indices of each keypoint's symmetric
            keypoint

    Returns:
        Tensor: The visibility with its second dimension re-indexed by
        ``flip_indices``.
    """
    assert vis.ndim == 2

    flipped = vis[:, flip_indices]
    return flipped


def aggregate_heatmaps(heatmaps: List[Tensor],
size: Optional[Tuple[int, int]],
align_corners: bool = False,
Expand Down
Loading

0 comments on commit 1340c3a

Please sign in to comment.