From d3ba1b9333e85ec425882a07083f1c9089def274 Mon Sep 17 00:00:00 2001 From: Billccx_server Date: Wed, 31 May 2023 14:58:38 +0800 Subject: [PATCH 1/5] add vis_head --- .../datasets/transforms/common_transforms.py | 10 + mmpose/datasets/transforms/formatting.py | 7 +- mmpose/models/heads/hybrid_heads/__init__.py | 5 +- mmpose/models/heads/hybrid_heads/vis_head.py | 211 ++++++++++++++++++ mmpose/models/losses/__init__.py | 11 +- mmpose/models/losses/classification_loss.py | 42 ++++ mmpose/models/utils/tta.py | 15 ++ 7 files changed, 292 insertions(+), 9 deletions(-) create mode 100644 mmpose/models/heads/hybrid_heads/vis_head.py diff --git a/mmpose/datasets/transforms/common_transforms.py b/mmpose/datasets/transforms/common_transforms.py index 8db0ff37c7..a727b5374d 100644 --- a/mmpose/datasets/transforms/common_transforms.py +++ b/mmpose/datasets/transforms/common_transforms.py @@ -1029,6 +1029,16 @@ def transform(self, results: Dict) -> Optional[dict]: results.update(encoded) + if results.get('keypoint_weights', None) is not None: + results['transformed_keypoints_visible'] = results[ + 'keypoint_weights'] + elif results.get('keypoints', None) is not None: + results['transformed_keypoints_visible'] = results[ + 'keypoints_visible'] + else: + raise ValueError('GenerateTarget requires \'keypoint_weights\' or' + ' \'keypoints_visible\' in the results.') + return results def __repr__(self) -> str: diff --git a/mmpose/datasets/transforms/formatting.py b/mmpose/datasets/transforms/formatting.py index 403147120d..eb421f12ad 100644 --- a/mmpose/datasets/transforms/formatting.py +++ b/mmpose/datasets/transforms/formatting.py @@ -129,7 +129,8 @@ class PackPoseInputs(BaseTransform): 'keypoint_x_labels': 'keypoint_x_labels', 'keypoint_y_labels': 'keypoint_y_labels', 'keypoint_weights': 'keypoint_weights', - 'instance_coords': 'instance_coords' + 'instance_coords': 'instance_coords', + 'transformed_keypoints_visible': 'keypoints_visible', } # items in `field_mapping_table` will be packed into @@ -196,6 +197,10 @@ def transform(self, results: dict) -> dict: if self.pack_transformed and 'transformed_keypoints' in results: gt_instances.set_field(results['transformed_keypoints'], 'transformed_keypoints') + if self.pack_transformed and \ + 'transformed_keypoints_visible' in results: + gt_instances.set_field(results['transformed_keypoints_visible'], + 'transformed_keypoints_visible') data_sample.gt_instances = gt_instances diff --git a/mmpose/models/heads/hybrid_heads/__init__.py b/mmpose/models/heads/hybrid_heads/__init__.py index 55d5a211c1..6431b6a2c2 100644 --- a/mmpose/models/heads/hybrid_heads/__init__.py +++ b/mmpose/models/heads/hybrid_heads/__init__.py @@ -1,6 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. from .dekr_head import DEKRHead +from .vis_head import VisPredictHead -__all__ = [ - 'DEKRHead', -] +__all__ = ['DEKRHead', 'VisPredictHead'] diff --git a/mmpose/models/heads/hybrid_heads/vis_head.py b/mmpose/models/heads/hybrid_heads/vis_head.py new file mode 100644 index 0000000000..eceaba607f --- /dev/null +++ b/mmpose/models/heads/hybrid_heads/vis_head.py @@ -0,0 +1,211 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import Tuple, Union
+
+import torch
+from torch import Tensor, nn
+
+from mmpose.models.utils.tta import flip_visibility
+from mmpose.registry import MODELS
+from mmpose.utils.tensor_utils import to_numpy
+from mmpose.utils.typing import (ConfigType, InstanceList, OptConfigType,
+                                 OptSampleList, Predictions)
+from ..base_head import BaseHead
+
+
+@MODELS.register_module()
+class VisPredictHead(BaseHead):
+    """VisPredictHead must be used together with other heads. It can predict
+    the coordinates of keypoints and their visibility simultaneously. In the
+    current version, it only supports top-down approaches.
+
+    Args:
+        pose_cfg (Config): Config to construct keypoints prediction head
+        loss (Config): Config for visibility loss. Defaults to use
+            :class:`BCEWithLogitsLoss`
+        use_sigmoid (bool): Whether to use sigmoid activation function
+        init_cfg (Config, optional): Config to control the initialization. See
+            :attr:`default_init_cfg` for default settings
+    """
+
+    def __init__(self,
+                 pose_cfg: ConfigType,
+                 loss: ConfigType = dict(
+                     type='BCEWithLogitsLoss', use_target_weight=True),
+                 use_sigmoid: bool = False,
+                 init_cfg: OptConfigType = None):
+
+        if init_cfg is None:
+            init_cfg = self.default_init_cfg
+
+        super().__init__(init_cfg)
+
+        self.in_channels = pose_cfg['in_channels']
+        if pose_cfg.get('num_joints', None) is not None:
+            self.out_channels = pose_cfg['num_joints']
+        elif pose_cfg.get('out_channels', None) is not None:
+            self.out_channels = pose_cfg['out_channels']
+        else:
+            raise ValueError('VisPredictHead requires \'num_joints\' or'
+                             ' \'out_channels\' in the pose_cfg.')
+
+        self.loss_module = MODELS.build(loss)
+
+        self.pose_head = MODELS.build(pose_cfg)
+        self.pose_cfg = pose_cfg
+
+        modules = [
+            nn.AdaptiveAvgPool2d(1),
+            nn.Flatten(),
+            nn.Linear(self.in_channels, self.out_channels)
+        ]
+        if use_sigmoid:
+            modules.append(nn.Sigmoid())
+
+        self.vis_head = nn.Sequential(*modules)
+
+    def vis_forward(self, feats: Tuple[Tensor]):
+        """Forward the vis_head. The input is multi-scale feature maps and
+        the output is the predicted visibility of each keypoint.
+
+        Args:
+            feats (Tuple[Tensor]): Multi scale feature maps.
+
+        Returns:
+            Tensor: output keypoint visibility.
+        """
+        x = feats[-1]
+        while len(x.shape) < 4:
+            x.unsqueeze_(-1)
+        x = self.vis_head(x)
+        return x.reshape(-1, self.out_channels)
+
+    def forward(self, feats: Tuple[Tensor]):
+        """Forward the network. The input is multi-scale feature maps and the
+        outputs are keypoint coordinates and their visibility.
+
+        Args:
+            feats (Tuple[Tensor]): Multi scale feature maps.
+
+        Returns:
+            Tuple[Tensor]: output keypoint coordinates and visibility.
+        """
+        x_pose = self.pose_head.forward(feats)
+        x_vis = self.vis_forward(feats)
+
+        return x_pose, x_vis
+
+    def integrate(self, batch_vis: Tensor,
+                  pose_preds: Union[Tuple, Predictions]) -> InstanceList:
+        """Add keypoints visibility prediction to pose prediction."""
+        if isinstance(pose_preds, tuple):
+            pose_pred_instances, pose_pred_fields = pose_preds
+        else:
+            pose_pred_instances = pose_preds
+            pose_pred_fields = None
+
+        batch_vis_np = to_numpy(batch_vis, unzip=True)
+
+        assert len(pose_pred_instances) == len(batch_vis_np)
+        for index in range(len(pose_pred_instances)):
+            pose_pred_instances[index].keypoint_visibility = batch_vis_np[
+                index]
+
+        return pose_pred_instances, pose_pred_fields
+
+    def predict(self,
+                feats: Tuple[Tensor],
+                batch_data_samples: OptSampleList,
+                test_cfg: ConfigType = {}) -> Predictions:
+        """Predict results from features.
+
+        Args:
+            feats (Tuple[Tensor] | List[Tuple[Tensor]]): The multi-stage
+                features (or multiple multi-stage features in TTA)
+            batch_data_samples (List[:obj:`PoseDataSample`]): The batch
+                data samples
+            test_cfg (dict): The runtime config for testing process. Defaults
+                to {}
+
+        Returns:
+            Union[InstanceList | Tuple[InstanceList | PixelDataList]]: If
+            the pose head's ``test_cfg['output_heatmaps']==True``, return
+            both the pose and heatmap predictions; otherwise only return the
+            pose prediction.
+
+            The pose prediction is a list of ``InstanceData``, each contains
+            the following fields:
+
+                - keypoints (np.ndarray): predicted keypoint coordinates in
+                    shape (num_instances, K, D) where K is the keypoint number
+                    and D is the keypoint dimension
+                - keypoint_scores (np.ndarray): predicted keypoint scores in
+                    shape (num_instances, K)
+                - keypoint_visibility (np.ndarray): predicted keypoints
+                    visibility in shape (num_instances, K)
+
+            The heatmap prediction is a list of ``PixelData``, each contains
+            the following fields:
+
+                - heatmaps (Tensor): The predicted heatmaps in shape (K, h, w)
+        """
+        if test_cfg.get('flip_test', False):
+            # TTA: flip test -> feats = [orig, flipped]
+            assert isinstance(feats, list) and len(feats) == 2
+            flip_indices = batch_data_samples[0].metainfo['flip_indices']
+            _feats, _feats_flip = feats
+
+            _batch_vis = self.vis_forward(_feats)
+            _batch_vis_flip = flip_visibility(
+                self.vis_forward(_feats_flip), flip_indices=flip_indices)
+            batch_vis = (_batch_vis + _batch_vis_flip) * 0.5
+        else:
+            batch_vis = self.vis_forward(feats)  # (B, K)
+
+        batch_vis.unsqueeze_(dim=1)  # (B, 1, K)
+
+        batch_pose = self.pose_head.predict(feats, batch_data_samples,
+                                            test_cfg)
+
+        return self.integrate(batch_vis, batch_pose)
+
+    def loss(self,
+             feats: Tuple[Tensor],
+             batch_data_samples: OptSampleList,
+             train_cfg: OptConfigType = {}) -> dict:
+        """Calculate losses from a batch of inputs and data samples.
+
+        Args:
+            feats (Tuple[Tensor]): The multi-stage features
+            batch_data_samples (List[:obj:`PoseDataSample`]): The batch
+                data samples
+            train_cfg (dict): The runtime config for training process.
+                Defaults to {}
+
+        Returns:
+            dict: A dictionary of losses.
+        """
+        vis_pred_outputs = self.vis_forward(feats)
+        vis_labels = torch.cat([
+            d.gt_instance_labels.keypoints_visible for d in batch_data_samples
+        ])
+        keypoint_weights = torch.cat([
+            d.gt_instance_labels.keypoint_weights for d in batch_data_samples
+        ])
+
+        # calculate vis losses
+        losses = dict()
+        loss_vis = self.loss_module(vis_pred_outputs, vis_labels,
+                                    keypoint_weights)
+
+        losses.update(loss_vis=loss_vis)
+
+        # calculate keypoints losses
+        loss_kpt = self.pose_head.loss(feats, batch_data_samples)
+        losses.update(loss_kpt)
+
+        return losses
+
+    @property
+    def default_init_cfg(self):
+        init_cfg = [dict(type='Normal', layer=['Linear'], std=0.01, bias=0)]
+        return init_cfg
diff --git a/mmpose/models/losses/__init__.py b/mmpose/models/losses/__init__.py
index f21071e156..ad0c879897 100644
--- a/mmpose/models/losses/__init__.py
+++ b/mmpose/models/losses/__init__.py
@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
from .ae_loss import AssociativeEmbeddingLoss -from .classification_loss import BCELoss, JSDiscretLoss, KLDiscretLoss +from .classification_loss import (BCELoss, BCEWithLogitsLoss, JSDiscretLoss, + KLDiscretLoss) from .heatmap_loss import (AdaptiveWingLoss, KeypointMSELoss, KeypointOHKMMSELoss) from .loss_wrappers import CombinedLoss, MultipleLossWrapper @@ -10,8 +11,8 @@ __all__ = [ 'KeypointMSELoss', 'KeypointOHKMMSELoss', 'SmoothL1Loss', 'WingLoss', - 'MPJPELoss', 'MSELoss', 'L1Loss', 'BCELoss', 'BoneLoss', - 'SemiSupervisionLoss', 'SoftWingLoss', 'AdaptiveWingLoss', 'RLELoss', - 'KLDiscretLoss', 'MultipleLossWrapper', 'JSDiscretLoss', 'CombinedLoss', - 'AssociativeEmbeddingLoss', 'SoftWeightSmoothL1Loss' + 'MPJPELoss', 'MSELoss', 'L1Loss', 'BCELoss', 'BCEWithLogitsLoss', + 'BoneLoss', 'SemiSupervisionLoss', 'SoftWingLoss', 'AdaptiveWingLoss', + 'RLELoss', 'KLDiscretLoss', 'MultipleLossWrapper', 'JSDiscretLoss', + 'CombinedLoss', 'AssociativeEmbeddingLoss', 'SoftWeightSmoothL1Loss' ] diff --git a/mmpose/models/losses/classification_loss.py b/mmpose/models/losses/classification_loss.py index 6c3bdf502b..780a77573f 100644 --- a/mmpose/models/losses/classification_loss.py +++ b/mmpose/models/losses/classification_loss.py @@ -48,6 +48,48 @@ def forward(self, output, target, target_weight=None): return loss * self.loss_weight +@MODELS.register_module() +class BCEWithLogitsLoss(nn.Module): + """Binary Cross Entropy With Logits loss. + + Args: + use_target_weight (bool): Option to use weighted loss. + Different joint types may have different target weights. + loss_weight (float): Weight of the loss. Default: 1.0. + """ + + def __init__(self, use_target_weight=False, loss_weight=1.) -> None: + super().__init__() + self.criterion = F.binary_cross_entropy_with_logits + self.use_target_weight = use_target_weight + self.loss_weight = loss_weight + + def forward(self, output, target, target_weight=None): + """Forward function. + + Note: + - batch_size: N + - num_labels: K + + Args: + output (torch.Tensor[N, K]): Output classification. + target (torch.Tensor[N, K]): Target classification. + target_weight (torch.Tensor[N, K] or torch.Tensor[N]): + Weights across different labels. + """ + + if self.use_target_weight: + assert target_weight is not None + loss = self.criterion(output, target, reduction='none') + if target_weight.dim() == 1: + target_weight = target_weight[:, None] + loss = (loss * target_weight).mean() + else: + loss = self.criterion(output, target) + + return loss * self.loss_weight + + @MODELS.register_module() class JSDiscretLoss(nn.Module): """Discrete JS Divergence loss for DSNT with Gaussian Heatmap. diff --git a/mmpose/models/utils/tta.py b/mmpose/models/utils/tta.py index 0add48a422..41d2f2fd47 100644 --- a/mmpose/models/utils/tta.py +++ b/mmpose/models/utils/tta.py @@ -114,6 +114,21 @@ def flip_coordinates(coords: Tensor, flip_indices: List[int], return coords +def flip_visibility(vis: Tensor, flip_indices: List[int]): + """Flip keypoints visibility for test-time augmentation. + + Args: + vis (Tensor): The keypoints visibility to flip. 
Should be a tensor + in shape [B, K] + flip_indices (List[int]): The indices of each keypoint's symmetric + keypoint + """ + assert vis.ndim == 2 + + vis = vis[:, flip_indices] + return vis + + def aggregate_heatmaps(heatmaps: List[Tensor], size: Optional[Tuple[int, int]], align_corners: bool = False, From 7ec801b54865756c340ce6e8044bf98014e7e8b0 Mon Sep 17 00:00:00 2001 From: Billccx_server Date: Sat, 10 Jun 2023 19:52:34 +0800 Subject: [PATCH 2/5] add acc_vis --- mmpose/models/heads/hybrid_heads/vis_head.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/mmpose/models/heads/hybrid_heads/vis_head.py b/mmpose/models/heads/hybrid_heads/vis_head.py index eceaba607f..68cdceb2ef 100644 --- a/mmpose/models/heads/hybrid_heads/vis_head.py +++ b/mmpose/models/heads/hybrid_heads/vis_head.py @@ -168,6 +168,15 @@ def predict(self, return self.integrate(batch_vis, batch_pose) + def vis_accuracy(self, vis_pred_outputs, vis_labels): + probabilities = torch.sigmoid(torch.flatten(vis_pred_outputs)) + threshold = 0.5 + predictions = (probabilities >= threshold).int() + labels = torch.flatten(vis_labels) + correct = torch.sum(predictions == labels).item() + accuracy = correct / len(labels) + return torch.tensor(accuracy) + def loss(self, feats: Tuple[Tensor], batch_data_samples: OptSampleList, @@ -199,6 +208,10 @@ def loss(self, losses.update(loss_vis=loss_vis) + # calculate vis accuracy + acc_vis = self.vis_accuracy(vis_pred_outputs, vis_labels) + losses.update(acc_vis=acc_vis) + # calculate keypoints losses loss_kpt = self.pose_head.loss(feats, batch_data_samples) losses.update(loss_kpt) From 88f33c16a4bd1fff61bf0d9c36f633067e09f56b Mon Sep 17 00:00:00 2001 From: Billccx_server Date: Sun, 11 Jun 2023 14:54:14 +0800 Subject: [PATCH 3/5] refine vis_head --- mmpose/models/heads/hybrid_heads/vis_head.py | 24 ++++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/mmpose/models/heads/hybrid_heads/vis_head.py b/mmpose/models/heads/hybrid_heads/vis_head.py index 68cdceb2ef..68392edd63 100644 --- a/mmpose/models/heads/hybrid_heads/vis_head.py +++ b/mmpose/models/heads/hybrid_heads/vis_head.py @@ -30,7 +30,7 @@ class VisPredictHead(BaseHead): def __init__(self, pose_cfg: ConfigType, loss: ConfigType = dict( - type='BCEWithLogitsLoss', use_target_weight=True), + type='BCEWithLogitsLoss', use_target_weight=False), use_sigmoid: bool = False, init_cfg: OptConfigType = None): @@ -53,6 +53,8 @@ def __init__(self, self.pose_head = MODELS.build(pose_cfg) self.pose_cfg = pose_cfg + self.use_sigmoid = use_sigmoid + modules = [ nn.AdaptiveAvgPool2d(1), nn.Flatten(), @@ -96,7 +98,10 @@ def forward(self, feats: Tuple[Tensor]): def integrate(self, batch_vis: Tensor, pose_preds: Union[Tuple, Predictions]) -> InstanceList: - """Add keypoints visibility prediction to pose prediction.""" + """Add keypoints visibility prediction to pose prediction. + + Overwrite the original keypoint_scores. 
+        """
         if isinstance(pose_preds, tuple):
             pose_pred_instances, pose_pred_fields = pose_preds
         else:
@@ -106,9 +111,8 @@ def integrate(self, batch_vis: Tensor,
         batch_vis_np = to_numpy(batch_vis, unzip=True)
 
         assert len(pose_pred_instances) == len(batch_vis_np)
-        for index in range(len(pose_pred_instances)):
-            pose_pred_instances[index].keypoint_visibility = batch_vis_np[
-                index]
+        for index, _ in enumerate(pose_pred_instances):
+            pose_pred_instances[index].keypoint_scores = batch_vis_np[index]
 
         return pose_pred_instances, pose_pred_fields
 
@@ -163,12 +167,16 @@ def predict(self,
         batch_vis.unsqueeze_(dim=1)  # (B, 1, K)
 
+        if not self.use_sigmoid:
+            batch_vis = torch.sigmoid(batch_vis)
+
         batch_pose = self.pose_head.predict(feats, batch_data_samples,
                                             test_cfg)
 
         return self.integrate(batch_vis, batch_pose)
 
     def vis_accuracy(self, vis_pred_outputs, vis_labels):
+        """Calculate visibility prediction accuracy."""
         probabilities = torch.sigmoid(torch.flatten(vis_pred_outputs))
         threshold = 0.5
         predictions = (probabilities >= threshold).int()
@@ -195,16 +203,12 @@ def loss(self,
         """
         vis_pred_outputs = self.vis_forward(feats)
         vis_labels = torch.cat([
-            d.gt_instance_labels.keypoints_visible for d in batch_data_samples
-        ])
-        keypoint_weights = torch.cat([
             d.gt_instance_labels.keypoint_weights for d in batch_data_samples
         ])
 
         # calculate vis losses
         losses = dict()
-        loss_vis = self.loss_module(vis_pred_outputs, vis_labels,
-                                    keypoint_weights)
+        loss_vis = self.loss_module(vis_pred_outputs, vis_labels)
 
         losses.update(loss_vis=loss_vis)
 

From 1ee410ab6b9921e720394cdb01b6deef88e18b3a Mon Sep 17 00:00:00 2001
From: Billccx_server
Date: Sat, 17 Jun 2023 21:05:22 +0800
Subject: [PATCH 4/5] refine BCELoss and add test_vis_head

---
 mmpose/models/heads/__init__.py              |   7 +-
 mmpose/models/heads/hybrid_heads/vis_head.py |   5 +-
 mmpose/models/losses/__init__.py             |  11 +-
 mmpose/models/losses/classification_loss.py  |  51 +----
 .../test_hybrid_heads/test_vis_head.py       | 190 ++++++++++++++++++
 5 files changed, 209 insertions(+), 55 deletions(-)
 create mode 100644 tests/test_models/test_heads/test_hybrid_heads/test_vis_head.py

diff --git a/mmpose/models/heads/__init__.py b/mmpose/models/heads/__init__.py
index 75a626569b..e01f2269e3 100644
--- a/mmpose/models/heads/__init__.py
+++ b/mmpose/models/heads/__init__.py
@@ -3,7 +3,7 @@
 from .coord_cls_heads import RTMCCHead, SimCCHead
 from .heatmap_heads import (AssociativeEmbeddingHead, CIDHead, CPMHead,
                             HeatmapHead, MSPNHead, ViPNASHead)
-from .hybrid_heads import DEKRHead
+from .hybrid_heads import DEKRHead, VisPredictHead
 from .regression_heads import (DSNTHead, IntegralRegressionHead,
                                RegressionHead, RLEHead,
                                TemporalRegressionHead,
                                TrajectoryRegressionHead)
 
 __all__ = [
     'BaseHead', 'HeatmapHead', 'CPMHead', 'MSPNHead', 'ViPNASHead',
     'RegressionHead', 'IntegralRegressionHead', 'SimCCHead', 'RLEHead',
-    'DSNTHead', 'AssociativeEmbeddingHead', 'DEKRHead', 'CIDHead', 'RTMCCHead',
-    'TemporalRegressionHead', 'TrajectoryRegressionHead'
+    'DSNTHead', 'AssociativeEmbeddingHead', 'DEKRHead', 'VisPredictHead',
+    'CIDHead', 'RTMCCHead', 'TemporalRegressionHead',
+    'TrajectoryRegressionHead'
 ]
diff --git a/mmpose/models/heads/hybrid_heads/vis_head.py b/mmpose/models/heads/hybrid_heads/vis_head.py
index 68392edd63..e9ea271ac5 100644
--- a/mmpose/models/heads/hybrid_heads/vis_head.py
+++ b/mmpose/models/heads/hybrid_heads/vis_head.py
@@ -21,7 +21,7 @@ class VisPredictHead(BaseHead):
     Args:
         pose_cfg (Config): Config to construct keypoints prediction head
         loss (Config): Config for
visibility loss. Defaults to use - :class:`BCEWithLogitsLoss` + :class:`BCELoss` use_sigmoid (bool): Whether to use sigmoid activation function init_cfg (Config, optional): Config to control the initialization. See :attr:`default_init_cfg` for default settings @@ -30,7 +30,8 @@ class VisPredictHead(BaseHead): def __init__(self, pose_cfg: ConfigType, loss: ConfigType = dict( - type='BCEWithLogitsLoss', use_target_weight=False), + type='BCELoss', use_target_weight=False, + with_logits=True), use_sigmoid: bool = False, init_cfg: OptConfigType = None): diff --git a/mmpose/models/losses/__init__.py b/mmpose/models/losses/__init__.py index ad0c879897..f21071e156 100644 --- a/mmpose/models/losses/__init__.py +++ b/mmpose/models/losses/__init__.py @@ -1,7 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .ae_loss import AssociativeEmbeddingLoss -from .classification_loss import (BCELoss, BCEWithLogitsLoss, JSDiscretLoss, - KLDiscretLoss) +from .classification_loss import BCELoss, JSDiscretLoss, KLDiscretLoss from .heatmap_loss import (AdaptiveWingLoss, KeypointMSELoss, KeypointOHKMMSELoss) from .loss_wrappers import CombinedLoss, MultipleLossWrapper @@ -11,8 +10,8 @@ __all__ = [ 'KeypointMSELoss', 'KeypointOHKMMSELoss', 'SmoothL1Loss', 'WingLoss', - 'MPJPELoss', 'MSELoss', 'L1Loss', 'BCELoss', 'BCEWithLogitsLoss', - 'BoneLoss', 'SemiSupervisionLoss', 'SoftWingLoss', 'AdaptiveWingLoss', - 'RLELoss', 'KLDiscretLoss', 'MultipleLossWrapper', 'JSDiscretLoss', - 'CombinedLoss', 'AssociativeEmbeddingLoss', 'SoftWeightSmoothL1Loss' + 'MPJPELoss', 'MSELoss', 'L1Loss', 'BCELoss', 'BoneLoss', + 'SemiSupervisionLoss', 'SoftWingLoss', 'AdaptiveWingLoss', 'RLELoss', + 'KLDiscretLoss', 'MultipleLossWrapper', 'JSDiscretLoss', 'CombinedLoss', + 'AssociativeEmbeddingLoss', 'SoftWeightSmoothL1Loss' ] diff --git a/mmpose/models/losses/classification_loss.py b/mmpose/models/losses/classification_loss.py index 780a77573f..4605acabd3 100644 --- a/mmpose/models/losses/classification_loss.py +++ b/mmpose/models/losses/classification_loss.py @@ -14,53 +14,16 @@ class BCELoss(nn.Module): use_target_weight (bool): Option to use weighted loss. Different joint types may have different target weights. loss_weight (float): Weight of the loss. Default: 1.0. + with_logits (bool): Whether to use BCEWithLogitsLoss. Default: False. """ - def __init__(self, use_target_weight=False, loss_weight=1.): + def __init__(self, + use_target_weight=False, + loss_weight=1., + with_logits=False): super().__init__() - self.criterion = F.binary_cross_entropy - self.use_target_weight = use_target_weight - self.loss_weight = loss_weight - - def forward(self, output, target, target_weight=None): - """Forward function. - - Note: - - batch_size: N - - num_labels: K - - Args: - output (torch.Tensor[N, K]): Output classification. - target (torch.Tensor[N, K]): Target classification. - target_weight (torch.Tensor[N, K] or torch.Tensor[N]): - Weights across different labels. - """ - - if self.use_target_weight: - assert target_weight is not None - loss = self.criterion(output, target, reduction='none') - if target_weight.dim() == 1: - target_weight = target_weight[:, None] - loss = (loss * target_weight).mean() - else: - loss = self.criterion(output, target) - - return loss * self.loss_weight - - -@MODELS.register_module() -class BCEWithLogitsLoss(nn.Module): - """Binary Cross Entropy With Logits loss. - - Args: - use_target_weight (bool): Option to use weighted loss. - Different joint types may have different target weights. 
- loss_weight (float): Weight of the loss. Default: 1.0. - """ - - def __init__(self, use_target_weight=False, loss_weight=1.) -> None: - super().__init__() - self.criterion = F.binary_cross_entropy_with_logits + self.criterion = F.binary_cross_entropy if not with_logits\ + else F.binary_cross_entropy_with_logits self.use_target_weight = use_target_weight self.loss_weight = loss_weight diff --git a/tests/test_models/test_heads/test_hybrid_heads/test_vis_head.py b/tests/test_models/test_heads/test_hybrid_heads/test_vis_head.py new file mode 100644 index 0000000000..a6aecc2852 --- /dev/null +++ b/tests/test_models/test_heads/test_hybrid_heads/test_vis_head.py @@ -0,0 +1,190 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import unittest +from typing import List, Tuple +from unittest import TestCase + +import torch +from mmengine.structures import InstanceData, PixelData +from torch import nn + +from mmpose.models.heads import VisPredictHead +from mmpose.testing import get_packed_inputs + + +class TestVisPredictHead(TestCase): + + def _get_feats( + self, + batch_size: int = 2, + feat_shapes: List[Tuple[int, int, int]] = [(32, 8, 6)], + ): + feats = [ + torch.rand((batch_size, ) + shape, dtype=torch.float32) + for shape in feat_shapes + ] + return feats + + def test_init(self): + codec = dict( + type='MSRAHeatmap', + input_size=(192, 256), + heatmap_size=(48, 64), + sigma=2.) + + head = VisPredictHead( + pose_cfg=dict( + type='HeatmapHead', + in_channels=32, + out_channels=17, + deconv_out_channels=None, + loss=dict(type='KeypointMSELoss', use_target_weight=True), + decoder=codec)) + + self.assertTrue(isinstance(head.vis_head, nn.Sequential)) + self.assertEqual(head.vis_head[2].weight.shape, (17, 32)) + self.assertIsNotNone(head.pose_head) + + def test_forward(self): + + codec = dict( + type='MSRAHeatmap', + input_size=(192, 256), + heatmap_size=(48, 64), + sigma=2) + + head = VisPredictHead( + pose_cfg=dict( + type='HeatmapHead', + in_channels=32, + out_channels=17, + deconv_out_channels=None, + loss=dict(type='KeypointMSELoss', use_target_weight=True), + decoder=codec)) + + feats = [torch.rand(1, 32, 128, 128)] + output_pose, output_vis = head.forward(feats) + + self.assertIsInstance(output_pose, torch.Tensor) + self.assertEqual(output_pose.shape, (1, 17, 128, 128)) + + self.assertIsInstance(output_vis, torch.Tensor) + self.assertEqual(output_vis.shape, (1, 17)) + + def test_predict(self): + + codec = dict( + type='MSRAHeatmap', + input_size=(192, 256), + heatmap_size=(48, 64), + sigma=2.) 
+
+        head = VisPredictHead(
+            pose_cfg=dict(
+                type='HeatmapHead',
+                in_channels=32,
+                out_channels=17,
+                deconv_out_channels=None,
+                loss=dict(type='KeypointMSELoss', use_target_weight=True),
+                decoder=codec))
+
+        feats = self._get_feats(batch_size=2, feat_shapes=[(32, 128, 128)])
+        batch_data_samples = get_packed_inputs(batch_size=2)['data_samples']
+
+        preds, _ = head.predict(feats, batch_data_samples)
+
+        self.assertEqual(len(preds), 2)
+        self.assertIsInstance(preds[0], InstanceData)
+        self.assertEqual(preds[0].keypoints.shape,
+                         batch_data_samples[0].gt_instances.keypoints.shape)
+        self.assertEqual(
+            preds[0].keypoint_scores.shape,
+            batch_data_samples[0].gt_instance_labels.keypoint_weights.shape)
+
+        # output heatmaps
+        head = VisPredictHead(
+            pose_cfg=dict(
+                type='HeatmapHead',
+                in_channels=32,
+                out_channels=17,
+                decoder=codec))
+        feats = self._get_feats(batch_size=2, feat_shapes=[(32, 8, 6)])
+        batch_data_samples = get_packed_inputs(batch_size=2)['data_samples']
+        _, pred_heatmaps = head.predict(
+            feats, batch_data_samples, test_cfg=dict(output_heatmaps=True))
+
+        self.assertIsInstance(pred_heatmaps[0], PixelData)
+        self.assertEqual(pred_heatmaps[0].heatmaps.shape, (17, 64, 48))
+
+    def test_tta(self):
+        # flip test: vis and heatmap
+        decoder_cfg = dict(
+            type='MSRAHeatmap',
+            input_size=(192, 256),
+            heatmap_size=(48, 64),
+            sigma=2.)
+
+        head = VisPredictHead(
+            pose_cfg=dict(
+                type='HeatmapHead',
+                in_channels=32,
+                out_channels=17,
+                decoder=decoder_cfg))
+
+        feats = self._get_feats(batch_size=2, feat_shapes=[(32, 8, 6)])
+        batch_data_samples = get_packed_inputs(batch_size=2)['data_samples']
+        preds, _ = head.predict([feats, feats],
+                                batch_data_samples,
+                                test_cfg=dict(
+                                    flip_test=True,
+                                    flip_mode='heatmap',
+                                    shift_heatmap=True,
+                                ))
+
+        self.assertEqual(len(preds), 2)
+        self.assertIsInstance(preds[0], InstanceData)
+        self.assertEqual(preds[0].keypoints.shape,
+                         batch_data_samples[0].gt_instances.keypoints.shape)
+        self.assertEqual(
+            preds[0].keypoint_scores.shape,
+            batch_data_samples[0].gt_instance_labels.keypoint_weights.shape)
+
+    def test_loss(self):
+        head = VisPredictHead(
+            pose_cfg=dict(
+                type='HeatmapHead',
+                in_channels=32,
+                out_channels=17,
+            ))
+
+        feats = self._get_feats(batch_size=2, feat_shapes=[(32, 8, 6)])
+        batch_data_samples = get_packed_inputs(batch_size=2)['data_samples']
+        losses = head.loss(feats, batch_data_samples)
+        self.assertIsInstance(losses['loss_kpt'], torch.Tensor)
+        self.assertEqual(losses['loss_kpt'].shape, torch.Size(()))
+        self.assertIsInstance(losses['acc_pose'], torch.Tensor)
+
+        self.assertIsInstance(losses['loss_vis'], torch.Tensor)
+        self.assertEqual(losses['loss_vis'].shape, torch.Size(()))
+        self.assertIsInstance(losses['acc_vis'], torch.Tensor)
+
+        head = VisPredictHead(
+            pose_cfg=dict(
+                type='HeatmapHead',
+                in_channels=32,
+                out_channels=17,
+            ))
+
+        feats = self._get_feats(batch_size=2, feat_shapes=[(32, 8, 6)])
+        batch_data_samples = get_packed_inputs(batch_size=2)['data_samples']
+        losses = head.loss(feats, batch_data_samples)
+        self.assertIsInstance(losses['loss_kpt'], torch.Tensor)
+        self.assertEqual(losses['loss_kpt'].shape, torch.Size(()))
+        self.assertIsInstance(losses['acc_pose'], torch.Tensor)
+
+        self.assertIsInstance(losses['loss_vis'], torch.Tensor)
+        self.assertEqual(losses['loss_vis'].shape, torch.Size(()))
+        self.assertIsInstance(losses['acc_vis'], torch.Tensor)
+
+
+if __name__ == '__main__':
+    unittest.main()

From bd1840150365c8fa1afaeab5c16b4545bec3d44a Mon Sep 17 00:00:00 2001
From: lupeng
Date:
Mon, 19 Jun 2023 09:04:16 +0800 Subject: [PATCH 5/5] fix ut --- mmpose/datasets/transforms/formatting.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mmpose/datasets/transforms/formatting.py b/mmpose/datasets/transforms/formatting.py index eb421f12ad..05aeef179f 100644 --- a/mmpose/datasets/transforms/formatting.py +++ b/mmpose/datasets/transforms/formatting.py @@ -210,7 +210,8 @@ def transform(self, results: dict) -> dict: if key in results: # For pose-lifting, store only target-related fields if 'lifting_target_label' in results and key in { - 'keypoint_labels', 'keypoint_weights' + 'keypoint_labels', 'keypoint_weights', + 'transformed_keypoints_visible' }: continue if isinstance(results[key], list):
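
---

A quick usage sketch of the head added in this series (the channel sizes, codec
values, and dummy input mirror test_vis_head.py above; they are illustrative,
not a recommended configuration). VisPredictHead wraps an existing top-down
head and adds a visibility branch -- global average pooling plus a single
linear layer -- over the same backbone features:

import torch
from mmpose.models.heads import VisPredictHead

codec = dict(
    type='MSRAHeatmap',
    input_size=(192, 256),
    heatmap_size=(48, 64),
    sigma=2.)

head = VisPredictHead(
    pose_cfg=dict(
        type='HeatmapHead',
        in_channels=32,
        out_channels=17,
        deconv_out_channels=None,
        loss=dict(type='KeypointMSELoss', use_target_weight=True),
        decoder=codec))

# Backbone output: a tuple/list of multi-scale feature maps. The visibility
# branch pools the last map with AdaptiveAvgPool2d(1) and predicts one value
# per keypoint.
feats = [torch.rand(1, 32, 128, 128)]

output_pose, output_vis = head.forward(feats)
# output_pose: heatmaps from the wrapped HeatmapHead, shape (1, 17, 128, 128)
# output_vis: per-keypoint visibility logits, shape (1, 17). At test time,
# predict() applies sigmoid when use_sigmoid=False (the default) and writes
# the values into keypoint_scores of each returned InstanceData.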
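
On the loss side, the refactored BCELoss keeps a single class for both the
plain and the logits-based criterion; with_logits=True matches what
VisPredictHead feeds it (raw linear outputs, no sigmoid). A minimal sketch of
that call, with the same shapes the head's loss() produces (visibility labels
are taken from keypoint_weights, so they are floats in [0, 1]):

import torch
from mmpose.models.losses import BCELoss

loss_module = BCELoss(use_target_weight=False, with_logits=True)

vis_pred_outputs = torch.randn(2, 17)  # raw logits, shape (N, K)
vis_labels = torch.rand(2, 17)         # visibility targets in [0, 1]

# Scalar tensor; internally F.binary_cross_entropy_with_logits is used,
# so no sigmoid is applied to the predictions beforehand.
loss_vis = loss_module(vis_pred_outputs, vis_labels)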