diff --git a/configs/maskformer/README.md b/configs/maskformer/README.md
new file mode 100644
index 00000000000..ce1384ae77e
--- /dev/null
+++ b/configs/maskformer/README.md
@@ -0,0 +1,60 @@
+# Per-Pixel Classification is Not All You Need for Semantic Segmentation
+
+## Abstract
+
+Modern approaches typically formulate semantic segmentation as a per-pixel classification
+task, while instance-level segmentation is handled with an alternative mask
+classification. Our key insight: mask classification is sufficiently general to solve
+both semantic- and instance-level segmentation tasks in a unified manner using
+the exact same model, loss, and training procedure. Following this observation,
+we propose MaskFormer, a simple mask classification model which predicts a
+set of binary masks, each associated with a single global class label prediction.
+Overall, the proposed mask classification-based method simplifies the landscape
+of effective approaches to semantic and panoptic segmentation tasks and shows
+excellent empirical results. In particular, we observe that MaskFormer outperforms
+per-pixel classification baselines when the number of classes is large. Our mask
+classification-based method outperforms both current state-of-the-art semantic
+(55.6 mIoU on ADE20K) and panoptic segmentation (52.7 PQ on COCO) models.
+
+<div align=center>
+<img src="https://camo.githubusercontent.com/29fb22298d506ce176caad3006a7b05ef2603ca12cece6c788b7e73c046e8bc9/68747470733a2f2f626f77656e63303232312e6769746875622e696f2f696d616765732f6d61736b666f726d65722e706e67" height="300"/>
+</div>
+
+## Citation
+
+```
+@inproceedings{cheng2021maskformer,
+  title={Per-Pixel Classification is Not All You Need for Semantic Segmentation},
+  author={Bowen Cheng and Alexander G. Schwing and Alexander Kirillov},
+  booktitle={NeurIPS},
+  year={2021}
+}
+```
+
+## Dataset
+
+MaskFormer requires the COCO and [COCO-panoptic](http://images.cocodataset.org/annotations/panoptic_annotations_trainval2017.zip) datasets for training and evaluation. You need to download and extract them into the COCO dataset path.
+The directory structure should look like this:
+
+```none
+mmdetection
+├── mmdet
+├── tools
+├── configs
+├── data
+│   ├── coco
+│   │   ├── annotations
+│   │   │   ├── panoptic_train2017.json
+│   │   │   ├── panoptic_train2017
+│   │   │   ├── panoptic_val2017.json
+│   │   │   ├── panoptic_val2017
+│   │   ├── train2017
+│   │   ├── val2017
+│   │   ├── test2017
+```
+
+## Results and Models
+
+| Backbone | Style   | Lr schd | Mem (GB) | Inf time (fps) | PQ  | SQ  | RQ  | PQ_th | SQ_th | RQ_th | PQ_st | SQ_st | RQ_st | Config | Download | Detail |
+| :------: | :-----: | :-----: | :------: | :------------: | :-: | :-: | :-: | :---: | :---: | :---: | :---: | :---: | :---: | :----: | :------: | :----: |
+| R-50     | pytorch | 75e     |          |                |     |     |     |       |       |       |       |       |       | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/maskformer/maskformer_r50_mstrain_16x1_75e_coco.py) |  | This variant is the one reported in Table XI of the paper [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) |
diff --git a/configs/maskformer/maskformer_r50_mstrain_16x1_75e_coco.py b/configs/maskformer/maskformer_r50_mstrain_16x1_75e_coco.py
new file mode 100644
index 00000000000..c9d92450570
--- /dev/null
+++ b/configs/maskformer/maskformer_r50_mstrain_16x1_75e_coco.py
@@ -0,0 +1,220 @@
+_base_ = [
+    '../_base_/datasets/coco_panoptic.py', '../_base_/default_runtime.py'
+]
+
+model = dict(
+    type='MaskFormer',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=-1,
+        norm_cfg=dict(type='BN', requires_grad=False),
+        norm_eval=True,
+        style='pytorch',
+        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
+    panoptic_head=dict(
+        type='MaskFormerHead',
+        in_channels=[256, 512, 1024, 2048],  # pass to pixel_decoder inside
+        feat_channels=256,
+        out_channels=256,
+        num_things_classes=80,
+        num_stuff_classes=53,
+        num_queries=100,
+        pixel_decoder=dict(
+            type='TransformerEncoderPixelDecoder',
+            norm_cfg=dict(type='GN', num_groups=32),
+            act_cfg=dict(type='ReLU'),
+            encoder=dict(
+                type='DetrTransformerEncoder',
+                num_layers=6,
+                transformerlayers=dict(
+                    type='BaseTransformerLayer',
+                    attn_cfgs=dict(
+                        type='MultiheadAttention',
+                        embed_dims=256,
+                        num_heads=8,
+                        attn_drop=0.1,
+                        proj_drop=0.1,
+                        dropout_layer=None,
+                        batch_first=False),
+                    ffn_cfgs=dict(
+                        embed_dims=256,
+                        feedforward_channels=2048,
+                        num_fcs=2,
+                        act_cfg=dict(type='ReLU', inplace=True),
+                        ffn_drop=0.1,
+                        dropout_layer=None,
+                        add_identity=True),
+                    operation_order=('self_attn', 'norm', 'ffn', 'norm'),
+                    norm_cfg=dict(type='LN'),
+                    init_cfg=None,
+                    batch_first=False),
+                init_cfg=None),
+            positional_encoding=dict(
+                type='SinePositionalEncoding', num_feats=128, normalize=True)),
+        enforce_decoder_input_project=False,
+        positional_encoding=dict(
+            type='SinePositionalEncoding', num_feats=128, normalize=True),
+        transformer_decoder=dict(
+            type='DetrTransformerDecoder',
+            return_intermediate=True,
+            num_layers=6,
+            transformerlayers=dict(
+                type='DetrTransformerDecoderLayer',
+                attn_cfgs=dict(
+                    type='MultiheadAttention',
+                    embed_dims=256,
+                    num_heads=8,
+                    attn_drop=0.1,
+                    proj_drop=0.1,
+                    dropout_layer=None,
+                    batch_first=False),
+                ffn_cfgs=dict(
+                    embed_dims=256,
+                    feedforward_channels=2048,
+                    num_fcs=2,
+                    act_cfg=dict(type='ReLU', inplace=True),
+                    ffn_drop=0.1,
+                    dropout_layer=None,
+                    add_identity=True),
+                # the following parameter is not used,
+                # it is only kept to make the current API happy
+                feedforward_channels=2048,
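+                # Each decoder layer below runs self-attention over the 100
+                # queries, then cross-attention from the queries to the
+                # flattened encoder memory, then an FFN, with a LayerNorm
+                # after each of the three blocks (see operation_order).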
+                operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
+                                 'ffn', 'norm')),
+            init_cfg=None),
+        loss_cls=dict(
+            type='CrossEntropyLoss',
+            bg_cls_weight=0.1,
+            use_sigmoid=False,
+            loss_weight=1.0,
+            reduction='mean',
+            class_weight=1.0),
+        loss_mask=dict(
+            type='FocalLoss',
+            use_sigmoid=True,
+            gamma=2.0,
+            alpha=0.25,
+            reduction='mean',
+            loss_weight=20.0),
+        loss_dice=dict(
+            type='DiceLoss',
+            use_sigmoid=True,
+            activate=True,
+            reduction='mean',
+            naive_dice=True,
+            eps=1.0,
+            loss_weight=1.0)),
+    train_cfg=dict(
+        assigner=dict(
+            type='MaskHungarianAssigner',
+            cls_cost=dict(type='ClassificationCost', weight=1.0),
+            mask_cost=dict(
+                type='FocalLossCost', weight=20.0, binary_input=True),
+            dice_cost=dict(
+                type='DiceCost', weight=1.0, pred_act=True, eps=1.0)),
+        sampler=dict(type='MaskPseudoSampler')),
+    test_cfg=dict(object_mask_thr=0.8, iou_thr=0.8),
+    # pretrained=None,
+    init_cfg=None)
+
+# dataset settings
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='LoadPanopticAnnotations',
+        with_bbox=True,
+        with_mask=True,
+        with_seg=True),
+    dict(type='RandomFlip', flip_ratio=0.5),
+    dict(
+        type='AutoAugment',
+        policies=[[
+            dict(
+                type='Resize',
+                img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
+                           (608, 1333), (640, 1333), (672, 1333), (704, 1333),
+                           (736, 1333), (768, 1333), (800, 1333)],
+                multiscale_mode='value',
+                keep_ratio=True)
+        ],
+                  [
+                      dict(
+                          type='Resize',
+                          img_scale=[(400, 1333), (500, 1333), (600, 1333)],
+                          multiscale_mode='value',
+                          keep_ratio=True),
+                      dict(
+                          type='RandomCrop',
+                          crop_type='absolute_range',
+                          crop_size=(384, 600),
+                          allow_negative_crop=True),
+                      dict(
+                          type='Resize',
+                          img_scale=[(480, 1333), (512, 1333), (544, 1333),
+                                     (576, 1333), (608, 1333), (640, 1333),
+                                     (672, 1333), (704, 1333), (736, 1333),
+                                     (768, 1333), (800, 1333)],
+                          multiscale_mode='value',
+                          override=True,
+                          keep_ratio=True)
+                  ]]),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size_divisor=1),
+    dict(type='DefaultFormatBundle'),
+    dict(
+        type='Collect',
+        keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg']),
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(1333, 800),
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='Pad', size_divisor=1),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    samples_per_gpu=1,
+    workers_per_gpu=1,
+    train=dict(pipeline=train_pipeline),
+    val=dict(pipeline=test_pipeline),
+    test=dict(pipeline=test_pipeline))
+
+# optimizer
+optimizer = dict(
+    type='AdamW',
+    lr=0.0001,
+    weight_decay=0.0001,
+    eps=1e-8,
+    betas=(0.9, 0.999),
+    paramwise_cfg=dict(
+        custom_keys={
+            'backbone': dict(lr_mult=0.1, decay_mult=1.0),
+            'query_embed': dict(lr_mult=1.0, decay_mult=0.0)
+        },
+        norm_decay_mult=0.0))
+optimizer_config = dict(grad_clip=dict(max_norm=0.01, norm_type=2))
+
+# learning policy
+lr_config = dict(
+    policy='step',
+    gamma=0.1,
+    by_epoch=True,
+    step=[50],
+    warmup='linear',
+    warmup_by_epoch=False,
+    warmup_ratio=1.0,  # no warmup
+    warmup_iters=10)
+runner = dict(type='EpochBasedRunner', max_epochs=75)
diff --git a/mmdet/core/bbox/assigners/__init__.py b/mmdet/core/bbox/assigners/__init__.py
index a182686491d..5eaf7fa3af6 100644
--- a/mmdet/core/bbox/assigners/__init__.py
+++ b/mmdet/core/bbox/assigners/__init__.py
@@ -6,6 +6,7 @@
 from .center_region_assigner import CenterRegionAssigner
 from .grid_assigner import GridAssigner
 from .hungarian_assigner import HungarianAssigner
+from .mask_hungarian_assigner import MaskHungarianAssigner
 from .max_iou_assigner import MaxIoUAssigner
 from .point_assigner import PointAssigner
 from .region_assigner import RegionAssigner
@@ -17,5 +18,5 @@
     'BaseAssigner', 'MaxIoUAssigner', 'ApproxMaxIoUAssigner', 'AssignResult',
     'PointAssigner', 'ATSSAssigner', 'CenterRegionAssigner', 'GridAssigner',
     'HungarianAssigner', 'RegionAssigner', 'UniformAssigner', 'SimOTAAssigner',
-    'TaskAlignedAssigner'
+    'TaskAlignedAssigner', 'MaskHungarianAssigner'
 ]
diff --git a/mmdet/core/bbox/assigners/mask_hungarian_assigner.py b/mmdet/core/bbox/assigners/mask_hungarian_assigner.py
new file mode 100644
index 00000000000..ef0f35831d6
--- /dev/null
+++ b/mmdet/core/bbox/assigners/mask_hungarian_assigner.py
@@ -0,0 +1,130 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmdet.core.bbox.builder import BBOX_ASSIGNERS
+from mmdet.core.bbox.match_costs.builder import build_match_cost
+from .assign_result import AssignResult
+from .base_assigner import BaseAssigner
+
+try:
+    from scipy.optimize import linear_sum_assignment
+except ImportError:
+    linear_sum_assignment = None
+
+
+@BBOX_ASSIGNERS.register_module()
+class MaskHungarianAssigner(BaseAssigner):
+    """Computes one-to-one matching between predictions and ground truth
+    masks.
+
+    This class computes an assignment between the targets and the predictions
+    based on the costs. The cost is a weighted sum of three components: the
+    classification cost, the mask focal cost and the mask dice cost. The
+    targets don't include the no_object, so generally there are more
+    predictions than targets. After the one-to-one matching, the unmatched
+    predictions are treated as background. Thus each query prediction will be
+    assigned with `0` or a positive integer indicating the ground truth index:
+
+    - 0: negative sample, no assigned gt
+    - positive integer: positive sample, index (1-based) of assigned gt
+
+    Args:
+        cls_cost (obj:`mmcv.ConfigDict` | dict): Classification cost config.
+        mask_cost (obj:`mmcv.ConfigDict` | dict): Mask cost config.
+        dice_cost (obj:`mmcv.ConfigDict` | dict): Dice cost config.
+    """
+
+    def __init__(self,
+                 cls_cost=dict(type='ClassificationCost', weight=1.0),
+                 mask_cost=dict(
+                     type='FocalLossCost', weight=1.0, binary_input=True),
+                 dice_cost=dict(type='DiceCost', weight=1.0)):
+        self.cls_cost = build_match_cost(cls_cost)
+        self.mask_cost = build_match_cost(mask_cost)
+        self.dice_cost = build_match_cost(dice_cost)
+
+    def assign(self,
+               cls_pred,
+               mask_pred,
+               gt_labels,
+               gt_mask,
+               img_meta,
+               gt_bboxes_ignore=None,
+               eps=1e-7):
+        """Computes one-to-one matching based on the weighted costs.
+
+        Args:
+            cls_pred (Tensor): Class prediction in shape
+                (num_query, cls_out_channels).
+            mask_pred (Tensor): Mask prediction in shape (num_query, H, W).
+            gt_labels (Tensor): Labels of 'gt_mask' in shape (num_gt, ).
+            gt_mask (Tensor): Ground truth mask in shape (num_gt, H, W).
+            img_meta (dict): Meta information for current image.
+            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
+                labelled as `ignored`. Default None.
+            eps (int | float, optional): A value added to the denominator for
+                numerical stability. Default 1e-7.
+
+        Returns:
+            :obj:`AssignResult`: The assigned result.
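+
+        Example:
+            A minimal illustrative sketch with random inputs; the shapes and
+            the 134 = 133 + 1 class channels are assumptions for COCO
+            panoptic, and scipy is required for the Hungarian solver:
+
+            >>> import torch
+            >>> from mmdet.core.bbox.assigners import MaskHungarianAssigner
+            >>> self = MaskHungarianAssigner()
+            >>> cls_pred = torch.rand(4, 134)
+            >>> mask_pred = torch.rand(4, 8, 8)
+            >>> gt_labels = torch.tensor([1, 77])
+            >>> gt_mask = torch.randint(0, 2, (2, 8, 8))
+            >>> result = self.assign(cls_pred, mask_pred, gt_labels,
+            ...                      gt_mask, img_meta=None)
+            >>> assert result.gt_inds.shape == (4, )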
+ """ + assert gt_bboxes_ignore is None, \ + 'Only case when gt_bboxes_ignore is None is supported.' + num_gt, num_query = gt_labels.shape[0], cls_pred.shape[0] + + # 1. assign -1 by default + assigned_gt_inds = cls_pred.new_full((num_query, ), + -1, + dtype=torch.long) + assigned_labels = cls_pred.new_full((num_query, ), + -1, + dtype=torch.long) + if num_gt == 0 or num_query == 0: + # No ground truth or boxes, return empty assignment + if num_gt == 0: + # No ground truth, assign all to background + assigned_gt_inds[:] = 0 + return AssignResult( + num_gt, assigned_gt_inds, None, labels=assigned_labels) + + # 2. compute the weighted costs + # classification and maskcost. + if self.cls_cost.weight != 0 and cls_pred is not None: + cls_cost = self.cls_cost(cls_pred, gt_labels) + else: + cls_cost = 0 + + if self.mask_cost.weight != 0: + # mask_pred shape = [num_query, h, w] + # gt_mask shape = [num_gt, h, w] + # mask_cost shape = [num_query, num_gt] + mask_cost = self.mask_cost(mask_pred, gt_mask) + else: + mask_cost = 0 + + if self.dice_cost.weight != 0: + dice_cost = self.dice_cost(mask_pred, gt_mask) + else: + dice_cost = 0 + cost = cls_cost + mask_cost + dice_cost + + # 3. do Hungarian matching on CPU using linear_sum_assignment + cost = cost.detach().cpu() + if linear_sum_assignment is None: + raise ImportError('Please run "pip install scipy" ' + 'to install scipy first.') + + matched_row_inds, matched_col_inds = linear_sum_assignment(cost) + matched_row_inds = torch.from_numpy(matched_row_inds).to( + cls_pred.device) + matched_col_inds = torch.from_numpy(matched_col_inds).to( + cls_pred.device) + + # 4. assign backgrounds and foregrounds + # assign all indices to backgrounds first + assigned_gt_inds[:] = 0 + # assign foregrounds based on matching results + assigned_gt_inds[matched_row_inds] = matched_col_inds + 1 + assigned_labels[matched_row_inds] = gt_labels[matched_col_inds] + return AssignResult( + num_gt, assigned_gt_inds, None, labels=assigned_labels) diff --git a/mmdet/core/bbox/match_costs/__init__.py b/mmdet/core/bbox/match_costs/__init__.py index 3f79a1ce36a..81ee588571e 100644 --- a/mmdet/core/bbox/match_costs/__init__.py +++ b/mmdet/core/bbox/match_costs/__init__.py @@ -1,8 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. from .builder import build_match_cost -from .match_cost import BBoxL1Cost, ClassificationCost, FocalLossCost, IoUCost +from .match_cost import (BBoxL1Cost, ClassificationCost, DiceCost, + FocalLossCost, IoUCost) __all__ = [ 'build_match_cost', 'ClassificationCost', 'BBoxL1Cost', 'IoUCost', - 'FocalLossCost' + 'FocalLossCost', 'DiceCost' ] diff --git a/mmdet/core/bbox/match_costs/match_cost.py b/mmdet/core/bbox/match_costs/match_cost.py index d5ce4ca9f59..3c0a164b3c8 100644 --- a/mmdet/core/bbox/match_costs/match_cost.py +++ b/mmdet/core/bbox/match_costs/match_cost.py @@ -35,9 +35,9 @@ def __call__(self, bbox_pred, gt_bboxes): Args: bbox_pred (Tensor): Predicted boxes with normalized coordinates (cx, cy, w, h), which are all in range [0, 1]. Shape - [num_query, 4]. + (num_query, 4). gt_bboxes (Tensor): Ground truth boxes with normalized - coordinates (x1, y1, x2, y2). Shape [num_gt, 4]. + coordinates (x1, y1, x2, y2). Shape (num_gt, 4). Returns: torch.Tensor: bbox_cost value with weight @@ -59,6 +59,8 @@ class FocalLossCost: alpha (int | float, optional): focal_loss alpha gamma (int | float, optional): focal_loss gamma eps (float, optional): default 1e-12 + binary_input (bool, optional): Whether the input is binary, + default False. 
 
     Examples:
         >>> from mmdet.core.bbox.match_costs.match_cost import FocalLossCost
@@ -74,17 +76,23 @@ class FocalLossCost:
                [-0.1950, -0.1207, -0.2626]])
     """
 
-    def __init__(self, weight=1., alpha=0.25, gamma=2, eps=1e-12):
+    def __init__(self,
+                 weight=1.,
+                 alpha=0.25,
+                 gamma=2,
+                 eps=1e-12,
+                 binary_input=False):
         self.weight = weight
         self.alpha = alpha
         self.gamma = gamma
         self.eps = eps
+        self.binary_input = binary_input
 
-    def __call__(self, cls_pred, gt_labels):
+    def _focal_loss_cost(self, cls_pred, gt_labels):
         """
         Args:
             cls_pred (Tensor): Predicted classification logits, shape
-                [num_query, num_class].
+                (num_query, num_class).
             gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).
 
         Returns:
@@ -95,9 +103,50 @@ def __call__(self, cls_pred, gt_labels):
                 1 - self.alpha) * cls_pred.pow(self.gamma)
         pos_cost = -(cls_pred + self.eps).log() * self.alpha * (
             1 - cls_pred).pow(self.gamma)
+
         cls_cost = pos_cost[:, gt_labels] - neg_cost[:, gt_labels]
         return cls_cost * self.weight
 
+    def _mask_focal_loss_cost(self, cls_pred, gt_labels):
+        """
+        Args:
+            cls_pred (Tensor): Predicted classification logits
+                in shape (num_query, d1, ..., dn), dtype=torch.float32.
+            gt_labels (Tensor): Ground truth in shape (num_gt, d1, ..., dn),
+                dtype=torch.long. Labels should be binary.
+
+        Returns:
+            Tensor: Focal cost matrix with weight in shape\
+                (num_query, num_gt).
+        """
+        cls_pred = cls_pred.flatten(1)
+        gt_labels = gt_labels.flatten(1).float()
+        n = cls_pred.shape[1]
+        cls_pred = cls_pred.sigmoid()
+        neg_cost = -(1 - cls_pred + self.eps).log() * (
+            1 - self.alpha) * cls_pred.pow(self.gamma)
+        pos_cost = -(cls_pred + self.eps).log() * self.alpha * (
+            1 - cls_pred).pow(self.gamma)
+
+        cls_cost = torch.einsum('nc,mc->nm', pos_cost, gt_labels) + \
+            torch.einsum('nc,mc->nm', neg_cost, (1 - gt_labels))
+        return cls_cost / n * self.weight
+
+    def __call__(self, cls_pred, gt_labels):
+        """
+        Args:
+            cls_pred (Tensor): Predicted classification logits.
+            gt_labels (Tensor): Labels.
+
+        Returns:
+            Tensor: Focal cost matrix with weight in shape\
+                (num_query, num_gt).
+        """
+        if self.binary_input:
+            return self._mask_focal_loss_cost(cls_pred, gt_labels)
+        else:
+            return self._focal_loss_cost(cls_pred, gt_labels)
+
 
 @MATCH_COST.register_module()
 class ClassificationCost:
@@ -128,7 +177,7 @@ def __call__(self, cls_pred, gt_labels):
         """
         Args:
             cls_pred (Tensor): Predicted classification logits, shape
-                [num_query, num_class].
+                (num_query, num_class).
             gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).
 
         Returns:
@@ -170,9 +219,9 @@ def __call__(self, bboxes, gt_bboxes):
         """
         Args:
             bboxes (Tensor): Predicted boxes with unnormalized coordinates
-                (x1, y1, x2, y2). Shape [num_query, 4].
+                (x1, y1, x2, y2). Shape (num_query, 4).
             gt_bboxes (Tensor): Ground truth boxes with unnormalized
-                coordinates (x1, y1, x2, y2). Shape [num_gt, 4].
+                coordinates (x1, y1, x2, y2). Shape (num_gt, 4).
 
         Returns:
             torch.Tensor: iou_cost value with weight
@@ -183,3 +232,52 @@ def __call__(self, bboxes, gt_bboxes):
         # The 1 is a constant that doesn't change the matching, so omitted.
         iou_cost = -overlaps
         return iou_cost * self.weight
+
+
+@MATCH_COST.register_module()
+class DiceCost:
+    """Cost of mask assignments based on dice losses.
+
+    Args:
+        weight (int | float, optional): loss_weight. Defaults to 1.
+        pred_act (bool, optional): Whether to apply sigmoid to mask_pred.
+            Defaults to False.
+        eps (float, optional): Defaults to 1e-3.
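+
+    Examples:
+        A small illustrative sketch with random tensors; the shapes are
+        assumptions, any (num_query, *) vs. (num_gt, *) pair works:
+
+        >>> import torch
+        >>> from mmdet.core.bbox.match_costs import DiceCost
+        >>> self = DiceCost(pred_act=True)
+        >>> mask_preds = torch.randn(3, 8, 8)  # raw mask logits
+        >>> gt_masks = torch.randint(0, 2, (2, 8, 8))
+        >>> cost = self(mask_preds, gt_masks)
+        >>> assert cost.shape == (3, 2)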
+ """ + + def __init__(self, weight=1., pred_act=False, eps=1e-3): + self.weight = weight + self.pred_act = pred_act + self.eps = eps + + def binary_mask_dice_loss(self, mask_preds, gt_masks): + """ + Args: + mask_preds (Tensor): Mask prediction in shape (num_query, *). + gt_masks (Tensor): Ground truth in shape (num_gt, *) + store 0 or 1, 0 for negative class and 1 for + positive class. + + Returns: + Tensor: Dice cost matrix in shape (num_query, num_gt). + """ + mask_preds = mask_preds.flatten(1) + gt_masks = gt_masks.flatten(1).float() + numerator = 2 * torch.einsum('nc,mc->nm', mask_preds, gt_masks) + denominator = mask_preds.sum(-1)[:, None] + gt_masks.sum(-1)[None, :] + loss = 1 - (numerator + self.eps) / (denominator + self.eps) + return loss + + def __call__(self, mask_preds, gt_masks): + """ + Args: + mask_preds (Tensor): Mask prediction logits in shape (num_query, *) + gt_masks (Tensor): Ground truth in shape (num_gt, *) + + Returns: + Tensor: Dice cost matrix with weight in shape (num_query, num_gt). + """ + if self.pred_act: + mask_preds = mask_preds.sigmoid() + dice_cost = self.binary_mask_dice_loss(mask_preds, gt_masks) + return dice_cost * self.weight diff --git a/mmdet/core/bbox/samplers/__init__.py b/mmdet/core/bbox/samplers/__init__.py index b9e83913eaa..f58505b59dc 100644 --- a/mmdet/core/bbox/samplers/__init__.py +++ b/mmdet/core/bbox/samplers/__init__.py @@ -3,6 +3,8 @@ from .combined_sampler import CombinedSampler from .instance_balanced_pos_sampler import InstanceBalancedPosSampler from .iou_balanced_neg_sampler import IoUBalancedNegSampler +from .mask_pseudo_sampler import MaskPseudoSampler +from .mask_sampling_result import MaskSamplingResult from .ohem_sampler import OHEMSampler from .pseudo_sampler import PseudoSampler from .random_sampler import RandomSampler @@ -12,5 +14,6 @@ __all__ = [ 'BaseSampler', 'PseudoSampler', 'RandomSampler', 'InstanceBalancedPosSampler', 'IoUBalancedNegSampler', 'CombinedSampler', - 'OHEMSampler', 'SamplingResult', 'ScoreHLRSampler' + 'OHEMSampler', 'SamplingResult', 'ScoreHLRSampler', 'MaskPseudoSampler', + 'MaskSamplingResult' ] diff --git a/mmdet/core/bbox/samplers/mask_pseudo_sampler.py b/mmdet/core/bbox/samplers/mask_pseudo_sampler.py new file mode 100644 index 00000000000..b5f69658d02 --- /dev/null +++ b/mmdet/core/bbox/samplers/mask_pseudo_sampler.py @@ -0,0 +1,44 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""copy from +https://github.com/ZwwWayne/K-Net/blob/main/knet/det/mask_pseudo_sampler.py.""" + +import torch + +from mmdet.core.bbox.builder import BBOX_SAMPLERS +from .base_sampler import BaseSampler +from .mask_sampling_result import MaskSamplingResult + + +@BBOX_SAMPLERS.register_module() +class MaskPseudoSampler(BaseSampler): + """A pseudo sampler that does not do sampling actually.""" + + def __init__(self, **kwargs): + pass + + def _sample_pos(self, **kwargs): + """Sample positive samples.""" + raise NotImplementedError + + def _sample_neg(self, **kwargs): + """Sample negative samples.""" + raise NotImplementedError + + def sample(self, assign_result, masks, gt_masks, **kwargs): + """Directly returns the positive and negative indices of samples. 
+
+        Args:
+            assign_result (:obj:`AssignResult`): Assigned results
+            masks (torch.Tensor): Predicted masks
+            gt_masks (torch.Tensor): Ground truth masks
+        Returns:
+            :obj:`SamplingResult`: sampler results
+        """
+        pos_inds = torch.nonzero(
+            assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique()
+        neg_inds = torch.nonzero(
+            assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique()
+        gt_flags = masks.new_zeros(masks.shape[0], dtype=torch.uint8)
+        sampling_result = MaskSamplingResult(pos_inds, neg_inds, masks,
+                                             gt_masks, assign_result, gt_flags)
+        return sampling_result
diff --git a/mmdet/core/bbox/samplers/mask_sampling_result.py b/mmdet/core/bbox/samplers/mask_sampling_result.py
new file mode 100644
index 00000000000..3d109432260
--- /dev/null
+++ b/mmdet/core/bbox/samplers/mask_sampling_result.py
@@ -0,0 +1,60 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""copy from
+https://github.com/ZwwWayne/K-Net/blob/main/knet/det/mask_pseudo_sampler.py."""
+
+import torch
+
+from .sampling_result import SamplingResult
+
+
+class MaskSamplingResult(SamplingResult):
+    """Mask sampling result."""
+
+    def __init__(self, pos_inds, neg_inds, masks, gt_masks, assign_result,
+                 gt_flags):
+        self.pos_inds = pos_inds
+        self.neg_inds = neg_inds
+        self.pos_masks = masks[pos_inds]
+        self.neg_masks = masks[neg_inds]
+        self.pos_is_gt = gt_flags[pos_inds]
+
+        self.num_gts = gt_masks.shape[0]
+        self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1
+
+        if gt_masks.numel() == 0:
+            # hack for index error case
+            assert self.pos_assigned_gt_inds.numel() == 0
+            self.pos_gt_masks = torch.empty_like(gt_masks)
+        else:
+            self.pos_gt_masks = gt_masks[self.pos_assigned_gt_inds, :]
+
+        if assign_result.labels is not None:
+            self.pos_gt_labels = assign_result.labels[pos_inds]
+        else:
+            self.pos_gt_labels = None
+
+    @property
+    def masks(self):
+        """torch.Tensor: concatenated positive and negative masks"""
+        return torch.cat([self.pos_masks, self.neg_masks])
+
+    def __nice__(self):
+        data = self.info.copy()
+        data['pos_masks'] = data.pop('pos_masks').shape
+        data['neg_masks'] = data.pop('neg_masks').shape
+        parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())]
+        body = '    ' + ',\n    '.join(parts)
+        return '{\n' + body + '\n}'
+
+    @property
+    def info(self):
+        """Returns a dictionary of info about the object."""
+        return {
+            'pos_inds': self.pos_inds,
+            'neg_inds': self.neg_inds,
+            'pos_masks': self.pos_masks,
+            'neg_masks': self.neg_masks,
+            'pos_is_gt': self.pos_is_gt,
+            'num_gts': self.num_gts,
+            'pos_assigned_gt_inds': self.pos_assigned_gt_inds,
+        }
diff --git a/mmdet/models/dense_heads/__init__.py b/mmdet/models/dense_heads/__init__.py
index 81d6ec2f74d..e931e608028 100644
--- a/mmdet/models/dense_heads/__init__.py
+++ b/mmdet/models/dense_heads/__init__.py
@@ -20,6 +20,7 @@
 from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead
 from .lad_head import LADHead
 from .ld_head import LDHead
+from .maskformer_head import MaskFormerHead
 from .nasfcos_head import NASFCOSHead
 from .paa_head import PAAHead
 from .pisa_retinanet_head import PISARetinaHead
@@ -49,5 +50,5 @@
     'CascadeRPNHead', 'EmbeddingRPNHead', 'LDHead', 'CascadeRPNHead',
     'AutoAssignHead', 'DETRHead', 'YOLOFHead', 'DeformableDETRHead',
     'SOLOHead', 'DecoupledSOLOHead', 'CenterNetHead', 'YOLOXHead',
-    'DecoupledSOLOLightHead', 'LADHead', 'TOODHead'
+    'DecoupledSOLOLightHead', 'LADHead', 'TOODHead', 'MaskFormerHead'
 ]
diff --git a/mmdet/models/dense_heads/maskformer_head.py b/mmdet/models/dense_heads/maskformer_head.py
new file mode 100644
index 00000000000..3cd060e53b6
--- /dev/null
+++ b/mmdet/models/dense_heads/maskformer_head.py
@@ -0,0 +1,666 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import Conv2d, build_plugin_layer, caffe2_xavier_init
+from mmcv.cnn.bricks.transformer import (build_positional_encoding,
+                                         build_transformer_layer_sequence)
+from mmcv.runner import force_fp32
+
+from mmdet.core import build_assigner, build_sampler, multi_apply, reduce_mean
+from mmdet.core.evaluation import INSTANCE_OFFSET
+from mmdet.models.utils import preprocess_panoptic_gt
+from ..builder import HEADS, build_loss
+from .anchor_free_head import AnchorFreeHead
+
+
+@HEADS.register_module()
+class MaskFormerHead(AnchorFreeHead):
+    """Implements the MaskFormer head.
+
+    See `Per-Pixel Classification is Not All You Need for Semantic
+    Segmentation <https://arxiv.org/pdf/2107.06278>`_ for details.
+
+    Args:
+        in_channels (list[int]): Number of channels in the input feature map.
+        feat_channels (int): Number of channels for feature.
+        out_channels (int): Number of channels for output.
+        num_things_classes (int): Number of things.
+        num_stuff_classes (int): Number of stuff.
+        num_queries (int): Number of queries in the transformer decoder.
+        pixel_decoder (obj:`mmcv.ConfigDict`|dict): Config for pixel decoder.
+            Defaults to None.
+        enforce_decoder_input_project (bool, optional): Whether to add a layer
+            to change the embed_dim of the transformer encoder in the pixel
+            decoder to the embed_dim of the transformer decoder.
+            Defaults to False.
+        transformer_decoder (obj:`mmcv.ConfigDict`|dict): Config for
+            transformer decoder. Defaults to None.
+        positional_encoding (obj:`mmcv.ConfigDict`|dict): Config for
+            transformer decoder position encoding. Defaults to None.
+        loss_cls (obj:`mmcv.ConfigDict`|dict): Config of the classification
+            loss. Defaults to `CrossEntropyLoss`.
+        loss_mask (obj:`mmcv.ConfigDict`|dict): Config of the mask loss.
+            Defaults to `FocalLoss`.
+        loss_dice (obj:`mmcv.ConfigDict`|dict): Config of the dice loss.
+            Defaults to `DiceLoss`.
+        train_cfg (obj:`mmcv.ConfigDict`|dict): Training config of the
+            MaskFormer head.
+        test_cfg (obj:`mmcv.ConfigDict`|dict): Testing config of the
+            MaskFormer head.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Defaults to None.
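+
+    Example:
+        A minimal, illustrative head config; the values mirror
+        ``maskformer_r50_mstrain_16x1_75e_coco.py`` from this PR, with the
+        nested pixel decoder and transformer decoder configs omitted:
+
+        >>> head_cfg = dict(
+        ...     type='MaskFormerHead',
+        ...     in_channels=[256, 512, 1024, 2048],
+        ...     feat_channels=256,
+        ...     out_channels=256,
+        ...     num_things_classes=80,
+        ...     num_stuff_classes=53,
+        ...     num_queries=100)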
+ """ + + def __init__(self, + in_channels, + feat_channels, + out_channels, + num_things_classes=80, + num_stuff_classes=53, + num_queries=100, + pixel_decoder=None, + enforce_decoder_input_project=False, + transformer_decoder=None, + positional_encoding=None, + loss_cls=dict( + type='CrossEntropyLoss', + bg_cls_weight=0.1, + use_sigmoid=False, + loss_weight=1.0, + class_weight=1.0), + loss_mask=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=20.0), + loss_dice=dict( + type='DiceLoss', + use_sigmoid=True, + activate=True, + naive_dice=True, + loss_weight=1.0), + train_cfg=None, + test_cfg=None, + init_cfg=None, + **kwargs): + super(AnchorFreeHead, self).__init__(init_cfg) + self.num_things_classes = num_things_classes + self.num_stuff_classes = num_stuff_classes + self.num_classes = self.num_things_classes + self.num_stuff_classes + self.num_queries = num_queries + + pixel_decoder.update( + in_channels=in_channels, + feat_channels=feat_channels, + out_channels=out_channels) + self.pixel_decoder = build_plugin_layer(pixel_decoder)[1] + self.transformer_decoder = build_transformer_layer_sequence( + transformer_decoder) + self.decoder_embed_dims = self.transformer_decoder.embed_dims + pixel_decoder_type = pixel_decoder.get('type') + if pixel_decoder_type == 'PixelDecoder' and ( + self.decoder_embed_dims != in_channels[-1] + or enforce_decoder_input_project): + self.decoder_input_proj = Conv2d( + in_channels[-1], self.decoder_embed_dims, kernel_size=1) + else: + self.decoder_input_proj = nn.Identity() + self.decoder_pe = build_positional_encoding(positional_encoding) + self.query_embed = nn.Embedding(self.num_queries, out_channels) + + self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1) + self.mask_embed = nn.Sequential( + nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), + nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), + nn.Linear(feat_channels, out_channels)) + + self.test_cfg = test_cfg + self.train_cfg = train_cfg + if train_cfg: + assert 'assigner' in train_cfg, 'assigner should be provided '\ + 'when train_cfg is set.' + assigner = train_cfg['assigner'] + self.assigner = build_assigner(assigner) + sampler_cfg = dict(type='MaskPseudoSampler') + self.sampler = build_sampler(sampler_cfg, context=self) + + self.bg_cls_weight = 0 + class_weight = loss_cls.get('class_weight', None) + if class_weight is not None and (self.__class__ is MaskFormerHead): + assert isinstance(class_weight, float), 'Expected ' \ + 'class_weight to have type float. Found ' \ + f'{type(class_weight)}.' + # NOTE following the official MaskFormerHead repo, bg_cls_weight + # means relative classification weight of the VOID class. + bg_cls_weight = loss_cls.get('bg_cls_weight', class_weight) + assert isinstance(bg_cls_weight, float), 'Expected ' \ + 'bg_cls_weight to have type float. Found ' \ + f'{type(bg_cls_weight)}.' 
+            class_weight = torch.ones(self.num_classes + 1) * class_weight
+            # set the VOID class weight at the last index
+            class_weight[self.num_classes] = bg_cls_weight
+            loss_cls.update({'class_weight': class_weight})
+            if 'bg_cls_weight' in loss_cls:
+                loss_cls.pop('bg_cls_weight')
+            self.bg_cls_weight = bg_cls_weight
+        self.loss_cls = build_loss(loss_cls)
+        self.loss_mask = build_loss(loss_mask)
+        self.loss_dice = build_loss(loss_dice)
+
+    def init_weights(self):
+        if isinstance(self.decoder_input_proj, Conv2d):
+            caffe2_xavier_init(self.decoder_input_proj, bias=0)
+
+        self.pixel_decoder.init_weights()
+
+        for p in self.transformer_decoder.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)
+
+    def preprocess_gt(self, gt_labels_list, gt_masks_list, gt_semantic_segs):
+        """Preprocess the ground truth for all images.
+
+        Args:
+            gt_labels_list (list[Tensor]): Each is the ground truth
+                labels of each bbox, with shape (num_gts, ).
+            gt_masks_list (list[BitmapMasks]): Each is the ground truth
+                masks of the instances in an image, with shape
+                (num_gts, h, w).
+            gt_semantic_segs (list[Tensor]): Ground truth of semantic
+                segmentation for each image, where values in
+                [0, num_thing_class - 1] mean things,
+                [num_thing_class, num_class - 1] mean stuff,
+                and 255 means VOID.
+
+        Returns:
+            tuple: a tuple containing the following targets.
+
+                - labels (list[Tensor]): Ground truth class indices for all\
+                    images. Each with shape (n, ), where n is the number of\
+                    stuff types plus the number of instances in an image.
+                - masks (list[Tensor]): Ground truth mask for each image,\
+                    each with shape (n, h, w).
+        """
+        num_things_list = [self.num_things_classes] * len(gt_labels_list)
+        num_stuff_list = [self.num_stuff_classes] * len(gt_labels_list)
+
+        targets = multi_apply(preprocess_panoptic_gt, gt_labels_list,
+                              gt_masks_list, gt_semantic_segs, num_things_list,
+                              num_stuff_list)
+        labels, masks = targets
+        return labels, masks
+
+    def get_targets(self, cls_scores_list, mask_preds_list, gt_labels_list,
+                    gt_masks_list, img_metas):
+        """Compute classification and mask targets for all images for a
+        decoder layer.
+
+        Args:
+            cls_scores_list (list[Tensor]): Mask score logits from a single
+                decoder layer for all images. Each with shape (num_queries,
+                cls_out_channels).
+            mask_preds_list (list[Tensor]): Mask logits from a single decoder
+                layer for all images. Each with shape (num_queries, h, w).
+            gt_labels_list (list[Tensor]): Ground truth class indices for all
+                images. Each with shape (n, ), where n is the number of stuff
+                types plus the number of instances in an image.
+            gt_masks_list (list[Tensor]): Ground truth mask for each image,
+                each with shape (n, h, w).
+            img_metas (list[dict]): List of image meta information.
+
+        Returns:
+            tuple[list[Tensor]]: a tuple containing the following targets.
+
+                - labels_list (list[Tensor]): Labels of all images.\
+                    Each with shape (num_queries, ).
+                - label_weights_list (list[Tensor]): Label weights of all\
+                    images. Each with shape (num_queries, ).
+                - mask_targets_list (list[Tensor]): Mask targets of all\
+                    images. Each with shape (num_queries, h, w).
+                - mask_weights_list (list[Tensor]): Mask weights of all\
+                    images. Each with shape (num_queries, ).
+                - num_total_pos (int): Number of positive samples in all\
+                    images.
+                - num_total_neg (int): Number of negative samples in all\
+                    images.
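+
+        Note:
+            Matching is performed independently per image: each
+            ``(cls_score, mask_pred)`` pair is matched against its own
+            ground truth in ``self._get_target_single``, which runs the
+            Hungarian assigner configured in ``train_cfg``.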
+ """ + (labels_list, label_weights_list, mask_targets_list, mask_weights_list, + pos_inds_list, + neg_inds_list) = multi_apply(self._get_target_single, cls_scores_list, + mask_preds_list, gt_labels_list, + gt_masks_list, img_metas) + + num_total_pos = sum((inds.numel() for inds in pos_inds_list)) + num_total_neg = sum((inds.numel() for inds in neg_inds_list)) + return (labels_list, label_weights_list, mask_targets_list, + mask_weights_list, num_total_pos, num_total_neg) + + def _get_target_single(self, cls_score, mask_pred, gt_labels, gt_masks, + img_metas): + """Compute classification and mask targets for one image. + + Args: + cls_score (Tensor): Mask score logits from a single decoder layer + for one image. Shape (num_queries, cls_out_channels). + mask_pred (Tensor): Mask logits for a single decoder layer for one + image. Shape (num_queries, h, w). + gt_labels (Tensor): Ground truth class indices for one image with + shape (n, ). n is the sum of number of stuff type and number + of instance in a image. + gt_masks (Tensor): Ground truth mask for each image, each with + shape (n, h, w). + img_metas (dict): Image informtation. + + Returns: + tuple[Tensor]: a tuple containing the following for one image. + + - labels (Tensor): Labels of each image. + shape (num_queries, ). + - label_weights (Tensor): Label weights of each image. + shape (num_queries, ). + - mask_targets (Tensor): Mask targets of each image. + shape (num_queries, h, w). + - mask_weights (Tensor): Mask weights of each image. + shape (num_queries, ). + - pos_inds (Tensor): Sampled positive indices for each image. + - neg_inds (Tensor): Sampled negative indices for each image. + """ + target_shape = mask_pred.shape[-2:] + if gt_masks.shape[0] > 0: + gt_masks_downsampled = F.interpolate( + gt_masks.unsqueeze(1).float(), target_shape, + mode='nearest').squeeze(1).long() + else: + gt_masks_downsampled = gt_masks + + # assign and sample + assign_result = self.assigner.assign(cls_score, mask_pred, gt_labels, + gt_masks_downsampled, img_metas) + sampling_result = self.sampler.sample(assign_result, mask_pred, + gt_masks) + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + + # label target + labels = gt_labels.new_full((self.num_queries, ), + self.num_classes, + dtype=torch.long) + labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds] + label_weights = gt_labels.new_ones(self.num_queries) + + # mask target + mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds] + mask_weights = mask_pred.new_zeros((self.num_queries, )) + mask_weights[pos_inds] = 1.0 + + return (labels, label_weights, mask_targets, mask_weights, pos_inds, + neg_inds) + + @force_fp32(apply_to=('all_cls_scores', 'all_mask_preds')) + def loss(self, all_cls_scores, all_mask_preds, gt_labels_list, + gt_masks_list, img_metas): + """Loss function. + + Args: + all_cls_scores (Tensor): Classification scores for all decoder + layers with shape (num_decoder, batch_size, num_queries, + cls_out_channels). + all_mask_preds (Tensor): Mask scores for all decoder layers with + shape (num_decoder, batch_size, num_queries, h, w). + gt_labels_list (list[Tensor]): Ground truth class indices for each + image with shape (n, ). n is the sum of number of stuff type + and number of instance in a image. + gt_masks_list (list[Tensor]): Ground truth mask for each image with + shape (n, h, w). + img_metas (list[dict]): List of image meta information. + + Returns: + dict[str, Tensor]: A dictionary of loss components. 
+ """ + num_dec_layers = len(all_cls_scores) + all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)] + all_gt_masks_list = [gt_masks_list for _ in range(num_dec_layers)] + img_metas_list = [img_metas for _ in range(num_dec_layers)] + losses_cls, losses_mask, losses_dice = multi_apply( + self.loss_single, all_cls_scores, all_mask_preds, + all_gt_labels_list, all_gt_masks_list, img_metas_list) + + loss_dict = dict() + # loss from the last decoder layer + loss_dict['loss_cls'] = losses_cls[-1] + loss_dict['loss_mask'] = losses_mask[-1] + loss_dict['loss_dice'] = losses_dice[-1] + # loss from other decoder layers + num_dec_layer = 0 + for loss_cls_i, loss_mask_i, loss_dice_i in zip( + losses_cls[:-1], losses_mask[:-1], losses_dice[:-1]): + loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i + loss_dict[f'd{num_dec_layer}.loss_mask'] = loss_mask_i + loss_dict[f'd{num_dec_layer}.loss_dice'] = loss_dice_i + num_dec_layer += 1 + return loss_dict + + def loss_single(self, cls_scores, mask_preds, gt_labels_list, + gt_masks_list, img_metas): + """Loss function for outputs from a single decoder layer. + + Args: + cls_scores (Tensor): Mask score logits from a single decoder layer + for all images. Shape (batch_size, num_queries, + cls_out_channels). + mask_preds (Tensor): Mask logits for a pixel decoder for all + images. Shape (batch_size, num_queries, h, w). + gt_labels_list (list[Tensor]): Ground truth class indices for each + image, each with shape (n, ). n is the sum of number of stuff + types and number of instances in a image. + gt_masks_list (list[Tensor]): Ground truth mask for each image, + each with shape (n, h, w). + img_metas (list[dict]): List of image meta information. + + Returns: + tuple[Tensor]: Loss components for outputs from a single decoder\ + layer. 
+ """ + num_imgs = cls_scores.size(0) + cls_scores_list = [cls_scores[i] for i in range(num_imgs)] + mask_preds_list = [mask_preds[i] for i in range(num_imgs)] + + (labels_list, label_weights_list, mask_targets_list, mask_weights_list, + num_total_pos, + num_total_neg) = self.get_targets(cls_scores_list, mask_preds_list, + gt_labels_list, gt_masks_list, + img_metas) + # shape (batch_size, num_queries) + labels = torch.stack(labels_list, dim=0) + # shape (batch_size, num_queries) + label_weights = torch.stack(label_weights_list, dim=0) + # shape (num_total_gts, h, w) + mask_targets = torch.cat(mask_targets_list, dim=0) + # shape (batch_size, num_queries) + mask_weights = torch.stack(mask_weights_list, dim=0) + + # classfication loss + # shape (batch_size * num_queries, ) + cls_scores = cls_scores.flatten(0, 1) + labels = labels.flatten(0, 1) + label_weights = label_weights.flatten(0, 1) + + class_weight = cls_scores.new_ones(self.num_classes + 1) + class_weight[-1] = self.bg_cls_weight + loss_cls = self.loss_cls( + cls_scores, + labels, + label_weights, + avg_factor=class_weight[labels].sum()) + + num_total_masks = reduce_mean(cls_scores.new_tensor([num_total_pos])) + num_total_masks = max(num_total_masks, 1) + + # extract positive ones + # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w) + mask_preds = mask_preds[mask_weights > 0] + target_shape = mask_targets.shape[-2:] + + if mask_targets.shape[0] == 0: + # zero match + loss_dice = mask_preds.sum() + loss_mask = mask_preds.sum() + return loss_cls, loss_mask, loss_dice + + # upsample to shape of target + # shape (num_total_gts, h, w) + mask_preds = F.interpolate( + mask_preds.unsqueeze(1), + target_shape, + mode='bilinear', + align_corners=False).squeeze(1) + + # dice loss + loss_dice = self.loss_dice( + mask_preds, mask_targets, avg_factor=num_total_masks) + + # mask loss + # FocalLoss support input of shape (n, num_class) + h, w = mask_preds.shape[-2:] + # shape (num_total_gts, h, w) -> (num_total_gts * h * w, 1) + mask_preds = mask_preds.reshape(-1, 1) + # shape (num_total_gts, h, w) -> (num_total_gts * h * w) + mask_targets = mask_targets.reshape(-1) + # target is (1 - mask_targets) !!! + loss_mask = self.loss_mask( + mask_preds, 1 - mask_targets, avg_factor=num_total_masks * h * w) + + return loss_cls, loss_mask, loss_dice + + def forward(self, feats, img_metas): + """Forward function. + + Args: + feats (list[Tensor]): Features from the upstream network, each + is a 4D-tensor. + img_metas (list[dict]): List of image information. + + Returns: + all_cls_scores (Tensor): Classification scores for each\ + scale level. Each is a 4D-tensor with shape\ + (num_decoder, batch_size, num_queries, cls_out_channels).\ + Note `cls_out_channels` should includes background. + all_mask_preds (Tensor): Mask scores for each decoder\ + layer. Each with shape (num_decoder, batch_size,\ + num_queries, h, w). + """ + batch_size = len(img_metas) + input_img_h, input_img_w = img_metas[0]['batch_input_shape'] + padding_mask = feats[-1].new_ones( + (batch_size, input_img_h, input_img_w), dtype=torch.float32) + for i in range(batch_size): + img_h, img_w, _ = img_metas[i]['img_shape'] + padding_mask[i, :img_h, :img_w] = 0 + padding_mask = F.interpolate( + padding_mask.unsqueeze(1), + size=feats[-1].shape[-2:], + mode='nearest').to(torch.bool).squeeze(1) + # when backbone is swin, memory is output of last stage of swin. + # when backbone is r50, memory is output of tranformer encoder. 
+        mask_features, memory = self.pixel_decoder(feats, img_metas)
+        pos_embed = self.decoder_pe(padding_mask)
+        memory = self.decoder_input_proj(memory)
+        # shape (batch_size, c, h, w) -> (h*w, batch_size, c)
+        memory = memory.flatten(2).permute(2, 0, 1)
+        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
+        # shape (batch_size, h * w)
+        padding_mask = padding_mask.flatten(1)
+        # shape = (num_queries, embed_dims)
+        query_embed = self.query_embed.weight
+        # shape = (num_queries, batch_size, embed_dims)
+        query_embed = query_embed.unsqueeze(1).repeat(1, batch_size, 1)
+        target = torch.zeros_like(query_embed)
+        # shape (num_decoder, num_queries, batch_size, embed_dims)
+        out_dec = self.transformer_decoder(
+            query=target,
+            key=memory,
+            value=memory,
+            key_pos=pos_embed,
+            query_pos=query_embed,
+            key_padding_mask=padding_mask)
+        # shape (num_decoder, batch_size, num_queries, embed_dims)
+        out_dec = out_dec.transpose(1, 2)
+
+        # cls_scores
+        all_cls_scores = self.cls_embed(out_dec)
+
+        # mask_preds
+        mask_embed = self.mask_embed(out_dec)
+        all_mask_preds = torch.einsum('lbqc,bchw->lbqhw', mask_embed,
+                                      mask_features)
+
+        return all_cls_scores, all_mask_preds
+
+    def forward_train(self,
+                      feats,
+                      img_metas,
+                      gt_bboxes,
+                      gt_labels,
+                      gt_masks,
+                      gt_semantic_seg,
+                      gt_bboxes_ignore=None):
+        """Forward function for training mode.
+
+        Args:
+            feats (list[Tensor]): Multi-level features from the upstream
+                network, each is a 4D-tensor.
+            img_metas (list[Dict]): List of image information.
+            gt_bboxes (list[Tensor]): Each element is ground truth bboxes of
+                the image, shape (num_gts, 4). Not used here.
+            gt_labels (list[Tensor]): Each element is ground truth labels of
+                each box, shape (num_gts,).
+            gt_masks (list[BitmapMasks]): Each element is the masks of the
+                instances in an image, shape (num_gts, h, w).
+            gt_semantic_seg (list[Tensor]): Each element is the ground truth
+                of semantic segmentation with the shape (N, H, W), where
+                values in [0, num_thing_class - 1] mean things,
+                [num_thing_class, num_class - 1] mean stuff,
+                and 255 means VOID.
+            gt_bboxes_ignore (list[Tensor]): Ground truth bboxes to be
+                ignored. Defaults to None.
+
+        Returns:
+            losses (dict[str, Tensor]): a dictionary of loss components
+        """
+        # not consider ignoring bboxes
+        assert gt_bboxes_ignore is None
+
+        # forward
+        all_cls_scores, all_mask_preds = self(feats, img_metas)
+
+        # preprocess ground truth
+        gt_labels, gt_masks = self.preprocess_gt(gt_labels, gt_masks,
+                                                 gt_semantic_seg)
+
+        # loss
+        losses = self.loss(all_cls_scores, all_mask_preds, gt_labels, gt_masks,
+                           img_metas)
+
+        return losses
+
+    def simple_test(self, feats, img_metas, rescale=False):
+        """Test segmentation without test-time augmentation.
+
+        Only the output of the last decoder layer is used.
+
+        Args:
+            feats (list[Tensor]): Multi-level features from the
+                upstream network, each is a 4D-tensor.
+            img_metas (list[dict]): List of image information.
+            rescale (bool, optional): If True, return boxes in
+                original image space. Default False.
+
+        Returns:
+            list[dict[str, np.array]]: semantic segmentation results\
+                and panoptic segmentation results for each image.
+
+            .. code-block:: none
+
+                [
+                    {
+                        'pan_results': <np.ndarray>, # shape = [h, w]
+                    },
+                    ...
+                ]
+        """
+        all_cls_scores, all_mask_preds = self(feats, img_metas)
+        mask_cls_results = all_cls_scores[-1]
+        mask_pred_results = all_mask_preds[-1]
+
+        # upsample masks
+        img_shape = img_metas[0]['batch_input_shape']
+        mask_pred_results = F.interpolate(
+            mask_pred_results,
+            size=(img_shape[0], img_shape[1]),
+            mode='bilinear',
+            align_corners=False)
+
+        results = []
+        for mask_cls_result, mask_pred_result, meta in zip(
+                mask_cls_results, mask_pred_results, img_metas):
+            # remove padding
+            img_height, img_width = meta['img_shape'][:2]
+            mask_pred_result = mask_pred_result[:, :img_height, :img_width]
+
+            if rescale:
+                # return result in original resolution
+                ori_height, ori_width = meta['ori_shape'][:2]
+                mask_pred_result = F.interpolate(mask_pred_result.unsqueeze(1),
+                                                 size=(ori_height, ori_width),
+                                                 mode='bilinear',
+                                                 align_corners=False)\
+                    .squeeze(1)
+
+            mask = self.post_process(mask_cls_result, mask_pred_result)
+            results.append(mask)
+
+        return results
+
+    def post_process(self, mask_cls, mask_pred):
+        """Panoptic segmentation inference.
+
+        This implementation is modified from\
+            https://github.com/facebookresearch/MaskFormer
+
+        Args:
+            mask_cls (Tensor): Classification outputs for an image.
+                shape = (num_queries, cls_out_channels).
+            mask_pred (Tensor): Mask outputs for an image.
+                shape = (num_queries, h, w).
+
+        Returns:
+            panoptic_seg (Tensor): Panoptic segmentation result of shape\
+                (h, w), where each element encodes\
+                ``segment_id = pred_class + instance_id * INSTANCE_OFFSET``.
+        """
+        object_mask_thr = self.test_cfg.get('object_mask_thr', 0.8)
+        iou_thr = self.test_cfg.get('iou_thr', 0.8)
+
+        scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
+        mask_pred = mask_pred.sigmoid()
+
+        keep = labels.ne(self.num_classes) & (scores > object_mask_thr)
+        cur_scores = scores[keep]
+        cur_classes = labels[keep]
+        cur_masks = mask_pred[keep]
+
+        cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks
+
+        h, w = cur_masks.shape[-2:]
+        panoptic_seg = torch.full((h, w),
+                                  self.num_classes,
+                                  dtype=torch.int32,
+                                  device=cur_masks.device)
+        if cur_masks.shape[0] == 0:
+            # We didn't detect any mask :(
+            pass
+        else:
+            cur_mask_ids = cur_prob_masks.argmax(0)
+            instance_id = 1
+            for k in range(cur_classes.shape[0]):
+                pred_class = int(cur_classes[k].item())
+                isthing = pred_class < self.num_things_classes
+                mask = cur_mask_ids == k
+                mask_area = mask.sum().item()
+                original_area = (cur_masks[k] >= 0.5).sum().item()
+                if mask_area > 0 and original_area > 0:
+                    if mask_area / original_area < iou_thr:
+                        continue
+
+                    if not isthing:
+                        # different stuff regions of the same class will be
+                        # merged here, and stuff shares the instance_id 0.
+                        panoptic_seg[mask] = pred_class
+                    else:
+                        panoptic_seg[mask] = (
+                            pred_class + instance_id * INSTANCE_OFFSET)
+                        instance_id += 1
+        return panoptic_seg
diff --git a/mmdet/models/detectors/__init__.py b/mmdet/models/detectors/__init__.py
index 456b8d424fb..9f05a282c18 100644
--- a/mmdet/models/detectors/__init__.py
+++ b/mmdet/models/detectors/__init__.py
@@ -19,6 +19,7 @@
 from .lad import LAD
 from .mask_rcnn import MaskRCNN
 from .mask_scoring_rcnn import MaskScoringRCNN
+from .maskformer import MaskFormer
 from .nasfcos import NASFCOS
 from .paa import PAA
 from .panoptic_fpn import PanopticFPN
@@ -49,5 +50,6 @@
     'NASFCOS', 'PointRend', 'GFL', 'CornerNet', 'PAA', 'YOLOV3', 'YOLACT',
     'VFNet', 'DETR', 'TridentFasterRCNN', 'SparseRCNN', 'SCNet', 'SOLO',
     'DeformableDETR', 'AutoAssign', 'YOLOF', 'CenterNet', 'YOLOX',
-    'TwoStagePanopticSegmentor', 'PanopticFPN', 'QueryInst', 'LAD', 'TOOD'
+    'TwoStagePanopticSegmentor', 'PanopticFPN', 'QueryInst', 'LAD', 'TOOD',
+    'MaskFormer'
 ]
diff --git a/mmdet/models/detectors/maskformer.py b/mmdet/models/detectors/maskformer.py
new file mode 100644
index 00000000000..17c5d6c895c
--- /dev/null
+++ b/mmdet/models/detectors/maskformer.py
@@ -0,0 +1,106 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import DETECTORS, build_backbone, build_head, build_neck
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class MaskFormer(SingleStageDetector):
+    r"""Implementation of `Per-Pixel Classification is
+    Not All You Need for Semantic Segmentation
+    <https://arxiv.org/pdf/2107.06278>`_"""
+
+    def __init__(self,
+                 backbone,
+                 neck=None,
+                 panoptic_head=None,
+                 train_cfg=None,
+                 test_cfg=None,
+                 init_cfg=None):
+        super(SingleStageDetector, self).__init__(init_cfg=init_cfg)
+        self.backbone = build_backbone(backbone)
+        if neck is not None:
+            self.neck = build_neck(neck)
+        panoptic_head.update(train_cfg=train_cfg)
+        panoptic_head.update(test_cfg=test_cfg)
+        self.panoptic_head = build_head(panoptic_head)
+        self.train_cfg = train_cfg
+        self.test_cfg = test_cfg
+
+    def forward_dummy(self, img, img_metas):
+        """Used for computing network flops. See
+        `mmdetection/tools/analysis_tools/get_flops.py`
+
+        Args:
+            img (Tensor): of shape (N, C, H, W) encoding input images.
+                Typically these should be mean centered and std scaled.
+            img_metas (list[Dict]): list of image info dict where each dict
+                has: 'img_shape', 'scale_factor', 'flip', and may also contain
+                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+                For details on the values of these keys see
+                `mmdet/datasets/pipelines/formatting.py:Collect`.
+        """
+        super(SingleStageDetector, self).forward_train(img, img_metas)
+        x = self.extract_feat(img)
+        outs = self.panoptic_head(x, img_metas)
+        return outs
+
+    def forward_train(self,
+                      img,
+                      img_metas,
+                      gt_bboxes,
+                      gt_labels,
+                      gt_masks,
+                      gt_semantic_seg,
+                      gt_bboxes_ignore=None,
+                      **kwargs):
+        """
+        Args:
+            img (Tensor): of shape (N, C, H, W) encoding input images.
+                Typically these should be mean centered and std scaled.
+            img_metas (list[Dict]): list of image info dict where each dict
+                has: 'img_shape', 'scale_factor', 'flip', and may also contain
+                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+                For details on the values of these keys see
+                `mmdet/datasets/pipelines/formatting.py:Collect`.
+            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+            gt_labels (list[Tensor]): class indices corresponding to each box.
+            gt_masks (list[BitmapMasks]): true segmentation masks for each box
+                used if the architecture supports a segmentation task.
+            gt_semantic_seg (list[Tensor]): semantic segmentation masks for
+                images.
+            gt_bboxes_ignore (list[Tensor]): specify which bounding
+                boxes can be ignored when computing the loss.
+                Defaults to None.
+
+        Returns:
+            dict[str, Tensor]: a dictionary of loss components
+        """
+        # add batch_input_shape in img_metas
+        super(SingleStageDetector, self).forward_train(img, img_metas)
+        x = self.extract_feat(img)
+        losses = self.panoptic_head.forward_train(x, img_metas, gt_bboxes,
+                                                  gt_labels, gt_masks,
+                                                  gt_semantic_seg,
+                                                  gt_bboxes_ignore)
+
+        return losses
+
+    def simple_test(self, img, img_metas, **kwargs):
+        """Test without augmentation."""
+        feat = self.extract_feat(img)
+        mask_results = self.panoptic_head.simple_test(feat, img_metas,
+                                                      **kwargs)
+
+        results = []
+        for mask in mask_results:
+            result = {'pan_results': mask.detach().cpu().numpy()}
+            results.append(result)
+
+        return results
+
+    def aug_test(self, imgs, img_metas, **kwargs):
+        raise NotImplementedError
+
+    def onnx_export(self, img, img_metas):
+        raise NotImplementedError
diff --git a/mmdet/models/losses/dice_loss.py b/mmdet/models/losses/dice_loss.py
index 121367a36ec..585beeaf1c6 100644
--- a/mmdet/models/losses/dice_loss.py
+++ b/mmdet/models/losses/dice_loss.py
@@ -11,10 +11,16 @@ def dice_loss(pred,
               weight=None,
               eps=1e-3,
               reduction='mean',
+              naive_dice=False,
               avg_factor=None):
-    """Calculate dice loss, which is proposed in
-    `V-Net: Fully Convolutional Neural Networks for Volumetric
-    Medical Image Segmentation <https://arxiv.org/abs/1606.04797>`_.
+    """Calculate dice loss. Two forms of dice loss are supported:
+
+    - the one proposed in `V-Net: Fully Convolutional Neural
+      Networks for Volumetric Medical Image Segmentation
+      <https://arxiv.org/abs/1606.04797>`_.
+    - the naive dice loss, in which the terms in the
+      denominator are raised to the first power instead of
+      the second power.
 
     Args:
         pred (torch.Tensor): The prediction, has a shape (n, *)
@@ -26,6 +32,11 @@ def dice_loss(pred,
         reduction (str, optional): The method used to reduce the loss into
             a scalar. Defaults to 'mean'.
             Options are "none", "mean" and "sum".
+        naive_dice (bool, optional): If False, use the dice
+            loss defined in the V-Net paper; otherwise, use the
+            naive dice loss, in which the terms in the denominator
+            are raised to the first power instead of the second
+            power. Defaults to False.
         avg_factor (int, optional): Average factor that is used to average
             the loss. Defaults to None.
     """
@@ -34,9 +45,15 @@ def dice_loss(pred,
     target = target.flatten(1).float()
 
     a = torch.sum(input * target, 1)
-    b = torch.sum(input * input, 1) + eps
-    c = torch.sum(target * target, 1) + eps
-    d = (2 * a) / (b + c)
+    if naive_dice:
+        b = torch.sum(input, 1)
+        c = torch.sum(target, 1)
+        d = (2 * a + eps) / (b + c + eps)
+    else:
+        b = torch.sum(input * input, 1) + eps
+        c = torch.sum(target * target, 1) + eps
+        d = (2 * a) / (b + c)
+
     loss = 1 - d
     if weight is not None:
         assert weight.ndim == loss.ndim
@@ -52,11 +69,10 @@ def __init__(self,
                  use_sigmoid=True,
                  activate=True,
                  reduction='mean',
+                 naive_dice=False,
                  loss_weight=1.0,
                  eps=1e-3):
-        """`Dice Loss, which is proposed in
-        `V-Net: Fully Convolutional Neural Networks for Volumetric
-        Medical Image Segmentation <https://arxiv.org/abs/1606.04797>`_.
+        """Compute dice loss.
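+
+        Two variants are supported, mirroring ``dice_loss`` above. With
+        ``a = sum(pred * target)`` per sample, the dice coefficient is:
+
+        - V-Net form (``naive_dice=False``):
+          ``d = 2a / (sum(pred**2) + sum(target**2) + 2*eps)``
+        - naive form (``naive_dice=True``):
+          ``d = (2a + eps) / (sum(pred) + sum(target) + eps)``
+
+        and the loss is ``1 - d``.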
         Args:
             use_sigmoid (bool, optional): Whether the prediction is
@@ -67,6 +83,11 @@ def __init__(self,
             reduction (str, optional): The method used
                 to reduce the loss. Options are "none",
                 "mean" and "sum". Defaults to 'mean'.
+            naive_dice (bool, optional): If False, use the dice
+                loss defined in the V-Net paper; otherwise, use the
+                naive dice loss, in which the terms in the denominator
+                are raised to the first power instead of the second
+                power. Defaults to False.
             loss_weight (float, optional): Weight of loss. Defaults to 1.0.
             eps (float): Avoid dividing by zero. Defaults to 1e-3.
         """
@@ -74,6 +95,7 @@ def __init__(self,
         super(DiceLoss, self).__init__()
         self.use_sigmoid = use_sigmoid
         self.reduction = reduction
+        self.naive_dice = naive_dice
         self.loss_weight = loss_weight
         self.eps = eps
         self.activate = activate
@@ -118,6 +140,7 @@ def forward(self,
             weight,
             eps=self.eps,
             reduction=reduction,
+            naive_dice=self.naive_dice,
             avg_factor=avg_factor)
 
         return loss
diff --git a/mmdet/models/plugins/__init__.py b/mmdet/models/plugins/__init__.py
index a4368551ddc..940d94e884a 100644
--- a/mmdet/models/plugins/__init__.py
+++ b/mmdet/models/plugins/__init__.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .dropblock import DropBlock
+from .pixel_decoder import PixelDecoder, TransformerEncoderPixelDecoder
 
-__all__ = ['DropBlock']
+__all__ = ['DropBlock', 'PixelDecoder', 'TransformerEncoderPixelDecoder']
diff --git a/mmdet/models/plugins/pixel_decoder.py b/mmdet/models/plugins/pixel_decoder.py
new file mode 100644
index 00000000000..f69daf46f9a
--- /dev/null
+++ b/mmdet/models/plugins/pixel_decoder.py
@@ -0,0 +1,245 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import PLUGIN_LAYERS, Conv2d, ConvModule, caffe2_xavier_init
+from mmcv.cnn.bricks.transformer import (build_positional_encoding,
+                                         build_transformer_layer_sequence)
+from mmcv.runner import BaseModule, ModuleList
+
+
+@PLUGIN_LAYERS.register_module()
+class PixelDecoder(BaseModule):
+    """Pixel decoder with an FPN-like structure.
+
+    Args:
+        in_channels (list[int] | tuple[int]): Number of channels in the
+            input feature maps.
+        feat_channels (int): Number of channels for features.
+        out_channels (int): Number of channels for outputs.
+        norm_cfg (obj:`mmcv.ConfigDict`|dict): Config for normalization.
+            Defaults to dict(type='GN', num_groups=32).
+        act_cfg (obj:`mmcv.ConfigDict`|dict): Config for activation.
+            Defaults to dict(type='ReLU').
+        init_cfg (obj:`mmcv.ConfigDict`|dict): Initialization config dict.
+            Default: None
+    """
+
+    def __init__(self,
+                 in_channels,
+                 feat_channels,
+                 out_channels,
+                 norm_cfg=dict(type='GN', num_groups=32),
+                 act_cfg=dict(type='ReLU'),
+                 init_cfg=None):
+        super().__init__(init_cfg=init_cfg)
+        self.in_channels = in_channels
+        self.num_inputs = len(in_channels)
+        self.lateral_convs = ModuleList()
+        self.output_convs = ModuleList()
+        self.use_bias = norm_cfg is None
+        for i in range(0, self.num_inputs - 1):
+            l_conv = ConvModule(
+                in_channels[i],
+                feat_channels,
+                kernel_size=1,
+                bias=self.use_bias,
+                norm_cfg=norm_cfg,
+                act_cfg=None)
+            o_conv = ConvModule(
+                feat_channels,
+                feat_channels,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+                bias=self.use_bias,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg)
+            self.lateral_convs.append(l_conv)
+            self.output_convs.append(o_conv)
+
+        self.last_feat_conv = ConvModule(
+            in_channels[-1],
+            feat_channels,
+            kernel_size=3,
+            padding=1,
+            stride=1,
+            bias=self.use_bias,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+        self.mask_feature = Conv2d(
+            feat_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+    def init_weights(self):
+        """Initialize weights."""
+        for i in range(0, self.num_inputs - 2):
+            caffe2_xavier_init(self.lateral_convs[i].conv, bias=0)
+            caffe2_xavier_init(self.output_convs[i].conv, bias=0)
+
+        caffe2_xavier_init(self.mask_feature, bias=0)
+        caffe2_xavier_init(self.last_feat_conv, bias=0)
+
+    def forward(self, feats, img_metas):
+        """
+        Args:
+            feats (list[Tensor]): Feature maps of each level. Each has
+                shape of (batch_size, c, h, w).
+            img_metas (list[dict]): List of image information. Passed in
+                for creating a more accurate padding mask. Not used here.
+
+        Returns:
+            tuple: a tuple containing the following:
+
+            - mask_feature (Tensor): Shape (batch_size, c, h, w).
+            - memory (Tensor): Output of last stage of backbone.\
+                Shape (batch_size, c, h, w).
+        """
+        y = self.last_feat_conv(feats[-1])
+        for i in range(self.num_inputs - 2, -1, -1):
+            x = feats[i]
+            cur_fpn = self.lateral_convs[i](x)
+            y = cur_fpn + \
+                F.interpolate(y, size=cur_fpn.shape[-2:], mode='nearest')
+            y = self.output_convs[i](y)
+
+        mask_feature = self.mask_feature(y)
+        memory = feats[-1]
+        return mask_feature, memory
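Assuming the plugin registration above behaves like other mmcv plugins, the plain PixelDecoder can be exercised on its own. A hedged sketch with toy feature maps (channel and spatial sizes mirror the unit tests added later in this diff):

```python
import torch
from mmcv.cnn import build_plugin_layer

cfg = dict(
    type='PixelDecoder',
    in_channels=[64, 128, 256, 512],  # C2-C5 of a small backbone
    feat_channels=64,
    out_channels=64)
# build_plugin_layer returns (name, layer); keep the layer.
decoder = build_plugin_layer(cfg)[1]
decoder.init_weights()
feats = [
    torch.rand(2, 64 * 2**i, 32 // 2**i, 40 // 2**i) for i in range(4)
]
# img_metas is unused by this decoder, so empty dicts suffice.
mask_feature, memory = decoder(feats, img_metas=[{}, {}])
# mask_feature: (2, 64, 32, 40); memory is feats[-1], returned unchanged.
```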
+
+
+@PLUGIN_LAYERS.register_module()
+class TransformerEncoderPixelDecoder(PixelDecoder):
+    """Pixel decoder with a transformer encoder inside.
+
+    Args:
+        in_channels (list[int] | tuple[int]): Number of channels in the
+            input feature maps.
+        feat_channels (int): Number of channels for the intermediate
+            features.
+        out_channels (int): Number of channels for the output.
+        norm_cfg (obj:`mmcv.ConfigDict`|dict): Config for normalization.
+            Defaults to dict(type='GN', num_groups=32).
+        act_cfg (obj:`mmcv.ConfigDict`|dict): Config for activation.
+            Defaults to dict(type='ReLU').
+        encoder (obj:`mmcv.ConfigDict`|dict): Config for the transformer
+            encoder. Defaults to None.
+        positional_encoding (obj:`mmcv.ConfigDict`|dict): Config for the
+            transformer encoder position encoding. Defaults to
+            dict(type='SinePositionalEncoding', num_feats=128,
+            normalize=True).
+        init_cfg (obj:`mmcv.ConfigDict`|dict): Initialization config dict.
+            Default: None
+    """
+
+    def __init__(self,
+                 in_channels,
+                 feat_channels,
+                 out_channels,
+                 norm_cfg=dict(type='GN', num_groups=32),
+                 act_cfg=dict(type='ReLU'),
+                 encoder=None,
+                 positional_encoding=dict(
+                     type='SinePositionalEncoding',
+                     num_feats=128,
+                     normalize=True),
+                 init_cfg=None):
+        super(TransformerEncoderPixelDecoder, self).__init__(
+            in_channels,
+            feat_channels,
+            out_channels,
+            norm_cfg,
+            act_cfg,
+            init_cfg=init_cfg)
+        self.last_feat_conv = None
+
+        self.encoder = build_transformer_layer_sequence(encoder)
+        self.encoder_embed_dims = self.encoder.embed_dims
+        assert self.encoder_embed_dims == feat_channels, 'embed_dims({}) of ' \
+            'transformer encoder must equal feat_channels({})'.format(
+                self.encoder_embed_dims, feat_channels)
+        self.positional_encoding = build_positional_encoding(
+            positional_encoding)
+        self.encoder_in_proj = Conv2d(
+            in_channels[-1], feat_channels, kernel_size=1)
+        self.encoder_out_proj = ConvModule(
+            feat_channels,
+            feat_channels,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            bias=self.use_bias,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+
+    def init_weights(self):
+        """Initialize weights."""
+        for i in range(0, self.num_inputs - 2):
+            caffe2_xavier_init(self.lateral_convs[i].conv, bias=0)
+            caffe2_xavier_init(self.output_convs[i].conv, bias=0)
+
+        caffe2_xavier_init(self.mask_feature, bias=0)
+        caffe2_xavier_init(self.encoder_in_proj, bias=0)
+        caffe2_xavier_init(self.encoder_out_proj.conv, bias=0)
+
+        for p in self.encoder.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)
+
+    def forward(self, feats, img_metas):
+        """
+        Args:
+            feats (list[Tensor]): Feature maps of each level. Each has
+                shape of (batch_size, c, h, w).
+            img_metas (list[dict]): List of image information. Passed in
+                for creating a more accurate padding mask.
+
+        Returns:
+            tuple: a tuple containing the following:
+
+            - mask_feature (Tensor): shape (batch_size, c, h, w).
+            - memory (Tensor): shape (batch_size, c, h, w).
+ """ + feat_last = feats[-1] + bs, c, h, w = feat_last.shape + input_img_h, input_img_w = img_metas[0]['batch_input_shape'] + padding_mask = feat_last.new_ones((bs, input_img_h, input_img_w), + dtype=torch.float32) + for i in range(bs): + img_h, img_w, _ = img_metas[i]['img_shape'] + padding_mask[i, :img_h, :img_w] = 0 + padding_mask = F.interpolate( + padding_mask.unsqueeze(1), + size=feat_last.shape[-2:], + mode='nearest').to(torch.bool).squeeze(1) + + pos_embed = self.positional_encoding(padding_mask) + feat_last = self.encoder_in_proj(feat_last) + # (batch_size, c, h, w) -> (num_queries, batch_size, c) + feat_last = feat_last.flatten(2).permute(2, 0, 1) + pos_embed = pos_embed.flatten(2).permute(2, 0, 1) + # (batch_size, h, w) -> (batch_size, h*w) + padding_mask = padding_mask.flatten(1) + memory = self.encoder( + query=feat_last, + key=None, + value=None, + query_pos=pos_embed, + query_key_padding_mask=padding_mask) + # (num_queries, batch_size, c) -> (batch_size, c, h, w) + memory = memory.permute(1, 2, 0).view(bs, self.encoder_embed_dims, h, + w) + y = self.encoder_out_proj(memory) + for i in range(self.num_inputs - 2, -1, -1): + x = feats[i] + cur_fpn = self.lateral_convs[i](x) + y = cur_fpn + \ + F.interpolate(y, size=cur_fpn.shape[-2:], mode='nearest') + y = self.output_convs[i](y) + + mask_feature = self.mask_feature(y) + return mask_feature, memory diff --git a/mmdet/models/utils/__init__.py b/mmdet/models/utils/__init__.py index 84dc141e850..add5693b60c 100644 --- a/mmdet/models/utils/__init__.py +++ b/mmdet/models/utils/__init__.py @@ -9,6 +9,7 @@ from .make_divisible import make_divisible from .misc import interpolate_as, sigmoid_geometric_mean from .normed_predictor import NormedConv2d, NormedLinear +from .panoptic_gt_processing import preprocess_panoptic_gt from .positional_encoding import (LearnedPositionalEncoding, SinePositionalEncoding) from .res_layer import ResLayer, SimplifiedBasicBlock @@ -25,5 +26,6 @@ 'NormedLinear', 'NormedConv2d', 'make_divisible', 'InvertedResidual', 'SELayer', 'interpolate_as', 'ConvUpsample', 'CSPLayer', 'adaptive_avg_pool2d', 'AdaptiveAvgPool2d', 'PatchEmbed', 'nchw_to_nlc', - 'nlc_to_nchw', 'pvt_convert', 'sigmoid_geometric_mean', 'DyReLU' + 'nlc_to_nchw', 'pvt_convert', 'sigmoid_geometric_mean', + 'preprocess_panoptic_gt', 'DyReLU' ] diff --git a/mmdet/models/utils/panoptic_gt_processing.py b/mmdet/models/utils/panoptic_gt_processing.py new file mode 100644 index 00000000000..513f644945c --- /dev/null +++ b/mmdet/models/utils/panoptic_gt_processing.py @@ -0,0 +1,62 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + + +def preprocess_panoptic_gt(gt_labels, gt_masks, gt_semantic_seg, num_things, + num_stuff): + """Preprocess the ground truth for a image. + + Args: + gt_labels (Tensor): Ground truth labels of each bbox, + with shape (num_gts, ). + gt_masks (BitmapMasks): Ground truth masks of each instances + of a image, shape (num_gts, h, w). + gt_semantic_seg (Tensor): Ground truth of semantic + segmentation with the shape (1, h, w). + [0, num_thing_class - 1] means things, + [num_thing_class, num_class-1] means stuff, + 255 means VOID. + target_shape (tuple[int]): Shape of output mask_preds. + Resize the masks to shape of mask_preds. + + Returns: + tuple: a tuple containing the following targets. + + - labels (Tensor): Ground truth class indices for a + image, with shape (n, ), n is the sum of number + of stuff type and number of instance in a image. 
diff --git a/mmdet/models/utils/__init__.py b/mmdet/models/utils/__init__.py
index 84dc141e850..add5693b60c 100644
--- a/mmdet/models/utils/__init__.py
+++ b/mmdet/models/utils/__init__.py
@@ -9,6 +9,7 @@ from .make_divisible import make_divisible
 from .misc import interpolate_as, sigmoid_geometric_mean
 from .normed_predictor import NormedConv2d, NormedLinear
+from .panoptic_gt_processing import preprocess_panoptic_gt
 from .positional_encoding import (LearnedPositionalEncoding,
                                   SinePositionalEncoding)
 from .res_layer import ResLayer, SimplifiedBasicBlock
@@ -25,5 +26,6 @@
     'NormedLinear', 'NormedConv2d', 'make_divisible', 'InvertedResidual',
     'SELayer', 'interpolate_as', 'ConvUpsample', 'CSPLayer',
     'adaptive_avg_pool2d', 'AdaptiveAvgPool2d', 'PatchEmbed', 'nchw_to_nlc',
-    'nlc_to_nchw', 'pvt_convert', 'sigmoid_geometric_mean', 'DyReLU'
+    'nlc_to_nchw', 'pvt_convert', 'sigmoid_geometric_mean',
+    'preprocess_panoptic_gt', 'DyReLU'
 ]
diff --git a/mmdet/models/utils/panoptic_gt_processing.py b/mmdet/models/utils/panoptic_gt_processing.py
new file mode 100644
index 00000000000..513f644945c
--- /dev/null
+++ b/mmdet/models/utils/panoptic_gt_processing.py
@@ -0,0 +1,62 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+
+def preprocess_panoptic_gt(gt_labels, gt_masks, gt_semantic_seg, num_things,
+                           num_stuff):
+    """Preprocess the ground truth for an image.
+
+    Args:
+        gt_labels (Tensor): Ground truth labels of each bbox,
+            with shape (num_gts, ).
+        gt_masks (BitmapMasks): Ground truth masks of each instance
+            of an image, shape (num_gts, h, w).
+        gt_semantic_seg (Tensor): Ground truth of semantic
+            segmentation with the shape (1, h, w).
+            [0, num_thing_class - 1] means things,
+            [num_thing_class, num_class - 1] means stuff,
+            255 means VOID.
+        num_things (int): Number of thing classes.
+        num_stuff (int): Number of stuff classes.
+
+    Returns:
+        tuple: a tuple containing the following targets.
+
+        - labels (Tensor): Ground truth class indices for an
+            image, with shape (n, ), where n is the sum of the
+            number of stuff types and the number of instances
+            in an image.
+        - masks (Tensor): Ground truth mask for an image, with
+            shape (n, h, w).
+    """
+    num_classes = num_things + num_stuff
+    things_labels = gt_labels
+    gt_semantic_seg = gt_semantic_seg.squeeze(0)
+
+    things_masks = gt_masks.pad(gt_semantic_seg.shape[-2:], pad_val=0)\
+        .to_tensor(dtype=torch.bool, device=gt_labels.device)
+
+    semantic_labels = torch.unique(
+        gt_semantic_seg,
+        sorted=False,
+        return_inverse=False,
+        return_counts=False)
+    stuff_masks_list = []
+    stuff_labels_list = []
+    for label in semantic_labels:
+        if label < num_things or label >= num_classes:
+            continue
+        stuff_mask = gt_semantic_seg == label
+        stuff_masks_list.append(stuff_mask)
+        stuff_labels_list.append(label)
+
+    if len(stuff_masks_list) > 0:
+        stuff_masks = torch.stack(stuff_masks_list, dim=0)
+        stuff_labels = torch.stack(stuff_labels_list, dim=0)
+        labels = torch.cat([things_labels, stuff_labels], dim=0)
+        masks = torch.cat([things_masks, stuff_masks], dim=0)
+    else:
+        labels = things_labels
+        masks = things_masks
+
+    masks = masks.long()
+    return labels, masks
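A small usage sketch of preprocess_panoptic_gt under the COCO panoptic split used throughout this PR (80 thing classes, 53 stuff classes); the inputs are toy values of my own choosing:

```python
import numpy as np
import torch

from mmdet.core.mask import BitmapMasks
from mmdet.models.utils import preprocess_panoptic_gt

# One thing instance (class 3) and one stuff region (class 80) in a 4x4 image.
gt_labels = torch.tensor([3])
thing = np.zeros((1, 4, 4), dtype=np.uint8)
thing[0, :2, :2] = 1
gt_masks = BitmapMasks(thing, 4, 4)
gt_semantic_seg = torch.full((1, 4, 4), 255, dtype=torch.long)  # VOID
gt_semantic_seg[0, 2:] = 80  # stuff rows

labels, masks = preprocess_panoptic_gt(
    gt_labels, gt_masks, gt_semantic_seg, num_things=80, num_stuff=53)
# labels -> tensor([3, 80]); masks -> shape (2, 4, 4), one mask per target.
```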
diff --git a/tests/test_models/test_dense_heads/test_maskformer_head.py b/tests/test_models/test_dense_heads/test_maskformer_head.py
new file mode 100644
index 00000000000..e70f09afe3f
--- /dev/null
+++ b/tests/test_models/test_dense_heads/test_maskformer_head.py
@@ -0,0 +1,203 @@
+import numpy as np
+import torch
+from mmcv import ConfigDict
+
+from mmdet.core.mask import BitmapMasks
+from mmdet.models.dense_heads import MaskFormerHead
+
+
+def test_maskformer_head_loss():
+    """Tests head loss when truth is empty and non-empty."""
+    base_channels = 64
+    # batch_input_shape = (128, 160)
+    img_metas = [{
+        'batch_input_shape': (128, 160),
+        'img_shape': (126, 160, 3),
+        'ori_shape': (63, 80, 3)
+    }, {
+        'batch_input_shape': (128, 160),
+        'img_shape': (120, 160, 3),
+        'ori_shape': (60, 80, 3)
+    }]
+    feats = [
+        torch.rand((2, 64 * 2**i, 4 * 2**(3 - i), 5 * 2**(3 - i)))
+        for i in range(4)
+    ]
+
+    config = ConfigDict(
+        dict(
+            type='MaskFormerHead',
+            in_channels=[base_channels * 2**i for i in range(4)],
+            feat_channels=base_channels,
+            out_channels=base_channels,
+            num_things_classes=80,
+            num_stuff_classes=53,
+            num_queries=100,
+            pixel_decoder=dict(
+                type='TransformerEncoderPixelDecoder',
+                norm_cfg=dict(type='GN', num_groups=32),
+                act_cfg=dict(type='ReLU'),
+                encoder=dict(
+                    type='DetrTransformerEncoder',
+                    num_layers=6,
+                    transformerlayers=dict(
+                        type='BaseTransformerLayer',
+                        attn_cfgs=dict(
+                            type='MultiheadAttention',
+                            embed_dims=base_channels,
+                            num_heads=8,
+                            attn_drop=0.1,
+                            proj_drop=0.1,
+                            dropout_layer=None,
+                            batch_first=False),
+                        ffn_cfgs=dict(
+                            embed_dims=base_channels,
+                            feedforward_channels=base_channels * 8,
+                            num_fcs=2,
+                            act_cfg=dict(type='ReLU', inplace=True),
+                            ffn_drop=0.1,
+                            dropout_layer=None,
+                            add_identity=True),
+                        operation_order=('self_attn', 'norm', 'ffn', 'norm'),
+                        norm_cfg=dict(type='LN'),
+                        init_cfg=None,
+                        batch_first=False),
+                    init_cfg=None),
+                positional_encoding=dict(
+                    type='SinePositionalEncoding',
+                    num_feats=base_channels // 2,
+                    normalize=True)),
+            enforce_decoder_input_project=False,
+            positional_encoding=dict(
+                type='SinePositionalEncoding',
+                num_feats=base_channels // 2,
+                normalize=True),
+            transformer_decoder=dict(
+                type='DetrTransformerDecoder',
+                return_intermediate=True,
+                num_layers=6,
+                transformerlayers=dict(
+                    type='DetrTransformerDecoderLayer',
+                    attn_cfgs=dict(
+                        type='MultiheadAttention',
+                        embed_dims=base_channels,
+                        num_heads=8,
+                        attn_drop=0.1,
+                        proj_drop=0.1,
+                        dropout_layer=None,
+                        batch_first=False),
+                    ffn_cfgs=dict(
+                        embed_dims=base_channels,
+                        feedforward_channels=base_channels * 8,
+                        num_fcs=2,
+                        act_cfg=dict(type='ReLU', inplace=True),
+                        ffn_drop=0.1,
+                        dropout_layer=None,
+                        add_identity=True),
+                    # the following parameter is not used; it is kept
+                    # only to satisfy the current API
+                    feedforward_channels=base_channels * 8,
+                    operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
+                                     'ffn', 'norm')),
+                init_cfg=None),
+            loss_cls=dict(
+                type='CrossEntropyLoss',
+                bg_cls_weight=0.1,
+                use_sigmoid=False,
+                loss_weight=1.0,
+                reduction='mean',
+                class_weight=1.0),
+            loss_mask=dict(
+                type='FocalLoss',
+                use_sigmoid=True,
+                gamma=2.0,
+                alpha=0.25,
+                reduction='mean',
+                loss_weight=20.0),
+            loss_dice=dict(
+                type='DiceLoss',
+                use_sigmoid=True,
+                activate=True,
+                reduction='mean',
+                naive_dice=True,
+                eps=1.0,
+                loss_weight=1.0),
+            train_cfg=dict(
+                assigner=dict(
+                    type='MaskHungarianAssigner',
+                    cls_cost=dict(type='ClassificationCost', weight=1.0),
+                    mask_cost=dict(
+                        type='FocalLossCost', weight=20.0, binary_input=True),
+                    dice_cost=dict(
+                        type='DiceCost', weight=1.0, pred_act=True, eps=1.0)),
+                sampler=dict(type='MaskPseudoSampler')),
+            test_cfg=dict(object_mask_thr=0.8, iou_thr=0.8)))
+    self = MaskFormerHead(**config)
+    self.init_weights()
+    all_cls_scores, all_mask_preds = self.forward(feats, img_metas)
+    # Test that empty ground truth encourages the network to predict
+    # background.
+    gt_labels_list = [torch.LongTensor([]), torch.LongTensor([])]
+    gt_masks_list = [
+        torch.zeros((0, 128, 160)).long(),
+        torch.zeros((0, 128, 160)).long()
+    ]
+
+    empty_gt_losses = self.loss(all_cls_scores, all_mask_preds, gt_labels_list,
+                                gt_masks_list, img_metas)
+    # When there is no truth, the cls loss should be nonzero but there should
+    # be no mask loss.
+    for key, loss in empty_gt_losses.items():
+        if 'cls' in key:
+            assert loss.item() > 0, 'cls loss should be non-zero'
+        elif 'mask' in key:
+            assert loss.item(
+            ) == 0, 'there should be no mask loss when there are no true masks'
+        elif 'dice' in key:
+            assert loss.item(
+            ) == 0, 'there should be no dice loss when there are no true masks'
+
+    # when truth is non-empty, cls, mask, and dice losses should all be
+    # nonzero; use random inputs
+    gt_labels_list = [
+        torch.tensor([10, 100]).long(),
+        torch.tensor([100, 10]).long()
+    ]
+    mask1 = torch.zeros((2, 128, 160)).long()
+    mask1[0, :50] = 1
+    mask1[1, 50:] = 1
+    mask2 = torch.zeros((2, 128, 160)).long()
+    mask2[0, :, :50] = 1
+    mask2[1, :, 50:] = 1
+    gt_masks_list = [mask1, mask2]
+    two_gt_losses = self.loss(all_cls_scores, all_mask_preds, gt_labels_list,
+                              gt_masks_list, img_metas)
+    for loss in two_gt_losses.values():
+        assert loss.item() > 0, 'all losses should be non-zero'
+
+    # test forward_train
+    gt_bboxes = None
+    gt_labels = [
+        torch.tensor([10]).long(),
+        torch.tensor([10]).long(),
+    ]
+    thing_mask1 = np.zeros((1, 128, 160), dtype=np.int32)
+    thing_mask1[0, :50] = 1
+    thing_mask2 = np.zeros((1, 128, 160), dtype=np.int32)
+    thing_mask2[0, :, 50:] = 1
+    gt_masks = [
+        BitmapMasks(thing_mask1, 128, 160),
+        BitmapMasks(thing_mask2, 128, 160),
+    ]
+    stuff_mask1 = torch.zeros((1, 128, 160)).long()
+    stuff_mask1[0, :50] = 10
+    stuff_mask1[0, 50:] = 100
+    stuff_mask2 = torch.zeros((1, 128, 160)).long()
+    stuff_mask2[0, :, 50:] = 10
+    stuff_mask2[0, :, :50] = 100
+    gt_semantic_seg = [stuff_mask1, stuff_mask2]
+
+    self.forward_train(feats, img_metas, gt_bboxes, gt_labels, gt_masks,
+                       gt_semantic_seg)
+
+    # test inference mode
+    self.simple_test(feats, img_metas)
diff --git a/tests/test_models/test_forward.py b/tests/test_models/test_forward.py
index db75b2fd418..6b28ba61514 100644
--- a/tests/test_models/test_forward.py
+++ b/tests/test_models/test_forward.py
@@ -700,3 +700,114 @@ def test_yolox_random_size():
         gt_labels=gt_labels,
         return_loss=True)
     assert detector._input_size == (64, 96)
+
+
+def test_maskformer_forward():
+    model_cfg = _get_detector_cfg(
+        'maskformer/maskformer_r50_mstrain_16x1_75e_coco.py')
+    base_channels = 32
+    model_cfg.backbone.depth = 18
+    model_cfg.backbone.init_cfg = None
+    model_cfg.backbone.base_channels = base_channels
+    model_cfg.panoptic_head.in_channels = [
+        base_channels * 2**i for i in range(4)
+    ]
+    model_cfg.panoptic_head.feat_channels = base_channels
+    model_cfg.panoptic_head.out_channels = base_channels
+    model_cfg.panoptic_head.pixel_decoder.encoder.\
+        transformerlayers.attn_cfgs.embed_dims = base_channels
+    model_cfg.panoptic_head.pixel_decoder.encoder.\
+        transformerlayers.ffn_cfgs.embed_dims = base_channels
+    model_cfg.panoptic_head.pixel_decoder.encoder.\
+        transformerlayers.ffn_cfgs.feedforward_channels = base_channels * 8
+    model_cfg.panoptic_head.pixel_decoder.\
+        positional_encoding.num_feats = base_channels // 2
+    model_cfg.panoptic_head.positional_encoding.\
+        num_feats = base_channels // 2
+    model_cfg.panoptic_head.transformer_decoder.\
+        transformerlayers.attn_cfgs.embed_dims = base_channels
+    model_cfg.panoptic_head.transformer_decoder.\
+        transformerlayers.ffn_cfgs.embed_dims = base_channels
+    model_cfg.panoptic_head.transformer_decoder.\
+        transformerlayers.ffn_cfgs.feedforward_channels = base_channels * 8
+    model_cfg.panoptic_head.transformer_decoder.\
+        transformerlayers.feedforward_channels = base_channels * 8
+
+    from mmdet.core import BitmapMasks
+    from mmdet.models import
build_detector + detector = build_detector(model_cfg) + + # Test forward train with non-empty truth batch + detector.train() + img_metas = [ + { + 'batch_input_shape': (128, 160), + 'img_shape': (126, 160, 3), + 'ori_shape': (63, 80, 3), + 'pad_shape': (128, 160, 3) + }, + ] + img = torch.rand((1, 3, 128, 160)) + gt_bboxes = None + gt_labels = [ + torch.tensor([10]).long(), + ] + thing_mask1 = np.zeros((1, 128, 160), dtype=np.int32) + thing_mask1[0, :50] = 1 + gt_masks = [ + BitmapMasks(thing_mask1, 128, 160), + ] + stuff_mask1 = torch.zeros((1, 128, 160)).long() + stuff_mask1[0, :50] = 10 + stuff_mask1[0, 50:] = 100 + gt_semantic_seg = [ + stuff_mask1, + ] + losses = detector.forward( + img=img, + img_metas=img_metas, + gt_bboxes=gt_bboxes, + gt_labels=gt_labels, + gt_masks=gt_masks, + gt_semantic_seg=gt_semantic_seg, + return_loss=True) + assert isinstance(losses, dict) + loss, _ = detector._parse_losses(losses) + assert float(loss.item()) > 0 + + # Test forward train with an empty truth batch + gt_bboxes = [ + torch.empty((0, 4)).float(), + ] + gt_labels = [ + torch.empty((0, )).long(), + ] + mask = np.zeros((0, 128, 160), dtype=np.uint8) + gt_masks = [ + BitmapMasks(mask, 128, 160), + ] + gt_semantic_seg = [ + torch.randint(0, 133, (0, 128, 160)), + ] + losses = detector.forward( + img, + img_metas, + gt_bboxes=gt_bboxes, + gt_labels=gt_labels, + gt_masks=gt_masks, + gt_semantic_seg=gt_semantic_seg, + return_loss=True) + assert isinstance(losses, dict) + loss, _ = detector._parse_losses(losses) + assert float(loss.item()) > 0 + + # Test forward test + detector.eval() + with torch.no_grad(): + img_list = [g[None, :] for g in img] + batch_results = [] + for one_img, one_meta in zip(img_list, img_metas): + result = detector.forward([one_img], [[one_meta]], + rescale=True, + return_loss=False) + batch_results.append(result) diff --git a/tests/test_models/test_loss.py b/tests/test_models/test_loss.py index 101e5efef5d..380bc3263f7 100644 --- a/tests/test_models/test_loss.py +++ b/tests/test_models/test_loss.py @@ -165,51 +165,55 @@ def test_loss_with_ignore_index(use_sigmoid): assert torch.allclose(loss, loss_with_forward_ignore) -def test_dice_loss(): +@pytest.mark.parametrize('naive_dice', [True, False]) +def test_dice_loss(naive_dice): loss_class = DiceLoss pred = torch.rand((10, 4, 4)) target = torch.rand((10, 4, 4)) weight = torch.rand((10)) # Test loss forward - loss = loss_class()(pred, target) + loss = loss_class(naive_dice=naive_dice)(pred, target) assert isinstance(loss, torch.Tensor) # Test loss forward with weight - loss = loss_class()(pred, target, weight) + loss = loss_class(naive_dice=naive_dice)(pred, target, weight) assert isinstance(loss, torch.Tensor) # Test loss forward with reduction_override - loss = loss_class()(pred, target, reduction_override='mean') + loss = loss_class(naive_dice=naive_dice)( + pred, target, reduction_override='mean') assert isinstance(loss, torch.Tensor) # Test loss forward with avg_factor - loss = loss_class()(pred, target, avg_factor=10) + loss = loss_class(naive_dice=naive_dice)(pred, target, avg_factor=10) assert isinstance(loss, torch.Tensor) with pytest.raises(ValueError): # loss can evaluate with avg_factor only if # reduction is None, 'none' or 'mean'. 
         reduction_override = 'sum'
-        loss_class()(
+        loss_class(naive_dice=naive_dice)(
             pred, target, avg_factor=10, reduction_override=reduction_override)

     # Test loss forward with avg_factor and reduction
     for reduction_override in [None, 'none', 'mean']:
-        loss_class()(
+        loss_class(naive_dice=naive_dice)(
             pred, target, avg_factor=10, reduction_override=reduction_override)
         assert isinstance(loss, torch.Tensor)

     # Test loss forward with has_acted=False and use_sigmoid=False
     with pytest.raises(NotImplementedError):
-        loss_class(use_sigmoid=False, activate=True)(pred, target)
+        loss_class(
+            use_sigmoid=False, activate=True, naive_dice=naive_dice)(pred,
+                                                                     target)

     # Test loss forward with weight.ndim != loss.ndim
     with pytest.raises(AssertionError):
         weight = torch.rand((2, 8))
-        loss_class()(pred, target, weight)
+        loss_class(naive_dice=naive_dice)(pred, target, weight)

     # Test loss forward with len(weight) != len(pred)
     with pytest.raises(AssertionError):
         weight = torch.rand((8))
-        loss_class()(pred, target, weight)
+        loss_class(naive_dice=naive_dice)(pred, target, weight)
diff --git a/tests/test_models/test_plugins.py b/tests/test_models/test_plugins.py
index 59416b20de2..b115fbd73f2 100644
--- a/tests/test_models/test_plugins.py
+++ b/tests/test_models/test_plugins.py
@@ -1,6 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import pytest
 import torch
+from mmcv import ConfigDict
+from mmcv.cnn import build_plugin_layer

 from mmdet.models.plugins import DropBlock

@@ -27,3 +29,83 @@ def test_dropblock():
     # warmup_iters cannot be less than 0
     with pytest.raises(AssertionError):
         DropBlock(0.5, 3, -1)
+
+
+def test_pixeldecoder():
+    base_channels = 64
+    pixel_decoder_cfg = ConfigDict(
+        dict(
+            type='PixelDecoder',
+            in_channels=[base_channels * 2**i for i in range(4)],
+            feat_channels=base_channels,
+            out_channels=base_channels,
+            norm_cfg=dict(type='GN', num_groups=32),
+            act_cfg=dict(type='ReLU')))
+    self = build_plugin_layer(pixel_decoder_cfg)[1]
+    img_metas = [{}, {}]
+    feats = [
+        torch.rand((2, base_channels * 2**i, 4 * 2**(3 - i), 5 * 2**(3 - i)))
+        for i in range(4)
+    ]
+    mask_feature, memory = self(feats, img_metas)
+
+    assert (memory == feats[-1]).all()
+    assert mask_feature.shape == feats[0].shape
+
+
+def test_transformerencoderpixeldecoder():
+    base_channels = 64
+    pixel_decoder_cfg = ConfigDict(
+        dict(
+            type='TransformerEncoderPixelDecoder',
+            in_channels=[base_channels * 2**i for i in range(4)],
+            feat_channels=base_channels,
+            out_channels=base_channels,
+            norm_cfg=dict(type='GN', num_groups=32),
+            act_cfg=dict(type='ReLU'),
+            encoder=dict(
+                type='DetrTransformerEncoder',
+                num_layers=6,
+                transformerlayers=dict(
+                    type='BaseTransformerLayer',
+                    attn_cfgs=dict(
+                        type='MultiheadAttention',
+                        embed_dims=base_channels,
+                        num_heads=8,
+                        attn_drop=0.1,
+                        proj_drop=0.1,
+                        dropout_layer=None,
+                        batch_first=False),
+                    ffn_cfgs=dict(
+                        embed_dims=base_channels,
+                        feedforward_channels=base_channels * 8,
+                        num_fcs=2,
+                        act_cfg=dict(type='ReLU', inplace=True),
+                        ffn_drop=0.1,
+                        dropout_layer=None,
+                        add_identity=True),
+                    operation_order=('self_attn', 'norm', 'ffn', 'norm'),
+                    norm_cfg=dict(type='LN'),
+                    init_cfg=None,
+                    batch_first=False),
+                init_cfg=None),
+            positional_encoding=dict(
+                type='SinePositionalEncoding',
+                num_feats=base_channels // 2,
+                normalize=True)))
+    self = build_plugin_layer(pixel_decoder_cfg)[1]
+    img_metas = [{
+        'batch_input_shape': (128, 160),
+        'img_shape': (120, 160, 3),
+    }, {
+        'batch_input_shape': (128, 160),
+        'img_shape': (125, 160, 3),
+    }]
+    feats = [
+        torch.rand((2, base_channels * 2**i, 4 * 2**(3 - i), 5 * 2**(3 - i)))
+        for i in range(4)
+    ]
+    mask_feature, memory = self(feats, img_metas)
+
+    assert memory.shape[-2:] == feats[-1].shape[-2:]
+    assert mask_feature.shape == feats[0].shape
diff --git a/tests/test_utils/test_assigner.py b/tests/test_utils/test_assigner.py
index ca82aeda127..7728510b166 100644
--- a/tests/test_utils/test_assigner.py
+++ b/tests/test_utils/test_assigner.py
@@ -10,8 +10,9 @@
 from mmdet.core.bbox.assigners import (ApproxMaxIoUAssigner,
                                        CenterRegionAssigner, HungarianAssigner,
-                                       MaxIoUAssigner, PointAssigner,
-                                       TaskAlignedAssigner, UniformAssigner)
+                                       MaskHungarianAssigner, MaxIoUAssigner,
+                                       PointAssigner, TaskAlignedAssigner,
+                                       UniformAssigner)


 def test_max_iou_assigner():
@@ -539,3 +540,69 @@ def test_task_aligned_assigner():
         pred_score, pred_bbox, anchor, gt_bboxes=gt_bboxes)
     expected_gt_inds = torch.LongTensor([0, 0, 0, 0])
     assert torch.all(assign_result.gt_inds == expected_gt_inds)
+
+
+def test_mask_hungarian_match_assigner():
+    # test no gt masks
+    assigner_cfg = dict(
+        cls_cost=dict(type='ClassificationCost', weight=1.0),
+        mask_cost=dict(type='FocalLossCost', weight=20.0, binary_input=True),
+        dice_cost=dict(type='DiceCost', weight=1.0, pred_act=True, eps=1.0))
+    self = MaskHungarianAssigner(**assigner_cfg)
+    cls_pred = torch.rand((10, 133))
+    mask_pred = torch.rand((10, 50, 50))
+
+    gt_labels = torch.empty((0, )).long()
+    gt_masks = torch.empty((0, 50, 50)).float()
+    img_meta = None
+    assign_result = self.assign(cls_pred, mask_pred, gt_labels, gt_masks,
+                                img_meta)
+    assert torch.all(assign_result.gt_inds == 0)
+    assert torch.all(assign_result.labels == -1)
+
+    # test with gt masks
+    gt_labels = torch.LongTensor([10, 100])
+    gt_masks = torch.zeros((2, 50, 50)).long()
+    gt_masks[0, :25] = 1
+    gt_masks[1, 25:] = 1
+    assign_result = self.assign(cls_pred, mask_pred, gt_labels, gt_masks,
+                                img_meta)
+    assert torch.all(assign_result.gt_inds > -1)
+    assert (assign_result.gt_inds > 0).sum() == gt_labels.size(0)
+    assert (assign_result.labels > -1).sum() == gt_labels.size(0)
+
+    # test with cls mode
+    assigner_cfg = dict(
+        cls_cost=dict(type='ClassificationCost', weight=1.0),
+        mask_cost=dict(type='FocalLossCost', weight=0.0, binary_input=True),
+        dice_cost=dict(type='DiceCost', weight=0.0, pred_act=True, eps=1.0))
+    self = MaskHungarianAssigner(**assigner_cfg)
+    assign_result = self.assign(cls_pred, mask_pred, gt_labels, gt_masks,
+                                img_meta)
+    assert torch.all(assign_result.gt_inds > -1)
+    assert (assign_result.gt_inds > 0).sum() == gt_labels.size(0)
+    assert (assign_result.labels > -1).sum() == gt_labels.size(0)
+
+    # test with mask focal mode
+    assigner_cfg = dict(
+        cls_cost=dict(type='ClassificationCost', weight=0.0),
+        mask_cost=dict(type='FocalLossCost', weight=1.0, binary_input=True),
+        dice_cost=dict(type='DiceCost', weight=0.0, pred_act=True, eps=1.0))
+    self = MaskHungarianAssigner(**assigner_cfg)
+    assign_result = self.assign(cls_pred, mask_pred, gt_labels, gt_masks,
+                                img_meta)
+    assert torch.all(assign_result.gt_inds > -1)
+    assert (assign_result.gt_inds > 0).sum() == gt_labels.size(0)
+    assert (assign_result.labels > -1).sum() == gt_labels.size(0)
+
+    # test with mask dice mode
+    assigner_cfg = dict(
+        cls_cost=dict(type='ClassificationCost', weight=0.0),
+        mask_cost=dict(type='FocalLossCost', weight=0.0, binary_input=True),
+        dice_cost=dict(type='DiceCost', weight=1.0, pred_act=True, eps=1.0))
+    self = MaskHungarianAssigner(**assigner_cfg)
+    assign_result = self.assign(cls_pred, mask_pred,
gt_labels, gt_masks, + img_meta) + assert torch.all(assign_result.gt_inds > -1) + assert (assign_result.gt_inds > 0).sum() == gt_labels.size(0) + assert (assign_result.labels > -1).sum() == gt_labels.size(0)
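Finally, once the config and modules above are registered, the model should be reachable through mmdet's standard high-level APIs. A hedged sketch: the checkpoint path is hypothetical (the README table does not yet list released weights), and it assumes inference_detector passes the panoptic result from simple_test through unchanged:

```python
from mmdet.apis import inference_detector, init_detector

config_file = 'configs/maskformer/maskformer_r50_mstrain_16x1_75e_coco.py'
checkpoint_file = 'checkpoints/maskformer_r50_75e_coco.pth'  # hypothetical
model = init_detector(config_file, checkpoint_file, device='cuda:0')
result = inference_detector(model, 'demo/demo.jpg')
# Per simple_test above, the result carries 'pan_results':
# an (H, W) array of panoptic segment ids.
print(result['pan_results'].shape)
```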