diff --git a/README.md b/README.md
index 73f5f8fe26..00620bcb3e 100644
--- a/README.md
+++ b/README.md
@@ -121,6 +121,7 @@ Supported methods:
- [x] [DPT (ArXiv'2021)](configs/dpt)
- [x] [Segmenter (ICCV'2021)](configs/segmenter)
- [x] [SegFormer (NeurIPS'2021)](configs/segformer)
+- [x] [K-Net (NeurIPS'2021)](configs/knet)
Supported datasets:
diff --git a/README_zh-CN.md b/README_zh-CN.md
index 657f50c814..afaaa8a3e6 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -120,6 +120,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
- [x] [DPT (ArXiv'2021)](configs/dpt)
- [x] [Segmenter (ICCV'2021)](configs/segmenter)
- [x] [SegFormer (NeurIPS'2021)](configs/segformer)
+- [x] [K-Net (NeurIPS'2021)](configs/knet)
已支持的数据集:
diff --git a/configs/knet/README.md b/configs/knet/README.md
new file mode 100644
index 0000000000..ef223360bd
--- /dev/null
+++ b/configs/knet/README.md
@@ -0,0 +1,49 @@
+# K-Net
+
+[K-Net: Towards Unified Image Segmentation](https://arxiv.org/abs/2106.14855)
+
+## Introduction
+
+
+
+Official Repo
+
+Code Snippet
+
+## Abstract
+
+
+
+Semantic, instance, and panoptic segmentations have been addressed using different and specialized frameworks despite their underlying connections. This paper presents a unified, simple, and effective framework for these essentially similar tasks. The framework, named K-Net, segments both instances and semantic categories consistently by a group of learnable kernels, where each kernel is responsible for generating a mask for either a potential instance or a stuff class. To remedy the difficulties of distinguishing various instances, we propose a kernel update strategy that enables each kernel dynamic and conditional on its meaningful group in the input image. K-Net can be trained in an end-to-end manner with bipartite matching, and its training and inference are naturally NMS-free and box-free. Without bells and whistles, K-Net surpasses all previous published state-of-the-art single-model results of panoptic segmentation on MS COCO test-dev split and semantic segmentation on ADE20K val split with 55.2% PQ and 54.3% mIoU, respectively. Its instance segmentation performance is also on par with Cascade Mask R-CNN on MS COCO with 60%-90% faster inference speeds. Code and models will be released at [this https URL](https://github.com/ZwwWayne/K-Net/).
+
+
+
+
+
+
+```bibtex
+@inproceedings{zhang2021knet,
+ title={{K-Net: Towards} Unified Image Segmentation},
+ author={Wenwei Zhang and Jiangmiao Pang and Kai Chen and Chen Change Loy},
+ year={2021},
+ booktitle={NeurIPS},
+}
+```
+
+## Results and models
+
+### ADE20K
+
+| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
+| --------------- | -------- | --------- | ------- | -------- | -------------- | ----- | ------------- | ---------------------------------------------------------------------------------------------------------------------------------------- | ----- |
+| KNet + FCN | R-50-D8 | 512x512 | 80000 | 7.01 | 19.24 | 43.60 | 45.12 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/knet/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k_20220228_043751-abcab920.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k_20220228_043751.log.json) |
+| KNet + PSPNet | R-50-D8 | 512x512 | 80000 | 6.98 | 20.04 | 44.18 | 45.58 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/knet/knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k_20220228_054634-d2c72240.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k_20220228_054634.log.json) |
+| KNet + DeepLabV3| R-50-D8 | 512x512 | 80000 | 7.42 | 12.10 | 45.06 | 46.11 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/knet/knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k_20220228_041642-00c8fbeb.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k_20220228_041642.log.json) |
+| KNet + UperNet | R-50-D8 | 512x512 | 80000 | 7.34 | 17.11 | 43.45 | 44.07 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/knet/knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k_20220304_125657-215753b0.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k_20220304_125657.log.json) |
+| KNet + UperNet | Swin-T | 512x512 | 80000 | 7.57 | 15.56 | 45.84 | 46.27 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/knet/knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k/knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k_20220303_133059-7545e1dc.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k/knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k_20220303_133059.log.json) |
+| KNet + UperNet | Swin-L | 512x512 | 80000 | 13.5 | 8.29 | 52.05 | 53.24 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/knet/knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k/knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k_20220303_154559-d8da9a90.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k/knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k_20220303_154559.log.json) |
+| KNet + UperNet | Swin-L | 640x640 | 80000 | 13.54 | 8.29 | 52.21 | 53.34 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/knet/knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k/knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k_20220301_220747-8787fc71.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k/knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k_20220301_220747.log.json) |
+
+Note:
+
+- All experiments of K-Net are implemented with 8 V100 (32G) GPUs with 2 samplers per GPU.
diff --git a/configs/knet/knet.yml b/configs/knet/knet.yml
new file mode 100644
index 0000000000..5e2e529557
--- /dev/null
+++ b/configs/knet/knet.yml
@@ -0,0 +1,169 @@
+Collections:
+- Name: KNet
+ Metadata:
+ Training Data:
+ - ADE20K
+ Paper:
+ URL: https://arxiv.org/abs/2106.14855
+ Title: 'K-Net: Towards Unified Image Segmentation'
+ README: configs/knet/README.md
+ Code:
+ URL: https://github.com/open-mmlab/mmsegmentation/blob/v0.23.0/mmseg/models/decode_heads/knet_head.py#L392
+ Version: v0.23.0
+ Converted From:
+ Code: https://github.com/ZwwWayne/K-Net/
+Models:
+- Name: knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k
+ In Collection: KNet
+ Metadata:
+ backbone: R-50-D8
+ crop size: (512,512)
+ lr schd: 80000
+ inference time (ms/im):
+ - value: 51.98
+ hardware: V100
+ backend: PyTorch
+ batch size: 1
+ mode: FP32
+ resolution: (512,512)
+ Training Memory (GB): 7.01
+ Results:
+ - Task: Semantic Segmentation
+ Dataset: ADE20K
+ Metrics:
+ mIoU: 43.6
+ mIoU(ms+flip): 45.12
+ Config: configs/knet/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k.py
+ Weights: https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k_20220228_043751-abcab920.pth
+- Name: knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k
+ In Collection: KNet
+ Metadata:
+ backbone: R-50-D8
+ crop size: (512,512)
+ lr schd: 80000
+ inference time (ms/im):
+ - value: 49.9
+ hardware: V100
+ backend: PyTorch
+ batch size: 1
+ mode: FP32
+ resolution: (512,512)
+ Training Memory (GB): 6.98
+ Results:
+ - Task: Semantic Segmentation
+ Dataset: ADE20K
+ Metrics:
+ mIoU: 44.18
+ mIoU(ms+flip): 45.58
+ Config: configs/knet/knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k.py
+ Weights: https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k_20220228_054634-d2c72240.pth
+- Name: knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k
+ In Collection: KNet
+ Metadata:
+ backbone: R-50-D8
+ crop size: (512,512)
+ lr schd: 80000
+ inference time (ms/im):
+ - value: 82.64
+ hardware: V100
+ backend: PyTorch
+ batch size: 1
+ mode: FP32
+ resolution: (512,512)
+ Training Memory (GB): 7.42
+ Results:
+ - Task: Semantic Segmentation
+ Dataset: ADE20K
+ Metrics:
+ mIoU: 45.06
+ mIoU(ms+flip): 46.11
+ Config: configs/knet/knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k.py
+ Weights: https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k_20220228_041642-00c8fbeb.pth
+- Name: knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k
+ In Collection: KNet
+ Metadata:
+ backbone: R-50-D8
+ crop size: (512,512)
+ lr schd: 80000
+ inference time (ms/im):
+ - value: 58.45
+ hardware: V100
+ backend: PyTorch
+ batch size: 1
+ mode: FP32
+ resolution: (512,512)
+ Training Memory (GB): 7.34
+ Results:
+ - Task: Semantic Segmentation
+ Dataset: ADE20K
+ Metrics:
+ mIoU: 43.45
+ mIoU(ms+flip): 44.07
+ Config: configs/knet/knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k.py
+ Weights: https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k_20220304_125657-215753b0.pth
+- Name: knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k
+ In Collection: KNet
+ Metadata:
+ backbone: Swin-T
+ crop size: (512,512)
+ lr schd: 80000
+ inference time (ms/im):
+ - value: 64.27
+ hardware: V100
+ backend: PyTorch
+ batch size: 1
+ mode: FP32
+ resolution: (512,512)
+ Training Memory (GB): 7.57
+ Results:
+ - Task: Semantic Segmentation
+ Dataset: ADE20K
+ Metrics:
+ mIoU: 45.84
+ mIoU(ms+flip): 46.27
+ Config: configs/knet/knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k.py
+ Weights: https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k/knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k_20220303_133059-7545e1dc.pth
+- Name: knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k
+ In Collection: KNet
+ Metadata:
+ backbone: Swin-L
+ crop size: (512,512)
+ lr schd: 80000
+ inference time (ms/im):
+ - value: 120.63
+ hardware: V100
+ backend: PyTorch
+ batch size: 1
+ mode: FP32
+ resolution: (512,512)
+ Training Memory (GB): 13.5
+ Results:
+ - Task: Semantic Segmentation
+ Dataset: ADE20K
+ Metrics:
+ mIoU: 52.05
+ mIoU(ms+flip): 53.24
+ Config: configs/knet/knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k.py
+ Weights: https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k/knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k_20220303_154559-d8da9a90.pth
+- Name: knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k
+ In Collection: KNet
+ Metadata:
+ backbone: Swin-L
+ crop size: (640,640)
+ lr schd: 80000
+ inference time (ms/im):
+ - value: 120.63
+ hardware: V100
+ backend: PyTorch
+ batch size: 1
+ mode: FP32
+ resolution: (640,640)
+ Training Memory (GB): 13.54
+ Results:
+ - Task: Semantic Segmentation
+ Dataset: ADE20K
+ Metrics:
+ mIoU: 52.21
+ mIoU(ms+flip): 53.34
+ Config: configs/knet/knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k.py
+ Weights: https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k/knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k_20220301_220747-8787fc71.pth
diff --git a/configs/knet/knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k.py b/configs/knet/knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k.py
new file mode 100644
index 0000000000..3edb05c875
--- /dev/null
+++ b/configs/knet/knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k.py
@@ -0,0 +1,93 @@
+_base_ = [
+ '../_base_/datasets/ade20k.py', '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_80k.py'
+]
+
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+num_stages = 3
+conv_kernel_size = 1
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='IterativeDecodeHead',
+ num_stages=num_stages,
+ kernel_update_head=[
+ dict(
+ type='KernelUpdateHead',
+ num_classes=150,
+ num_ffn_fcs=2,
+ num_heads=8,
+ num_mask_fcs=1,
+ feedforward_channels=2048,
+ in_channels=512,
+ out_channels=512,
+ dropout=0.0,
+ conv_kernel_size=conv_kernel_size,
+ ffn_act_cfg=dict(type='ReLU', inplace=True),
+ with_ffn=True,
+ feat_transform_cfg=dict(
+ conv_cfg=dict(type='Conv2d'), act_cfg=None),
+ kernel_updator_cfg=dict(
+ type='KernelUpdator',
+ in_channels=256,
+ feat_channels=256,
+ out_channels=256,
+ act_cfg=dict(type='ReLU', inplace=True),
+ norm_cfg=dict(type='LN'))) for _ in range(num_stages)
+ ],
+ kernel_generate_head=dict(
+ type='ASPPHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ dilations=(1, 12, 24, 36),
+ dropout_ratio=0.1,
+ num_classes=150,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=150,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
+
+# optimizer
+optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, weight_decay=0.0005)
+optimizer_config = dict(grad_clip=dict(max_norm=1, norm_type=2))
+# learning policy
+lr_config = dict(
+ _delete_=True,
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[60000, 72000],
+ by_epoch=False)
+# In K-Net implementation we use batch size 2 per GPU as default
+data = dict(samples_per_gpu=2, workers_per_gpu=2)
diff --git a/configs/knet/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k.py b/configs/knet/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k.py
new file mode 100644
index 0000000000..29a088f721
--- /dev/null
+++ b/configs/knet/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k.py
@@ -0,0 +1,93 @@
+_base_ = [
+ '../_base_/datasets/ade20k.py', '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_80k.py'
+]
+
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+num_stages = 3
+conv_kernel_size = 1
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='IterativeDecodeHead',
+ num_stages=num_stages,
+ kernel_update_head=[
+ dict(
+ type='KernelUpdateHead',
+ num_classes=150,
+ num_ffn_fcs=2,
+ num_heads=8,
+ num_mask_fcs=1,
+ feedforward_channels=2048,
+ in_channels=512,
+ out_channels=512,
+ dropout=0.0,
+ conv_kernel_size=conv_kernel_size,
+ ffn_act_cfg=dict(type='ReLU', inplace=True),
+ with_ffn=True,
+ feat_transform_cfg=dict(
+ conv_cfg=dict(type='Conv2d'), act_cfg=None),
+ kernel_updator_cfg=dict(
+ type='KernelUpdator',
+ in_channels=256,
+ feat_channels=256,
+ out_channels=256,
+ act_cfg=dict(type='ReLU', inplace=True),
+ norm_cfg=dict(type='LN'))) for _ in range(num_stages)
+ ],
+ kernel_generate_head=dict(
+ type='FCNHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ num_convs=2,
+ concat_input=True,
+ dropout_ratio=0.1,
+ num_classes=150,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=150,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
+# optimizer
+optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, weight_decay=0.0005)
+optimizer_config = dict(grad_clip=dict(max_norm=1, norm_type=2))
+# learning policy
+lr_config = dict(
+ _delete_=True,
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[60000, 72000],
+ by_epoch=False)
+# In K-Net implementation we use batch size 2 per GPU as default
+data = dict(samples_per_gpu=2, workers_per_gpu=2)
diff --git a/configs/knet/knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k.py b/configs/knet/knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k.py
new file mode 100644
index 0000000000..d77a3b4423
--- /dev/null
+++ b/configs/knet/knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k.py
@@ -0,0 +1,92 @@
+_base_ = [
+ '../_base_/datasets/ade20k.py', '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_80k.py'
+]
+
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+num_stages = 3
+conv_kernel_size = 1
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='IterativeDecodeHead',
+ num_stages=num_stages,
+ kernel_update_head=[
+ dict(
+ type='KernelUpdateHead',
+ num_classes=150,
+ num_ffn_fcs=2,
+ num_heads=8,
+ num_mask_fcs=1,
+ feedforward_channels=2048,
+ in_channels=512,
+ out_channels=512,
+ dropout=0.0,
+ conv_kernel_size=conv_kernel_size,
+ ffn_act_cfg=dict(type='ReLU', inplace=True),
+ with_ffn=True,
+ feat_transform_cfg=dict(
+ conv_cfg=dict(type='Conv2d'), act_cfg=None),
+ kernel_updator_cfg=dict(
+ type='KernelUpdator',
+ in_channels=256,
+ feat_channels=256,
+ out_channels=256,
+ act_cfg=dict(type='ReLU', inplace=True),
+ norm_cfg=dict(type='LN'))) for _ in range(num_stages)
+ ],
+ kernel_generate_head=dict(
+ type='PSPHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ pool_scales=(1, 2, 3, 6),
+ dropout_ratio=0.1,
+ num_classes=150,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=150,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
+# optimizer
+optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, weight_decay=0.0005)
+optimizer_config = dict(grad_clip=dict(max_norm=1, norm_type=2))
+# learning policy
+lr_config = dict(
+ _delete_=True,
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[60000, 72000],
+ by_epoch=False)
+# In K-Net implementation we use batch size 2 per GPU as default
+data = dict(samples_per_gpu=2, workers_per_gpu=2)
diff --git a/configs/knet/knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k.py b/configs/knet/knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k.py
new file mode 100644
index 0000000000..0071cea750
--- /dev/null
+++ b/configs/knet/knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k.py
@@ -0,0 +1,93 @@
+_base_ = [
+ '../_base_/datasets/ade20k.py', '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_80k.py'
+]
+
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+num_stages = 3
+conv_kernel_size = 1
+
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 1, 1),
+ strides=(1, 2, 2, 2),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='IterativeDecodeHead',
+ num_stages=num_stages,
+ kernel_update_head=[
+ dict(
+ type='KernelUpdateHead',
+ num_classes=150,
+ num_ffn_fcs=2,
+ num_heads=8,
+ num_mask_fcs=1,
+ feedforward_channels=2048,
+ in_channels=512,
+ out_channels=512,
+ dropout=0.0,
+ conv_kernel_size=conv_kernel_size,
+ ffn_act_cfg=dict(type='ReLU', inplace=True),
+ with_ffn=True,
+ feat_transform_cfg=dict(
+ conv_cfg=dict(type='Conv2d'), act_cfg=None),
+ kernel_updator_cfg=dict(
+ type='KernelUpdator',
+ in_channels=256,
+ feat_channels=256,
+ out_channels=256,
+ act_cfg=dict(type='ReLU', inplace=True),
+ norm_cfg=dict(type='LN'))) for _ in range(num_stages)
+ ],
+ kernel_generate_head=dict(
+ type='UPerHead',
+ in_channels=[256, 512, 1024, 2048],
+ in_index=[0, 1, 2, 3],
+ pool_scales=(1, 2, 3, 6),
+ channels=512,
+ dropout_ratio=0.1,
+ num_classes=150,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=150,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
+# optimizer
+optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, weight_decay=0.0005)
+optimizer_config = dict(grad_clip=dict(max_norm=1, norm_type=2))
+# learning policy
+lr_config = dict(
+ _delete_=True,
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[60000, 72000],
+ by_epoch=False)
+# In K-Net implementation we use batch size 2 per GPU as default
+data = dict(samples_per_gpu=2, workers_per_gpu=2)
diff --git a/configs/knet/knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k.py b/configs/knet/knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k.py
new file mode 100644
index 0000000000..b9d1a0952d
--- /dev/null
+++ b/configs/knet/knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k.py
@@ -0,0 +1,19 @@
+_base_ = 'knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k.py'
+
+checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_large_patch4_window7_224_22k_20220308-d5bdebaf.pth' # noqa
+# model settings
+model = dict(
+ pretrained=checkpoint_file,
+ backbone=dict(
+ embed_dims=192,
+ depths=[2, 2, 18, 2],
+ num_heads=[6, 12, 24, 48],
+ window_size=7,
+ use_abs_pos_embed=False,
+ drop_path_rate=0.3,
+ patch_norm=True),
+ decode_head=dict(
+ kernel_generate_head=dict(in_channels=[192, 384, 768, 1536])),
+ auxiliary_head=dict(in_channels=768))
+# In K-Net implementation we use batch size 2 per GPU as default
+data = dict(samples_per_gpu=2, workers_per_gpu=2)
diff --git a/configs/knet/knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k.py b/configs/knet/knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k.py
new file mode 100644
index 0000000000..fc6e9fe39f
--- /dev/null
+++ b/configs/knet/knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k.py
@@ -0,0 +1,54 @@
+_base_ = 'knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k.py'
+
+checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_large_patch4_window7_224_22k_20220308-d5bdebaf.pth' # noqa
+# model settings
+model = dict(
+ pretrained=checkpoint_file,
+ backbone=dict(
+ embed_dims=192,
+ depths=[2, 2, 18, 2],
+ num_heads=[6, 12, 24, 48],
+ window_size=7,
+ use_abs_pos_embed=False,
+ drop_path_rate=0.4,
+ patch_norm=True),
+ decode_head=dict(
+ kernel_generate_head=dict(in_channels=[192, 384, 768, 1536])),
+ auxiliary_head=dict(in_channels=768))
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+crop_size = (640, 640)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', reduce_zero_label=True),
+ dict(type='Resize', img_scale=(2048, 640), ratio_range=(0.5, 2.0)),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(2048, 640),
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ train=dict(pipeline=train_pipeline),
+ val=dict(pipeline=test_pipeline),
+ test=dict(pipeline=test_pipeline))
+# In K-Net implementation we use batch size 2 per GPU as default
+data = dict(samples_per_gpu=2, workers_per_gpu=2)
diff --git a/configs/knet/knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k.py b/configs/knet/knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k.py
new file mode 100644
index 0000000000..0b29b2b8c4
--- /dev/null
+++ b/configs/knet/knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k.py
@@ -0,0 +1,57 @@
+_base_ = 'knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k.py'
+
+checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_tiny_patch4_window7_224_20220308-f41b89d3.pth' # noqa
+
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+num_stages = 3
+conv_kernel_size = 1
+
+model = dict(
+ type='EncoderDecoder',
+ pretrained=checkpoint_file,
+ backbone=dict(
+ _delete_=True,
+ type='SwinTransformer',
+ embed_dims=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=7,
+ mlp_ratio=4,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.3,
+ use_abs_pos_embed=False,
+ patch_norm=True,
+ out_indices=(0, 1, 2, 3)),
+ decode_head=dict(
+ kernel_generate_head=dict(in_channels=[96, 192, 384, 768])),
+ auxiliary_head=dict(in_channels=384))
+
+# modify learning rate following the official implementation of Swin Transformer # noqa
+optimizer = dict(
+ _delete_=True,
+ type='AdamW',
+ lr=0.00006,
+ betas=(0.9, 0.999),
+ weight_decay=0.0005,
+ paramwise_cfg=dict(
+ custom_keys={
+ 'absolute_pos_embed': dict(decay_mult=0.),
+ 'relative_position_bias_table': dict(decay_mult=0.),
+ 'norm': dict(decay_mult=0.)
+ }))
+optimizer_config = dict(grad_clip=dict(max_norm=1, norm_type=2))
+# learning policy
+lr_config = dict(
+ _delete_=True,
+ policy='step',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=0.001,
+ step=[60000, 72000],
+ by_epoch=False)
+# In K-Net implementation we use batch size 2 per GPU as default
+data = dict(samples_per_gpu=2, workers_per_gpu=2)
diff --git a/mmseg/models/decode_heads/__init__.py b/mmseg/models/decode_heads/__init__.py
index dcde813264..8add7615c2 100644
--- a/mmseg/models/decode_heads/__init__.py
+++ b/mmseg/models/decode_heads/__init__.py
@@ -13,6 +13,7 @@
from .fpn_head import FPNHead
from .gc_head import GCHead
from .isa_head import ISAHead
+from .knet_head import IterativeDecodeHead, KernelUpdateHead, KernelUpdator
from .lraspp_head import LRASPPHead
from .nl_head import NLHead
from .ocr_head import OCRHead
@@ -34,5 +35,6 @@
'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead',
'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead',
'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegmenterMaskTransformerHead',
- 'SegformerHead', 'ISAHead', 'STDCHead'
+ 'SegformerHead', 'ISAHead', 'STDCHead', 'IterativeDecodeHead',
+ 'KernelUpdateHead', 'KernelUpdator'
]
diff --git a/mmseg/models/decode_heads/knet_head.py b/mmseg/models/decode_heads/knet_head.py
new file mode 100644
index 0000000000..f73daccb64
--- /dev/null
+++ b/mmseg/models/decode_heads/knet_head.py
@@ -0,0 +1,453 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
+from mmcv.cnn.bricks.transformer import (FFN, TRANSFORMER_LAYER,
+ MultiheadAttention,
+ build_transformer_layer)
+
+from mmseg.models.builder import HEADS, build_head
+from mmseg.models.decode_heads.decode_head import BaseDecodeHead
+from mmseg.utils import get_root_logger
+
+
+@TRANSFORMER_LAYER.register_module()
+class KernelUpdator(nn.Module):
+ """Dynamic Kernel Updator in Kernel Update Head.
+
+ Args:
+ in_channels (int): The number of channels of input feature map.
+ Default: 256.
+ feat_channels (int): The number of middle-stage channels in
+ the kernel updator. Default: 64.
+ out_channels (int): The number of output channels.
+ gate_sigmoid (bool): Whether use sigmoid function in gate
+ mechanism. Default: True.
+ gate_norm_act (bool): Whether add normalization and activation
+ layer in gate mechanism. Default: False.
+ activate_out: Whether add activation after gate mechanism.
+ Default: False.
+ norm_cfg (dict | None): Config of norm layers.
+ Default: dict(type='LN').
+ act_cfg (dict): Config of activation layers.
+ Default: dict(type='ReLU').
+ """
+
+ def __init__(
+ self,
+ in_channels=256,
+ feat_channels=64,
+ out_channels=None,
+ gate_sigmoid=True,
+ gate_norm_act=False,
+ activate_out=False,
+ norm_cfg=dict(type='LN'),
+ act_cfg=dict(type='ReLU', inplace=True),
+ ):
+ super(KernelUpdator, self).__init__()
+ self.in_channels = in_channels
+ self.feat_channels = feat_channels
+ self.out_channels_raw = out_channels
+ self.gate_sigmoid = gate_sigmoid
+ self.gate_norm_act = gate_norm_act
+ self.activate_out = activate_out
+ self.act_cfg = act_cfg
+ self.norm_cfg = norm_cfg
+ self.out_channels = out_channels if out_channels else in_channels
+
+ self.num_params_in = self.feat_channels
+ self.num_params_out = self.feat_channels
+ self.dynamic_layer = nn.Linear(
+ self.in_channels, self.num_params_in + self.num_params_out)
+ self.input_layer = nn.Linear(self.in_channels,
+ self.num_params_in + self.num_params_out,
+ 1)
+ self.input_gate = nn.Linear(self.in_channels, self.feat_channels, 1)
+ self.update_gate = nn.Linear(self.in_channels, self.feat_channels, 1)
+ if self.gate_norm_act:
+ self.gate_norm = build_norm_layer(norm_cfg, self.feat_channels)[1]
+
+ self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
+ self.norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1]
+ self.input_norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
+ self.input_norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1]
+
+ self.activation = build_activation_layer(act_cfg)
+
+ self.fc_layer = nn.Linear(self.feat_channels, self.out_channels, 1)
+ self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1]
+
+ def forward(self, update_feature, input_feature):
+ """Forward function of KernelUpdator.
+
+ Args:
+ update_feature (torch.Tensor): Feature map assembled from
+ each group. It would be reshaped with last dimension
+ shape: `self.in_channels`.
+ input_feature (torch.Tensor): Intermediate feature
+ with shape: (N, num_classes, conv_kernel_size**2, channels).
+ Returns:
+ Tensor: The output tensor of shape (N*C1/C2, K*K, C2), where N is
+ the number of classes, C1 and C2 are the feature map channels of
+ KernelUpdateHead and KernelUpdator, respectively.
+ """
+
+ update_feature = update_feature.reshape(-1, self.in_channels)
+ num_proposals = update_feature.size(0)
+ # dynamic_layer works for
+ # phi_1 and psi_3 in Eq.(4) and (5) of K-Net paper
+ parameters = self.dynamic_layer(update_feature)
+ param_in = parameters[:, :self.num_params_in].view(
+ -1, self.feat_channels)
+ param_out = parameters[:, -self.num_params_out:].view(
+ -1, self.feat_channels)
+
+ # input_layer works for
+ # phi_2 and psi_4 in Eq.(4) and (5) of K-Net paper
+ input_feats = self.input_layer(
+ input_feature.reshape(num_proposals, -1, self.feat_channels))
+ input_in = input_feats[..., :self.num_params_in]
+ input_out = input_feats[..., -self.num_params_out:]
+
+ # `gate_feats` is F^G in K-Net paper
+ gate_feats = input_in * param_in.unsqueeze(-2)
+ if self.gate_norm_act:
+ gate_feats = self.activation(self.gate_norm(gate_feats))
+
+ input_gate = self.input_norm_in(self.input_gate(gate_feats))
+ update_gate = self.norm_in(self.update_gate(gate_feats))
+ if self.gate_sigmoid:
+ input_gate = input_gate.sigmoid()
+ update_gate = update_gate.sigmoid()
+ param_out = self.norm_out(param_out)
+ input_out = self.input_norm_out(input_out)
+
+ if self.activate_out:
+ param_out = self.activation(param_out)
+ input_out = self.activation(input_out)
+
+ # Gate mechanism. Eq.(5) in original paper.
+ # param_out has shape (batch_size, feat_channels, out_channels)
+ features = update_gate * param_out.unsqueeze(
+ -2) + input_gate * input_out
+
+ features = self.fc_layer(features)
+ features = self.fc_norm(features)
+ features = self.activation(features)
+
+ return features
+
+
+@HEADS.register_module()
+class KernelUpdateHead(nn.Module):
+ """Kernel Update Head in K-Net.
+
+ Args:
+ num_classes (int): Number of classes. Default: 150.
+ num_ffn_fcs (int): The number of fully-connected layers in
+ FFNs. Default: 2.
+ num_heads (int): The number of parallel attention heads.
+ Default: 8.
+ num_mask_fcs (int): The number of fully connected layers for
+ mask prediction. Default: 3.
+ feedforward_channels (int): The hidden dimension of FFNs.
+ Defaults: 2048.
+ in_channels (int): The number of channels of input feature map.
+ Default: 256.
+ out_channels (int): The number of output channels.
+ Default: 256.
+ dropout (float): The Probability of an element to be
+ zeroed in MultiheadAttention and FFN. Default 0.0.
+ act_cfg (dict): Config of activation layers.
+ Default: dict(type='ReLU').
+ ffn_act_cfg (dict): Config of activation layers in FFN.
+ Default: dict(type='ReLU').
+ conv_kernel_size (int): The kernel size of convolution in
+ Kernel Update Head for dynamic kernel updation.
+ Default: 1.
+ feat_transform_cfg (dict | None): Config of feature transform.
+ Default: None.
+ kernel_init (bool): Whether initiate mask kernel in mask head.
+ Default: False.
+ with_ffn (bool): Whether add FFN in kernel update head.
+ Default: True.
+ feat_gather_stride (int): Stride of convolution in feature transform.
+ Default: 1.
+ mask_transform_stride (int): Stride of mask transform.
+ Default: 1.
+ kernel_updator_cfg (dict): Config of kernel updator.
+ Default: dict(
+ type='DynamicConv',
+ in_channels=256,
+ feat_channels=64,
+ out_channels=256,
+ act_cfg=dict(type='ReLU', inplace=True),
+ norm_cfg=dict(type='LN')).
+ """
+
+ def __init__(self,
+ num_classes=150,
+ num_ffn_fcs=2,
+ num_heads=8,
+ num_mask_fcs=3,
+ feedforward_channels=2048,
+ in_channels=256,
+ out_channels=256,
+ dropout=0.0,
+ act_cfg=dict(type='ReLU', inplace=True),
+ ffn_act_cfg=dict(type='ReLU', inplace=True),
+ conv_kernel_size=1,
+ feat_transform_cfg=None,
+ kernel_init=False,
+ with_ffn=True,
+ feat_gather_stride=1,
+ mask_transform_stride=1,
+ kernel_updator_cfg=dict(
+ type='DynamicConv',
+ in_channels=256,
+ feat_channels=64,
+ out_channels=256,
+ act_cfg=dict(type='ReLU', inplace=True),
+ norm_cfg=dict(type='LN'))):
+ super(KernelUpdateHead, self).__init__()
+ self.num_classes = num_classes
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.fp16_enabled = False
+ self.dropout = dropout
+ self.num_heads = num_heads
+ self.kernel_init = kernel_init
+ self.with_ffn = with_ffn
+ self.conv_kernel_size = conv_kernel_size
+ self.feat_gather_stride = feat_gather_stride
+ self.mask_transform_stride = mask_transform_stride
+
+ self.attention = MultiheadAttention(in_channels * conv_kernel_size**2,
+ num_heads, dropout)
+ self.attention_norm = build_norm_layer(
+ dict(type='LN'), in_channels * conv_kernel_size**2)[1]
+ self.kernel_update_conv = build_transformer_layer(kernel_updator_cfg)
+
+ if feat_transform_cfg is not None:
+ kernel_size = feat_transform_cfg.pop('kernel_size', 1)
+ transform_channels = in_channels
+ self.feat_transform = ConvModule(
+ transform_channels,
+ in_channels,
+ kernel_size,
+ stride=feat_gather_stride,
+ padding=int(feat_gather_stride // 2),
+ **feat_transform_cfg)
+ else:
+ self.feat_transform = None
+
+ if self.with_ffn:
+ self.ffn = FFN(
+ in_channels,
+ feedforward_channels,
+ num_ffn_fcs,
+ act_cfg=ffn_act_cfg,
+ dropout=dropout)
+ self.ffn_norm = build_norm_layer(dict(type='LN'), in_channels)[1]
+
+ self.mask_fcs = nn.ModuleList()
+ for _ in range(num_mask_fcs):
+ self.mask_fcs.append(
+ nn.Linear(in_channels, in_channels, bias=False))
+ self.mask_fcs.append(
+ build_norm_layer(dict(type='LN'), in_channels)[1])
+ self.mask_fcs.append(build_activation_layer(act_cfg))
+
+ self.fc_mask = nn.Linear(in_channels, out_channels)
+
+ def init_weights(self):
+ """Use xavier initialization for all weight parameter and set
+ classification head bias as a specific value when use focal loss."""
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+ else:
+ # adopt the default initialization for
+ # the weight and bias of the layer norm
+ pass
+ if self.kernel_init:
+ logger = get_root_logger()
+ logger.info(
+ 'mask kernel in mask head is normal initialized by std 0.01')
+ nn.init.normal_(self.fc_mask.weight, mean=0, std=0.01)
+
+ def forward(self, x, proposal_feat, mask_preds, mask_shape=None):
+ """Forward function of Dynamic Instance Interactive Head.
+
+ Args:
+ x (Tensor): Feature map from FPN with shape
+ (batch_size, feature_dimensions, H , W).
+ proposal_feat (Tensor): Intermediate feature get from
+ diihead in last stage, has shape
+ (batch_size, num_proposals, feature_dimensions)
+ mask_preds (Tensor): mask prediction from the former stage in shape
+ (batch_size, num_proposals, H, W).
+
+ Returns:
+ Tuple: The first tensor is predicted mask with shape
+ (N, num_classes, H, W), the second tensor is dynamic kernel
+ with shape (N, num_classes, channels, K, K).
+ """
+ N, num_proposals = proposal_feat.shape[:2]
+ if self.feat_transform is not None:
+ x = self.feat_transform(x)
+
+ C, H, W = x.shape[-3:]
+
+ mask_h, mask_w = mask_preds.shape[-2:]
+ if mask_h != H or mask_w != W:
+ gather_mask = F.interpolate(
+ mask_preds, (H, W), align_corners=False, mode='bilinear')
+ else:
+ gather_mask = mask_preds
+
+ sigmoid_masks = gather_mask.softmax(dim=1)
+
+ # Group Feature Assembling. Eq.(3) in original paper.
+ # einsum is faster than bmm by 30%
+ x_feat = torch.einsum('bnhw,bchw->bnc', sigmoid_masks, x)
+
+ # obj_feat in shape [B, N, C, K, K] -> [B, N, C, K*K] -> [B, N, K*K, C]
+ proposal_feat = proposal_feat.reshape(N, num_proposals,
+ self.in_channels,
+ -1).permute(0, 1, 3, 2)
+ obj_feat = self.kernel_update_conv(x_feat, proposal_feat)
+
+ # [B, N, K*K, C] -> [B, N, K*K*C] -> [N, B, K*K*C]
+ obj_feat = obj_feat.reshape(N, num_proposals, -1).permute(1, 0, 2)
+ obj_feat = self.attention_norm(self.attention(obj_feat))
+ # [N, B, K*K*C] -> [B, N, K*K*C]
+ obj_feat = obj_feat.permute(1, 0, 2)
+
+ # obj_feat in shape [B, N, K*K*C] -> [B, N, K*K, C]
+ obj_feat = obj_feat.reshape(N, num_proposals, -1, self.in_channels)
+
+ # FFN
+ if self.with_ffn:
+ obj_feat = self.ffn_norm(self.ffn(obj_feat))
+
+ mask_feat = obj_feat
+
+ for reg_layer in self.mask_fcs:
+ mask_feat = reg_layer(mask_feat)
+
+ # [B, N, K*K, C] -> [B, N, C, K*K]
+ mask_feat = self.fc_mask(mask_feat).permute(0, 1, 3, 2)
+
+ if (self.mask_transform_stride == 2 and self.feat_gather_stride == 1):
+ mask_x = F.interpolate(
+ x, scale_factor=0.5, mode='bilinear', align_corners=False)
+ H, W = mask_x.shape[-2:]
+ else:
+ mask_x = x
+ # group conv is 5x faster than unfold and uses about 1/5 memory
+ # Group conv vs. unfold vs. concat batch, 2.9ms :13.5ms :3.8ms
+ # Group conv vs. unfold vs. concat batch, 278 : 1420 : 369
+ # but in real training group conv is slower than concat batch
+ # so we keep using concat batch.
+ # fold_x = F.unfold(
+ # mask_x,
+ # self.conv_kernel_size,
+ # padding=int(self.conv_kernel_size // 2))
+ # mask_feat = mask_feat.reshape(N, num_proposals, -1)
+ # new_mask_preds = torch.einsum('bnc,bcl->bnl', mask_feat, fold_x)
+ # [B, N, C, K*K] -> [B*N, C, K, K]
+ mask_feat = mask_feat.reshape(N, num_proposals, C,
+ self.conv_kernel_size,
+ self.conv_kernel_size)
+ # [B, C, H, W] -> [1, B*C, H, W]
+ new_mask_preds = []
+ for i in range(N):
+ new_mask_preds.append(
+ F.conv2d(
+ mask_x[i:i + 1],
+ mask_feat[i],
+ padding=int(self.conv_kernel_size // 2)))
+
+ new_mask_preds = torch.cat(new_mask_preds, dim=0)
+ new_mask_preds = new_mask_preds.reshape(N, num_proposals, H, W)
+ if self.mask_transform_stride == 2:
+ new_mask_preds = F.interpolate(
+ new_mask_preds,
+ scale_factor=2,
+ mode='bilinear',
+ align_corners=False)
+
+ if mask_shape is not None and mask_shape[0] != H:
+ new_mask_preds = F.interpolate(
+ new_mask_preds,
+ mask_shape,
+ align_corners=False,
+ mode='bilinear')
+
+ return new_mask_preds, obj_feat.permute(0, 1, 3, 2).reshape(
+ N, num_proposals, self.in_channels, self.conv_kernel_size,
+ self.conv_kernel_size)
+
+
+@HEADS.register_module()
+class IterativeDecodeHead(BaseDecodeHead):
+ """K-Net: Towards Unified Image Segmentation.
+
+ This head is the implementation of
+ `K-Net: `_.
+
+ Args:
+ num_stages (int): The number of stages (kernel update heads)
+ in IterativeDecodeHead. Default: 3.
+ kernel_generate_head:(dict): Config of kernel generate head which
+ generate mask predictions, dynamic kernels and class predictions
+ for next kernel update heads.
+ kernel_update_head (dict): Config of kernel update head which refine
+ dynamic kernels and class predictions iteratively.
+
+ """
+
+ def __init__(self, num_stages, kernel_generate_head, kernel_update_head,
+ **kwargs):
+ super(BaseDecodeHead, self).__init__(**kwargs)
+ assert num_stages == len(kernel_update_head)
+ self.num_stages = num_stages
+ self.kernel_generate_head = build_head(kernel_generate_head)
+ self.kernel_update_head = nn.ModuleList()
+ self.align_corners = self.kernel_generate_head.align_corners
+ self.num_classes = self.kernel_generate_head.num_classes
+ self.input_transform = self.kernel_generate_head.input_transform
+ self.ignore_index = self.kernel_generate_head.ignore_index
+
+ for head_cfg in kernel_update_head:
+ self.kernel_update_head.append(build_head(head_cfg))
+
+ def forward(self, inputs):
+ """Forward function."""
+ feats = self.kernel_generate_head._forward_feature(inputs)
+ sem_seg = self.kernel_generate_head.cls_seg(feats)
+ seg_kernels = self.kernel_generate_head.conv_seg.weight.clone()
+ seg_kernels = seg_kernels[None].expand(
+ feats.size(0), *seg_kernels.size())
+
+ stage_segs = [sem_seg]
+ for i in range(self.num_stages):
+ sem_seg, seg_kernels = self.kernel_update_head[i](feats,
+ seg_kernels,
+ sem_seg)
+ stage_segs.append(sem_seg)
+ if self.training:
+ return stage_segs
+ # only return the prediction of the last stage during testing
+ return stage_segs[-1]
+
+ def losses(self, seg_logit, seg_label):
+ losses = dict()
+ for i, logit in enumerate(seg_logit):
+ loss = self.kernel_generate_head.losses(logit, seg_label)
+ for k, v in loss.items():
+ losses[f'{k}.s{i}'] = v
+
+ return losses
diff --git a/model-index.yml b/model-index.yml
index cd82220bbd..235ad7f6e7 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -22,6 +22,7 @@ Import:
- configs/hrnet/hrnet.yml
- configs/icnet/icnet.yml
- configs/isanet/isanet.yml
+- configs/knet/knet.yml
- configs/mobilenet_v2/mobilenet_v2.yml
- configs/mobilenet_v3/mobilenet_v3.yml
- configs/nonlocal_net/nonlocal_net.yml
diff --git a/tests/test_models/test_heads/test_knet_head.py b/tests/test_models/test_heads/test_knet_head.py
new file mode 100644
index 0000000000..e6845a6d3f
--- /dev/null
+++ b/tests/test_models/test_heads/test_knet_head.py
@@ -0,0 +1,195 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmseg.models.decode_heads.knet_head import (IterativeDecodeHead,
+ KernelUpdateHead)
+from .utils import to_cuda
+
+num_stages = 3
+conv_kernel_size = 1
+
+kernel_updator_cfg = dict(
+ type='KernelUpdator',
+ in_channels=16,
+ feat_channels=16,
+ out_channels=16,
+ gate_norm_act=True,
+ activate_out=True,
+ act_cfg=dict(type='ReLU', inplace=True),
+ norm_cfg=dict(type='LN'))
+
+
+def test_knet_head():
+ # test init function of kernel update head
+ kernel_update_head = KernelUpdateHead(
+ num_classes=150,
+ num_ffn_fcs=2,
+ num_heads=8,
+ num_mask_fcs=1,
+ feedforward_channels=128,
+ in_channels=32,
+ out_channels=32,
+ dropout=0.0,
+ conv_kernel_size=conv_kernel_size,
+ ffn_act_cfg=dict(type='ReLU', inplace=True),
+ with_ffn=True,
+ feat_transform_cfg=dict(conv_cfg=dict(type='Conv2d'), act_cfg=None),
+ kernel_init=True,
+ kernel_updator_cfg=kernel_updator_cfg)
+ kernel_update_head.init_weights()
+
+ head = IterativeDecodeHead(
+ num_stages=num_stages,
+ kernel_update_head=[
+ dict(
+ type='KernelUpdateHead',
+ num_classes=150,
+ num_ffn_fcs=2,
+ num_heads=8,
+ num_mask_fcs=1,
+ feedforward_channels=128,
+ in_channels=32,
+ out_channels=32,
+ dropout=0.0,
+ conv_kernel_size=conv_kernel_size,
+ ffn_act_cfg=dict(type='ReLU', inplace=True),
+ with_ffn=True,
+ feat_transform_cfg=dict(
+ conv_cfg=dict(type='Conv2d'), act_cfg=None),
+ kernel_init=False,
+ kernel_updator_cfg=kernel_updator_cfg)
+ for _ in range(num_stages)
+ ],
+ kernel_generate_head=dict(
+ type='FCNHead',
+ in_channels=128,
+ in_index=3,
+ channels=32,
+ num_convs=2,
+ concat_input=True,
+ dropout_ratio=0.1,
+ num_classes=150,
+ align_corners=False))
+ head.init_weights()
+ inputs = [
+ torch.randn(1, 16, 27, 32),
+ torch.randn(1, 32, 27, 16),
+ torch.randn(1, 64, 27, 16),
+ torch.randn(1, 128, 27, 16)
+ ]
+
+ if torch.cuda.is_available():
+ head, inputs = to_cuda(head, inputs)
+ outputs = head(inputs)
+ assert outputs[-1].shape == (1, head.num_classes, 27, 16)
+
+ # test whether only return the prediction of
+ # the last stage during testing
+ with torch.no_grad():
+ head.eval()
+ outputs = head(inputs)
+ assert outputs.shape == (1, head.num_classes, 27, 16)
+
+ # test K-Net without `feat_transform_cfg`
+ head = IterativeDecodeHead(
+ num_stages=num_stages,
+ kernel_update_head=[
+ dict(
+ type='KernelUpdateHead',
+ num_classes=150,
+ num_ffn_fcs=2,
+ num_heads=8,
+ num_mask_fcs=1,
+ feedforward_channels=128,
+ in_channels=32,
+ out_channels=32,
+ dropout=0.0,
+ conv_kernel_size=conv_kernel_size,
+ ffn_act_cfg=dict(type='ReLU', inplace=True),
+ with_ffn=True,
+ feat_transform_cfg=None,
+ kernel_updator_cfg=kernel_updator_cfg)
+ for _ in range(num_stages)
+ ],
+ kernel_generate_head=dict(
+ type='FCNHead',
+ in_channels=128,
+ in_index=3,
+ channels=32,
+ num_convs=2,
+ concat_input=True,
+ dropout_ratio=0.1,
+ num_classes=150,
+ align_corners=False))
+ head.init_weights()
+
+ inputs = [
+ torch.randn(1, 16, 27, 32),
+ torch.randn(1, 32, 27, 16),
+ torch.randn(1, 64, 27, 16),
+ torch.randn(1, 128, 27, 16)
+ ]
+
+ if torch.cuda.is_available():
+ head, inputs = to_cuda(head, inputs)
+ outputs = head(inputs)
+ assert outputs[-1].shape == (1, head.num_classes, 27, 16)
+
+ # test K-Net with
+ # self.mask_transform_stride == 2 and self.feat_gather_stride == 1
+ head = IterativeDecodeHead(
+ num_stages=num_stages,
+ kernel_update_head=[
+ dict(
+ type='KernelUpdateHead',
+ num_classes=150,
+ num_ffn_fcs=2,
+ num_heads=8,
+ num_mask_fcs=1,
+ feedforward_channels=128,
+ in_channels=32,
+ out_channels=32,
+ dropout=0.0,
+ conv_kernel_size=conv_kernel_size,
+ ffn_act_cfg=dict(type='ReLU', inplace=True),
+ with_ffn=True,
+ feat_transform_cfg=dict(
+ conv_cfg=dict(type='Conv2d'), act_cfg=None),
+ kernel_init=False,
+ mask_transform_stride=2,
+ feat_gather_stride=1,
+ kernel_updator_cfg=kernel_updator_cfg)
+ for _ in range(num_stages)
+ ],
+ kernel_generate_head=dict(
+ type='FCNHead',
+ in_channels=128,
+ in_index=3,
+ channels=32,
+ num_convs=2,
+ concat_input=True,
+ dropout_ratio=0.1,
+ num_classes=150,
+ align_corners=False))
+ head.init_weights()
+
+ inputs = [
+ torch.randn(1, 16, 27, 32),
+ torch.randn(1, 32, 27, 16),
+ torch.randn(1, 64, 27, 16),
+ torch.randn(1, 128, 27, 16)
+ ]
+
+ if torch.cuda.is_available():
+ head, inputs = to_cuda(head, inputs)
+ outputs = head(inputs)
+ assert outputs[-1].shape == (1, head.num_classes, 26, 16)
+
+ # test loss function in K-Net
+ fake_label = torch.ones_like(
+ outputs[-1][:, 0:1, :, :], dtype=torch.int16).long()
+ loss = head.losses(seg_logit=outputs, seg_label=fake_label)
+ assert loss['loss_ce.s0'] != torch.zeros_like(loss['loss_ce.s0'])
+ assert loss['loss_ce.s1'] != torch.zeros_like(loss['loss_ce.s1'])
+ assert loss['loss_ce.s2'] != torch.zeros_like(loss['loss_ce.s2'])
+ assert loss['loss_ce.s3'] != torch.zeros_like(loss['loss_ce.s3'])