[Enhance] Update FAQ docs #6587

Merged on Dec 8, 2021 (27 commits)
Changes shown from 18 of the 27 commits.

Commits
18e1cdc - Fix mosaic repr typo (#6523) (lkm2835, Nov 18, 2021)
eb4c6ac - Include mmflow in readme (#6545) (Czm369, Nov 19, 2021)
2068e3e - Make OHEM work with seesaw loss (#6514) (ohwi, Nov 19, 2021)
2cb9a43 - [Enhance] Support file_client in Datasets and evaluating panoptic res… (AronLin, Nov 22, 2021)
4c48569 - Fix MMDetection model to ONNX command (#6558) (Rishit-dagli, Nov 22, 2021)
17858d3 - Update README.md (#6567) (RangiLyu, Nov 23, 2021)
46fbf2f - [Feature] Support custom persistent_workers (#6435) (hhaAndroid, Nov 23, 2021)
74eb539 - Fix SSD512 config error (#6574) (hhaAndroid, Nov 24, 2021)
5b95b46 - Catch symlink failure on Windows (#6482) (del-zhenwu, Nov 24, 2021)
5bf6695 - [Feature] Support Label Assignment Distillation (LAD) (#6342) (thuyngch, Nov 24, 2021)
c2d03ba - [Fix] Avoid infinite GPU waiting in dist training (#6501) (fingertap, Nov 24, 2021)
143ff85 - Support to collect the best models (#6560) (hhaAndroid, Nov 24, 2021)
4c13bf1 - [Enhance]: Optimize augmentation pipeline to speed up training. (#6442) (RangiLyu, Nov 24, 2021)
e130054 - Refactor YOLOX (#6443) (hhaAndroid, Nov 24, 2021)
d9697f3 - [Refactor] Remove some code in `mmdet/apis/train.py` (#6576) (Czm369, Nov 24, 2021)
27862b6 - Fix lad repeatedly output warning message (#6584) (hhaAndroid, Nov 24, 2021)
fcfc695 - update faq docs (hhaAndroid, Nov 25, 2021)
062b1fe - update (hhaAndroid, Nov 25, 2021)
e413f9e - update (hhaAndroid, Nov 25, 2021)
6b1550e - update (hhaAndroid, Dec 2, 2021)
fdf8b97 - merge dev-2.19.1 (hhaAndroid, Dec 2, 2021)
bf20e69 - fix lint (hhaAndroid, Dec 2, 2021)
76eaf17 - update (hhaAndroid, Dec 8, 2021)
a78471f - update (hhaAndroid, Dec 8, 2021)
a9f7358 - update (hhaAndroid, Dec 8, 2021)
965426a - update readme (hhaAndroid, Dec 8, 2021)
898f160 - Rephrase (ZwwWayne, Dec 8, 2021)
30 changes: 23 additions & 7 deletions .dev_scripts/gather_models.py
@@ -53,6 +53,14 @@ def get_final_epoch(config):
return cfg.runner.max_epochs


def get_best_epoch(exp_dir):
best_epoch_full_path = list(
sorted(glob.glob(osp.join(exp_dir, 'best_*.pth'))))[-1]
best_epoch_model_path = best_epoch_full_path.split('/')[-1]
best_epoch = best_epoch_model_path.split('_')[-1].split('.')[0]
return best_epoch_model_path, int(best_epoch)


def get_real_epoch(config):
cfg = mmcv.Config.fromfile('./configs/' + config)
epoch = cfg.runner.max_epochs
@@ -160,6 +168,10 @@ def parse_args():
help='root path of benchmarked models to be gathered')
parser.add_argument(
'out', type=str, help='output path of gathered models to be stored')
parser.add_argument(
'--best',
action='store_true',
help='whether to gather the best model.')

args = parser.parse_args()
return args
@@ -187,10 +199,13 @@ def main():
for used_config in used_configs:
exp_dir = osp.join(models_root, used_config)
# check whether the exps is finished
final_epoch = get_final_epoch(used_config)
final_model = 'epoch_{}.pth'.format(final_epoch)
model_path = osp.join(exp_dir, final_model)
if args.best is True:
final_model, final_epoch = get_best_epoch(exp_dir)
else:
final_epoch = get_final_epoch(used_config)
final_model = 'epoch_{}.pth'.format(final_epoch)

model_path = osp.join(exp_dir, final_model)
# skip if the model is still training
if not osp.exists(model_path):
continue
@@ -221,6 +236,7 @@ def main():
results=model_performance,
epochs=final_epoch,
model_time=model_time,
final_model=final_model,
log_json_path=osp.split(log_json_path)[-1]))

# publish model for each checkpoint
@@ -234,7 +250,7 @@
model_name += '_' + model['model_time']
publish_model_path = osp.join(model_publish_dir, model_name)
trained_model_path = osp.join(models_root, model['config'],
'epoch_{}.pth'.format(model['epochs']))
model['final_model'])

# convert model
final_model_path = process_checkpoint(trained_model_path,
@@ -254,9 +270,9 @@
config_path = osp.join(
'configs',
config_path) if 'configs' not in config_path else config_path
target_cconfig_path = osp.split(config_path)[-1]
shutil.copy(config_path,
osp.join(model_publish_dir, target_cconfig_path))
target_config_path = osp.split(config_path)[-1]
shutil.copy(config_path, osp.join(model_publish_dir,
target_config_path))

model['model_path'] = final_model_path
publish_model_infos.append(model)
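For context, here is a minimal, self-contained sketch of the best-checkpoint lookup that the new `--best` option relies on; the checkpoint file names used in the demonstration are hypothetical examples, not output produced by this PR.

```python
# Minimal sketch of the best-checkpoint lookup used when --best is passed.
# The checkpoint names created below are hypothetical examples.
import glob
import os.path as osp
import pathlib
import tempfile


def get_best_epoch(exp_dir):
    # Same logic as the helper added to gather_models.py: take the last
    # 'best_*.pth' file in sorted order and parse the epoch from its name.
    best_epoch_full_path = list(
        sorted(glob.glob(osp.join(exp_dir, 'best_*.pth'))))[-1]
    best_epoch_model_path = osp.split(best_epoch_full_path)[-1]
    best_epoch = best_epoch_model_path.split('_')[-1].split('.')[0]
    return best_epoch_model_path, int(best_epoch)


with tempfile.TemporaryDirectory() as exp_dir:
    for name in ('epoch_12.pth', 'best_bbox_mAP_epoch_9.pth'):
        pathlib.Path(exp_dir, name).touch()
    print(get_best_epoch(exp_dir))  # ('best_bbox_mAP_epoch_9.pth', 9)
```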
3 changes: 2 additions & 1 deletion README.md
@@ -15,7 +15,7 @@

[📘Documentation](https://mmdetection.readthedocs.io/en/v2.18.1/) |
[🛠️Installation](https://mmdetection.readthedocs.io/en/v2.18.1/get_started.html) |
[👀Model Zoo](https://mmdetection.readthedocs.io/zh_CN/v2.18.1/model_zoo.html) |
[👀Model Zoo](https://mmdetection.readthedocs.io/en/v2.18.1/model_zoo.html) |
[🆕Update News](https://mmdetection.readthedocs.io/en/v2.18.1/changelog.html) |
[🚀Ongoing Projects](https://github.com/open-mmlab/mmdetection/projects) |
[🤔Reporting Issues](https://github.com/open-mmlab/mmdetection/issues/new/choose)
@@ -205,3 +205,4 @@ If you use this toolbox or benchmark in your research, please cite this project.
- [MMEditing](https://github.com/open-mmlab/mmediting): OpenMMLab image and video editing toolbox.
- [MMOCR](https://github.com/open-mmlab/mmocr): A Comprehensive Toolbox for Text Detection, Recognition and Understanding.
- [MMGeneration](https://github.com/open-mmlab/mmgeneration): OpenMMLab image and video generative models toolbox.
- [MMFlow](https://github.com/open-mmlab/mmflow): OpenMMLab optical flow toolbox and benchmark.
1 change: 1 addition & 0 deletions README_zh-CN.md
@@ -203,6 +203,7 @@ MMDetection 是一款由来自不同高校和企业的研发人员共同参与
- [MMEditing](https://github.com/open-mmlab/mmediting): OpenMMLab image and video editing toolbox
- [MMOCR](https://github.com/open-mmlab/mmocr): OpenMMLab toolbox for end-to-end text detection, recognition, and understanding
- [MMGeneration](https://github.com/open-mmlab/mmgeneration): OpenMMLab image and video generative models toolbox
- [MMFlow](https://github.com/open-mmlab/mmflow): OpenMMLab optical flow toolbox and benchmark

## Welcome to the OpenMMLab Community

32 changes: 32 additions & 0 deletions configs/lad/README.md
@@ -0,0 +1,32 @@
# Improving Object Detection by Label Assignment Distillation

<!-- [ALGORITHM] -->

```latex
@inproceedings{nguyen2021improving,
title={Improving Object Detection by Label Assignment Distillation},
author={Chuong H. Nguyen and Thuy C. Nguyen and Tuan N. Tang and Nam L. H. Phan},
booktitle = {WACV},
year={2022}
}
```

## Results and Models

We provide config files to reproduce the object detection results of the WACV 2022 paper
Improving Object Detection by Label Assignment Distillation.

### PAA with LAD

| Teacher | Student | Training schedule | AP (val) | Config |
| :-------: | :-----: | :---------------: | :------: | :----------------------------------------------------: |
| -- | R-50 | 1x | 40.4 | |
| -- | R-101 | 1x | 42.6 | |
| R-101 | R-50 | 1x | 41.6 | [config](lad_r50_paa_r101_fpn_coco_1x.py) |
| R-50 | R-101 | 1x | 43.2 | [config](lad_r101_paa_r50_fpn_coco_1x.py) |

## Note

- Meaning of the config name: lad_r50(student model)_paa(based on PAA)_r101(teacher model)_fpn(neck)_coco(dataset)_1x(12-epoch schedule).py
- Results may fluctuate by about 0.2 mAP.
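Before launching training, it can be useful to confirm which backbone is the student and which is the teacher. The snippet below is a minimal sketch that loads one of the configs with mmcv and prints the relevant fields; it assumes the mmdetection repository root as the working directory.

```python
# Minimal sketch: load a LAD config and confirm the student/teacher pairing.
# Assumes the current working directory is the mmdetection repository root.
import mmcv

cfg = mmcv.Config.fromfile('configs/lad/lad_r50_paa_r101_fpn_coco_1x.py')

# Student settings live under model.backbone / model.bbox_head,
# teacher settings under model.teacher_backbone / model.teacher_bbox_head.
print('student backbone depth:', cfg.model.backbone.depth)          # 50
print('teacher backbone depth:', cfg.model.teacher_backbone.depth)  # 101
print('teacher checkpoint:', cfg.model.teacher_ckpt)
```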
121 changes: 121 additions & 0 deletions configs/lad/lad_r101_paa_r50_fpn_coco_1x.py
@@ -0,0 +1,121 @@
_base_ = [
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
teacher_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/paa/paa_r50_fpn_1x_coco/paa_r50_fpn_1x_coco_20200821-936edec3.pth' # noqa
model = dict(
type='LAD',
# student
backbone=dict(
type='ResNet',
depth=101,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained',
checkpoint='torchvision://resnet101')),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs='on_output',
num_outs=5),
bbox_head=dict(
type='LADHead',
reg_decoded_bbox=True,
score_voting=True,
topk=9,
num_classes=80,
in_channels=256,
stacked_convs=4,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
ratios=[1.0],
octave_base_scale=8,
scales_per_octave=1,
strides=[8, 16, 32, 64, 128]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[0.1, 0.1, 0.2, 0.2]),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='GIoULoss', loss_weight=1.3),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.5)),
# teacher
teacher_ckpt=teacher_ckpt,
teacher_backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch'),
teacher_neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs='on_output',
num_outs=5),
teacher_bbox_head=dict(
type='LADHead',
reg_decoded_bbox=True,
score_voting=True,
topk=9,
num_classes=80,
in_channels=256,
stacked_convs=4,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
ratios=[1.0],
octave_base_scale=8,
scales_per_octave=1,
strides=[8, 16, 32, 64, 128]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[0.1, 0.1, 0.2, 0.2]),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='GIoULoss', loss_weight=1.3),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.5)),
# training and testing settings
train_cfg=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.1,
neg_iou_thr=0.1,
min_pos_iou=0,
ignore_iof_thr=-1),
allowed_border=-1,
pos_weight=-1,
debug=False),
test_cfg=dict(
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
score_voting=True,
nms=dict(type='nms', iou_threshold=0.6),
max_per_img=100))
data = dict(samples_per_gpu=8, workers_per_gpu=4)
optimizer = dict(lr=0.01)
fp16 = dict(loss_scale=512.)
120 changes: 120 additions & 0 deletions configs/lad/lad_r50_paa_r101_fpn_coco_1x.py
@@ -0,0 +1,120 @@
_base_ = [
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
teacher_ckpt = 'http://download.openmmlab.com/mmdetection/v2.0/paa/paa_r101_fpn_1x_coco/paa_r101_fpn_1x_coco_20200821-0a1825a4.pth' # noqa
model = dict(
type='LAD',
# student
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs='on_output',
num_outs=5),
bbox_head=dict(
type='LADHead',
reg_decoded_bbox=True,
score_voting=True,
topk=9,
num_classes=80,
in_channels=256,
stacked_convs=4,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
ratios=[1.0],
octave_base_scale=8,
scales_per_octave=1,
strides=[8, 16, 32, 64, 128]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[0.1, 0.1, 0.2, 0.2]),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='GIoULoss', loss_weight=1.3),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.5)),
# teacher
teacher_ckpt=teacher_ckpt,
teacher_backbone=dict(
type='ResNet',
depth=101,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch'),
teacher_neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs='on_output',
num_outs=5),
teacher_bbox_head=dict(
type='LADHead',
reg_decoded_bbox=True,
score_voting=True,
topk=9,
num_classes=80,
in_channels=256,
stacked_convs=4,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
ratios=[1.0],
octave_base_scale=8,
scales_per_octave=1,
strides=[8, 16, 32, 64, 128]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[0.1, 0.1, 0.2, 0.2]),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='GIoULoss', loss_weight=1.3),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.5)),
# training and testing settings
train_cfg=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.1,
neg_iou_thr=0.1,
min_pos_iou=0,
ignore_iof_thr=-1),
allowed_border=-1,
pos_weight=-1,
debug=False),
test_cfg=dict(
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
score_voting=True,
nms=dict(type='nms', iou_threshold=0.6),
max_per_img=100))
data = dict(samples_per_gpu=8, workers_per_gpu=4)
optimizer = dict(lr=0.01)
fp16 = dict(loss_scale=512.)