diff --git a/README.md b/README.md
index 39688919559..0ce31cf8820 100644
--- a/README.md
+++ b/README.md
@@ -68,6 +68,7 @@ The master branch works with **PyTorch 1.6+**.
- **State of the art**
The toolbox stems from the codebase developed by the *MMDet* team, who won [COCO Detection Challenge](http://cocodataset.org/#detection-leaderboard) in 2018, and we keep pushing it forward.
+ The newly released [RTMDet](configs/rtmdet) also obtains new state-of-the-art results on real-time instance segmentation and rotated object detection, as well as the best parameter-accuracy trade-off on object detection.
@@ -75,6 +76,24 @@ Apart from MMDetection, we also released [MMEngine](https://github.com/open-mmla
## What's New
+### Highlight
+
+We are excited to announce our latest work on real-time object recognition tasks, **RTMDet**, a family of fully convolutional single-stage detectors. RTMDet not only achieves the best parameter-accuracy trade-off on object detection from tiny to extra-large model sizes but also obtains new state-of-the-art performance on instance segmentation and rotated object detection tasks. Details can be found in the [technical report](https://arxiv.org/abs/2212.07784). Pre-trained models are [here](configs/rtmdet).
+
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/real-time-instance-segmentation-on-mscoco)](https://paperswithcode.com/sota/real-time-instance-segmentation-on-mscoco?p=rtmdet-an-empirical-study-of-designing-real)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/object-detection-in-aerial-images-on-dota-1)](https://paperswithcode.com/sota/object-detection-in-aerial-images-on-dota-1?p=rtmdet-an-empirical-study-of-designing-real)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/object-detection-in-aerial-images-on-hrsc2016)](https://paperswithcode.com/sota/object-detection-in-aerial-images-on-hrsc2016?p=rtmdet-an-empirical-study-of-designing-real)
+
+| Task                     | Dataset | AP                                       | FPS (TRT FP16 BS=1, 3090) |
+| ------------------------ | ------- | ---------------------------------------- | ------------------------- |
+| Object Detection         | COCO    | 52.8                                     | 322                       |
+| Instance Segmentation    | COCO    | 44.6                                     | 188                       |
+| Rotated Object Detection | DOTA    | 78.9 (single-scale) / 81.3 (multi-scale) | 121                       |
+
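+For a quick try of the released models, a minimal inference sketch with the
+MMDetection 3.x high-level APIs looks like this (the checkpoint filename
+comes from the RTMDet model zoo in `configs/rtmdet`; the demo image path is
+illustrative):
+
+```python
+from mmdet.apis import inference_detector, init_detector
+
+# build RTMDet-tiny from its config and the released COCO checkpoint
+model = init_detector(
+    'configs/rtmdet/rtmdet_tiny_8xb32-300e_coco.py',
+    'rtmdet_tiny_8xb32-300e_coco_20220902_112414-78e30dcc.pth',
+    device='cuda:0')
+# single-image inference; returns a DetDataSample with pred_instances
+result = inference_detector(model, 'demo/demo.jpg')
+```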
+
**v3.0.0rc4** was released on 25/11/2022:
- Support [CondInst](https://arxiv.org/abs/2003.05664)
@@ -187,6 +206,7 @@ Results and models are available in the [model zoo](docs/en/model_zoo.md).
Deformable DETR (ICLR'2021)
TOOD (ICCV'2021)
DDOD (ACM MM'2021)
+ RTMDet (ArXiv'2022)
@@ -206,6 +226,7 @@ Results and models are available in the [model zoo](docs/en/model_zoo.md).
Mask2Former (ArXiv'2021)
CondInst (ECCV 2020)
SparseInst (CVPR 2022)
+ RTMDet (ArXiv'2022)
|
diff --git a/README_zh-CN.md b/README_zh-CN.md
index 4255bfae257..a8359cfcefc 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -67,6 +67,7 @@ MMDetection 是一个基于 PyTorch 的目标检测开源工具箱。它是 [Ope
- **性能高**
MMDetection 这个算法库源自于 COCO 2018 目标检测竞赛的冠军团队 *MMDet* 团队开发的代码,我们在之后持续进行了改进和提升。
+ 新发布的 [RTMDet](configs/rtmdet) 还在实时实例分割和旋转目标检测任务中取得了最先进的成果,同时也在目标检测模型中取得了最佳的参数量和精度平衡。
@@ -74,6 +75,24 @@ MMDetection 是一个基于 PyTorch 的目标检测开源工具箱。它是 [Ope
## 最新进展
+### 亮点
+
+我们很高兴向大家介绍我们在实时目标识别任务方面的最新成果 RTMDet,包含了一系列的全卷积单阶段检测模型。 RTMDet 不仅在从 tiny 到 extra-large 尺寸的目标检测模型上实现了最佳的参数量和精度的平衡,而且在实时实例分割和旋转目标检测任务上取得了最先进的成果。 更多细节请参阅[技术报告](https://arxiv.org/abs/2212.07784)。 预训练模型可以在[这里](configs/rtmdet)找到。
+
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/real-time-instance-segmentation-on-mscoco)](https://paperswithcode.com/sota/real-time-instance-segmentation-on-mscoco?p=rtmdet-an-empirical-study-of-designing-real)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/object-detection-in-aerial-images-on-dota-1)](https://paperswithcode.com/sota/object-detection-in-aerial-images-on-dota-1?p=rtmdet-an-empirical-study-of-designing-real)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/object-detection-in-aerial-images-on-hrsc2016)](https://paperswithcode.com/sota/object-detection-in-aerial-images-on-hrsc2016?p=rtmdet-an-empirical-study-of-designing-real)
+
+| Task                     | Dataset | AP                                       | FPS (TRT FP16 BS=1, 3090) |
+| ------------------------ | ------- | ---------------------------------------- | ------------------------- |
+| Object Detection         | COCO    | 52.8                                     | 322                       |
+| Instance Segmentation    | COCO    | 44.6                                     | 188                       |
+| Rotated Object Detection | DOTA    | 78.9 (single-scale) / 81.3 (multi-scale) | 121                       |
+
+
**v3.0.0rc4** 版本已经在 2022.11.25 发布:
- 支持了 [CondInst](https://arxiv.org/abs/2003.05664)
@@ -188,6 +207,7 @@ MMDetection 是一个基于 PyTorch 的目标检测开源工具箱。它是 [Ope
Deformable DETR (ICLR'2021)
TOOD (ICCV'2021)
DDOD (ACM MM'2021)
+ RTMDet (ArXiv'2022)
|
@@ -207,6 +227,7 @@ MMDetection 是一个基于 PyTorch 的目标检测开源工具箱。它是 [Ope
Mask2Former (ArXiv'2021)
CondInst (ECCV 2020)
SparseInst (CVPR 2022)
+ RTMDet (ArXiv'2022)
|
diff --git a/configs/rtmdet/README.md b/configs/rtmdet/README.md
index f677baa5b0a..1c06812a748 100644
--- a/configs/rtmdet/README.md
+++ b/configs/rtmdet/README.md
@@ -1,25 +1,83 @@
-# RTMDet
+# RTMDet: An Empirical Study of Designing Real-Time Object Detectors
+
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/real-time-instance-segmentation-on-mscoco)](https://paperswithcode.com/sota/real-time-instance-segmentation-on-mscoco?p=rtmdet-an-empirical-study-of-designing-real)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/object-detection-in-aerial-images-on-dota-1)](https://paperswithcode.com/sota/object-detection-in-aerial-images-on-dota-1?p=rtmdet-an-empirical-study-of-designing-real)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/object-detection-in-aerial-images-on-hrsc2016)](https://paperswithcode.com/sota/object-detection-in-aerial-images-on-hrsc2016?p=rtmdet-an-empirical-study-of-designing-real)
## Abstract
-Our tech-report will be released soon.
+In this paper, we aim to design an efficient real-time object detector that exceeds the YOLO series and is easily extensible for many object recognition tasks such as instance segmentation and rotated object detection. To obtain a more efficient model architecture, we explore an architecture that has compatible capacities in the backbone and neck, constructed by a basic building block that consists of large-kernel depth-wise convolutions. We further introduce soft labels when calculating matching costs in the dynamic label assignment to improve accuracy. Together with better training techniques, the resulting object detector, named RTMDet, achieves 52.8% AP on COCO with 300+ FPS on an NVIDIA 3090 GPU, outperforming the current mainstream industrial detectors. RTMDet achieves the best parameter-accuracy trade-off with tiny/small/medium/large/extra-large model sizes for various application scenarios, and obtains new state-of-the-art performance on real-time instance segmentation and rotated object detection. We hope the experimental results can provide new insights into designing versatile real-time object detectors for many object recognition tasks.
-
+
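+As a rough illustration of the building block described above, a
+large-kernel depth-wise unit can be sketched in PyTorch as follows. The
+channel count, the 5x5 kernel, and the residual wiring here are illustrative
+assumptions, not the exact CSPNeXt definition:
+
+```python
+import torch
+import torch.nn as nn
+
+
+class LargeKernelDWBlock(nn.Module):
+    """Point-wise conv followed by a large-kernel depth-wise conv."""
+
+    def __init__(self, channels: int, kernel_size: int = 5):
+        super().__init__()
+        self.pointwise = nn.Conv2d(channels, channels, 1)
+        # groups=channels makes the conv depth-wise: one filter per channel
+        self.depthwise = nn.Conv2d(
+            channels, channels, kernel_size,
+            padding=kernel_size // 2, groups=channels)
+        self.norm = nn.BatchNorm2d(channels)
+        self.act = nn.SiLU(inplace=True)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x + self.act(self.norm(self.depthwise(self.pointwise(x))))
+
+
+print(LargeKernelDWBlock(64)(torch.randn(1, 64, 80, 80)).shape)
+# torch.Size([1, 64, 80, 80])
+```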
## Results and Models
-| Backbone | size | box AP | Params(M) | FLOPS(G) | TRT-FP16-Latency(ms) | Config | Download |
+## Object Detection
+
+| Model | size | box AP | Params(M) | FLOPs(G) | TRT-FP16-Latency(ms) | Config | Download |
| :---------: | :--: | :----: | :-------: | :------: | :------------------: | :----------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
-| RTMDet-tiny | 640 | 40.9 | 4.8 | 8.1 | 0.98 | [config](./rtmdet_tiny_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco_20220902_112414-78e30dcc.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco_20220902_112414.log.json) |
-| RTMDet-s | 640 | 44.5 | 8.89 | 14.8 | 1.22 | [config](./rtmdet_s_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_s_8xb32-300e_coco/rtmdet_s_8xb32-300e_coco_20220905_161602-387a891e.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_s_8xb32-300e_coco/rtmdet_s_8xb32-300e_coco_20220905_161602.log.json) |
-| RTMDet-m | 640 | 49.1 | 24.71 | 39.27 | 1.62 | [config](./rtmdet_m_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_m_8xb32-300e_coco/rtmdet_m_8xb32-300e_coco_20220719_112220-229f527c.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_m_8xb32-300e_coco/rtmdet_m_8xb32-300e_coco_20220719_112220.log.json) |
-| RTMDet-l | 640 | 51.3 | 52.3 | 80.23 | 2.44 | [config](./rtmdet_l_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_l_8xb32-300e_coco/rtmdet_l_8xb32-300e_coco_20220719_112030-5a0be7c4.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_l_8xb32-300e_coco/rtmdet_l_8xb32-300e_coco_20220719_112030.log.json) |
-| RTMDet-x | 640 | 52.6 | 94.86 | 141.67 | 3.10 | [config](./rtmdet_x_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_x_8xb32-300e_coco/rtmdet_x_8xb32-300e_coco_20220715_230555-cc79b9ae.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_x_8xb32-300e_coco/rtmdet_x_8xb32-300e_coco_20220715_230555.log.json) |
+| RTMDet-tiny | 640 | 41.1 | 4.8 | 8.1 | 0.98 | [config](./rtmdet_tiny_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco_20220902_112414-78e30dcc.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco_20220902_112414.log.json) |
+| RTMDet-s | 640 | 44.6 | 8.89 | 14.8 | 1.22 | [config](./rtmdet_s_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_s_8xb32-300e_coco/rtmdet_s_8xb32-300e_coco_20220905_161602-387a891e.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_s_8xb32-300e_coco/rtmdet_s_8xb32-300e_coco_20220905_161602.log.json) |
+| RTMDet-m | 640 | 49.4 | 24.71 | 39.27 | 1.62 | [config](./rtmdet_m_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_m_8xb32-300e_coco/rtmdet_m_8xb32-300e_coco_20220719_112220-229f527c.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_m_8xb32-300e_coco/rtmdet_m_8xb32-300e_coco_20220719_112220.log.json) |
+| RTMDet-l | 640 | 51.5 | 52.3 | 80.23 | 2.44 | [config](./rtmdet_l_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_l_8xb32-300e_coco/rtmdet_l_8xb32-300e_coco_20220719_112030-5a0be7c4.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_l_8xb32-300e_coco/rtmdet_l_8xb32-300e_coco_20220719_112030.log.json) |
+| RTMDet-x | 640 | 52.8 | 94.86 | 141.67 | 3.10 | [config](./rtmdet_x_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_x_8xb32-300e_coco/rtmdet_x_8xb32-300e_coco_20220715_230555-cc79b9ae.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_x_8xb32-300e_coco/rtmdet_x_8xb32-300e_coco_20220715_230555.log.json) |
**Note**:
-1. The inference speed is measured on an NVIDIA 3090 GPU with TensorRT 8.4.3, cuDNN 8.2.0, FP16, batch size=1, and without NMS.
+1. The inference speed of RTMDet is measured on an NVIDIA 3090 GPU with TensorRT 8.4.3, cuDNN 8.2.0, FP16, batch size=1, and without NMS.
+2. For a fair comparison, the bbox post-processing config was changed to be consistent with YOLOv5/6/7 after [PR#9494](https://github.com/open-mmlab/mmdetection/pull/9494), bringing an AP improvement of about 0.1–0.3%.
+
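+The bbox post-processing mentioned in note 2 corresponds to the `test_cfg`
+block of the detection configs. As a sketch, a user who prefers the earlier
+deployment-oriented settings (the values below are the previous defaults of
+`rtmdet_l_8xb32-300e_coco.py`) could override it in a derived config:
+
+```python
+_base_ = './rtmdet_l_8xb32-300e_coco.py'
+
+# the new defaults (nms_pre=30000, score_thr=0.001, max_per_img=300)
+# maximize benchmark AP at the cost of heavier post-processing
+model = dict(
+    test_cfg=dict(
+        nms_pre=1000,
+        min_bbox_size=0,
+        score_thr=0.05,
+        nms=dict(type='nms', iou_threshold=0.6),
+        max_per_img=100))
+```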
+## Instance Segmentation
+
+RTMDet-Ins achieves state-of-the-art real-time instance segmentation performance on the COCO dataset:
+
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/real-time-instance-segmentation-on-mscoco)](https://paperswithcode.com/sota/real-time-instance-segmentation-on-mscoco?p=rtmdet-an-empirical-study-of-designing-real)
+
+| Model | size | box AP | mask AP | Params(M) | FLOPs(G) | TRT-FP16-Latency(ms) | Config | Download |
+| :-------------: | :--: | :----: | :-----: | :-------: | :------: | :------------------: | :--------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| RTMDet-Ins-tiny | 640 | 40.5 | 35.4 | 5.6 | 11.8 | 1.70 | [config](./rtmdet-ins_tiny_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_tiny_8xb32-300e_coco/rtmdet-ins_tiny_8xb32-300e_coco_20221130_151727-ec670f7e.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_tiny_8xb32-300e_coco/rtmdet-ins_tiny_8xb32-300e_coco_20221130_151727.log.json) |
+| RTMDet-Ins-s | 640 | 44.0 | 38.7 | 10.18 | 21.5 | 1.93 | [config](./rtmdet-ins_s_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_s_8xb32-300e_coco/rtmdet-ins_s_8xb32-300e_coco_20221121_212604-fdc5d7ec.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_s_8xb32-300e_coco/rtmdet-ins_s_8xb32-300e_coco_20221121_212604.log.json) |
+| RTMDet-Ins-m | 640 | 48.8 | 42.1 | 27.58 | 54.13 | 2.69 | [config](./rtmdet-ins_m_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_m_8xb32-300e_coco/rtmdet-ins_m_8xb32-300e_coco_20221123_001039-6eba602e.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_m_8xb32-300e_coco/rtmdet-ins_m_8xb32-300e_coco_20221123_001039.log.json) |
+| RTMDet-Ins-l | 640 | 51.2 | 43.7 | 57.37 | 106.56 | 3.68 | [config](./rtmdet-ins_l_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_l_8xb32-300e_coco/rtmdet-ins_l_8xb32-300e_coco_20221124_103237-78d1d652.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_l_8xb32-300e_coco/rtmdet-ins_l_8xb32-300e_coco_20221124_103237.log.json) |
+| RTMDet-Ins-x | 640 | 52.4 | 44.6 | 102.7 | 182.7 | 5.31 | [config](./rtmdet-ins_x_8xb16-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_x_8xb16-300e_coco/rtmdet-ins_x_8xb16-300e_coco_20221124_111313-33d4595b.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_l_8xb16-300e_coco/rtmdet-ins_x_8xb16-300e_coco_20221124_111313.log.json) |
+
+**Note**:
+
+1. The inference speed of RTMDet-Ins is measured on an NVIDIA 3090 GPU with TensorRT 8.4.3, cuDNN 8.2.0, FP16, batch size=1. The top 100 masks are kept, and the post-processing latency is included.
+
+## Rotated Object Detection
+
+RTMDet-R achieves state-of-the-art performance on various remote sensing datasets:
+
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/object-detection-in-aerial-images-on-dota-1)](https://paperswithcode.com/sota/object-detection-in-aerial-images-on-dota-1?p=rtmdet-an-empirical-study-of-designing-real)
+
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/one-stage-anchor-free-oriented-object-1)](https://paperswithcode.com/sota/one-stage-anchor-free-oriented-object-1?p=rtmdet-an-empirical-study-of-designing-real)
+
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/object-detection-in-aerial-images-on-hrsc2016)](https://paperswithcode.com/sota/object-detection-in-aerial-images-on-hrsc2016?p=rtmdet-an-empirical-study-of-designing-real)
+
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/one-stage-anchor-free-oriented-object-3)](https://paperswithcode.com/sota/one-stage-anchor-free-oriented-object-3?p=rtmdet-an-empirical-study-of-designing-real)
+
+Models and configs of RTMDet-R are available in [MMRotate](https://github.com/open-mmlab/mmrotate/tree/1.x/configs/rotated_rtmdet).
+
+## Citation
+
+```latex
+@misc{lyu2022rtmdet,
+ title={RTMDet: An Empirical Study of Designing Real-Time Object Detectors},
+ author={Chengqi Lyu and Wenwei Zhang and Haian Huang and Yue Zhou and Yudong Wang and Yanyi Liu and Shilong Zhang and Kai Chen},
+ year={2022},
+ eprint={2212.07784},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV}
+}
+```
+
+## Visualization
+
diff --git a/configs/rtmdet/metafile.yml b/configs/rtmdet/metafile.yml
index 0d854191934..9c0487f3ff1 100644
--- a/configs/rtmdet/metafile.yml
+++ b/configs/rtmdet/metafile.yml
@@ -79,3 +79,88 @@ Models:
Metrics:
box AP: 52.6
Weights: https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_x_8xb32-300e_coco/rtmdet_x_8xb32-300e_coco_20220715_230555-cc79b9ae.pth
+
+ - Name: rtmdet-ins_tiny_8xb32-300e_coco
+ In Collection: RTMDet
+ Config: configs/rtmdet/rtmdet-ins_tiny_8xb32-300e_coco.py
+ Metadata:
+ Training Memory (GB): 18.4
+ Epochs: 300
+ Results:
+ - Task: Object Detection
+ Dataset: COCO
+ Metrics:
+ box AP: 40.5
+ - Task: Instance Segmentation
+ Dataset: COCO
+ Metrics:
+ mask AP: 35.4
+ Weights: https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_tiny_8xb32-300e_coco/rtmdet-ins_tiny_8xb32-300e_coco_20221130_151727-ec670f7e.pth
+
+ - Name: rtmdet-ins_s_8xb32-300e_coco
+ In Collection: RTMDet
+ Config: configs/rtmdet/rtmdet-ins_s_8xb32-300e_coco.py
+ Metadata:
+ Training Memory (GB): 27.6
+ Epochs: 300
+ Results:
+ - Task: Object Detection
+ Dataset: COCO
+ Metrics:
+ box AP: 44.0
+ - Task: Instance Segmentation
+ Dataset: COCO
+ Metrics:
+ mask AP: 38.7
+ Weights: https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_s_8xb32-300e_coco/rtmdet-ins_s_8xb32-300e_coco_20221121_212604-fdc5d7ec.pth
+
+ - Name: rtmdet-ins_m_8xb32-300e_coco
+ In Collection: RTMDet
+ Config: configs/rtmdet/rtmdet-ins_m_8xb32-300e_coco.py
+ Metadata:
+ Training Memory (GB): 42.5
+ Epochs: 300
+ Results:
+ - Task: Object Detection
+ Dataset: COCO
+ Metrics:
+ box AP: 48.8
+ - Task: Instance Segmentation
+ Dataset: COCO
+ Metrics:
+ mask AP: 42.1
+ Weights: https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_m_8xb32-300e_coco/rtmdet-ins_m_8xb32-300e_coco_20221123_001039-6eba602e.pth
+
+ - Name: rtmdet-ins_l_8xb32-300e_coco
+ In Collection: RTMDet
+ Config: configs/rtmdet/rtmdet-ins_l_8xb32-300e_coco.py
+ Metadata:
+ Training Memory (GB): 59.8
+ Epochs: 300
+ Results:
+ - Task: Object Detection
+ Dataset: COCO
+ Metrics:
+ box AP: 51.2
+ - Task: Instance Segmentation
+ Dataset: COCO
+ Metrics:
+ mask AP: 43.7
+ Weights: https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_l_8xb32-300e_coco/rtmdet-ins_l_8xb32-300e_coco_20221124_103237-78d1d652.pth
+
+ - Name: rtmdet-ins_x_8xb16-300e_coco
+ In Collection: RTMDet
+ Config: configs/rtmdet/rtmdet-ins_x_8xb16-300e_coco.py
+ Metadata:
+ Training Memory (GB): 33.7
+ Epochs: 300
+ Results:
+ - Task: Object Detection
+ Dataset: COCO
+ Metrics:
+ box AP: 52.4
+ - Task: Instance Segmentation
+ Dataset: COCO
+ Metrics:
+ mask AP: 44.6
+ Weights: https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_x_8xb16-300e_coco/rtmdet-ins_x_8xb16-300e_coco_20221124_111313-33d4595b.pth
diff --git a/configs/rtmdet/rtmdet-ins_l_8xb32-300e_coco.py b/configs/rtmdet/rtmdet-ins_l_8xb32-300e_coco.py
new file mode 100644
index 00000000000..1ecacab8044
--- /dev/null
+++ b/configs/rtmdet/rtmdet-ins_l_8xb32-300e_coco.py
@@ -0,0 +1,108 @@
+_base_ = './rtmdet_l_8xb32-300e_coco.py'
+model = dict(
+ bbox_head=dict(
+ _delete_=True,
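+        # `_delete_=True` replaces the base config's bbox_head entirely
+        # instead of merging these keys into it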
+ type='RTMDetInsSepBNHead',
+ num_classes=80,
+ in_channels=256,
+ stacked_convs=2,
+ share_conv=True,
+ pred_kernel_size=1,
+ feat_channels=256,
+ act_cfg=dict(type='SiLU', inplace=True),
+ norm_cfg=dict(type='SyncBN', requires_grad=True),
+ anchor_generator=dict(
+ type='MlvlPointGenerator', offset=0, strides=[8, 16, 32]),
+ bbox_coder=dict(type='DistancePointBBoxCoder'),
+ loss_cls=dict(
+ type='QualityFocalLoss',
+ use_sigmoid=True,
+ beta=2.0,
+ loss_weight=1.0),
+ loss_bbox=dict(type='GIoULoss', loss_weight=2.0),
+ loss_mask=dict(
+ type='DiceLoss', loss_weight=2.0, eps=5e-6, reduction='mean')),
+ test_cfg=dict(
+ nms_pre=1000,
+ min_bbox_size=0,
+ score_thr=0.05,
+ nms=dict(type='nms', iou_threshold=0.6),
+ max_per_img=100,
+ mask_thr_binary=0.5),
+)
+
+train_pipeline = [
+ dict(
+ type='LoadImageFromFile',
+ file_client_args={{_base_.file_client_args}}),
+ dict(
+ type='LoadAnnotations',
+ with_bbox=True,
+ with_mask=True,
+ poly2mask=False),
+ dict(type='CachedMosaic', img_scale=(640, 640), pad_val=114.0),
+ dict(
+ type='RandomResize',
+ scale=(1280, 1280),
+ ratio_range=(0.1, 2.0),
+ keep_ratio=True),
+ dict(
+ type='RandomCrop',
+ crop_size=(640, 640),
+ recompute_bbox=True,
+ allow_negative_crop=True),
+ dict(type='YOLOXHSVRandomAug'),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))),
+ dict(
+ type='CachedMixUp',
+ img_scale=(640, 640),
+ ratio_range=(1.0, 1.0),
+ max_cached_images=20,
+ pad_val=(114, 114, 114)),
+ dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1)),
+ dict(type='PackDetInputs')
+]
+
+train_dataloader = dict(pin_memory=True, dataset=dict(pipeline=train_pipeline))
+
+train_pipeline_stage2 = [
+ dict(
+ type='LoadImageFromFile',
+ file_client_args={{_base_.file_client_args}}),
+ dict(
+ type='LoadAnnotations',
+ with_bbox=True,
+ with_mask=True,
+ poly2mask=False),
+ dict(
+ type='RandomResize',
+ scale=(640, 640),
+ ratio_range=(0.1, 2.0),
+ keep_ratio=True),
+ dict(
+ type='RandomCrop',
+ crop_size=(640, 640),
+ recompute_bbox=True,
+ allow_negative_crop=True),
+ dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1)),
+ dict(type='YOLOXHSVRandomAug'),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))),
+ dict(type='PackDetInputs')
+]
+custom_hooks = [
+ dict(
+ type='EMAHook',
+ ema_type='ExpMomentumEMA',
+ momentum=0.0002,
+ update_buffers=True,
+ priority=49),
+ dict(
+ type='PipelineSwitchHook',
+ switch_epoch=280,
+ switch_pipeline=train_pipeline_stage2)
+]
+
+val_evaluator = dict(metric=['bbox', 'segm'])
+test_evaluator = val_evaluator
diff --git a/configs/rtmdet/rtmdet-ins_m_8xb32-300e_coco.py b/configs/rtmdet/rtmdet-ins_m_8xb32-300e_coco.py
new file mode 100644
index 00000000000..66da9148775
--- /dev/null
+++ b/configs/rtmdet/rtmdet-ins_m_8xb32-300e_coco.py
@@ -0,0 +1,6 @@
+_base_ = './rtmdet-ins_l_8xb32-300e_coco.py'
+
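+# channel widths follow widen_factor (256 * 0.75 = 192, 512 * 0.75 = 384,
+# 1024 * 0.75 = 768); numbers of blocks follow deepen_factor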
+model = dict(
+ backbone=dict(deepen_factor=0.67, widen_factor=0.75),
+ neck=dict(in_channels=[192, 384, 768], out_channels=192, num_csp_blocks=2),
+ bbox_head=dict(in_channels=192, feat_channels=192))
diff --git a/configs/rtmdet/rtmdet-ins_s_8xb32-300e_coco.py b/configs/rtmdet/rtmdet-ins_s_8xb32-300e_coco.py
new file mode 100644
index 00000000000..7785f2ff208
--- /dev/null
+++ b/configs/rtmdet/rtmdet-ins_s_8xb32-300e_coco.py
@@ -0,0 +1,84 @@
+_base_ = './rtmdet-ins_l_8xb32-300e_coco.py'
+checkpoint = 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-s_imagenet_600e.pth' # noqa
+model = dict(
+ backbone=dict(
+ deepen_factor=0.33,
+ widen_factor=0.5,
+ init_cfg=dict(
+ type='Pretrained', prefix='backbone.', checkpoint=checkpoint)),
+ neck=dict(in_channels=[128, 256, 512], out_channels=128, num_csp_blocks=1),
+ bbox_head=dict(in_channels=128, feat_channels=128))
+
+train_pipeline = [
+ dict(
+ type='LoadImageFromFile',
+ file_client_args={{_base_.file_client_args}}),
+ dict(
+ type='LoadAnnotations',
+ with_bbox=True,
+ with_mask=True,
+ poly2mask=False),
+ dict(type='CachedMosaic', img_scale=(640, 640), pad_val=114.0),
+ dict(
+ type='RandomResize',
+ scale=(1280, 1280),
+ ratio_range=(0.5, 2.0),
+ keep_ratio=True),
+ dict(
+ type='RandomCrop',
+ crop_size=(640, 640),
+ recompute_bbox=True,
+ allow_negative_crop=True),
+ dict(type='YOLOXHSVRandomAug'),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))),
+ dict(
+ type='CachedMixUp',
+ img_scale=(640, 640),
+ ratio_range=(1.0, 1.0),
+ max_cached_images=20,
+ pad_val=(114, 114, 114)),
+ dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1)),
+ dict(type='PackDetInputs')
+]
+
+train_pipeline_stage2 = [
+ dict(
+ type='LoadImageFromFile',
+ file_client_args={{_base_.file_client_args}}),
+ dict(
+ type='LoadAnnotations',
+ with_bbox=True,
+ with_mask=True,
+ poly2mask=False),
+ dict(
+ type='RandomResize',
+ scale=(640, 640),
+ ratio_range=(0.5, 2.0),
+ keep_ratio=True),
+ dict(
+ type='RandomCrop',
+ crop_size=(640, 640),
+ recompute_bbox=True,
+ allow_negative_crop=True),
+ dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1)),
+ dict(type='YOLOXHSVRandomAug'),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))),
+ dict(type='PackDetInputs')
+]
+
+train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
+
+custom_hooks = [
+ dict(
+ type='EMAHook',
+ ema_type='ExpMomentumEMA',
+ momentum=0.0002,
+ update_buffers=True,
+ priority=49),
+ dict(
+ type='PipelineSwitchHook',
+ switch_epoch=280,
+ switch_pipeline=train_pipeline_stage2)
+]
diff --git a/configs/rtmdet/rtmdet-ins_tiny_8xb32-300e_coco.py b/configs/rtmdet/rtmdet-ins_tiny_8xb32-300e_coco.py
new file mode 100644
index 00000000000..33b62878027
--- /dev/null
+++ b/configs/rtmdet/rtmdet-ins_tiny_8xb32-300e_coco.py
@@ -0,0 +1,50 @@
+_base_ = './rtmdet-ins_s_8xb32-300e_coco.py'
+
+checkpoint = 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-tiny_imagenet_600e.pth' # noqa
+
+model = dict(
+ backbone=dict(
+ deepen_factor=0.167,
+ widen_factor=0.375,
+ init_cfg=dict(
+ type='Pretrained', prefix='backbone.', checkpoint=checkpoint)),
+ neck=dict(in_channels=[96, 192, 384], out_channels=96, num_csp_blocks=1),
+ bbox_head=dict(in_channels=96, feat_channels=96))
+
+train_pipeline = [
+ dict(
+ type='LoadImageFromFile',
+ file_client_args={{_base_.file_client_args}}),
+ dict(
+ type='LoadAnnotations',
+ with_bbox=True,
+ with_mask=True,
+ poly2mask=False),
+ dict(
+ type='CachedMosaic',
+ img_scale=(640, 640),
+ pad_val=114.0,
+ max_cached_images=20,
+ random_pop=False),
+ dict(
+ type='RandomResize',
+ scale=(1280, 1280),
+ ratio_range=(0.5, 2.0),
+ keep_ratio=True),
+ dict(type='RandomCrop', crop_size=(640, 640)),
+ dict(type='YOLOXHSVRandomAug'),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))),
+ dict(
+ type='CachedMixUp',
+ img_scale=(640, 640),
+ ratio_range=(1.0, 1.0),
+ max_cached_images=10,
+ random_pop=False,
+ pad_val=(114, 114, 114),
+ prob=0.5),
+ dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1)),
+ dict(type='PackDetInputs')
+]
+
+train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
diff --git a/configs/rtmdet/rtmdet-ins_x_8xb16-300e_coco.py b/configs/rtmdet/rtmdet-ins_x_8xb16-300e_coco.py
new file mode 100644
index 00000000000..daaa640edac
--- /dev/null
+++ b/configs/rtmdet/rtmdet-ins_x_8xb16-300e_coco.py
@@ -0,0 +1,31 @@
+_base_ = './rtmdet-ins_l_8xb32-300e_coco.py'
+
+model = dict(
+ backbone=dict(deepen_factor=1.33, widen_factor=1.25),
+ neck=dict(
+ in_channels=[320, 640, 1280], out_channels=320, num_csp_blocks=4),
+ bbox_head=dict(in_channels=320, feat_channels=320))
+
+base_lr = 0.002
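+# half the lr of the 8xb32 configs, following linear lr scaling for the
+# halved per-GPU batch size (assumes the base config's base_lr of 0.004)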
+
+# optimizer
+optim_wrapper = dict(optimizer=dict(lr=base_lr))
+
+# learning rate
+param_scheduler = [
+ dict(
+ type='LinearLR',
+ start_factor=1.0e-5,
+ by_epoch=False,
+ begin=0,
+ end=1000),
+ dict(
+ # use cosine lr from 150 to 300 epoch
+ type='CosineAnnealingLR',
+ eta_min=base_lr * 0.05,
+ begin=_base_.max_epochs // 2,
+ end=_base_.max_epochs,
+ T_max=_base_.max_epochs // 2,
+ by_epoch=True,
+ convert_to_iter_based=True),
+]
diff --git a/configs/rtmdet/rtmdet_l_8xb32-300e_coco.py b/configs/rtmdet/rtmdet_l_8xb32-300e_coco.py
index 33ccb839c6c..59facd71cb6 100644
--- a/configs/rtmdet/rtmdet_l_8xb32-300e_coco.py
+++ b/configs/rtmdet/rtmdet_l_8xb32-300e_coco.py
@@ -18,7 +18,7 @@
widen_factor=1,
channel_attention=True,
norm_cfg=dict(type='SyncBN'),
- act_cfg=dict(type='SiLU')),
+ act_cfg=dict(type='SiLU', inplace=True)),
neck=dict(
type='CSPNeXtPAFPN',
in_channels=[256, 512, 1024],
@@ -26,7 +26,7 @@
num_csp_blocks=3,
expand_ratio=0.5,
norm_cfg=dict(type='SyncBN'),
- act_cfg=dict(type='SiLU')),
+ act_cfg=dict(type='SiLU', inplace=True)),
bbox_head=dict(
type='RTMDetSepBNHead',
num_classes=80,
@@ -47,18 +47,18 @@
share_conv=True,
pred_kernel_size=1,
norm_cfg=dict(type='SyncBN'),
- act_cfg=dict(type='SiLU')),
+ act_cfg=dict(type='SiLU', inplace=True)),
train_cfg=dict(
assigner=dict(type='DynamicSoftLabelAssigner', topk=13),
allowed_border=-1,
pos_weight=-1,
debug=False),
test_cfg=dict(
- nms_pre=1000,
+ nms_pre=30000,
min_bbox_size=0,
- score_thr=0.05,
- nms=dict(type='nms', iou_threshold=0.6),
- max_per_img=100),
+ score_thr=0.001,
+ nms=dict(type='nms', iou_threshold=0.65),
+ max_per_img=300),
)
train_pipeline = [
@@ -134,6 +134,9 @@
val_interval=interval,
dynamic_intervals=[(max_epochs - stage2_num_epochs, 1)])
+val_evaluator = dict(proposal_nums=(100, 1, 10))
+test_evaluator = val_evaluator
+
# optimizer
optim_wrapper = dict(
_delete_=True,
diff --git a/mmdet/datasets/transforms/transforms.py b/mmdet/datasets/transforms/transforms.py
index 7ae88dcc568..c9e95bd7476 100644
--- a/mmdet/datasets/transforms/transforms.py
+++ b/mmdet/datasets/transforms/transforms.py
@@ -3244,6 +3244,9 @@ def transform(self, results: dict) -> dict:
mosaic_bboxes = []
mosaic_bboxes_labels = []
mosaic_ignore_flags = []
+ mosaic_masks = []
+        with_mask = 'gt_masks' in results
+
if len(results['img'].shape) == 3:
mosaic_img = np.full(
(int(self.img_scale[0] * 2), int(self.img_scale[1] * 2), 3),
@@ -3298,6 +3301,20 @@ def transform(self, results: dict) -> dict:
mosaic_bboxes.append(gt_bboxes_i)
mosaic_bboxes_labels.append(gt_bboxes_labels_i)
mosaic_ignore_flags.append(gt_ignore_flags_i)
+ if with_mask and results_patch.get('gt_masks', None) is not None:
+ gt_masks_i = results_patch['gt_masks']
+ gt_masks_i = gt_masks_i.rescale(float(scale_ratio_i))
+ gt_masks_i = gt_masks_i.translate(
+ out_shape=(int(self.img_scale[0] * 2),
+ int(self.img_scale[1] * 2)),
+ offset=padw,
+ direction='horizontal')
+ gt_masks_i = gt_masks_i.translate(
+ out_shape=(int(self.img_scale[0] * 2),
+ int(self.img_scale[1] * 2)),
+ offset=padh,
+ direction='vertical')
+ mosaic_masks.append(gt_masks_i)
mosaic_bboxes = mosaic_bboxes[0].cat(mosaic_bboxes, 0)
mosaic_bboxes_labels = np.concatenate(mosaic_bboxes_labels, 0)
@@ -3317,6 +3334,10 @@ def transform(self, results: dict) -> dict:
results['gt_bboxes'] = mosaic_bboxes
results['gt_bboxes_labels'] = mosaic_bboxes_labels
results['gt_ignore_flags'] = mosaic_ignore_flags
+
+ if with_mask:
+ mosaic_masks = mosaic_masks[0].cat(mosaic_masks)
+ results['gt_masks'] = mosaic_masks[inside_inds]
return results
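
As a self-contained sketch of the mask operations the mosaic branch above relies on (`rescale` followed by two `translate` calls, all methods of `BitmapMasks` in `mmdet.structures.mask`; the sizes below are toy values):

```python
import numpy as np

from mmdet.structures.mask import BitmapMasks

# one toy instance mask in a 4x4 image
masks = BitmapMasks(np.ones((1, 4, 4), dtype=np.uint8), 4, 4)
masks = masks.rescale(2.0)  # the scale_ratio_i step -> 8x8
# shift into a 16x16 mosaic canvas, as done above with padw/padh
masks = masks.translate(out_shape=(16, 16), offset=3, direction='horizontal')
masks = masks.translate(out_shape=(16, 16), offset=5, direction='vertical')
print(masks.masks.shape)  # (1, 16, 16)
```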
def __repr__(self):
@@ -3481,6 +3502,7 @@ def transform(self, results: dict) -> dict:
return results
retrieve_img = retrieve_results['img']
+    with_mask = 'gt_masks' in results
jit_factor = random.uniform(*self.ratio_range)
is_filp = random.uniform(0, 1) > self.flip_ratio
@@ -3532,16 +3554,32 @@ def transform(self, results: dict) -> dict:
# 6. adjust bbox
retrieve_gt_bboxes = retrieve_results['gt_bboxes']
retrieve_gt_bboxes.rescale_([scale_ratio, scale_ratio])
+ if with_mask:
+ retrieve_gt_masks = retrieve_results['gt_masks'].rescale(
+ scale_ratio)
+
if self.bbox_clip_border:
retrieve_gt_bboxes.clip_([origin_h, origin_w])
if is_filp:
retrieve_gt_bboxes.flip_([origin_h, origin_w],
direction='horizontal')
+ if with_mask:
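+                # BitmapMasks.flip defaults to the horizontal direction,
+                # matching the bbox flip above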
+ retrieve_gt_masks = retrieve_gt_masks.flip()
# 7. filter
cp_retrieve_gt_bboxes = retrieve_gt_bboxes.clone()
cp_retrieve_gt_bboxes.translate_([-x_offset, -y_offset])
+ if with_mask:
+ retrieve_gt_masks = retrieve_gt_masks.translate(
+ out_shape=(target_h, target_w),
+ offset=-x_offset,
+ direction='horizontal')
+ retrieve_gt_masks = retrieve_gt_masks.translate(
+ out_shape=(target_h, target_w),
+ offset=-y_offset,
+ direction='vertical')
+
if self.bbox_clip_border:
cp_retrieve_gt_bboxes.clip_([target_h, target_w])
@@ -3558,19 +3596,25 @@ def transform(self, results: dict) -> dict:
(results['gt_bboxes_labels'], retrieve_gt_bboxes_labels), axis=0)
mixup_gt_ignore_flags = np.concatenate(
(results['gt_ignore_flags'], retrieve_gt_ignore_flags), axis=0)
+ if with_mask:
+ mixup_gt_masks = retrieve_gt_masks.cat(
+ [results['gt_masks'], retrieve_gt_masks])
# remove outside bbox
inside_inds = mixup_gt_bboxes.is_inside([target_h, target_w]).numpy()
mixup_gt_bboxes = mixup_gt_bboxes[inside_inds]
mixup_gt_bboxes_labels = mixup_gt_bboxes_labels[inside_inds]
mixup_gt_ignore_flags = mixup_gt_ignore_flags[inside_inds]
+ if with_mask:
+ mixup_gt_masks = mixup_gt_masks[inside_inds]
results['img'] = mixup_img.astype(np.uint8)
results['img_shape'] = mixup_img.shape
results['gt_bboxes'] = mixup_gt_bboxes
results['gt_bboxes_labels'] = mixup_gt_bboxes_labels
results['gt_ignore_flags'] = mixup_gt_ignore_flags
-
+ if with_mask:
+ results['gt_masks'] = mixup_gt_masks
return results
def __repr__(self):
diff --git a/mmdet/models/dense_heads/__init__.py b/mmdet/models/dense_heads/__init__.py
index 0ab3ba2018e..469f5cc69d8 100644
--- a/mmdet/models/dense_heads/__init__.py
+++ b/mmdet/models/dense_heads/__init__.py
@@ -34,6 +34,7 @@
from .retina_sepbn_head import RetinaSepBNHead
from .rpn_head import RPNHead
from .rtmdet_head import RTMDetHead, RTMDetSepBNHead
+from .rtmdet_ins_head import RTMDetInsHead, RTMDetInsSepBNHead
from .sabl_retina_head import SABLRetinaHead
from .solo_head import DecoupledSOLOHead, DecoupledSOLOLightHead, SOLOHead
from .solov2_head import SOLOV2Head
@@ -58,5 +59,5 @@
'DecoupledSOLOHead', 'DecoupledSOLOLightHead', 'SOLOV2Head', 'LADHead',
'TOODHead', 'MaskFormerHead', 'Mask2FormerHead', 'DDODHead',
'CenterNetUpdateHead', 'RTMDetHead', 'RTMDetSepBNHead', 'CondInstBboxHead',
- 'CondInstMaskHead'
+ 'CondInstMaskHead', 'RTMDetInsHead', 'RTMDetInsSepBNHead'
]
diff --git a/mmdet/models/dense_heads/rtmdet_head.py b/mmdet/models/dense_heads/rtmdet_head.py
index 42c15c1f6dd..3c53b68669d 100644
--- a/mmdet/models/dense_heads/rtmdet_head.py
+++ b/mmdet/models/dense_heads/rtmdet_head.py
@@ -266,7 +266,7 @@ def loss_by_feat(self,
batch_img_metas,
batch_gt_instances_ignore=batch_gt_instances_ignore)
(anchor_list, labels_list, label_weights_list, bbox_targets_list,
- assign_metrics_list) = cls_reg_targets
+ assign_metrics_list, sampling_results_list) = cls_reg_targets
losses_cls, losses_bbox,\
cls_avg_factors, bbox_avg_factors = multi_apply(
@@ -353,7 +353,7 @@ def get_targets(self,
batch_gt_instances_ignore = [None] * num_imgs
# anchor_list: list(b * [-1, 4])
(all_anchors, all_labels, all_label_weights, all_bbox_targets,
- all_assign_metrics) = multi_apply(
+ all_assign_metrics, sampling_results_list) = multi_apply(
self._get_targets_single,
cls_scores.detach(),
bbox_preds.detach(),
@@ -378,7 +378,7 @@ def get_targets(self,
num_level_anchors)
return (anchors_list, labels_list, label_weights_list,
- bbox_targets_list, assign_metrics_list)
+ bbox_targets_list, assign_metrics_list, sampling_results_list)
def _get_targets_single(self,
cls_scores: Tensor,
@@ -486,7 +486,8 @@ def _get_targets_single(self,
bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
assign_metrics = unmap(assign_metrics, num_total_anchors,
inside_flags)
- return (anchors, labels, label_weights, bbox_targets, assign_metrics)
+ return (anchors, labels, label_weights, bbox_targets, assign_metrics,
+ sampling_result)
def get_anchors(self,
featmap_sizes: List[tuple],
diff --git a/mmdet/models/dense_heads/rtmdet_ins_head.py b/mmdet/models/dense_heads/rtmdet_ins_head.py
new file mode 100644
index 00000000000..e355bdb79f8
--- /dev/null
+++ b/mmdet/models/dense_heads/rtmdet_ins_head.py
@@ -0,0 +1,1034 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import math
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule, is_norm
+from mmcv.ops import batched_nms
+from mmengine.model import (BaseModule, bias_init_with_prob, constant_init,
+ normal_init)
+from mmengine.structures import InstanceData
+from torch import Tensor
+
+from mmdet.models.layers.transformer import inverse_sigmoid
+from mmdet.models.utils import (filter_scores_and_topk, multi_apply,
+ select_single_mlvl, sigmoid_geometric_mean)
+from mmdet.registry import MODELS
+from mmdet.structures.bbox import (cat_boxes, distance2bbox, get_box_tensor,
+ get_box_wh, scale_boxes)
+from mmdet.utils import ConfigType, InstanceList, OptInstanceList, reduce_mean
+from .rtmdet_head import RTMDetHead
+
+
+@MODELS.register_module()
+class RTMDetInsHead(RTMDetHead):
+ """Detection Head of RTMDet-Ins.
+
+ Args:
+ num_prototypes (int): Number of mask prototype features extracted
+ from the mask head. Defaults to 8.
+ dyconv_channels (int): Channel of the dynamic conv layers.
+ Defaults to 8.
+ num_dyconvs (int): Number of the dynamic convolution layers.
+ Defaults to 3.
+ mask_loss_stride (int): Down sample stride of the masks for loss
+ computation. Defaults to 4.
+ loss_mask (:obj:`ConfigDict` or dict): Config dict for mask loss.
+ """
+
+ def __init__(self,
+ *args,
+ num_prototypes: int = 8,
+ dyconv_channels: int = 8,
+ num_dyconvs: int = 3,
+ mask_loss_stride: int = 4,
+ loss_mask=dict(
+ type='DiceLoss',
+ loss_weight=2.0,
+ eps=5e-6,
+ reduction='mean'),
+ **kwargs) -> None:
+ self.num_prototypes = num_prototypes
+ self.num_dyconvs = num_dyconvs
+ self.dyconv_channels = dyconv_channels
+ self.mask_loss_stride = mask_loss_stride
+ super().__init__(*args, **kwargs)
+ self.loss_mask = MODELS.build(loss_mask)
+
+ def _init_layers(self) -> None:
+ """Initialize layers of the head."""
+ super()._init_layers()
+ # a branch to predict kernels of dynamic convs
+ self.kernel_convs = nn.ModuleList()
+ # calculate num dynamic parameters
+ weight_nums, bias_nums = [], []
+ for i in range(self.num_dyconvs):
+ if i == 0:
+ weight_nums.append(
+ # mask prototype and coordinate features
+ (self.num_prototypes + 2) * self.dyconv_channels)
+ bias_nums.append(self.dyconv_channels * 1)
+ elif i == self.num_dyconvs - 1:
+ weight_nums.append(self.dyconv_channels * 1)
+ bias_nums.append(1)
+ else:
+ weight_nums.append(self.dyconv_channels * self.dyconv_channels)
+ bias_nums.append(self.dyconv_channels * 1)
+ self.weight_nums = weight_nums
+ self.bias_nums = bias_nums
+ self.num_gen_params = sum(weight_nums) + sum(bias_nums)
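+        # with the defaults (num_prototypes=8, dyconv_channels=8,
+        # num_dyconvs=3): (8 + 2) * 8 + 8 + 8 * 8 + 8 + 8 + 1 = 169
+        # dynamic parameters are predicted per instance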
+
+ for i in range(self.stacked_convs):
+ chn = self.in_channels if i == 0 else self.feat_channels
+ self.kernel_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ pred_pad_size = self.pred_kernel_size // 2
+ self.rtm_kernel = nn.Conv2d(
+ self.feat_channels,
+ self.num_gen_params,
+ self.pred_kernel_size,
+ padding=pred_pad_size)
+ self.mask_head = MaskFeatModule(
+ in_channels=self.in_channels,
+ feat_channels=self.feat_channels,
+ stacked_convs=4,
+ num_levels=len(self.prior_generator.strides),
+ num_prototypes=self.num_prototypes,
+ act_cfg=self.act_cfg,
+ norm_cfg=self.norm_cfg)
+
+ def forward(self, feats: Tuple[Tensor, ...]) -> tuple:
+ """Forward features from the upstream network.
+
+ Args:
+ feats (tuple[Tensor]): Features from the upstream network, each is
+ a 4D-tensor.
+
+ Returns:
+ tuple: Usually a tuple of classification scores and bbox prediction
+ - cls_scores (list[Tensor]): Classification scores for all scale
+ levels, each is a 4D-tensor, the channels number is
+ num_base_priors * num_classes.
+ - bbox_preds (list[Tensor]): Box energies / deltas for all scale
+ levels, each is a 4D-tensor, the channels number is
+ num_base_priors * 4.
+ - kernel_preds (list[Tensor]): Dynamic conv kernels for all scale
+ levels, each is a 4D-tensor, the channels number is
+ num_gen_params.
+ - mask_feat (Tensor): Output feature of the mask head. Each is a
+ 4D-tensor, the channels number is num_prototypes.
+ """
+ mask_feat = self.mask_head(feats)
+
+ cls_scores = []
+ bbox_preds = []
+ kernel_preds = []
+ for idx, (x, scale, stride) in enumerate(
+ zip(feats, self.scales, self.prior_generator.strides)):
+ cls_feat = x
+ reg_feat = x
+ kernel_feat = x
+
+ for cls_layer in self.cls_convs:
+ cls_feat = cls_layer(cls_feat)
+ cls_score = self.rtm_cls(cls_feat)
+
+ for kernel_layer in self.kernel_convs:
+ kernel_feat = kernel_layer(kernel_feat)
+ kernel_pred = self.rtm_kernel(kernel_feat)
+
+ for reg_layer in self.reg_convs:
+ reg_feat = reg_layer(reg_feat)
+
+ if self.with_objectness:
+ objectness = self.rtm_obj(reg_feat)
+ cls_score = inverse_sigmoid(
+ sigmoid_geometric_mean(cls_score, objectness))
+
+ reg_dist = scale(self.rtm_reg(reg_feat)) * stride[0]
+
+ cls_scores.append(cls_score)
+ bbox_preds.append(reg_dist)
+ kernel_preds.append(kernel_pred)
+ return tuple(cls_scores), tuple(bbox_preds), tuple(
+ kernel_preds), mask_feat
+
+ def predict_by_feat(self,
+ cls_scores: List[Tensor],
+ bbox_preds: List[Tensor],
+ kernel_preds: List[Tensor],
+ mask_feat: Tensor,
+ score_factors: Optional[List[Tensor]] = None,
+ batch_img_metas: Optional[List[dict]] = None,
+ cfg: Optional[ConfigType] = None,
+ rescale: bool = False,
+ with_nms: bool = True) -> InstanceList:
+ """Transform a batch of output features extracted from the head into
+ bbox results.
+
+ Note: When score_factors is not None, the cls_scores are
+ usually multiplied by it then obtain the real score used in NMS,
+ such as CenterNess in FCOS, IoU branch in ATSS.
+
+ Args:
+ cls_scores (list[Tensor]): Classification scores for all
+ scale levels, each is a 4D-tensor, has shape
+ (batch_size, num_priors * num_classes, H, W).
+ bbox_preds (list[Tensor]): Box energies / deltas for all
+ scale levels, each is a 4D-tensor, has shape
+ (batch_size, num_priors * 4, H, W).
+ kernel_preds (list[Tensor]): Kernel predictions of dynamic
+ convs for all scale levels, each is a 4D-tensor, has shape
+ (batch_size, num_params, H, W).
+ mask_feat (Tensor): Mask prototype features extracted from the
+ mask head, has shape (batch_size, num_prototypes, H, W).
+ score_factors (list[Tensor], optional): Score factor for
+ all scale level, each is a 4D-tensor, has shape
+ (batch_size, num_priors * 1, H, W). Defaults to None.
+ batch_img_metas (list[dict], Optional): Batch image meta info.
+ Defaults to None.
+ cfg (ConfigDict, optional): Test / postprocessing
+ configuration, if None, test_cfg would be used.
+ Defaults to None.
+ rescale (bool): If True, return boxes in original image space.
+ Defaults to False.
+ with_nms (bool): If True, do nms before return boxes.
+ Defaults to True.
+
+ Returns:
+ list[:obj:`InstanceData`]: Object detection results of each image
+ after the post process. Each item usually contains following keys.
+
+ - scores (Tensor): Classification scores, has a shape
+ (num_instance, )
+ - labels (Tensor): Labels of bboxes, has a shape
+ (num_instances, ).
+ - bboxes (Tensor): Has a shape (num_instances, 4),
+ the last dimension 4 arrange as (x1, y1, x2, y2).
+ - masks (Tensor): Has a shape (num_instances, h, w).
+ """
+ assert len(cls_scores) == len(bbox_preds)
+
+ if score_factors is None:
+ # e.g. Retina, FreeAnchor, Foveabox, etc.
+ with_score_factors = False
+ else:
+ # e.g. FCOS, PAA, ATSS, AutoAssign, etc.
+ with_score_factors = True
+ assert len(cls_scores) == len(score_factors)
+
+ num_levels = len(cls_scores)
+
+ featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
+ mlvl_priors = self.prior_generator.grid_priors(
+ featmap_sizes,
+ dtype=cls_scores[0].dtype,
+ device=cls_scores[0].device,
+ with_stride=True)
+
+ result_list = []
+
+ for img_id in range(len(batch_img_metas)):
+ img_meta = batch_img_metas[img_id]
+ cls_score_list = select_single_mlvl(
+ cls_scores, img_id, detach=True)
+ bbox_pred_list = select_single_mlvl(
+ bbox_preds, img_id, detach=True)
+ kernel_pred_list = select_single_mlvl(
+ kernel_preds, img_id, detach=True)
+ if with_score_factors:
+ score_factor_list = select_single_mlvl(
+ score_factors, img_id, detach=True)
+ else:
+ score_factor_list = [None for _ in range(num_levels)]
+
+ results = self._predict_by_feat_single(
+ cls_score_list=cls_score_list,
+ bbox_pred_list=bbox_pred_list,
+ kernel_pred_list=kernel_pred_list,
+ mask_feat=mask_feat[img_id],
+ score_factor_list=score_factor_list,
+ mlvl_priors=mlvl_priors,
+ img_meta=img_meta,
+ cfg=cfg,
+ rescale=rescale,
+ with_nms=with_nms)
+ result_list.append(results)
+ return result_list
+
+ def _predict_by_feat_single(self,
+ cls_score_list: List[Tensor],
+ bbox_pred_list: List[Tensor],
+ kernel_pred_list: List[Tensor],
+ mask_feat: Tensor,
+ score_factor_list: List[Tensor],
+ mlvl_priors: List[Tensor],
+ img_meta: dict,
+ cfg: ConfigType,
+ rescale: bool = False,
+ with_nms: bool = True) -> InstanceData:
+ """Transform a single image's features extracted from the head into
+ bbox and mask results.
+
+ Args:
+ cls_score_list (list[Tensor]): Box scores from all scale
+ levels of a single image, each item has shape
+ (num_priors * num_classes, H, W).
+ bbox_pred_list (list[Tensor]): Box energies / deltas from
+ all scale levels of a single image, each item has shape
+ (num_priors * 4, H, W).
+            kernel_pred_list (list[Tensor]): Kernel predictions of dynamic
+                convs for all scale levels of a single image, each is a
+                3D-tensor, has shape (num_params, H, W).
+ mask_feat (Tensor): Mask prototype features of a single image
+ extracted from the mask head, has shape (num_prototypes, H, W).
+ score_factor_list (list[Tensor]): Score factor from all scale
+ levels of a single image, each item has shape
+ (num_priors * 1, H, W).
+ mlvl_priors (list[Tensor]): Each element in the list is
+ the priors of a single level in feature pyramid. In all
+ anchor-based methods, it has shape (num_priors, 4). In
+ all anchor-free methods, it has shape (num_priors, 2)
+ when `with_stride=True`, otherwise it still has shape
+ (num_priors, 4).
+ img_meta (dict): Image meta info.
+ cfg (mmengine.Config): Test / postprocessing configuration,
+ if None, test_cfg would be used.
+ rescale (bool): If True, return boxes in original image space.
+ Defaults to False.
+ with_nms (bool): If True, do nms before return boxes.
+ Defaults to True.
+
+ Returns:
+ :obj:`InstanceData`: Detection results of each image
+ after the post process.
+ Each item usually contains following keys.
+
+ - scores (Tensor): Classification scores, has a shape
+ (num_instance, )
+ - labels (Tensor): Labels of bboxes, has a shape
+ (num_instances, ).
+ - bboxes (Tensor): Has a shape (num_instances, 4),
+ the last dimension 4 arrange as (x1, y1, x2, y2).
+ - masks (Tensor): Has a shape (num_instances, h, w).
+ """
+ if score_factor_list[0] is None:
+ # e.g. Retina, FreeAnchor, etc.
+ with_score_factors = False
+ else:
+ # e.g. FCOS, PAA, ATSS, etc.
+ with_score_factors = True
+
+ cfg = self.test_cfg if cfg is None else cfg
+ cfg = copy.deepcopy(cfg)
+ img_shape = img_meta['img_shape']
+ nms_pre = cfg.get('nms_pre', -1)
+
+ mlvl_bbox_preds = []
+ mlvl_kernels = []
+ mlvl_valid_priors = []
+ mlvl_scores = []
+ mlvl_labels = []
+ if with_score_factors:
+ mlvl_score_factors = []
+ else:
+ mlvl_score_factors = None
+
+ for level_idx, (cls_score, bbox_pred, kernel_pred,
+ score_factor, priors) in \
+ enumerate(zip(cls_score_list, bbox_pred_list, kernel_pred_list,
+ score_factor_list, mlvl_priors)):
+
+ assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
+
+ dim = self.bbox_coder.encode_size
+ bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, dim)
+ if with_score_factors:
+ score_factor = score_factor.permute(1, 2,
+ 0).reshape(-1).sigmoid()
+ cls_score = cls_score.permute(1, 2,
+ 0).reshape(-1, self.cls_out_channels)
+ kernel_pred = kernel_pred.permute(1, 2, 0).reshape(
+ -1, self.num_gen_params)
+ if self.use_sigmoid_cls:
+ scores = cls_score.sigmoid()
+ else:
+ # remind that we set FG labels to [0, num_class-1]
+ # since mmdet v2.0
+ # BG cat_id: num_class
+ scores = cls_score.softmax(-1)[:, :-1]
+
+ # After https://github.com/open-mmlab/mmdetection/pull/6268/,
+ # this operation keeps fewer bboxes under the same `nms_pre`.
+ # There is no difference in performance for most models. If you
+ # find a slight drop in performance, you can set a larger
+ # `nms_pre` than before.
+ score_thr = cfg.get('score_thr', 0)
+
+ results = filter_scores_and_topk(
+ scores, score_thr, nms_pre,
+ dict(
+ bbox_pred=bbox_pred,
+ priors=priors,
+ kernel_pred=kernel_pred))
+ scores, labels, keep_idxs, filtered_results = results
+
+ bbox_pred = filtered_results['bbox_pred']
+ priors = filtered_results['priors']
+ kernel_pred = filtered_results['kernel_pred']
+
+ if with_score_factors:
+ score_factor = score_factor[keep_idxs]
+
+ mlvl_bbox_preds.append(bbox_pred)
+ mlvl_valid_priors.append(priors)
+ mlvl_scores.append(scores)
+ mlvl_labels.append(labels)
+ mlvl_kernels.append(kernel_pred)
+
+ if with_score_factors:
+ mlvl_score_factors.append(score_factor)
+
+ bbox_pred = torch.cat(mlvl_bbox_preds)
+ priors = cat_boxes(mlvl_valid_priors)
+ bboxes = self.bbox_coder.decode(
+ priors[..., :2], bbox_pred, max_shape=img_shape)
+
+ results = InstanceData()
+ results.bboxes = bboxes
+ results.priors = priors
+ results.scores = torch.cat(mlvl_scores)
+ results.labels = torch.cat(mlvl_labels)
+ results.kernels = torch.cat(mlvl_kernels)
+ if with_score_factors:
+ results.score_factors = torch.cat(mlvl_score_factors)
+
+ return self._bbox_mask_post_process(
+ results=results,
+ mask_feat=mask_feat,
+ cfg=cfg,
+ rescale=rescale,
+ with_nms=with_nms,
+ img_meta=img_meta)
+
+ def _bbox_mask_post_process(
+ self,
+ results: InstanceData,
+ mask_feat,
+ cfg: ConfigType,
+ rescale: bool = False,
+ with_nms: bool = True,
+ img_meta: Optional[dict] = None) -> InstanceData:
+ """bbox and mask post-processing method.
+
+        The boxes are rescaled to the original image scale and NMS is
+        applied. `with_nms=False` is usually used for aug test.
+
+ Args:
+            results (:obj:`InstanceData`): Detection instance results,
+ each item has shape (num_bboxes, ).
+ cfg (ConfigDict): Test / postprocessing configuration,
+ if None, test_cfg would be used.
+ rescale (bool): If True, return boxes in original image space.
+ Default to False.
+ with_nms (bool): If True, do nms before return boxes.
+ Default to True.
+ img_meta (dict, optional): Image meta info. Defaults to None.
+
+ Returns:
+ :obj:`InstanceData`: Detection results of each image
+ after the post process.
+ Each item usually contains following keys.
+
+ - scores (Tensor): Classification scores, has a shape
+ (num_instance, )
+ - labels (Tensor): Labels of bboxes, has a shape
+ (num_instances, ).
+ - bboxes (Tensor): Has a shape (num_instances, 4),
+ the last dimension 4 arrange as (x1, y1, x2, y2).
+ - masks (Tensor): Has a shape (num_instances, h, w).
+ """
+ stride = self.prior_generator.strides[0][0]
+ if rescale:
+ assert img_meta.get('scale_factor') is not None
+ scale_factor = [1 / s for s in img_meta['scale_factor']]
+ results.bboxes = scale_boxes(results.bboxes, scale_factor)
+
+ if hasattr(results, 'score_factors'):
+ # TODO: Add sqrt operation in order to be consistent with
+ # the paper.
+ score_factors = results.pop('score_factors')
+ results.scores = results.scores * score_factors
+
+ # filter small size bboxes
+ if cfg.get('min_bbox_size', -1) >= 0:
+ w, h = get_box_wh(results.bboxes)
+ valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size)
+ if not valid_mask.all():
+ results = results[valid_mask]
+
+ # TODO: deal with `with_nms` and `nms_cfg=None` in test_cfg
+ assert with_nms, 'with_nms must be True for RTMDet-Ins'
+ if results.bboxes.numel() > 0:
+ bboxes = get_box_tensor(results.bboxes)
+ det_bboxes, keep_idxs = batched_nms(bboxes, results.scores,
+ results.labels, cfg.nms)
+ results = results[keep_idxs]
+ # some nms would reweight the score, such as softnms
+ results.scores = det_bboxes[:, -1]
+ results = results[:cfg.max_per_img]
+
+ # process masks
+ mask_logits = self._mask_predict_by_feat_single(
+ mask_feat, results.kernels, results.priors)
+
+ mask_logits = F.interpolate(
+ mask_logits.unsqueeze(0), scale_factor=stride, mode='bilinear')
+ if rescale:
+ ori_h, ori_w = img_meta['ori_shape'][:2]
+ mask_logits = F.interpolate(
+ mask_logits,
+ size=[
+ math.ceil(mask_logits.shape[-2] * scale_factor[0]),
+ math.ceil(mask_logits.shape[-1] * scale_factor[1])
+ ],
+ mode='bilinear',
+ align_corners=False)[..., :ori_h, :ori_w]
+ masks = mask_logits.sigmoid().squeeze(0)
+ masks = masks > cfg.mask_thr_binary
+ results.masks = masks
+ else:
+ h, w = img_meta['ori_shape'][:2] if rescale else img_meta[
+ 'img_shape'][:2]
+ results.masks = torch.zeros(
+ size=(results.bboxes.shape[0], h, w),
+ dtype=torch.bool,
+ device=results.bboxes.device)
+
+ return results
+
+ def parse_dynamic_params(self, flatten_kernels: Tensor) -> tuple:
+ """split kernel head prediction to conv weight and bias."""
+ n_inst = flatten_kernels.size(0)
+ n_layers = len(self.weight_nums)
+ params_splits = list(
+ torch.split_with_sizes(
+ flatten_kernels, self.weight_nums + self.bias_nums, dim=1))
+ weight_splits = params_splits[:n_layers]
+ bias_splits = params_splits[n_layers:]
+ for i in range(n_layers):
+ if i < n_layers - 1:
+ weight_splits[i] = weight_splits[i].reshape(
+ n_inst * self.dyconv_channels, -1, 1, 1)
+ bias_splits[i] = bias_splits[i].reshape(n_inst *
+ self.dyconv_channels)
+ else:
+ weight_splits[i] = weight_splits[i].reshape(n_inst, -1, 1, 1)
+ bias_splits[i] = bias_splits[i].reshape(n_inst)
+
+ return weight_splits, bias_splits
+
+ def _mask_predict_by_feat_single(self, mask_feat: Tensor, kernels: Tensor,
+ priors: Tensor) -> Tensor:
+ """Generate mask logits from mask features with dynamic convs.
+
+ Args:
+ mask_feat (Tensor): Mask prototype features.
+ Has shape (num_prototypes, H, W).
+ kernels (Tensor): Kernel parameters for each instance.
+ Has shape (num_instance, num_params)
+ priors (Tensor): Center priors for each instance.
+ Has shape (num_instance, 4).
+ Returns:
+ Tensor: Instance segmentation masks for each instance.
+ Has shape (num_instance, H, W).
+ """
+ num_inst = priors.shape[0]
+ h, w = mask_feat.size()[-2:]
+ if num_inst < 1:
+ return torch.empty(
+ size=(num_inst, h, w),
+ dtype=mask_feat.dtype,
+ device=mask_feat.device)
+ if len(mask_feat.shape) < 4:
+            mask_feat = mask_feat.unsqueeze(0)
+
+ coord = self.prior_generator.single_level_grid_priors(
+ (h, w), level_idx=0).reshape(1, -1, 2)
+ points = priors[:, :2].reshape(-1, 1, 2)
+ strides = priors[:, 2:].reshape(-1, 1, 2)
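+        # relative (x, y) offsets of every mask-feat location from each
+        # instance center, normalized by 8x the prior stride; these form
+        # two extra coordinate channels for the dynamic convs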
+ relative_coord = (points - coord).permute(0, 2, 1) / (
+ strides[..., 0].reshape(-1, 1, 1) * 8)
+ relative_coord = relative_coord.reshape(num_inst, 2, h, w)
+
+ mask_feat = torch.cat(
+ [relative_coord,
+ mask_feat.repeat(num_inst, 1, 1, 1)], dim=1)
+ weights, biases = self.parse_dynamic_params(kernels)
+
+ n_layers = len(weights)
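+        # fold all instances into the channel dim so that one grouped conv
+        # (groups=num_inst) applies every instance's private kernels at once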
+ x = mask_feat.reshape(1, -1, h, w)
+ for i, (weight, bias) in enumerate(zip(weights, biases)):
+ x = F.conv2d(
+ x, weight, bias=bias, stride=1, padding=0, groups=num_inst)
+ if i < n_layers - 1:
+ x = F.relu(x)
+ x = x.reshape(num_inst, h, w)
+ return x
+
+ def loss_mask_by_feat(self, mask_feats: Tensor, flatten_kernels: Tensor,
+ sampling_results_list: list,
+ batch_gt_instances: InstanceList) -> Tensor:
+ """Compute instance segmentation loss.
+
+ Args:
+ mask_feats (list[Tensor]): Mask prototype features extracted from
+ the mask head. Has shape (N, num_prototypes, H, W)
+ flatten_kernels (list[Tensor]): Kernels of the dynamic conv layers.
+ Has shape (N, num_instances, num_params)
+ sampling_results_list (list[:obj:`SamplingResults`]): Batch of
+ assignment results.
+ batch_gt_instances (list[:obj:`InstanceData`]): Batch of
+ gt_instance. It usually includes ``bboxes`` and ``labels``
+ attributes.
+
+ Returns:
+ Tensor: The mask loss tensor.
+ """
+ batch_pos_mask_logits = []
+ pos_gt_masks = []
+ for idx, (mask_feat, kernels, sampling_results,
+ gt_instances) in enumerate(
+ zip(mask_feats, flatten_kernels, sampling_results_list,
+ batch_gt_instances)):
+ pos_priors = sampling_results.pos_priors
+ pos_inds = sampling_results.pos_inds
+ pos_kernels = kernels[pos_inds] # n_pos, num_gen_params
+ pos_mask_logits = self._mask_predict_by_feat_single(
+ mask_feat, pos_kernels, pos_priors)
+ if gt_instances.masks.numel() == 0:
+ gt_masks = torch.empty_like(gt_instances.masks)
+ else:
+ gt_masks = gt_instances.masks[
+ sampling_results.pos_assigned_gt_inds, :]
+ batch_pos_mask_logits.append(pos_mask_logits)
+ pos_gt_masks.append(gt_masks)
+
+ pos_gt_masks = torch.cat(pos_gt_masks, 0)
+ batch_pos_mask_logits = torch.cat(batch_pos_mask_logits, 0)
+
+ # avg_factor
+ num_pos = batch_pos_mask_logits.shape[0]
+ num_pos = reduce_mean(
+ mask_feats.new_tensor([num_pos])).clamp_(min=1).item()
+
+ if batch_pos_mask_logits.shape[0] == 0:
+ return mask_feats.sum() * 0
+
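+ # align resolutions: predictions live at the first prior stride, so
+ # upsample them and subsample gt masks to self.mask_loss_stride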
+ scale = self.prior_generator.strides[0][0] // self.mask_loss_stride
+ # upsample pred masks
+ batch_pos_mask_logits = F.interpolate(
+ batch_pos_mask_logits.unsqueeze(0),
+ scale_factor=scale,
+ mode='bilinear',
+ align_corners=False).squeeze(0)
+ # downsample gt masks by sampling at pixel centers (offset
+ # mask_loss_stride // 2) so they align with the upsampled logits
+ pos_gt_masks = pos_gt_masks[:, self.mask_loss_stride //
+ 2::self.mask_loss_stride,
+ self.mask_loss_stride //
+ 2::self.mask_loss_stride]
+
+ loss_mask = self.loss_mask(
+ batch_pos_mask_logits,
+ pos_gt_masks,
+ weight=None,
+ avg_factor=num_pos)
+
+ return loss_mask
+
+ def loss_by_feat(self,
+ cls_scores: List[Tensor],
+ bbox_preds: List[Tensor],
+ kernel_preds: List[Tensor],
+ mask_feat: Tensor,
+ batch_gt_instances: InstanceList,
+ batch_img_metas: List[dict],
+ batch_gt_instances_ignore: OptInstanceList = None):
+ """Compute losses of the head.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ Has shape (N, num_anchors * num_classes, H, W)
+ bbox_preds (list[Tensor]): Decoded box for each scale
+ level with shape (N, num_anchors * 4, H, W) in
+ [tl_x, tl_y, br_x, br_y] format.
+ kernel_preds (list[Tensor]): Dynamic conv kernel parameters
+ for each scale level with shape
+ (N, num_gen_params, H, W).
+ mask_feat (Tensor): Mask prototype features extracted from
+ the mask head with shape (N, num_prototypes, H, W).
+ batch_gt_instances (list[:obj:`InstanceData`]): Batch of
+ gt_instance. It usually includes ``bboxes`` and ``labels``
+ attributes.
+ batch_img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
+ Batch of gt_instances_ignore. It includes ``bboxes`` attribute
+ data that is ignored during training and testing.
+ Defaults to None.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ num_imgs = len(batch_img_metas)
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ assert len(featmap_sizes) == self.prior_generator.num_levels
+
+ device = cls_scores[0].device
+ anchor_list, valid_flag_list = self.get_anchors(
+ featmap_sizes, batch_img_metas, device=device)
+ flatten_cls_scores = torch.cat([
+ cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
+ self.cls_out_channels)
+ for cls_score in cls_scores
+ ], 1)
+ flatten_kernels = torch.cat([
+ kernel_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
+ self.num_gen_params)
+ for kernel_pred in kernel_preds
+ ], 1)
+ decoded_bboxes = []
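+ # priors are identical across images in the batch, so decoding uses
+ # the first image's anchor list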
+ for anchor, bbox_pred in zip(anchor_list[0], bbox_preds):
+ anchor = anchor.reshape(-1, 4)
+ bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
+ bbox_pred = distance2bbox(anchor, bbox_pred)
+ decoded_bboxes.append(bbox_pred)
+
+ flatten_bboxes = torch.cat(decoded_bboxes, 1)
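+ # convert gt masks (e.g. BitmapMasks) to bool tensors on the
+ # prediction device for mask target sampling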
+ for gt_instances in batch_gt_instances:
+ gt_instances.masks = gt_instances.masks.to_tensor(
+ dtype=torch.bool, device=device)
+
+ cls_reg_targets = self.get_targets(
+ flatten_cls_scores,
+ flatten_bboxes,
+ anchor_list,
+ valid_flag_list,
+ batch_gt_instances,
+ batch_img_metas,
+ batch_gt_instances_ignore=batch_gt_instances_ignore)
+ (anchor_list, labels_list, label_weights_list, bbox_targets_list,
+ assign_metrics_list, sampling_results_list) = cls_reg_targets
+
+ (losses_cls, losses_bbox, cls_avg_factors,
+ bbox_avg_factors) = multi_apply(
+ self.loss_by_feat_single,
+ cls_scores,
+ decoded_bboxes,
+ labels_list,
+ label_weights_list,
+ bbox_targets_list,
+ assign_metrics_list,
+ self.prior_generator.strides)
+
+ cls_avg_factor = reduce_mean(sum(cls_avg_factors)).clamp_(min=1).item()
+ losses_cls = list(map(lambda x: x / cls_avg_factor, losses_cls))
+
+ bbox_avg_factor = reduce_mean(
+ sum(bbox_avg_factors)).clamp_(min=1).item()
+ losses_bbox = list(map(lambda x: x / bbox_avg_factor, losses_bbox))
+
+ loss_mask = self.loss_mask_by_feat(mask_feat, flatten_kernels,
+ sampling_results_list,
+ batch_gt_instances)
+ loss = dict(
+ loss_cls=losses_cls, loss_bbox=losses_bbox, loss_mask=loss_mask)
+ return loss
+
+
+class MaskFeatModule(BaseModule):
+ """Mask feature head used in RTMDet-Ins.
+
+ Args:
+ in_channels (int): Number of channels in the input feature map.
+ feat_channels (int): Number of hidden channels of the mask feature
+ map branch.
+ num_levels (int): Number of input feature map levels that are
+ fused to predict the mask feature map.
+ num_prototypes (int): Number of output channels of the mask
+ feature map branch, i.e. the number of channels of the mask
+ feature map that will be dynamically convolved with the
+ predicted kernels.
+ stacked_convs (int): Number of convs in mask feature branch.
+ act_cfg (:obj:`ConfigDict` or dict): Config dict for activation
+ layer. Defaults to dict(type='ReLU', inplace=True).
+ norm_cfg (dict): Config dict for normalization layer.
+ Defaults to dict(type='BN').
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ feat_channels: int = 256,
+ stacked_convs: int = 4,
+ num_levels: int = 3,
+ num_prototypes: int = 8,
+ act_cfg: ConfigType = dict(type='ReLU', inplace=True),
+ norm_cfg: ConfigType = dict(type='BN')
+ ) -> None:
+ super().__init__(init_cfg=None)
+ self.num_levels = num_levels
+ self.fusion_conv = nn.Conv2d(num_levels * in_channels, in_channels, 1)
+ convs = []
+ for i in range(stacked_convs):
+ in_c = in_channels if i == 0 else feat_channels
+ convs.append(
+ ConvModule(
+ in_c,
+ feat_channels,
+ 3,
+ padding=1,
+ act_cfg=act_cfg,
+ norm_cfg=norm_cfg))
+ self.stacked_convs = nn.Sequential(*convs)
+ self.projection = nn.Conv2d(
+ feat_channels, num_prototypes, kernel_size=1)
+
+ def forward(self, features: Tuple[Tensor, ...]) -> Tensor:
+ # multi-level feature fusion
+ fusion_feats = [features[0]]
+ size = features[0].shape[-2:]
+ for i in range(1, self.num_levels):
+ f = F.interpolate(features[i], size=size, mode='bilinear')
+ fusion_feats.append(f)
+ fusion_feats = torch.cat(fusion_feats, dim=1)
+ fusion_feats = self.fusion_conv(fusion_feats)
+ # project fused features to num_prototypes mask feature channels
+ mask_features = self.stacked_convs(fusion_feats)
+ mask_features = self.projection(mask_features)
+ return mask_features
+
+
+@MODELS.register_module()
+class RTMDetInsSepBNHead(RTMDetInsHead):
+ """Detection Head of RTMDet-Ins with sep-bn layers.
+
+ Args:
+ num_classes (int): Number of categories excluding the background
+ category.
+ in_channels (int): Number of channels in the input feature map.
+ share_conv (bool): Whether to share conv layers between stages.
+ Defaults to True.
+ with_objectness (bool): Whether to add an objectness branch.
+ Defaults to False.
+ norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
+ layer. Defaults to dict(type='BN', requires_grad=True).
+ act_cfg (:obj:`ConfigDict` or dict): Config dict for activation
+ layer. Defaults to dict(type='SiLU', inplace=True).
+ pred_kernel_size (int): Kernel size of prediction layer. Defaults to 1.
+ """
+
+ def __init__(self,
+ num_classes: int,
+ in_channels: int,
+ share_conv: bool = True,
+ with_objectness: bool = False,
+ norm_cfg: ConfigType = dict(type='BN', requires_grad=True),
+ act_cfg: ConfigType = dict(type='SiLU', inplace=True),
+ pred_kernel_size: int = 1,
+ **kwargs) -> None:
+ self.share_conv = share_conv
+ super().__init__(
+ num_classes,
+ in_channels,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ pred_kernel_size=pred_kernel_size,
+ with_objectness=with_objectness,
+ **kwargs)
+
+ def _init_layers(self) -> None:
+ """Initialize layers of the head."""
+ self.cls_convs = nn.ModuleList()
+ self.reg_convs = nn.ModuleList()
+ self.kernel_convs = nn.ModuleList()
+
+ self.rtm_cls = nn.ModuleList()
+ self.rtm_reg = nn.ModuleList()
+ self.rtm_kernel = nn.ModuleList()
+ self.rtm_obj = nn.ModuleList()
+
+ # calculate num dynamic parameters
+ weight_nums, bias_nums = [], []
+ for i in range(self.num_dyconvs):
+ if i == 0:
+ weight_nums.append(
+ (self.num_prototypes + 2) * self.dyconv_channels)
+ bias_nums.append(self.dyconv_channels)
+ elif i == self.num_dyconvs - 1:
+ weight_nums.append(self.dyconv_channels)
+ bias_nums.append(1)
+ else:
+ weight_nums.append(self.dyconv_channels * self.dyconv_channels)
+ bias_nums.append(self.dyconv_channels)
+ self.weight_nums = weight_nums
+ self.bias_nums = bias_nums
+ self.num_gen_params = sum(weight_nums) + sum(bias_nums)
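+ # e.g. with the RTMDetInsHead defaults (num_prototypes=8,
+ # dyconv_channels=8, num_dyconvs=3): weights (8+2)*8 + 8*8 + 8 = 152,
+ # biases 8 + 8 + 1 = 17, i.e. 169 dynamic params per instance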
+ pred_pad_size = self.pred_kernel_size // 2
+
+ for n in range(len(self.prior_generator.strides)):
+ cls_convs = nn.ModuleList()
+ reg_convs = nn.ModuleList()
+ kernel_convs = nn.ModuleList()
+ for i in range(self.stacked_convs):
+ chn = self.in_channels if i == 0 else self.feat_channels
+ cls_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ reg_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ kernel_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ self.cls_convs.append(cls_convs)
+ self.reg_convs.append(reg_convs)
+ self.kernel_convs.append(kernel_convs)
+
+ self.rtm_cls.append(
+ nn.Conv2d(
+ self.feat_channels,
+ self.num_base_priors * self.cls_out_channels,
+ self.pred_kernel_size,
+ padding=pred_pad_size))
+ self.rtm_reg.append(
+ nn.Conv2d(
+ self.feat_channels,
+ self.num_base_priors * 4,
+ self.pred_kernel_size,
+ padding=pred_pad_size))
+ self.rtm_kernel.append(
+ nn.Conv2d(
+ self.feat_channels,
+ self.num_gen_params,
+ self.pred_kernel_size,
+ padding=pred_pad_size))
+ if self.with_objectness:
+ self.rtm_obj.append(
+ nn.Conv2d(
+ self.feat_channels,
+ 1,
+ self.pred_kernel_size,
+ padding=pred_pad_size))
+
+ if self.share_conv:
+ for n in range(len(self.prior_generator.strides)):
+ for i in range(self.stacked_convs):
+ self.cls_convs[n][i].conv = self.cls_convs[0][i].conv
+ self.reg_convs[n][i].conv = self.reg_convs[0][i].conv
+
+ self.mask_head = MaskFeatModule(
+ in_channels=self.in_channels,
+ feat_channels=self.feat_channels,
+ stacked_convs=4,
+ num_levels=len(self.prior_generator.strides),
+ num_prototypes=self.num_prototypes,
+ act_cfg=self.act_cfg,
+ norm_cfg=self.norm_cfg)
+
+ def init_weights(self) -> None:
+ """Initialize weights of the head."""
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ normal_init(m, mean=0, std=0.01)
+ if is_norm(m):
+ constant_init(m, 1)
+ bias_cls = bias_init_with_prob(0.01)
+ for rtm_cls, rtm_reg, rtm_kernel in zip(self.rtm_cls, self.rtm_reg,
+ self.rtm_kernel):
+ normal_init(rtm_cls, std=0.01, bias=bias_cls)
+ normal_init(rtm_reg, std=0.01, bias=1)
+ if self.with_objectness:
+ for rtm_obj in self.rtm_obj:
+ normal_init(rtm_obj, std=0.01, bias=bias_cls)
+
+ def forward(self, feats: Tuple[Tensor, ...]) -> tuple:
+ """Forward features from the upstream network.
+
+ Args:
+ feats (tuple[Tensor]): Features from the upstream network, each is
+ a 4D-tensor.
+
+ Returns:
+ tuple: A tuple of classification scores, bbox predictions,
+ kernel predictions and mask prototype features.
+ - cls_scores (list[Tensor]): Classification scores for all scale
+ levels, each is a 4D-tensor, the channels number is
+ num_base_priors * num_classes.
+ - bbox_preds (list[Tensor]): Box energies / deltas for all scale
+ levels, each is a 4D-tensor, the channels number is
+ num_base_priors * 4.
+ - kernel_preds (list[Tensor]): Dynamic conv kernels for all scale
+ levels, each is a 4D-tensor, the channels number is
+ num_gen_params.
+ - mask_feat (Tensor): Mask prototype features from the mask head,
+ a 4D-tensor whose channel number is num_prototypes.
+ """
+ mask_feat = self.mask_head(feats)
+
+ cls_scores = []
+ bbox_preds = []
+ kernel_preds = []
+ for idx, (x, stride) in enumerate(
+ zip(feats, self.prior_generator.strides)):
+ cls_feat = x
+ reg_feat = x
+ kernel_feat = x
+
+ for cls_layer in self.cls_convs[idx]:
+ cls_feat = cls_layer(cls_feat)
+ cls_score = self.rtm_cls[idx](cls_feat)
+
+ for kernel_layer in self.kernel_convs[idx]:
+ kernel_feat = kernel_layer(kernel_feat)
+ kernel_pred = self.rtm_kernel[idx](kernel_feat)
+
+ for reg_layer in self.reg_convs[idx]:
+ reg_feat = reg_layer(reg_feat)
+
+ if self.with_objectness:
+ objectness = self.rtm_obj[idx](reg_feat)
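+ # fuse cls and objectness scores via their geometric mean in
+ # probability space, then map the result back to logits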
+ cls_score = inverse_sigmoid(
+ sigmoid_geometric_mean(cls_score, objectness))
+
+ reg_dist = F.relu(self.rtm_reg[idx](reg_feat)) * stride[0]
+
+ cls_scores.append(cls_score)
+ bbox_preds.append(reg_dist)
+ kernel_preds.append(kernel_pred)
+ return tuple(cls_scores), tuple(bbox_preds), tuple(
+ kernel_preds), mask_feat
diff --git a/mmdet/models/task_modules/assigners/dynamic_soft_label_assigner.py b/mmdet/models/task_modules/assigners/dynamic_soft_label_assigner.py
index 00276e05b80..3fc7af39b22 100644
--- a/mmdet/models/task_modules/assigners/dynamic_soft_label_assigner.py
+++ b/mmdet/models/task_modules/assigners/dynamic_soft_label_assigner.py
@@ -16,6 +16,26 @@
EPS = 1.0e-7
+def center_of_mass(masks: Tensor, eps: float = 1e-7) -> Tensor:
+ """Compute the masks center of mass.
+
+ Args:
+ masks (Tensor): Mask tensor, has shape (num_masks, H, W).
+ eps (float): A small number to avoid the normalizer being zero.
+ Defaults to 1e-7.
+
+ Returns:
+ Tensor: The masks center of mass. Has shape (num_masks, 2).
+ """
+ n, h, w = masks.shape
+ grid_h = torch.arange(h, device=masks.device)[:, None]
+ grid_w = torch.arange(w, device=masks.device)
+ normalizer = masks.sum(dim=(1, 2)).float().clamp(min=eps)
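+ # center of mass = mean of pixel coordinates weighted by mask
+ # occupancy; the clamp keeps empty masks from dividing by zero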
+ center_y = (masks * grid_h).sum(dim=(1, 2)) / normalizer
+ center_x = (masks * grid_w).sum(dim=(1, 2)) / normalizer
+ center = torch.cat([center_x[:, None], center_y[:, None]], dim=1)
+ return center
+
+
@TASK_UTILS.register_module()
class DynamicSoftLabelAssigner(BaseAssigner):
"""Computes matching between predictions and ground truth with dynamic soft
@@ -118,7 +138,9 @@ def assign(self,
dtype=torch.long)
return AssignResult(
num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
- if isinstance(gt_bboxes, BaseBoxes):
+ if hasattr(gt_instances, 'masks'):
+ gt_center = center_of_mass(gt_instances.masks, eps=EPS)
+ elif isinstance(gt_bboxes, BaseBoxes):
gt_center = gt_bboxes.centers
else:
# Tensor boxes will be treated as horizontal boxes by defaults
diff --git a/setup.cfg b/setup.cfg
index 3014050d6b8..70dd621c8f5 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -18,4 +18,4 @@ SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true
[codespell]
skip = *.ipynb
quiet-level = 3
-ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids,TOOD,tood,ba,warmup,nam
+ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids,TOOD,tood,ba,warmup,nam,DOTA,dota
|