[Feature] Add Maskformer to mmdet #7212

Merged: 6 commits, Feb 22, 2022
60 changes: 60 additions & 0 deletions configs/maskformer/README.md
@@ -0,0 +1,60 @@
# Per-Pixel Classification is Not All You Need for Semantic Segmentation

## Abstract

Modern approaches typically formulate semantic segmentation as a per-pixel classification
task, while instance-level segmentation is handled with an alternative mask
classification. Our key insight: mask classification is sufficiently general to solve
both semantic- and instance-level segmentation tasks in a unified manner using
the exact same model, loss, and training procedure. Following this observation,
we propose MaskFormer, a simple mask classification model which predicts a
set of binary masks, each associated with a single global class label prediction.
Overall, the proposed mask classification-based method simplifies the landscape
of effective approaches to semantic and panoptic segmentation tasks and shows
excellent empirical results. In particular, we observe that MaskFormer outperforms
per-pixel classification baselines when the number of classes is large. Our mask
classification-based method outperforms both current state-of-the-art semantic
(55.6 mIoU on ADE20K) and panoptic segmentation (52.7 PQ on COCO) models.

<div align=center>
<img src="https://camo.githubusercontent.com/29fb22298d506ce176caad3006a7b05ef2603ca12cece6c788b7e73c046e8bc9/68747470733a2f2f626f77656e63303232312e6769746875622e696f2f696d616765732f6d61736b666f726d65722e706e67" height="300"/>
</div>
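To make the phrase "a set of binary masks, each associated with a single global class label prediction" concrete, here is a minimal sketch of how such predictions could be turned into a segmentation. It is an illustration and not the code added in this PR: the function name `assign_pixels` and the toy inputs are hypothetical, and the rule used (each pixel goes to the query with the highest class confidence times mask probability) is a simplified reading of the inference described in the MaskFormer paper.

```python
def assign_pixels(class_probs, mask_probs):
    """Assign each pixel to the query maximising class confidence * mask probability.

    class_probs: per-query class distributions (the "no object" entry is
                 assumed to be excluded already).
    mask_probs:  per-query H x W grids of mask probabilities.
    Returns an H x W grid of (query_index, class_index) tuples.
    """
    # Each query contributes exactly one global label and its confidence.
    labels = [max(range(len(p)), key=p.__getitem__) for p in class_probs]
    scores = [p[c] for p, c in zip(class_probs, labels)]

    height, width = len(mask_probs[0]), len(mask_probs[0][0])
    result = []
    for y in range(height):
        row = []
        for x in range(width):
            best = max(
                range(len(mask_probs)),
                key=lambda q: scores[q] * mask_probs[q][y][x],
            )
            row.append((best, labels[best]))
        result.append(row)
    return result


# Toy example: two queries, two classes, a 1 x 2 image.
class_probs = [[0.9, 0.1], [0.2, 0.8]]
mask_probs = [
    [[0.9, 0.1]],  # query 0 covers the left pixel
    [[0.2, 0.7]],  # query 1 covers the right pixel
]
segmentation = assign_pixels(class_probs, mask_probs)
print(segmentation)
```

In this toy case the left pixel is assigned to query 0 (class 0) and the right pixel to query 1 (class 1). The `object_mask_thr` setting in the config below suggests that the real head also filters out low-confidence queries, which this sketch does not model.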

## Citation

```
@inproceedings{cheng2021maskformer,
title={Per-Pixel Classification is Not All You Need for Semantic Segmentation},
author={Bowen Cheng and Alexander G. Schwing and Alexander Kirillov},
booktitle={NeurIPS},
year={2021}
}
```

## Dataset

MaskFormer requires the COCO and [COCO-panoptic](http://images.cocodataset.org/annotations/panoptic_annotations_trainval2017.zip) datasets for training and evaluation. Download the panoptic annotations and extract them into the COCO dataset path.
The directory should look like this:

```none
mmdetection
├── mmdet
├── tools
├── configs
├── data
│   ├── coco
│   │   ├── annotations
│   │   │   ├── panoptic_train2017.json
│   │   │   ├── panoptic_train2017
│   │   │   ├── panoptic_val2017.json
│   │   │   ├── panoptic_val2017
│   │   ├── train2017
│   │   ├── val2017
│   │   ├── test2017
```
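A quick way to confirm the layout above before training is to check that every expected entry exists. The sketch below is only a convenience script and not part of mmdetection; the names `EXPECTED_ENTRIES` and `missing_coco_entries` are hypothetical, and the default root `data/coco` assumes the script is run from the `mmdetection` directory.

```python
from pathlib import Path

# Entries relative to data/coco, taken from the directory tree above.
EXPECTED_ENTRIES = [
    "annotations/panoptic_train2017.json",
    "annotations/panoptic_train2017",
    "annotations/panoptic_val2017.json",
    "annotations/panoptic_val2017",
    "train2017",
    "val2017",
    "test2017",
]


def missing_coco_entries(coco_root):
    """Return the expected entries that do not exist under coco_root."""
    root = Path(coco_root)
    return [entry for entry in EXPECTED_ENTRIES if not (root / entry).exists()]


if __name__ == "__main__":
    for entry in missing_coco_entries("data/coco"):
        print(f"missing: {entry}")
```

An empty result means all seven entries are present; the script does not verify the contents of the files or folders.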

## Results and Models

| Backbone | style | Lr schd | Mem (GB) | Inf time (fps) | PQ | SQ | RQ | PQ_th | SQ_th | RQ_th | PQ_st | SQ_st | RQ_st | Config | Download | detail |
| :------: | :-----: | :-----: | :------: | :------------: | :-: | :-: | :-: | :---: | :---: | :---: | :---: | :---: | :---: | :---------------------------------------------------------------------------------------------------------------------: | :----------------------: | :---: |
| R-50 | pytorch | 75e | | | | | | | | | | | | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/maskformer/maskformer_r50_mstrain_16x1_75e_coco.py) | | This version was mentioned in Table XI, in paper [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) |
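
The PQ, SQ, and RQ columns refer to panoptic quality, segmentation quality, and recognition quality; the `_th` and `_st` suffixes restrict the metric to "things" and "stuff" classes. As a reminder of how the three relate, here is a minimal sketch using the standard definitions (PQ = SQ × RQ). The function name and the toy numbers are hypothetical, and this is not the evaluation code used to produce the table.

```python
def panoptic_quality(matched_ious, num_false_positives, num_false_negatives):
    """Return (PQ, SQ, RQ) for one class.

    matched_ious: IoU of every matched (true positive) segment pair.
    """
    num_true_positives = len(matched_ious)
    denominator = (
        num_true_positives
        + 0.5 * num_false_positives
        + 0.5 * num_false_negatives
    )
    if denominator == 0:
        return 0.0, 0.0, 0.0
    # SQ: mean IoU over matched segments.
    sq = sum(matched_ious) / num_true_positives if num_true_positives else 0.0
    # RQ: an F1-style detection score over segments.
    rq = num_true_positives / denominator
    return sq * rq, sq, rq


# Toy example: two matches, one unmatched prediction, one missed ground truth.
pq, sq, rq = panoptic_quality([0.8, 0.6], 1, 1)
print(f"PQ={pq:.3f} SQ={sq:.3f} RQ={rq:.3f}")
```

Benchmark numbers are usually the average of these per-class values over all classes.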
220 changes: 220 additions & 0 deletions configs/maskformer/maskformer_r50_mstrain_16x1_75e_coco.py
@@ -0,0 +1,220 @@
_base_ = [
'../_base_/datasets/coco_panoptic.py', '../_base_/default_runtime.py'
]

model = dict(
type='MaskFormer',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=-1,
norm_cfg=dict(type='BN', requires_grad=False),
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
panoptic_head=dict(
type='MaskFormerHead',
in_channels=[256, 512, 1024, 2048], # passed on to the pixel_decoder
feat_channels=256,
out_channels=256,
num_things_classes=80,
num_stuff_classes=53,
num_queries=100,
pixel_decoder=dict(
type='TransformerEncoderPixelDecoder',
norm_cfg=dict(type='GN', num_groups=32),
act_cfg=dict(type='ReLU'),
encoder=dict(
type='DetrTransformerEncoder',
num_layers=6,
transformerlayers=dict(
type='BaseTransformerLayer',
attn_cfgs=dict(
type='MultiheadAttention',
embed_dims=256,
num_heads=8,
attn_drop=0.1,
proj_drop=0.1,
dropout_layer=None,
batch_first=False),
ffn_cfgs=dict(
embed_dims=256,
feedforward_channels=2048,
num_fcs=2,
act_cfg=dict(type='ReLU', inplace=True),
ffn_drop=0.1,
dropout_layer=None,
add_identity=True),
operation_order=('self_attn', 'norm', 'ffn', 'norm'),
norm_cfg=dict(type='LN'),
init_cfg=None,
batch_first=False),
init_cfg=None),
positional_encoding=dict(
type='SinePositionalEncoding', num_feats=128, normalize=True)),
enforce_decoder_input_project=False,
positional_encoding=dict(
type='SinePositionalEncoding', num_feats=128, normalize=True),
transformer_decoder=dict(
type='DetrTransformerDecoder',
return_intermediate=True,
num_layers=6,
transformerlayers=dict(
type='DetrTransformerDecoderLayer',
attn_cfgs=dict(
type='MultiheadAttention',
embed_dims=256,
num_heads=8,
attn_drop=0.1,
proj_drop=0.1,
dropout_layer=None,
batch_first=False),
ffn_cfgs=dict(
embed_dims=256,
feedforward_channels=2048,
num_fcs=2,
act_cfg=dict(type='ReLU', inplace=True),
ffn_drop=0.1,
dropout_layer=None,
add_identity=True),
# the following parameter is unused;
# it is kept only to satisfy the current API
feedforward_channels=2048,
operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
'ffn', 'norm')),
init_cfg=None),
loss_cls=dict(
type='CrossEntropyLoss',
bg_cls_weight=0.1,
use_sigmoid=False,
loss_weight=1.0,
reduction='mean',
class_weight=1.0),
loss_mask=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
reduction='mean',
loss_weight=20.0),
loss_dice=dict(
type='DiceLoss',
use_sigmoid=True,
activate=True,
reduction='mean',
naive_dice=True,
eps=1.0,
loss_weight=1.0)),
train_cfg=dict(
assigner=dict(
type='MaskHungarianAssigner',
cls_cost=dict(type='ClassificationCost', weight=1.0),
mask_cost=dict(
type='FocalLossCost', weight=20.0, binary_input=True),
dice_cost=dict(
type='DiceCost', weight=1.0, pred_act=True, eps=1.0)),
sampler=dict(type='MaskPseudoSampler')),
test_cfg=dict(object_mask_thr=0.8, iou_thr=0.8),
# pretrained=None,
init_cfg=None)

# dataset settings
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='LoadPanopticAnnotations',
with_bbox=True,
with_mask=True,
with_seg=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(
type='AutoAugment',
policies=[[
dict(
type='Resize',
img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
(608, 1333), (640, 1333), (672, 1333), (704, 1333),
(736, 1333), (768, 1333), (800, 1333)],
multiscale_mode='value',
keep_ratio=True)
],
[
dict(
type='Resize',
img_scale=[(400, 1333), (500, 1333), (600, 1333)],
multiscale_mode='value',
keep_ratio=True),
dict(
type='RandomCrop',
crop_type='absolute_range',
crop_size=(384, 600),
allow_negative_crop=True),
dict(
type='Resize',
img_scale=[(480, 1333), (512, 1333), (544, 1333),
(576, 1333), (608, 1333), (640, 1333),
(672, 1333), (704, 1333), (736, 1333),
(768, 1333), (800, 1333)],
multiscale_mode='value',
override=True,
keep_ratio=True)
]]),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=1),
dict(type='DefaultFormatBundle'),
dict(
type='Collect',
keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=1),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=1,
workers_per_gpu=1,
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))

# optimizer
optimizer = dict(
type='AdamW',
lr=0.0001,
weight_decay=0.0001,
eps=1e-8,
betas=(0.9, 0.999),
paramwise_cfg=dict(
custom_keys={
'backbone': dict(lr_mult=0.1, decay_mult=1.0),
'query_embed': dict(lr_mult=1.0, decay_mult=0.0)
},
norm_decay_mult=0.0))
optimizer_config = dict(grad_clip=dict(max_norm=0.01, norm_type=2))

# learning policy
lr_config = dict(
policy='step',
gamma=0.1,
by_epoch=True,
step=[50],
warmup='linear',
warmup_by_epoch=False,
warmup_ratio=1.0, # no warmup
warmup_iters=10)
runner = dict(type='EpochBasedRunner', max_epochs=75)
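
To see what the optimizer and schedule above imply in practice, the following sketch computes the effective learning rate per epoch. It is a simplified model and not mmcv's implementation: it assumes a 0-based epoch counter, a decay applied from epoch 50 onward, substring matching of the `custom_keys` against parameter names, and no warmup effect (the config sets `warmup_ratio=1.0`). The names `effective_lr`, `LR_MULT`, and the example parameter names are hypothetical.

```python
BASE_LR = 1e-4          # optimizer lr
DECAY_EPOCHS = [50]     # lr_config step
GAMMA = 0.1             # lr_config gamma
MAX_EPOCHS = 75         # runner max_epochs

# lr_mult values from paramwise_cfg.custom_keys
LR_MULT = {"backbone": 0.1, "query_embed": 1.0}


def effective_lr(epoch, param_name):
    """Learning rate for a parameter at a given 0-based epoch (simplified)."""
    num_decays = sum(1 for step in DECAY_EPOCHS if epoch >= step)
    lr = BASE_LR * GAMMA ** num_decays
    for key, mult in LR_MULT.items():
        if key in param_name:
            return lr * mult
    return lr


if __name__ == "__main__":
    for epoch in (0, 49, 50, MAX_EPOCHS - 1):
        backbone_lr = effective_lr(epoch, "backbone.layer1.0.conv1.weight")
        head_lr = effective_lr(epoch, "panoptic_head.cls_embed.weight")
        print(f"epoch {epoch:2d}: backbone={backbone_lr:.1e} head={head_lr:.1e}")
```

Under these assumptions the backbone trains at one tenth of the head's learning rate throughout, and both rates drop by a factor of ten for the last 25 epochs.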
3 changes: 2 additions & 1 deletion mmdet/core/bbox/assigners/__init__.py
@@ -6,6 +6,7 @@
 from .center_region_assigner import CenterRegionAssigner
 from .grid_assigner import GridAssigner
 from .hungarian_assigner import HungarianAssigner
+from .mask_hungarian_assigner import MaskHungarianAssigner
 from .max_iou_assigner import MaxIoUAssigner
 from .point_assigner import PointAssigner
 from .region_assigner import RegionAssigner
@@ -17,5 +18,5 @@
     'BaseAssigner', 'MaxIoUAssigner', 'ApproxMaxIoUAssigner', 'AssignResult',
     'PointAssigner', 'ATSSAssigner', 'CenterRegionAssigner', 'GridAssigner',
     'HungarianAssigner', 'RegionAssigner', 'UniformAssigner', 'SimOTAAssigner',
-    'TaskAlignedAssigner'
+    'TaskAlignedAssigner', 'MaskHungarianAssigner'
 ]
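
The `MaskHungarianAssigner` exported here is configured in the MaskFormer config with three cost terms: `cls_cost` (weight 1.0), `mask_cost` (weight 20.0), and `dice_cost` (weight 1.0). Its implementation is not shown in this part of the diff, so the sketch below only illustrates the general idea of one-to-one matching on a weighted cost sum. The function names and toy cost values are hypothetical, and a brute-force search over permutations stands in for the Hungarian algorithm, which is only practical for tiny inputs.

```python
from itertools import permutations

# Cost weights taken from train_cfg.assigner in the MaskFormer config.
WEIGHTS = {"cls": 1.0, "mask": 20.0, "dice": 1.0}


def combine_costs(cls_cost, mask_cost, dice_cost):
    """Weighted sum of three [num_queries][num_gts] cost matrices."""
    return [
        [
            WEIGHTS["cls"] * c + WEIGHTS["mask"] * m + WEIGHTS["dice"] * d
            for c, m, d in zip(cls_row, mask_row, dice_row)
        ]
        for cls_row, mask_row, dice_row in zip(cls_cost, mask_cost, dice_cost)
    ]


def match_queries(cost):
    """Return a list whose entry g is the query assigned to ground truth g.

    Brute force over all one-to-one assignments; assumes
    num_queries >= num_gts and a very small problem size.
    """
    num_queries = len(cost)
    num_gts = len(cost[0]) if cost else 0
    best = min(
        permutations(range(num_queries), num_gts),
        key=lambda queries: sum(cost[q][g] for g, q in enumerate(queries)),
    )
    return list(best)


# Toy example: 3 queries, 2 ground-truth masks.
cls_cost = [[-0.9, -0.1], [-0.2, -0.7], [-0.5, -0.5]]
mask_cost = [[0.01, 0.05], [0.04, 0.02], [0.03, 0.03]]
dice_cost = [[0.1, 0.9], [0.8, 0.2], [0.5, 0.5]]

assignment = match_queries(combine_costs(cls_cost, mask_cost, dice_cost))
print(assignment)
```

Here ground truth 0 is matched to query 0 and ground truth 1 to query 1; query 2 stays unmatched and would be treated as background.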