From 9bc147e12fa43e1cdbbfad448b0227b19133d2d3 Mon Sep 17 00:00:00 2001 From: zhangshilong <2392587229zsl@gmail.com> Date: Mon, 29 Nov 2021 15:58:26 +0800 Subject: [PATCH 1/3] add an example of swin is used in one-stage model --- .../retinanet_swin-t-p4-w7_fpn_1x_coco.py | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 configs/swin/retinanet_swin-t-p4-w7_fpn_1x_coco.py diff --git a/configs/swin/retinanet_swin-t-p4-w7_fpn_1x_coco.py b/configs/swin/retinanet_swin-t-p4-w7_fpn_1x_coco.py new file mode 100644 index 00000000000..d5f5b699fa0 --- /dev/null +++ b/configs/swin/retinanet_swin-t-p4-w7_fpn_1x_coco.py @@ -0,0 +1,30 @@ +_base_ = [ + '../_base_/models/retinanet_r50_fpn.py', + '../_base_/datasets/coco_detection.py', + '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' +] +# optimizer +pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth' # noqa +model = dict( + backbone=dict( + _delete_=True, + type='SwinTransformer', + embed_dims=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4, + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + patch_norm=True, + out_indices=(1, 2, 3), + # Please only add indices that would be used + # in FPN, otherwise some parameter will not be used + with_cp=False, + init_cfg=dict(type='Pretrained', checkpoint=pretrained)), + neck=dict(in_channels=[192, 384, 768], start_level=0, num_outs=5)) + +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) From 3489893beb7e5c25a14785d58008e2165882dd74 Mon Sep 17 00:00:00 2001 From: zhangshilong <2392587229zsl@gmail.com> Date: Mon, 29 Nov 2021 20:45:55 +0800 Subject: [PATCH 2/3] fix comments --- configs/swin/retinanet_swin-t-p4-w7_fpn_1x_coco.py | 1 - 1 file changed, 1 deletion(-) diff --git a/configs/swin/retinanet_swin-t-p4-w7_fpn_1x_coco.py b/configs/swin/retinanet_swin-t-p4-w7_fpn_1x_coco.py index d5f5b699fa0..9d620fd980a 100644 --- a/configs/swin/retinanet_swin-t-p4-w7_fpn_1x_coco.py +++ b/configs/swin/retinanet_swin-t-p4-w7_fpn_1x_coco.py @@ -3,7 +3,6 @@ '../_base_/datasets/coco_detection.py', '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' ] -# optimizer pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth' # noqa model = dict( backbone=dict( From 0de48fd4450da22d9b0e2fe0a3971b3a044ebb4e Mon Sep 17 00:00:00 2001 From: zhangshilong <2392587229zsl@gmail.com> Date: Thu, 2 Dec 2021 13:14:16 +0800 Subject: [PATCH 3/3] add a notice --- configs/swin/README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/configs/swin/README.md b/configs/swin/README.md index d18632a38ac..2b86d1ee799 100644 --- a/configs/swin/README.md +++ b/configs/swin/README.md @@ -23,3 +23,8 @@ | Swin-T | ImageNet-1K | 3x | yes | no | 10.2 | | 46.0 | 41.6 | [config](./mask_rcnn_swin-t-p4-w7_fpn_ms-crop-3x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/swin/mask_rcnn_swin-t-p4-w7_fpn_ms-crop-3x_coco/mask_rcnn_swin-t-p4-w7_fpn_ms-crop-3x_coco_20210906_131725-bacf6f7b.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/swin/mask_rcnn_swin-t-p4-w7_fpn_ms-crop-3x_coco/mask_rcnn_swin-t-p4-w7_fpn_ms-crop-3x_coco_20210906_131725.log.json) | | Swin-T | ImageNet-1K | 3x | yes | yes | 7.8 | | 46.0 | 41.7 | [config](./mask_rcnn_swin-t-p4-w7_fpn_fp16_ms-crop-3x_coco.py)| [model](https://download.openmmlab.com/mmdetection/v2.0/swin/mask_rcnn_swin-t-p4-w7_fpn_fp16_ms-crop-3x_coco/mask_rcnn_swin-t-p4-w7_fpn_fp16_ms-crop-3x_coco_20210908_165006-90a4008c.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/swin/mask_rcnn_swin-t-p4-w7_fpn_fp16_ms-crop-3x_coco/mask_rcnn_swin-t-p4-w7_fpn_fp16_ms-crop-3x_coco_20210908_165006.log.json) | | Swin-S | ImageNet-1K | 3x | yes | yes | 11.9 | | 48.2 | 43.2 | [config](./mask_rcnn_swin-s-p4-w7_fpn_fp16_ms-crop-3x_coco.py)| [model](https://download.openmmlab.com/mmdetection/v2.0/swin/mask_rcnn_swin-s-p4-w7_fpn_fp16_ms-crop-3x_coco/mask_rcnn_swin-s-p4-w7_fpn_fp16_ms-crop-3x_coco_20210903_104808-b92c91f1.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/swin/mask_rcnn_swin-s-p4-w7_fpn_fp16_ms-crop-3x_coco/mask_rcnn_swin-s-p4-w7_fpn_fp16_ms-crop-3x_coco_20210903_104808.log.json) | + +### Notice +Please follow the example +of `retinanet_swin-t-p4-w7_fpn_1x_coco.py` when you want to combine Swin Transformer with +the one-stage detector. Because there is a layer norm at the outs of Swin Transformer, you must set `start_level` as 0 in FPN, so we have to set the `out_indices` of backbone as `[1,2,3]`.