diff --git a/README.md b/README.md
index 73f5f8fe26..00620bcb3e 100644
--- a/README.md
+++ b/README.md
@@ -121,6 +121,7 @@ Supported methods:
 - [x] [DPT (ArXiv'2021)](configs/dpt)
 - [x] [Segmenter (ICCV'2021)](configs/segmenter)
 - [x] [SegFormer (NeurIPS'2021)](configs/segformer)
+- [x] [K-Net (NeurIPS'2021)](configs/knet)
 
 Supported datasets:
 
diff --git a/README_zh-CN.md b/README_zh-CN.md
index 657f50c814..afaaa8a3e6 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -120,6 +120,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
 - [x] [DPT (ArXiv'2021)](configs/dpt)
 - [x] [Segmenter (ICCV'2021)](configs/segmenter)
 - [x] [SegFormer (NeurIPS'2021)](configs/segformer)
+- [x] [K-Net (NeurIPS'2021)](configs/knet)
 
 已支持的数据集:
 
diff --git a/configs/knet/README.md b/configs/knet/README.md
new file mode 100644
index 0000000000..ef223360bd
--- /dev/null
+++ b/configs/knet/README.md
@@ -0,0 +1,49 @@
+# K-Net
+
+[K-Net: Towards Unified Image Segmentation](https://arxiv.org/abs/2106.14855)
+
+## Introduction
+
+<a href="https://github.com/ZwwWayne/K-Net/">Official Repo</a>
+
+<a href="https://github.com/open-mmlab/mmsegmentation/blob/v0.23.0/mmseg/models/decode_heads/knet_head.py#L392">Code Snippet</a>
+
+## Abstract
+
+Semantic, instance, and panoptic segmentations have been addressed using different and specialized frameworks despite their underlying connections. This paper presents a unified, simple, and effective framework for these essentially similar tasks. The framework, named K-Net, segments both instances and semantic categories consistently by a group of learnable kernels, where each kernel is responsible for generating a mask for either a potential instance or a stuff class. To remedy the difficulties of distinguishing various instances, we propose a kernel update strategy that enables each kernel dynamic and conditional on its meaningful group in the input image. K-Net can be trained in an end-to-end manner with bipartite matching, and its training and inference are naturally NMS-free and box-free. Without bells and whistles, K-Net surpasses all previous published state-of-the-art single-model results of panoptic segmentation on MS COCO test-dev split and semantic segmentation on ADE20K val split with 55.2% PQ and 54.3% mIoU, respectively. Its instance segmentation performance is also on par with Cascade Mask R-CNN on MS COCO with 60%-90% faster inference speeds. Code and models will be released at [this https URL](https://github.com/ZwwWayne/K-Net/).
+
+
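+The core of the method — mask prediction by learnable kernels followed by a
+group-feature-based kernel update — can be sketched in a few lines of PyTorch
+(a minimal illustration with made-up shapes, not the actual implementation;
+the real kernel update head added by this PR lives in
+`mmseg/models/decode_heads/knet_head.py`):
+
+```python
+import torch
+import torch.nn.functional as F
+
+B, C, H, W, K = 2, 256, 64, 64, 150   # batch, channels, height, width, kernels
+
+feats = torch.randn(B, C, H, W)       # decoder feature map
+kernels = torch.randn(K, C, 1, 1)     # one learnable kernel per class
+
+# each kernel convolves the features to produce one mask logit map
+masks = F.conv2d(feats, kernels)      # (B, K, H, W)
+
+# group feature assembling: pool the features each kernel responds to,
+# weighted by its normalized mask; this group feature then drives the
+# kernel update, making each kernel conditional on the image content
+weights = masks.softmax(dim=1)
+group_feats = torch.einsum('bkhw,bchw->bkc', weights, feats)  # (B, K, C)
+```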
+
+
+
+```bibtex
+@inproceedings{zhang2021knet,
+    title={{K-Net: Towards} Unified Image Segmentation},
+    author={Wenwei Zhang and Jiangmiao Pang and Kai Chen and Chen Change Loy},
+    year={2021},
+    booktitle={NeurIPS},
+}
+```
+
+## Results and models
+
+### ADE20K
+
+| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
+| ---------------- | -------- | --------- | ------- | -------- | -------------- | ----- | ------------- | ------ | -------- |
+| KNet + FCN | R-50-D8 | 512x512 | 80000 | 7.01 | 19.24 | 43.60 | 45.12 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/knet/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k_20220228_043751-abcab920.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k_20220228_043751.log.json) |
+| KNet + PSPNet | R-50-D8 | 512x512 | 80000 | 6.98 | 20.04 | 44.18 | 45.58 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/knet/knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k_20220228_054634-d2c72240.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k_20220228_054634.log.json) |
+| KNet + DeepLabV3 | R-50-D8 | 512x512 | 80000 | 7.42 | 12.10 | 45.06 | 46.11 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/knet/knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k_20220228_041642-00c8fbeb.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k_20220228_041642.log.json) |
+| KNet + UperNet | R-50-D8 | 512x512 | 80000 | 7.34 | 17.11 | 43.45 | 44.07 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/knet/knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k_20220304_125657-215753b0.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k_20220304_125657.log.json) |
+| KNet + UperNet | Swin-T | 512x512 | 80000 | 7.57 | 15.56 | 45.84 | 46.27 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/knet/knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k/knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k_20220303_133059-7545e1dc.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k/knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k_20220303_133059.log.json) |
+| KNet + UperNet | Swin-L | 512x512 | 80000 | 13.5 | 8.29 | 52.05 | 53.24 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/knet/knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k/knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k_20220303_154559-d8da9a90.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k/knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k_20220303_154559.log.json) |
+| KNet + UperNet | Swin-L | 640x640 | 80000 | 13.54 | 8.29 | 52.21 | 53.34 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/knet/knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k/knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k_20220301_220747-8787fc71.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k/knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k_20220301_220747.log.json) |
+
+Note:
+
+- All experiments of K-Net are trained with 8 V100 (32G) GPUs with 2 samples per GPU.
diff --git a/configs/knet/knet.yml b/configs/knet/knet.yml
new file mode 100644
index 0000000000..5e2e529557
--- /dev/null
+++ b/configs/knet/knet.yml
@@ -0,0 +1,169 @@
+Collections:
+- Name: KNet
+  Metadata:
+    Training Data:
+    - ADE20K
+  Paper:
+    URL: https://arxiv.org/abs/2106.14855
+    Title: 'K-Net: Towards Unified Image Segmentation'
+  README: configs/knet/README.md
+  Code:
+    URL: https://github.com/open-mmlab/mmsegmentation/blob/v0.23.0/mmseg/models/decode_heads/knet_head.py#L392
+    Version: v0.23.0
+  Converted From:
+    Code: https://github.com/ZwwWayne/K-Net/
+Models:
+- Name: knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k
+  In Collection: KNet
+  Metadata:
+    backbone: R-50-D8
+    crop size: (512,512)
+    lr schd: 80000
+    inference time (ms/im):
+    - value: 51.98
+      hardware: V100
+      backend: PyTorch
+      batch size: 1
+      mode: FP32
+      resolution: (512,512)
+    Training Memory (GB): 7.01
+  Results:
+  - Task: Semantic Segmentation
+    Dataset: ADE20K
+    Metrics:
+      mIoU: 43.6
+      mIoU(ms+flip): 45.12
+  Config: configs/knet/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k.py
+  Weights: https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k_20220228_043751-abcab920.pth
+- Name: knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k
+  In Collection: KNet
+  Metadata:
+    backbone: R-50-D8
+    crop size: (512,512)
+    lr schd: 80000
+    inference time (ms/im):
+    - value: 49.9
+      hardware: V100
+      backend: PyTorch
+      batch size: 1
+      mode: FP32
+      resolution: (512,512)
+    Training Memory (GB): 6.98
+  Results:
+  - Task: Semantic Segmentation
+    Dataset: ADE20K
+    Metrics:
+      mIoU: 44.18
+      mIoU(ms+flip): 45.58
+  Config: configs/knet/knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k.py
+  Weights: https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k_20220228_054634-d2c72240.pth
+- Name: knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k
+  In
Collection: KNet + Metadata: + backbone: R-50-D8 + crop size: (512,512) + lr schd: 80000 + inference time (ms/im): + - value: 82.64 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP32 + resolution: (512,512) + Training Memory (GB): 7.42 + Results: + - Task: Semantic Segmentation + Dataset: ADE20K + Metrics: + mIoU: 45.06 + mIoU(ms+flip): 46.11 + Config: configs/knet/knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k.py + Weights: https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k_20220228_041642-00c8fbeb.pth +- Name: knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k + In Collection: KNet + Metadata: + backbone: R-50-D8 + crop size: (512,512) + lr schd: 80000 + inference time (ms/im): + - value: 58.45 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP32 + resolution: (512,512) + Training Memory (GB): 7.34 + Results: + - Task: Semantic Segmentation + Dataset: ADE20K + Metrics: + mIoU: 43.45 + mIoU(ms+flip): 44.07 + Config: configs/knet/knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k.py + Weights: https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k_20220304_125657-215753b0.pth +- Name: knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k + In Collection: KNet + Metadata: + backbone: Swin-T + crop size: (512,512) + lr schd: 80000 + inference time (ms/im): + - value: 64.27 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP32 + resolution: (512,512) + Training Memory (GB): 7.57 + Results: + - Task: Semantic Segmentation + Dataset: ADE20K + Metrics: + mIoU: 45.84 + mIoU(ms+flip): 46.27 + Config: configs/knet/knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k.py + Weights: https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k/knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k_20220303_133059-7545e1dc.pth +- Name: knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k + In Collection: KNet + Metadata: + backbone: Swin-L + crop size: (512,512) + lr schd: 80000 + inference time (ms/im): + - value: 120.63 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP32 + resolution: (512,512) + Training Memory (GB): 13.5 + Results: + - Task: Semantic Segmentation + Dataset: ADE20K + Metrics: + mIoU: 52.05 + mIoU(ms+flip): 53.24 + Config: configs/knet/knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k.py + Weights: https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k/knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k_20220303_154559-d8da9a90.pth +- Name: knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k + In Collection: KNet + Metadata: + backbone: Swin-L + crop size: (640,640) + lr schd: 80000 + inference time (ms/im): + - value: 120.63 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP32 + resolution: (640,640) + Training Memory (GB): 13.54 + Results: + - Task: Semantic Segmentation + Dataset: ADE20K + Metrics: + mIoU: 52.21 + mIoU(ms+flip): 53.34 + Config: configs/knet/knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k.py + Weights: https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k/knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k_20220301_220747-8787fc71.pth diff --git a/configs/knet/knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k.py 
b/configs/knet/knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k.py new file mode 100644 index 0000000000..3edb05c875 --- /dev/null +++ b/configs/knet/knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k.py @@ -0,0 +1,93 @@ +_base_ = [ + '../_base_/datasets/ade20k.py', '../_base_/default_runtime.py', + '../_base_/schedules/schedule_80k.py' +] + +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +num_stages = 3 +conv_kernel_size = 1 +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='IterativeDecodeHead', + num_stages=num_stages, + kernel_update_head=[ + dict( + type='KernelUpdateHead', + num_classes=150, + num_ffn_fcs=2, + num_heads=8, + num_mask_fcs=1, + feedforward_channels=2048, + in_channels=512, + out_channels=512, + dropout=0.0, + conv_kernel_size=conv_kernel_size, + ffn_act_cfg=dict(type='ReLU', inplace=True), + with_ffn=True, + feat_transform_cfg=dict( + conv_cfg=dict(type='Conv2d'), act_cfg=None), + kernel_updator_cfg=dict( + type='KernelUpdator', + in_channels=256, + feat_channels=256, + out_channels=256, + act_cfg=dict(type='ReLU', inplace=True), + norm_cfg=dict(type='LN'))) for _ in range(num_stages) + ], + kernel_generate_head=dict( + type='ASPPHead', + in_channels=2048, + in_index=3, + channels=512, + dilations=(1, 12, 24, 36), + dropout_ratio=0.1, + num_classes=150, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=150, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) + +# optimizer +optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, weight_decay=0.0005) +optimizer_config = dict(grad_clip=dict(max_norm=1, norm_type=2)) +# learning policy +lr_config = dict( + _delete_=True, + policy='step', + warmup='linear', + warmup_iters=1000, + warmup_ratio=0.001, + step=[60000, 72000], + by_epoch=False) +# In K-Net implementation we use batch size 2 per GPU as default +data = dict(samples_per_gpu=2, workers_per_gpu=2) diff --git a/configs/knet/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k.py b/configs/knet/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k.py new file mode 100644 index 0000000000..29a088f721 --- /dev/null +++ b/configs/knet/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k.py @@ -0,0 +1,93 @@ +_base_ = [ + '../_base_/datasets/ade20k.py', '../_base_/default_runtime.py', + '../_base_/schedules/schedule_80k.py' +] + +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +num_stages = 3 +conv_kernel_size = 1 +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='IterativeDecodeHead', + num_stages=num_stages, + kernel_update_head=[ + dict( + type='KernelUpdateHead', + 
num_classes=150, + num_ffn_fcs=2, + num_heads=8, + num_mask_fcs=1, + feedforward_channels=2048, + in_channels=512, + out_channels=512, + dropout=0.0, + conv_kernel_size=conv_kernel_size, + ffn_act_cfg=dict(type='ReLU', inplace=True), + with_ffn=True, + feat_transform_cfg=dict( + conv_cfg=dict(type='Conv2d'), act_cfg=None), + kernel_updator_cfg=dict( + type='KernelUpdator', + in_channels=256, + feat_channels=256, + out_channels=256, + act_cfg=dict(type='ReLU', inplace=True), + norm_cfg=dict(type='LN'))) for _ in range(num_stages) + ], + kernel_generate_head=dict( + type='FCNHead', + in_channels=2048, + in_index=3, + channels=512, + num_convs=2, + concat_input=True, + dropout_ratio=0.1, + num_classes=150, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=150, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) +# optimizer +optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, weight_decay=0.0005) +optimizer_config = dict(grad_clip=dict(max_norm=1, norm_type=2)) +# learning policy +lr_config = dict( + _delete_=True, + policy='step', + warmup='linear', + warmup_iters=1000, + warmup_ratio=0.001, + step=[60000, 72000], + by_epoch=False) +# In K-Net implementation we use batch size 2 per GPU as default +data = dict(samples_per_gpu=2, workers_per_gpu=2) diff --git a/configs/knet/knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k.py b/configs/knet/knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k.py new file mode 100644 index 0000000000..d77a3b4423 --- /dev/null +++ b/configs/knet/knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k.py @@ -0,0 +1,92 @@ +_base_ = [ + '../_base_/datasets/ade20k.py', '../_base_/default_runtime.py', + '../_base_/schedules/schedule_80k.py' +] + +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +num_stages = 3 +conv_kernel_size = 1 +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='IterativeDecodeHead', + num_stages=num_stages, + kernel_update_head=[ + dict( + type='KernelUpdateHead', + num_classes=150, + num_ffn_fcs=2, + num_heads=8, + num_mask_fcs=1, + feedforward_channels=2048, + in_channels=512, + out_channels=512, + dropout=0.0, + conv_kernel_size=conv_kernel_size, + ffn_act_cfg=dict(type='ReLU', inplace=True), + with_ffn=True, + feat_transform_cfg=dict( + conv_cfg=dict(type='Conv2d'), act_cfg=None), + kernel_updator_cfg=dict( + type='KernelUpdator', + in_channels=256, + feat_channels=256, + out_channels=256, + act_cfg=dict(type='ReLU', inplace=True), + norm_cfg=dict(type='LN'))) for _ in range(num_stages) + ], + kernel_generate_head=dict( + type='PSPHead', + in_channels=2048, + in_index=3, + channels=512, + pool_scales=(1, 2, 3, 6), + dropout_ratio=0.1, + num_classes=150, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))), + auxiliary_head=dict( + type='FCNHead', + 
in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=150, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) +# optimizer +optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, weight_decay=0.0005) +optimizer_config = dict(grad_clip=dict(max_norm=1, norm_type=2)) +# learning policy +lr_config = dict( + _delete_=True, + policy='step', + warmup='linear', + warmup_iters=1000, + warmup_ratio=0.001, + step=[60000, 72000], + by_epoch=False) +# In K-Net implementation we use batch size 2 per GPU as default +data = dict(samples_per_gpu=2, workers_per_gpu=2) diff --git a/configs/knet/knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k.py b/configs/knet/knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k.py new file mode 100644 index 0000000000..0071cea750 --- /dev/null +++ b/configs/knet/knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k.py @@ -0,0 +1,93 @@ +_base_ = [ + '../_base_/datasets/ade20k.py', '../_base_/default_runtime.py', + '../_base_/schedules/schedule_80k.py' +] + +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +num_stages = 3 +conv_kernel_size = 1 + +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 1, 1), + strides=(1, 2, 2, 2), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='IterativeDecodeHead', + num_stages=num_stages, + kernel_update_head=[ + dict( + type='KernelUpdateHead', + num_classes=150, + num_ffn_fcs=2, + num_heads=8, + num_mask_fcs=1, + feedforward_channels=2048, + in_channels=512, + out_channels=512, + dropout=0.0, + conv_kernel_size=conv_kernel_size, + ffn_act_cfg=dict(type='ReLU', inplace=True), + with_ffn=True, + feat_transform_cfg=dict( + conv_cfg=dict(type='Conv2d'), act_cfg=None), + kernel_updator_cfg=dict( + type='KernelUpdator', + in_channels=256, + feat_channels=256, + out_channels=256, + act_cfg=dict(type='ReLU', inplace=True), + norm_cfg=dict(type='LN'))) for _ in range(num_stages) + ], + kernel_generate_head=dict( + type='UPerHead', + in_channels=[256, 512, 1024, 2048], + in_index=[0, 1, 2, 3], + pool_scales=(1, 2, 3, 6), + channels=512, + dropout_ratio=0.1, + num_classes=150, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=150, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) +# optimizer +optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, weight_decay=0.0005) +optimizer_config = dict(grad_clip=dict(max_norm=1, norm_type=2)) +# learning policy +lr_config = dict( + _delete_=True, + policy='step', + warmup='linear', + warmup_iters=1000, + warmup_ratio=0.001, + step=[60000, 72000], + by_epoch=False) +# In K-Net implementation we use batch size 2 per GPU as default +data = dict(samples_per_gpu=2, workers_per_gpu=2) diff --git 
a/configs/knet/knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k.py b/configs/knet/knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k.py
new file mode 100644
index 0000000000..b9d1a0952d
--- /dev/null
+++ b/configs/knet/knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k.py
@@ -0,0 +1,19 @@
+_base_ = 'knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k.py'
+
+checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_large_patch4_window7_224_22k_20220308-d5bdebaf.pth'  # noqa
+# model settings
+model = dict(
+    pretrained=checkpoint_file,
+    backbone=dict(
+        embed_dims=192,
+        depths=[2, 2, 18, 2],
+        num_heads=[6, 12, 24, 48],
+        window_size=7,
+        use_abs_pos_embed=False,
+        drop_path_rate=0.3,
+        patch_norm=True),
+    decode_head=dict(
+        kernel_generate_head=dict(in_channels=[192, 384, 768, 1536])),
+    auxiliary_head=dict(in_channels=768))
+# In K-Net implementation we use batch size 2 per GPU as default
+data = dict(samples_per_gpu=2, workers_per_gpu=2)
diff --git a/configs/knet/knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k.py b/configs/knet/knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k.py
new file mode 100644
index 0000000000..fc6e9fe39f
--- /dev/null
+++ b/configs/knet/knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k.py
@@ -0,0 +1,54 @@
+_base_ = 'knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k.py'
+
+checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_large_patch4_window7_224_22k_20220308-d5bdebaf.pth'  # noqa
+# model settings
+model = dict(
+    pretrained=checkpoint_file,
+    backbone=dict(
+        embed_dims=192,
+        depths=[2, 2, 18, 2],
+        num_heads=[6, 12, 24, 48],
+        window_size=7,
+        use_abs_pos_embed=False,
+        drop_path_rate=0.4,
+        patch_norm=True),
+    decode_head=dict(
+        kernel_generate_head=dict(in_channels=[192, 384, 768, 1536])),
+    auxiliary_head=dict(in_channels=768))
+
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+crop_size = (640, 640)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations', reduce_zero_label=True),
+    dict(type='Resize', img_scale=(2048, 640), ratio_range=(0.5, 2.0)),
+    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(type='RandomFlip', prob=0.5),
+    dict(type='PhotoMetricDistortion'),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(2048, 640),
+        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+# In K-Net implementation we use batch size 2 per GPU as default
+data = dict(
+    samples_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(pipeline=train_pipeline),
+    val=dict(pipeline=test_pipeline),
+    test=dict(pipeline=test_pipeline))
diff --git a/configs/knet/knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k.py b/configs/knet/knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k.py
new file mode 100644
index 0000000000..0b29b2b8c4
--- /dev/null
+++ b/configs/knet/knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k.py
@@ -0,0 +1,57 @@
+_base_ = 'knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k.py'
+
+checkpoint_file
= 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_tiny_patch4_window7_224_20220308-f41b89d3.pth' # noqa + +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +num_stages = 3 +conv_kernel_size = 1 + +model = dict( + type='EncoderDecoder', + pretrained=checkpoint_file, + backbone=dict( + _delete_=True, + type='SwinTransformer', + embed_dims=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4, + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.3, + use_abs_pos_embed=False, + patch_norm=True, + out_indices=(0, 1, 2, 3)), + decode_head=dict( + kernel_generate_head=dict(in_channels=[96, 192, 384, 768])), + auxiliary_head=dict(in_channels=384)) + +# modify learning rate following the official implementation of Swin Transformer # noqa +optimizer = dict( + _delete_=True, + type='AdamW', + lr=0.00006, + betas=(0.9, 0.999), + weight_decay=0.0005, + paramwise_cfg=dict( + custom_keys={ + 'absolute_pos_embed': dict(decay_mult=0.), + 'relative_position_bias_table': dict(decay_mult=0.), + 'norm': dict(decay_mult=0.) + })) +optimizer_config = dict(grad_clip=dict(max_norm=1, norm_type=2)) +# learning policy +lr_config = dict( + _delete_=True, + policy='step', + warmup='linear', + warmup_iters=1000, + warmup_ratio=0.001, + step=[60000, 72000], + by_epoch=False) +# In K-Net implementation we use batch size 2 per GPU as default +data = dict(samples_per_gpu=2, workers_per_gpu=2) diff --git a/mmseg/models/decode_heads/__init__.py b/mmseg/models/decode_heads/__init__.py index dcde813264..8add7615c2 100644 --- a/mmseg/models/decode_heads/__init__.py +++ b/mmseg/models/decode_heads/__init__.py @@ -13,6 +13,7 @@ from .fpn_head import FPNHead from .gc_head import GCHead from .isa_head import ISAHead +from .knet_head import IterativeDecodeHead, KernelUpdateHead, KernelUpdator from .lraspp_head import LRASPPHead from .nl_head import NLHead from .ocr_head import OCRHead @@ -34,5 +35,6 @@ 'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead', 'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead', 'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegmenterMaskTransformerHead', - 'SegformerHead', 'ISAHead', 'STDCHead' + 'SegformerHead', 'ISAHead', 'STDCHead', 'IterativeDecodeHead', + 'KernelUpdateHead', 'KernelUpdator' ] diff --git a/mmseg/models/decode_heads/knet_head.py b/mmseg/models/decode_heads/knet_head.py new file mode 100644 index 0000000000..f73daccb64 --- /dev/null +++ b/mmseg/models/decode_heads/knet_head.py @@ -0,0 +1,453 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer +from mmcv.cnn.bricks.transformer import (FFN, TRANSFORMER_LAYER, + MultiheadAttention, + build_transformer_layer) + +from mmseg.models.builder import HEADS, build_head +from mmseg.models.decode_heads.decode_head import BaseDecodeHead +from mmseg.utils import get_root_logger + + +@TRANSFORMER_LAYER.register_module() +class KernelUpdator(nn.Module): + """Dynamic Kernel Updator in Kernel Update Head. + + Args: + in_channels (int): The number of channels of input feature map. + Default: 256. + feat_channels (int): The number of middle-stage channels in + the kernel updator. Default: 64. + out_channels (int): The number of output channels. + gate_sigmoid (bool): Whether use sigmoid function in gate + mechanism. Default: True. 
+ gate_norm_act (bool): Whether add normalization and activation + layer in gate mechanism. Default: False. + activate_out: Whether add activation after gate mechanism. + Default: False. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='LN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + """ + + def __init__( + self, + in_channels=256, + feat_channels=64, + out_channels=None, + gate_sigmoid=True, + gate_norm_act=False, + activate_out=False, + norm_cfg=dict(type='LN'), + act_cfg=dict(type='ReLU', inplace=True), + ): + super(KernelUpdator, self).__init__() + self.in_channels = in_channels + self.feat_channels = feat_channels + self.out_channels_raw = out_channels + self.gate_sigmoid = gate_sigmoid + self.gate_norm_act = gate_norm_act + self.activate_out = activate_out + self.act_cfg = act_cfg + self.norm_cfg = norm_cfg + self.out_channels = out_channels if out_channels else in_channels + + self.num_params_in = self.feat_channels + self.num_params_out = self.feat_channels + self.dynamic_layer = nn.Linear( + self.in_channels, self.num_params_in + self.num_params_out) + self.input_layer = nn.Linear(self.in_channels, + self.num_params_in + self.num_params_out, + 1) + self.input_gate = nn.Linear(self.in_channels, self.feat_channels, 1) + self.update_gate = nn.Linear(self.in_channels, self.feat_channels, 1) + if self.gate_norm_act: + self.gate_norm = build_norm_layer(norm_cfg, self.feat_channels)[1] + + self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1] + self.norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1] + self.input_norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1] + self.input_norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1] + + self.activation = build_activation_layer(act_cfg) + + self.fc_layer = nn.Linear(self.feat_channels, self.out_channels, 1) + self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1] + + def forward(self, update_feature, input_feature): + """Forward function of KernelUpdator. + + Args: + update_feature (torch.Tensor): Feature map assembled from + each group. It would be reshaped with last dimension + shape: `self.in_channels`. + input_feature (torch.Tensor): Intermediate feature + with shape: (N, num_classes, conv_kernel_size**2, channels). + Returns: + Tensor: The output tensor of shape (N*C1/C2, K*K, C2), where N is + the number of classes, C1 and C2 are the feature map channels of + KernelUpdateHead and KernelUpdator, respectively. 
+ """ + + update_feature = update_feature.reshape(-1, self.in_channels) + num_proposals = update_feature.size(0) + # dynamic_layer works for + # phi_1 and psi_3 in Eq.(4) and (5) of K-Net paper + parameters = self.dynamic_layer(update_feature) + param_in = parameters[:, :self.num_params_in].view( + -1, self.feat_channels) + param_out = parameters[:, -self.num_params_out:].view( + -1, self.feat_channels) + + # input_layer works for + # phi_2 and psi_4 in Eq.(4) and (5) of K-Net paper + input_feats = self.input_layer( + input_feature.reshape(num_proposals, -1, self.feat_channels)) + input_in = input_feats[..., :self.num_params_in] + input_out = input_feats[..., -self.num_params_out:] + + # `gate_feats` is F^G in K-Net paper + gate_feats = input_in * param_in.unsqueeze(-2) + if self.gate_norm_act: + gate_feats = self.activation(self.gate_norm(gate_feats)) + + input_gate = self.input_norm_in(self.input_gate(gate_feats)) + update_gate = self.norm_in(self.update_gate(gate_feats)) + if self.gate_sigmoid: + input_gate = input_gate.sigmoid() + update_gate = update_gate.sigmoid() + param_out = self.norm_out(param_out) + input_out = self.input_norm_out(input_out) + + if self.activate_out: + param_out = self.activation(param_out) + input_out = self.activation(input_out) + + # Gate mechanism. Eq.(5) in original paper. + # param_out has shape (batch_size, feat_channels, out_channels) + features = update_gate * param_out.unsqueeze( + -2) + input_gate * input_out + + features = self.fc_layer(features) + features = self.fc_norm(features) + features = self.activation(features) + + return features + + +@HEADS.register_module() +class KernelUpdateHead(nn.Module): + """Kernel Update Head in K-Net. + + Args: + num_classes (int): Number of classes. Default: 150. + num_ffn_fcs (int): The number of fully-connected layers in + FFNs. Default: 2. + num_heads (int): The number of parallel attention heads. + Default: 8. + num_mask_fcs (int): The number of fully connected layers for + mask prediction. Default: 3. + feedforward_channels (int): The hidden dimension of FFNs. + Defaults: 2048. + in_channels (int): The number of channels of input feature map. + Default: 256. + out_channels (int): The number of output channels. + Default: 256. + dropout (float): The Probability of an element to be + zeroed in MultiheadAttention and FFN. Default 0.0. + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + ffn_act_cfg (dict): Config of activation layers in FFN. + Default: dict(type='ReLU'). + conv_kernel_size (int): The kernel size of convolution in + Kernel Update Head for dynamic kernel updation. + Default: 1. + feat_transform_cfg (dict | None): Config of feature transform. + Default: None. + kernel_init (bool): Whether initiate mask kernel in mask head. + Default: False. + with_ffn (bool): Whether add FFN in kernel update head. + Default: True. + feat_gather_stride (int): Stride of convolution in feature transform. + Default: 1. + mask_transform_stride (int): Stride of mask transform. + Default: 1. + kernel_updator_cfg (dict): Config of kernel updator. + Default: dict( + type='DynamicConv', + in_channels=256, + feat_channels=64, + out_channels=256, + act_cfg=dict(type='ReLU', inplace=True), + norm_cfg=dict(type='LN')). 
+ """ + + def __init__(self, + num_classes=150, + num_ffn_fcs=2, + num_heads=8, + num_mask_fcs=3, + feedforward_channels=2048, + in_channels=256, + out_channels=256, + dropout=0.0, + act_cfg=dict(type='ReLU', inplace=True), + ffn_act_cfg=dict(type='ReLU', inplace=True), + conv_kernel_size=1, + feat_transform_cfg=None, + kernel_init=False, + with_ffn=True, + feat_gather_stride=1, + mask_transform_stride=1, + kernel_updator_cfg=dict( + type='DynamicConv', + in_channels=256, + feat_channels=64, + out_channels=256, + act_cfg=dict(type='ReLU', inplace=True), + norm_cfg=dict(type='LN'))): + super(KernelUpdateHead, self).__init__() + self.num_classes = num_classes + self.in_channels = in_channels + self.out_channels = out_channels + self.fp16_enabled = False + self.dropout = dropout + self.num_heads = num_heads + self.kernel_init = kernel_init + self.with_ffn = with_ffn + self.conv_kernel_size = conv_kernel_size + self.feat_gather_stride = feat_gather_stride + self.mask_transform_stride = mask_transform_stride + + self.attention = MultiheadAttention(in_channels * conv_kernel_size**2, + num_heads, dropout) + self.attention_norm = build_norm_layer( + dict(type='LN'), in_channels * conv_kernel_size**2)[1] + self.kernel_update_conv = build_transformer_layer(kernel_updator_cfg) + + if feat_transform_cfg is not None: + kernel_size = feat_transform_cfg.pop('kernel_size', 1) + transform_channels = in_channels + self.feat_transform = ConvModule( + transform_channels, + in_channels, + kernel_size, + stride=feat_gather_stride, + padding=int(feat_gather_stride // 2), + **feat_transform_cfg) + else: + self.feat_transform = None + + if self.with_ffn: + self.ffn = FFN( + in_channels, + feedforward_channels, + num_ffn_fcs, + act_cfg=ffn_act_cfg, + dropout=dropout) + self.ffn_norm = build_norm_layer(dict(type='LN'), in_channels)[1] + + self.mask_fcs = nn.ModuleList() + for _ in range(num_mask_fcs): + self.mask_fcs.append( + nn.Linear(in_channels, in_channels, bias=False)) + self.mask_fcs.append( + build_norm_layer(dict(type='LN'), in_channels)[1]) + self.mask_fcs.append(build_activation_layer(act_cfg)) + + self.fc_mask = nn.Linear(in_channels, out_channels) + + def init_weights(self): + """Use xavier initialization for all weight parameter and set + classification head bias as a specific value when use focal loss.""" + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + else: + # adopt the default initialization for + # the weight and bias of the layer norm + pass + if self.kernel_init: + logger = get_root_logger() + logger.info( + 'mask kernel in mask head is normal initialized by std 0.01') + nn.init.normal_(self.fc_mask.weight, mean=0, std=0.01) + + def forward(self, x, proposal_feat, mask_preds, mask_shape=None): + """Forward function of Dynamic Instance Interactive Head. + + Args: + x (Tensor): Feature map from FPN with shape + (batch_size, feature_dimensions, H , W). + proposal_feat (Tensor): Intermediate feature get from + diihead in last stage, has shape + (batch_size, num_proposals, feature_dimensions) + mask_preds (Tensor): mask prediction from the former stage in shape + (batch_size, num_proposals, H, W). + + Returns: + Tuple: The first tensor is predicted mask with shape + (N, num_classes, H, W), the second tensor is dynamic kernel + with shape (N, num_classes, channels, K, K). 
+ """ + N, num_proposals = proposal_feat.shape[:2] + if self.feat_transform is not None: + x = self.feat_transform(x) + + C, H, W = x.shape[-3:] + + mask_h, mask_w = mask_preds.shape[-2:] + if mask_h != H or mask_w != W: + gather_mask = F.interpolate( + mask_preds, (H, W), align_corners=False, mode='bilinear') + else: + gather_mask = mask_preds + + sigmoid_masks = gather_mask.softmax(dim=1) + + # Group Feature Assembling. Eq.(3) in original paper. + # einsum is faster than bmm by 30% + x_feat = torch.einsum('bnhw,bchw->bnc', sigmoid_masks, x) + + # obj_feat in shape [B, N, C, K, K] -> [B, N, C, K*K] -> [B, N, K*K, C] + proposal_feat = proposal_feat.reshape(N, num_proposals, + self.in_channels, + -1).permute(0, 1, 3, 2) + obj_feat = self.kernel_update_conv(x_feat, proposal_feat) + + # [B, N, K*K, C] -> [B, N, K*K*C] -> [N, B, K*K*C] + obj_feat = obj_feat.reshape(N, num_proposals, -1).permute(1, 0, 2) + obj_feat = self.attention_norm(self.attention(obj_feat)) + # [N, B, K*K*C] -> [B, N, K*K*C] + obj_feat = obj_feat.permute(1, 0, 2) + + # obj_feat in shape [B, N, K*K*C] -> [B, N, K*K, C] + obj_feat = obj_feat.reshape(N, num_proposals, -1, self.in_channels) + + # FFN + if self.with_ffn: + obj_feat = self.ffn_norm(self.ffn(obj_feat)) + + mask_feat = obj_feat + + for reg_layer in self.mask_fcs: + mask_feat = reg_layer(mask_feat) + + # [B, N, K*K, C] -> [B, N, C, K*K] + mask_feat = self.fc_mask(mask_feat).permute(0, 1, 3, 2) + + if (self.mask_transform_stride == 2 and self.feat_gather_stride == 1): + mask_x = F.interpolate( + x, scale_factor=0.5, mode='bilinear', align_corners=False) + H, W = mask_x.shape[-2:] + else: + mask_x = x + # group conv is 5x faster than unfold and uses about 1/5 memory + # Group conv vs. unfold vs. concat batch, 2.9ms :13.5ms :3.8ms + # Group conv vs. unfold vs. concat batch, 278 : 1420 : 369 + # but in real training group conv is slower than concat batch + # so we keep using concat batch. + # fold_x = F.unfold( + # mask_x, + # self.conv_kernel_size, + # padding=int(self.conv_kernel_size // 2)) + # mask_feat = mask_feat.reshape(N, num_proposals, -1) + # new_mask_preds = torch.einsum('bnc,bcl->bnl', mask_feat, fold_x) + # [B, N, C, K*K] -> [B*N, C, K, K] + mask_feat = mask_feat.reshape(N, num_proposals, C, + self.conv_kernel_size, + self.conv_kernel_size) + # [B, C, H, W] -> [1, B*C, H, W] + new_mask_preds = [] + for i in range(N): + new_mask_preds.append( + F.conv2d( + mask_x[i:i + 1], + mask_feat[i], + padding=int(self.conv_kernel_size // 2))) + + new_mask_preds = torch.cat(new_mask_preds, dim=0) + new_mask_preds = new_mask_preds.reshape(N, num_proposals, H, W) + if self.mask_transform_stride == 2: + new_mask_preds = F.interpolate( + new_mask_preds, + scale_factor=2, + mode='bilinear', + align_corners=False) + + if mask_shape is not None and mask_shape[0] != H: + new_mask_preds = F.interpolate( + new_mask_preds, + mask_shape, + align_corners=False, + mode='bilinear') + + return new_mask_preds, obj_feat.permute(0, 1, 3, 2).reshape( + N, num_proposals, self.in_channels, self.conv_kernel_size, + self.conv_kernel_size) + + +@HEADS.register_module() +class IterativeDecodeHead(BaseDecodeHead): + """K-Net: Towards Unified Image Segmentation. + + This head is the implementation of + `K-Net: `_. + + Args: + num_stages (int): The number of stages (kernel update heads) + in IterativeDecodeHead. Default: 3. + kernel_generate_head:(dict): Config of kernel generate head which + generate mask predictions, dynamic kernels and class predictions + for next kernel update heads. 
+        kernel_update_head (list[dict]): Configs of the kernel update heads,
+            which refine the dynamic kernels and class predictions
+            iteratively.
+
+    """
+
+    def __init__(self, num_stages, kernel_generate_head, kernel_update_head,
+                 **kwargs):
+        # deliberately bypass BaseDecodeHead.__init__; the attributes below
+        # are taken from the kernel generate head instead
+        super(BaseDecodeHead, self).__init__(**kwargs)
+        assert num_stages == len(kernel_update_head)
+        self.num_stages = num_stages
+        self.kernel_generate_head = build_head(kernel_generate_head)
+        self.kernel_update_head = nn.ModuleList()
+        self.align_corners = self.kernel_generate_head.align_corners
+        self.num_classes = self.kernel_generate_head.num_classes
+        self.input_transform = self.kernel_generate_head.input_transform
+        self.ignore_index = self.kernel_generate_head.ignore_index
+
+        for head_cfg in kernel_update_head:
+            self.kernel_update_head.append(build_head(head_cfg))
+
+    def forward(self, inputs):
+        """Forward function."""
+        feats = self.kernel_generate_head._forward_feature(inputs)
+        sem_seg = self.kernel_generate_head.cls_seg(feats)
+        seg_kernels = self.kernel_generate_head.conv_seg.weight.clone()
+        seg_kernels = seg_kernels[None].expand(
+            feats.size(0), *seg_kernels.size())
+
+        stage_segs = [sem_seg]
+        for i in range(self.num_stages):
+            sem_seg, seg_kernels = self.kernel_update_head[i](feats,
+                                                              seg_kernels,
+                                                              sem_seg)
+            stage_segs.append(sem_seg)
+        if self.training:
+            return stage_segs
+        # only return the prediction of the last stage during testing
+        return stage_segs[-1]
+
+    def losses(self, seg_logit, seg_label):
+        losses = dict()
+        for i, logit in enumerate(seg_logit):
+            loss = self.kernel_generate_head.losses(logit, seg_label)
+            for k, v in loss.items():
+                losses[f'{k}.s{i}'] = v
+
+        return losses
diff --git a/model-index.yml b/model-index.yml
index cd82220bbd..235ad7f6e7 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -22,6 +22,7 @@ Import:
 - configs/hrnet/hrnet.yml
 - configs/icnet/icnet.yml
 - configs/isanet/isanet.yml
+- configs/knet/knet.yml
 - configs/mobilenet_v2/mobilenet_v2.yml
 - configs/mobilenet_v3/mobilenet_v3.yml
 - configs/nonlocal_net/nonlocal_net.yml
diff --git a/tests/test_models/test_heads/test_knet_head.py b/tests/test_models/test_heads/test_knet_head.py
new file mode 100644
index 0000000000..e6845a6d3f
--- /dev/null
+++ b/tests/test_models/test_heads/test_knet_head.py
@@ -0,0 +1,195 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch + +from mmseg.models.decode_heads.knet_head import (IterativeDecodeHead, + KernelUpdateHead) +from .utils import to_cuda + +num_stages = 3 +conv_kernel_size = 1 + +kernel_updator_cfg = dict( + type='KernelUpdator', + in_channels=16, + feat_channels=16, + out_channels=16, + gate_norm_act=True, + activate_out=True, + act_cfg=dict(type='ReLU', inplace=True), + norm_cfg=dict(type='LN')) + + +def test_knet_head(): + # test init function of kernel update head + kernel_update_head = KernelUpdateHead( + num_classes=150, + num_ffn_fcs=2, + num_heads=8, + num_mask_fcs=1, + feedforward_channels=128, + in_channels=32, + out_channels=32, + dropout=0.0, + conv_kernel_size=conv_kernel_size, + ffn_act_cfg=dict(type='ReLU', inplace=True), + with_ffn=True, + feat_transform_cfg=dict(conv_cfg=dict(type='Conv2d'), act_cfg=None), + kernel_init=True, + kernel_updator_cfg=kernel_updator_cfg) + kernel_update_head.init_weights() + + head = IterativeDecodeHead( + num_stages=num_stages, + kernel_update_head=[ + dict( + type='KernelUpdateHead', + num_classes=150, + num_ffn_fcs=2, + num_heads=8, + num_mask_fcs=1, + feedforward_channels=128, + in_channels=32, + out_channels=32, + dropout=0.0, + conv_kernel_size=conv_kernel_size, + ffn_act_cfg=dict(type='ReLU', inplace=True), + with_ffn=True, + feat_transform_cfg=dict( + conv_cfg=dict(type='Conv2d'), act_cfg=None), + kernel_init=False, + kernel_updator_cfg=kernel_updator_cfg) + for _ in range(num_stages) + ], + kernel_generate_head=dict( + type='FCNHead', + in_channels=128, + in_index=3, + channels=32, + num_convs=2, + concat_input=True, + dropout_ratio=0.1, + num_classes=150, + align_corners=False)) + head.init_weights() + inputs = [ + torch.randn(1, 16, 27, 32), + torch.randn(1, 32, 27, 16), + torch.randn(1, 64, 27, 16), + torch.randn(1, 128, 27, 16) + ] + + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + outputs = head(inputs) + assert outputs[-1].shape == (1, head.num_classes, 27, 16) + + # test whether only return the prediction of + # the last stage during testing + with torch.no_grad(): + head.eval() + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 27, 16) + + # test K-Net without `feat_transform_cfg` + head = IterativeDecodeHead( + num_stages=num_stages, + kernel_update_head=[ + dict( + type='KernelUpdateHead', + num_classes=150, + num_ffn_fcs=2, + num_heads=8, + num_mask_fcs=1, + feedforward_channels=128, + in_channels=32, + out_channels=32, + dropout=0.0, + conv_kernel_size=conv_kernel_size, + ffn_act_cfg=dict(type='ReLU', inplace=True), + with_ffn=True, + feat_transform_cfg=None, + kernel_updator_cfg=kernel_updator_cfg) + for _ in range(num_stages) + ], + kernel_generate_head=dict( + type='FCNHead', + in_channels=128, + in_index=3, + channels=32, + num_convs=2, + concat_input=True, + dropout_ratio=0.1, + num_classes=150, + align_corners=False)) + head.init_weights() + + inputs = [ + torch.randn(1, 16, 27, 32), + torch.randn(1, 32, 27, 16), + torch.randn(1, 64, 27, 16), + torch.randn(1, 128, 27, 16) + ] + + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + outputs = head(inputs) + assert outputs[-1].shape == (1, head.num_classes, 27, 16) + + # test K-Net with + # self.mask_transform_stride == 2 and self.feat_gather_stride == 1 + head = IterativeDecodeHead( + num_stages=num_stages, + kernel_update_head=[ + dict( + type='KernelUpdateHead', + num_classes=150, + num_ffn_fcs=2, + num_heads=8, + num_mask_fcs=1, + feedforward_channels=128, + in_channels=32, + out_channels=32, + 
dropout=0.0, + conv_kernel_size=conv_kernel_size, + ffn_act_cfg=dict(type='ReLU', inplace=True), + with_ffn=True, + feat_transform_cfg=dict( + conv_cfg=dict(type='Conv2d'), act_cfg=None), + kernel_init=False, + mask_transform_stride=2, + feat_gather_stride=1, + kernel_updator_cfg=kernel_updator_cfg) + for _ in range(num_stages) + ], + kernel_generate_head=dict( + type='FCNHead', + in_channels=128, + in_index=3, + channels=32, + num_convs=2, + concat_input=True, + dropout_ratio=0.1, + num_classes=150, + align_corners=False)) + head.init_weights() + + inputs = [ + torch.randn(1, 16, 27, 32), + torch.randn(1, 32, 27, 16), + torch.randn(1, 64, 27, 16), + torch.randn(1, 128, 27, 16) + ] + + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + outputs = head(inputs) + assert outputs[-1].shape == (1, head.num_classes, 26, 16) + + # test loss function in K-Net + fake_label = torch.ones_like( + outputs[-1][:, 0:1, :, :], dtype=torch.int16).long() + loss = head.losses(seg_logit=outputs, seg_label=fake_label) + assert loss['loss_ce.s0'] != torch.zeros_like(loss['loss_ce.s0']) + assert loss['loss_ce.s1'] != torch.zeros_like(loss['loss_ce.s1']) + assert loss['loss_ce.s2'] != torch.zeros_like(loss['loss_ce.s2']) + assert loss['loss_ce.s3'] != torch.zeros_like(loss['loss_ce.s3'])
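
Beyond the unit tests above, the new configs can also be smoke-tested end to end. One possible sketch (not part of the diff; it assumes the existing mmseg 0.x `Config`/`build_segmentor` APIs and the repository root as working directory — the BN swap is only there so the SyncBN layers in the config can run a CPU forward pass):

```python
import torch
import mmcv

from mmseg.models import build_segmentor

# build K-Net + FCN from the config added in this PR
cfg = mmcv.Config.fromfile(
    'configs/knet/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k.py')
cfg.model.pretrained = None  # skip downloading the ImageNet backbone weights
# SyncBN needs distributed training; use plain BN for a local check
bn = dict(type='BN', requires_grad=True)
cfg.model.backbone.norm_cfg = bn
cfg.model.decode_head.kernel_generate_head.norm_cfg = bn
cfg.model.auxiliary_head.norm_cfg = bn

model = build_segmentor(cfg.model)
model.init_weights()
model.eval()

# whole-image inference on a dummy input
img = torch.randn(1, 3, 512, 512)
img_meta = dict(
    ori_shape=(512, 512, 3),
    img_shape=(512, 512, 3),
    pad_shape=(512, 512, 3),
    scale_factor=1.0,
    flip=False)
with torch.no_grad():
    result = model(return_loss=False, img=[img], img_metas=[[img_meta]])
print(result[0].shape)  # (512, 512) array of per-pixel class indices
```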