[Fix] Fix MaskFormer and Mask2Former (open-mmlab#9515)
Co-authored-by: Kei-Chi Tse <109070650+KeiChiTse@users.noreply.github.com>
2 people authored and yumion committed Jan 31, 2024
1 parent e5afc49 commit 4eafa7c
Showing 21 changed files with 528 additions and 223 deletions.
3 changes: 1 addition & 2 deletions configs/dab_detr/dab-detr_r50_8xb2-50e_coco.py
@@ -66,8 +66,7 @@
ffn_drop=0.,
act_cfg=dict(type='PReLU'))),
return_intermediate=True),
positional_encoding_cfg=dict(
num_feats=128, temperature=20, normalize=True),
positional_encoding=dict(num_feats=128, temperature=20, normalize=True),
bbox_head=dict(
type='DABDETRHead',
num_classes=80,
@@ -54,7 +54,7 @@
ffn_cfg=dict(
embed_dims=256, feedforward_channels=1024, ffn_drop=0.1)),
post_norm_cfg=None),
positional_encoding_cfg=dict(num_feats=128, normalize=True, offset=-0.5),
positional_encoding=dict(num_feats=128, normalize=True, offset=-0.5),
bbox_head=dict(
type='DeformableDETRHead',
num_classes=80,
2 changes: 1 addition & 1 deletion configs/detr/detr_r50_8xb2-150e_coco.py
@@ -62,7 +62,7 @@
ffn_drop=0.1,
act_cfg=dict(type='ReLU', inplace=True))),
return_intermediate=True),
positional_encoding_cfg=dict(num_feats=128, normalize=True),
positional_encoding=dict(num_feats=128, normalize=True),
bbox_head=dict(
type='DETRHead',
num_classes=80,
2 changes: 1 addition & 1 deletion configs/dino/dino_4scale_r50_8xb2-12e_coco.py
@@ -52,7 +52,7 @@
feedforward_channels=2048, # 1024 for DeformDETR
ffn_drop=0.0)), # 0.1 for DeformDETR
post_norm_cfg=None),
positional_encoding_cfg=dict(
positional_encoding=dict(
num_feats=128,
normalize=True,
offset=0.0, # -0.5 for DeformDETR
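The four config hunks above apply the same rename: positional_encoding_cfg becomes positional_encoding, with the dict contents unchanged. A minimal sketch, not part of this commit, of overriding the renamed key in a config that inherits from the DINO base file shown above:

_base_ = ['./dino_4scale_r50_8xb2-12e_coco.py']

model = dict(
    positional_encoding=dict(  # was `positional_encoding_cfg` before this fix
        num_feats=128,
        normalize=True,
        offset=0.0))  # -0.5 for DeformDETR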
57 changes: 20 additions & 37 deletions configs/mask2former/mask2former_r50_8xb2-lsj-50e_coco-panoptic.py
@@ -55,62 +55,45 @@
num_outs=3,
norm_cfg=dict(type='GN', num_groups=32),
act_cfg=dict(type='ReLU'),
encoder=dict(
type='DetrTransformerEncoder',
encoder=dict( # DeformableDetrTransformerEncoder
num_layers=6,
transformerlayers=dict(
type='BaseTransformerLayer',
attn_cfgs=dict(
type='MultiScaleDeformableAttention',
layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
self_attn_cfg=dict( # MultiScaleDeformableAttention
embed_dims=256,
num_heads=8,
num_levels=3,
num_points=4,
im2col_step=64,
dropout=0.0,
batch_first=False,
norm_cfg=None,
init_cfg=None),
ffn_cfgs=dict(
type='FFN',
batch_first=True),
ffn_cfg=dict(
embed_dims=256,
feedforward_channels=1024,
num_fcs=2,
ffn_drop=0.0,
act_cfg=dict(type='ReLU', inplace=True)),
operation_order=('self_attn', 'norm', 'ffn', 'norm')),
init_cfg=None),
positional_encoding=dict(
type='SinePositionalEncoding', num_feats=128, normalize=True),
init_cfg=None),
act_cfg=dict(type='ReLU', inplace=True)))),
positional_encoding=dict(num_feats=128, normalize=True)),
enforce_decoder_input_project=False,
positional_encoding=dict(
type='SinePositionalEncoding', num_feats=128, normalize=True),
transformer_decoder=dict(
type='DetrTransformerDecoder',
positional_encoding=dict(num_feats=128, normalize=True),
transformer_decoder=dict( # Mask2FormerTransformerDecoder
return_intermediate=True,
num_layers=9,
transformerlayers=dict(
type='DetrTransformerDecoderLayer',
attn_cfgs=dict(
type='MultiheadAttention',
layer_cfg=dict( # Mask2FormerTransformerDecoderLayer
self_attn_cfg=dict( # MultiheadAttention
embed_dims=256,
num_heads=8,
dropout=0.0,
batch_first=True),
cross_attn_cfg=dict( # MultiheadAttention
embed_dims=256,
num_heads=8,
attn_drop=0.0,
proj_drop=0.0,
dropout_layer=None,
batch_first=False),
ffn_cfgs=dict(
dropout=0.0,
batch_first=True),
ffn_cfg=dict(
embed_dims=256,
feedforward_channels=2048,
num_fcs=2,
act_cfg=dict(type='ReLU', inplace=True),
ffn_drop=0.0,
dropout_layer=None,
add_identity=True),
feedforward_channels=2048,
operation_order=('cross_attn', 'norm', 'self_attn', 'norm',
'ffn', 'norm')),
act_cfg=dict(type='ReLU', inplace=True))),
init_cfg=None),
loss_cls=dict(
type='CrossEntropyLoss',
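In the panoptic config above, the old transformerlayers / attn_cfgs / ffn_cfgs nesting is replaced by layer_cfg with self_attn_cfg, cross_attn_cfg and ffn_cfg, so nested fields now sit at shorter paths. A hedged sketch of reading them back from the loaded config (mmengine attribute access; the file path assumes the mmdetection repository root):

from mmengine.config import Config

cfg = Config.fromfile(
    'configs/mask2former/mask2former_r50_8xb2-lsj-50e_coco-panoptic.py')
head = cfg.model.panoptic_head
# the same paths Mask2FormerHead.__init__ reads after this fix (see below)
num_heads = head.transformer_decoder.layer_cfg.cross_attn_cfg.num_heads
num_levels = head.pixel_decoder.encoder.layer_cfg.self_attn_cfg.num_levels
print(num_heads, num_levels)  # expected: 8 3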
4 changes: 2 additions & 2 deletions configs/mask2former/mask2former_r50_8xb2-lsj-50e_coco.py
@@ -1,4 +1,5 @@
_base_ = ['./mask2former_r50_8xb2-lsj-50e_coco-panoptic.py']

num_things_classes = 80
num_stuff_classes = 0
num_classes = num_things_classes + num_stuff_classes
@@ -56,7 +57,7 @@
]

test_pipeline = [
dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
dict(type='LoadImageFromFile', to_float32=True),
dict(type='Resize', scale=(1333, 800), keep_ratio=True),
# If you don't have a gt annotation, delete the pipeline
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
@@ -75,7 +76,6 @@
ann_file='annotations/instances_train2017.json',
data_prefix=dict(img='train2017/'),
pipeline=train_pipeline))

val_dataloader = dict(
dataset=dict(
type=dataset_type,
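In the instance-segmentation config, LoadImageFromFile now takes to_float32=True instead of file_client_args. A sketch of the resulting test pipeline; the final PackDetInputs step and its meta_keys are assumed, since they are not shown in this hunk:

test_pipeline = [
    dict(type='LoadImageFromFile', to_float32=True),  # no file_client_args
    dict(type='Resize', scale=(1333, 800), keep_ratio=True),
    # If you don't have a gt annotation, delete the pipeline
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(
        type='PackDetInputs',  # assumed final step, not part of this hunk
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor'))
]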
66 changes: 22 additions & 44 deletions configs/maskformer/maskformer_r50_ms-16xb1-75e_coco.py
@@ -41,65 +41,43 @@
type='TransformerEncoderPixelDecoder',
norm_cfg=dict(type='GN', num_groups=32),
act_cfg=dict(type='ReLU'),
encoder=dict(
type='DetrTransformerEncoder',
encoder=dict( # DetrTransformerEncoder
num_layers=6,
transformerlayers=dict(
type='BaseTransformerLayer',
attn_cfgs=dict(
type='MultiheadAttention',
layer_cfg=dict( # DetrTransformerEncoderLayer
self_attn_cfg=dict( # MultiheadAttention
embed_dims=256,
num_heads=8,
attn_drop=0.1,
proj_drop=0.1,
dropout_layer=None,
batch_first=False),
ffn_cfgs=dict(
dropout=0.1,
batch_first=True),
ffn_cfg=dict(
embed_dims=256,
feedforward_channels=2048,
num_fcs=2,
act_cfg=dict(type='ReLU', inplace=True),
ffn_drop=0.1,
dropout_layer=None,
add_identity=True),
operation_order=('self_attn', 'norm', 'ffn', 'norm'),
norm_cfg=dict(type='LN'),
init_cfg=None,
batch_first=False),
init_cfg=None),
positional_encoding=dict(
type='SinePositionalEncoding', num_feats=128, normalize=True)),
act_cfg=dict(type='ReLU', inplace=True)))),
positional_encoding=dict(num_feats=128, normalize=True)),
enforce_decoder_input_project=False,
positional_encoding=dict(
type='SinePositionalEncoding', num_feats=128, normalize=True),
transformer_decoder=dict(
type='DetrTransformerDecoder',
return_intermediate=True,
positional_encoding=dict(num_feats=128, normalize=True),
transformer_decoder=dict( # DetrTransformerDecoder
num_layers=6,
transformerlayers=dict(
type='DetrTransformerDecoderLayer',
attn_cfgs=dict(
type='MultiheadAttention',
layer_cfg=dict( # DetrTransformerDecoderLayer
self_attn_cfg=dict( # MultiheadAttention
embed_dims=256,
num_heads=8,
attn_drop=0.1,
proj_drop=0.1,
dropout_layer=None,
batch_first=False),
ffn_cfgs=dict(
dropout=0.1,
batch_first=True),
cross_attn_cfg=dict( # MultiheadAttention
embed_dims=256,
num_heads=8,
dropout=0.1,
batch_first=True),
ffn_cfg=dict(
embed_dims=256,
feedforward_channels=2048,
num_fcs=2,
act_cfg=dict(type='ReLU', inplace=True),
ffn_drop=0.1,
dropout_layer=None,
add_identity=True),
# the following parameter was not used,
# just make current api happy
feedforward_channels=2048,
operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
'ffn', 'norm')),
init_cfg=None),
act_cfg=dict(type='ReLU', inplace=True))),
return_intermediate=True),
loss_cls=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
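The MaskFormer decoder keeps the vanilla DETR layer order (self-attention, then cross-attention), but its per-layer options now live in layer_cfg. A sketch of building an equivalent decoder directly, assuming DetrTransformerDecoder is exported from mmdet.models.layers with this constructor (the head-side changes below use the same pattern for the Mask2Former decoder):

from mmdet.models.layers import DetrTransformerDecoder

decoder = DetrTransformerDecoder(
    num_layers=6,
    layer_cfg=dict(
        self_attn_cfg=dict(
            embed_dims=256, num_heads=8, dropout=0.1, batch_first=True),
        cross_attn_cfg=dict(
            embed_dims=256, num_heads=8, dropout=0.1, batch_first=True),
        ffn_cfg=dict(
            embed_dims=256,
            feedforward_channels=2048,
            num_fcs=2,
            ffn_drop=0.1,
            act_cfg=dict(type='ReLU', inplace=True))),
    return_intermediate=True)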
45 changes: 22 additions & 23 deletions mmdet/models/dense_heads/mask2former_head.py
@@ -14,6 +14,7 @@
from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures import SampleList
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig, reduce_mean
from ..layers import Mask2FormerTransformerDecoder, SinePositionalEncoding
from ..utils import get_uncertain_point_coords_with_randomness
from .anchor_free_head import AnchorFreeHead
from .maskformer_head import MaskFormerHead
@@ -42,7 +43,8 @@ class Mask2FormerHead(MaskFormerHead):
transformer_decoder (:obj:`ConfigDict` or dict): Config for
transformer decoder. Defaults to None.
positional_encoding (:obj:`ConfigDict` or dict): Config for
transformer decoder position encoding. Defaults to None.
transformer decoder position encoding. Defaults to
dict(num_feats=128, normalize=True).
loss_cls (:obj:`ConfigDict` or dict): Config of the classification
loss. Defaults to None.
loss_mask (:obj:`ConfigDict` or dict): Config of the mask loss.
@@ -69,9 +71,7 @@ def __init__(self,
enforce_decoder_input_project: bool = False,
transformer_decoder: ConfigType = ...,
positional_encoding: ConfigType = dict(
type='SinePositionalEncoding',
num_feats=128,
normalize=True),
num_feats=128, normalize=True),
loss_cls: ConfigType = dict(
type='CrossEntropyLoss',
use_sigmoid=False,
@@ -101,18 +101,18 @@ def __init__(self,
self.num_classes = self.num_things_classes + self.num_stuff_classes
self.num_queries = num_queries
self.num_transformer_feat_level = num_transformer_feat_level
self.num_heads = transformer_decoder.transformerlayers. \
attn_cfgs.num_heads
self.num_heads = transformer_decoder.layer_cfg.cross_attn_cfg.num_heads
self.num_transformer_decoder_layers = transformer_decoder.num_layers
assert pixel_decoder.encoder.transformerlayers. \
attn_cfgs.num_levels == num_transformer_feat_level
assert pixel_decoder.encoder.layer_cfg. \
self_attn_cfg.num_levels == num_transformer_feat_level
pixel_decoder_ = copy.deepcopy(pixel_decoder)
pixel_decoder_.update(
in_channels=in_channels,
feat_channels=feat_channels,
out_channels=out_channels)
self.pixel_decoder = MODELS.build(pixel_decoder_)
self.transformer_decoder = MODELS.build(transformer_decoder)
self.transformer_decoder = Mask2FormerTransformerDecoder(
**transformer_decoder)
self.decoder_embed_dims = self.transformer_decoder.embed_dims

self.decoder_input_projs = ModuleList()
Expand All @@ -125,7 +125,8 @@ def __init__(self,
feat_channels, self.decoder_embed_dims, kernel_size=1))
else:
self.decoder_input_projs.append(nn.Identity())
self.decoder_positional_encoding = MODELS.build(positional_encoding)
self.decoder_positional_encoding = SinePositionalEncoding(
**positional_encoding)
self.query_embed = nn.Embedding(self.num_queries, feat_channels)
self.query_feat = nn.Embedding(self.num_queries, feat_channels)
# from low resolution to high resolution
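As with the decoder, the positional encoding is now instantiated directly from its config dict rather than through MODELS.build. A small sketch of the resulting module, assuming the usual sine-encoding output shape of (batch_size, 2 * num_feats, h, w):

import torch

from mmdet.models.layers import SinePositionalEncoding

pos_enc = SinePositionalEncoding(num_feats=128, normalize=True)
mask = torch.zeros(2, 32, 32, dtype=torch.bool)  # no padded pixels
pos = pos_enc(mask)                              # (2, 2 * num_feats, 32, 32)
assert pos.shape == (2, 256, 32, 32)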
@@ -338,7 +339,7 @@ def _forward_head(self, decoder_out: Tensor, mask_feature: Tensor,
"""Forward for head part which is called after every decoder layer.
Args:
decoder_out (Tensor): in shape (num_queries, batch_size, c).
decoder_out (Tensor): in shape (batch_size, num_queries, c).
mask_feature (Tensor): in shape (batch_size, c, h, w).
attn_mask_target_size (tuple[int, int]): target attention
mask size.
@@ -355,7 +356,6 @@ def _forward_head(self, decoder_out: Tensor, mask_feature: Tensor,
(batch_size * num_heads, num_queries, h, w).
"""
decoder_out = self.transformer_decoder.post_norm(decoder_out)
decoder_out = decoder_out.transpose(0, 1)
# shape (num_queries, batch_size, c)
cls_pred = self.cls_embed(decoder_out)
# shape (num_queries, batch_size, c)
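With the transpose removed, _forward_head consumes a batch-first tensor of shape (batch_size, num_queries, c) directly. A toy shape check with illustrative sizes (not taken from the configs):

import torch

batch_size, num_queries, c, num_classes = 2, 100, 256, 133
decoder_out = torch.randn(batch_size, num_queries, c)  # already batch-first
cls_embed = torch.nn.Linear(c, num_classes + 1)
cls_pred = cls_embed(decoder_out)
assert cls_pred.shape == (batch_size, num_queries, num_classes + 1)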
@@ -408,25 +408,25 @@ def forward(self, x: List[Tensor],
decoder_positional_encodings = []
for i in range(self.num_transformer_feat_level):
decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i])
# shape (batch_size, c, h, w) -> (h*w, batch_size, c)
decoder_input = decoder_input.flatten(2).permute(2, 0, 1)
# shape (batch_size, c, h, w) -> (batch_size, h*w, c)
decoder_input = decoder_input.flatten(2).permute(0, 2, 1)
level_embed = self.level_embed.weight[i].view(1, 1, -1)
decoder_input = decoder_input + level_embed
# shape (batch_size, c, h, w) -> (h*w, batch_size, c)
# shape (batch_size, c, h, w) -> (batch_size, h*w, c)
mask = decoder_input.new_zeros(
(batch_size, ) + multi_scale_memorys[i].shape[-2:],
dtype=torch.bool)
decoder_positional_encoding = self.decoder_positional_encoding(
mask)
decoder_positional_encoding = decoder_positional_encoding.flatten(
2).permute(2, 0, 1)
2).permute(0, 2, 1)
decoder_inputs.append(decoder_input)
decoder_positional_encodings.append(decoder_positional_encoding)
# shape (num_queries, c) -> (num_queries, batch_size, c)
query_feat = self.query_feat.weight.unsqueeze(1).repeat(
(1, batch_size, 1))
query_embed = self.query_embed.weight.unsqueeze(1).repeat(
(1, batch_size, 1))
# shape (num_queries, c) -> (batch_size, num_queries, c)
query_feat = self.query_feat.weight.unsqueeze(0).repeat(
(batch_size, 1, 1))
query_embed = self.query_embed.weight.unsqueeze(0).repeat(
(batch_size, 1, 1))

cls_pred_list = []
mask_pred_list = []
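The same batch-first convention runs through forward: multi-scale memories are flattened to (batch_size, h*w, c) and the query embeddings are repeated along a new leading batch dimension. A quick shape check:

import torch

batch_size, c, h, w, num_queries = 2, 256, 32, 32, 100
feat = torch.randn(batch_size, c, h, w)
decoder_input = feat.flatten(2).permute(0, 2, 1)  # (batch_size, h*w, c)
assert decoder_input.shape == (batch_size, h * w, c)

query_embed = torch.nn.Embedding(num_queries, c)
query = query_embed.weight.unsqueeze(0).repeat((batch_size, 1, 1))
assert query.shape == (batch_size, num_queries, c)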
@@ -443,14 +443,13 @@ def forward(self, x: List[Tensor],

# cross_attn + self_attn
layer = self.transformer_decoder.layers[i]
attn_masks = [attn_mask, None]
query_feat = layer(
query=query_feat,
key=decoder_inputs[level_idx],
value=decoder_inputs[level_idx],
query_pos=query_embed,
key_pos=decoder_positional_encodings[level_idx],
attn_masks=attn_masks,
cross_attn_mask=attn_mask,
query_key_padding_mask=None,
# here we do not apply masking on padded region
key_padding_mask=None)
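The per-layer decoder call now passes the attention mask as cross_attn_mask instead of the old attn_masks list. A hedged sketch of the mask convention with plain PyTorch MultiheadAttention, using the (batch_size * num_heads, num_queries, h*w) shape from the _forward_head docstring above (True marks positions that are not attended):

import torch

batch_size, num_heads, num_queries, hw, embed_dims = 2, 8, 100, 32 * 32, 256
# boolean mask of shape (batch_size * num_heads, num_queries, h*w)
attn_mask = torch.zeros(
    batch_size * num_heads, num_queries, hw, dtype=torch.bool)

attn = torch.nn.MultiheadAttention(embed_dims, num_heads, batch_first=True)
query = torch.randn(batch_size, num_queries, embed_dims)
key = value = torch.randn(batch_size, hw, embed_dims)
out, _ = attn(query, key, value, attn_mask=attn_mask)
assert out.shape == (batch_size, num_queries, embed_dims)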