YOLOv3 - Continue on #1695 #3083

Merged
merged 83 commits into from
Aug 31, 2020
Changes from 48 commits
9b5b39e
Implement YOLOv3
Nov 19, 2019
787def0
Remove unused function
Nov 20, 2019
589158c
Merge branch 'master' into yolo
wuhy08 Dec 2, 2019
722fef8
Update yolov3_ms_aug_273e.py
wuhy08 Mar 12, 2020
08ef678
Merge remote-tracking branch 'origin-mm/master' into yolo
wuhy08 Mar 12, 2020
d61c575
Add README.md
wuhy08 Mar 26, 2020
082a63f
Merge branch 'master' of github.com:open-mmlab/mmdetection into yolo-2.0
ElectronicElephant Jun 20, 2020
e09390e
port to mmdet-2.0 api
ElectronicElephant Jun 23, 2020
0edf116
unify registry
ElectronicElephant Jun 23, 2020
035793a
Merge pull request #1 from ElectronicElephant/yolo-dev
ElectronicElephant Jun 23, 2020
e13afb6
Merge pull request #2 from ElectronicElephant/yolo-dev
ElectronicElephant Jun 23, 2020
1b4ecfc
port to ConvModule and remove ConvLayer
ElectronicElephant Jun 29, 2020
1f64c24
Merge pull request #3 from ElectronicElephant/yolo-dev
ElectronicElephant Jun 29, 2020
75c60f5
Refactor Backbone
ElectronicElephant Jun 29, 2020
3e59739
Update README
ElectronicElephant Jun 29, 2020
4f24fa1
Lint and format
ElectronicElephant Jun 29, 2020
4f82cd5
Merge pull request #4 from ElectronicElephant/yolo-dev
ElectronicElephant Jun 29, 2020
0ff50f8
Unify the class name
ElectronicElephant Jun 29, 2020
fd0e591
fix the `label - 1` problem
ElectronicElephant Jun 29, 2020
0b4db8f
Unify the class name and fix the `label-1` problem
ElectronicElephant Jun 29, 2020
4c7a837
Move a lot hard-coded params to the __init__ function
ElectronicElephant Jun 29, 2020
a3219ff
Refactor YOLOV3Neck
ElectronicElephant Jun 29, 2020
8e64bee
Add norm_cfg and act_cfg to backbone
ElectronicElephant Jun 29, 2020
40a6bbb
Update Config
ElectronicElephant Jun 29, 2020
14a13ae
Merge pull request #6 from ElectronicElephant/yolo-dev
ElectronicElephant Jun 30, 2020
59c5e38
Fix doc string
ElectronicElephant Jun 30, 2020
fc7d002
Merge pull request #7 from ElectronicElephant/yolo-dev
ElectronicElephant Jun 30, 2020
12ccde0
Fix nms (thanks to @LMerCy)
ElectronicElephant Jul 3, 2020
f257251
Add doc string
ElectronicElephant Jul 3, 2020
797e3dc
Merge pull request #10 from ElectronicElephant/yolo-dev
ElectronicElephant Jul 3, 2020
556eac0
Update config
ElectronicElephant Jul 5, 2020
77d7a6d
Remove pretrained in head and neck
ElectronicElephant Jul 5, 2020
23fd70e
Add support for conv_cfg in neck
ElectronicElephant Jul 5, 2020
0958b3c
Update mmdet/models/dense_heads/yolo_head.py
ElectronicElephant Jul 5, 2020
80e8cdc
Update mmdet/models/dense_heads/yolo_head.py
ElectronicElephant Jul 5, 2020
55ef429
Fix README.md
ElectronicElephant Jul 5, 2020
067f651
Fix typos
ElectronicElephant Jul 5, 2020
4ebe16b
Resolve comments
ElectronicElephant Jul 5, 2020
7abf5de
update config
ElectronicElephant Jul 5, 2020
deb86e3
Merge pull request #12 from ElectronicElephant/yolo-dev
ElectronicElephant Jul 5, 2020
e7d5081
Merge branch 'master' into yolo-2.0-solve-conflicts
ElectronicElephant Jul 5, 2020
d0aecdb
flake8, yapf, docformatter, etc
ElectronicElephant Jul 5, 2020
a65b733
Update README
ElectronicElephant Jul 13, 2020
1cd4e77
Add conv_cfg to backbone and head
ElectronicElephant Jul 16, 2020
4f28a0f
Move some config to arch_settings in backbone
ElectronicElephant Jul 16, 2020
013d50c
Add doc strings and replace Warning with warnings.warn()
ElectronicElephant Jul 16, 2020
5f435b8
Fix bug.
ElectronicElephant Jul 16, 2020
bbde141
Update doc
ElectronicElephant Jul 16, 2020
f6784d5
Add _frozen_stages for backbone
ElectronicElephant Jul 21, 2020
a74355a
Update mmdet/models/backbones/darknet.py
ElectronicElephant Jul 21, 2020
04965c3
Fix inplace bug
ElectronicElephant Jul 21, 2020
44d6be3
Merge branch 'yolo-2.0-solve-conflicts' of github.com:ElectronicEleph…
ElectronicElephant Jul 21, 2020
7bbe4b4
fix indent
ElectronicElephant Jul 21, 2020
16b39d1
refactor config
xvjiarui Jul 26, 2020
b1a07ce
merge master>2.2.1
xvjiarui Jul 26, 2020
cfb5126
set 8GPU lr
xvjiarui Jul 26, 2020
48f791d
fixed typo
xvjiarui Jul 26, 2020
e3b6d26
update performance table
xvjiarui Aug 2, 2020
411180e
merge master
xvjiarui Aug 2, 2020
cce0b73
Resolve conversation
ElectronicElephant Aug 4, 2020
4b16235
Add anchor generator and coder
xvjiarui Aug 11, 2020
8b06764
fixed test
xvjiarui Aug 11, 2020
832e2d2
Finish refactor
xvjiarui Aug 13, 2020
4524789
Merge branch 'master' into yolo-2.0-solve-conflicts
xvjiarui Aug 13, 2020
a55f5b4
refactor anchor order
xvjiarui Aug 14, 2020
596e96e
fixed batch size
xvjiarui Aug 14, 2020
01722be
Fixed train_cfg
xvjiarui Aug 17, 2020
ec3b71c
fix yolo assigner
sudo-rm-covid19 Aug 18, 2020
00fb03e
clean up
xvjiarui Aug 19, 2020
c8577d3
Merge branch 'yolo-2.0-solve-conflicts' of https://github.com/sudo-rm…
xvjiarui Aug 19, 2020
ff7ff9d
Fixed format
xvjiarui Aug 19, 2020
5b8567d
Update model zoo
xvjiarui Aug 23, 2020
3d72a34
change to mmcv pretrain link
xvjiarui Aug 23, 2020
4c85d6c
add test forward
xvjiarui Aug 23, 2020
5649356
fixed comma and docstring
xvjiarui Aug 26, 2020
6d829eb
Refactor loss
xvjiarui Aug 26, 2020
d938a49
reformat
xvjiarui Aug 26, 2020
7271a5f
fixed avg_factor
xvjiarui Aug 27, 2020
aee61cf
revert to original
xvjiarui Aug 27, 2020
292a051
fixed format
xvjiarui Aug 28, 2020
1516dc8
update table
xvjiarui Aug 28, 2020
f526be0
fixed BCE
xvjiarui Aug 30, 2020
fd74165
merge master
xvjiarui Aug 30, 2020
24 changes: 24 additions & 0 deletions configs/yolo/README.md
@@ -0,0 +1,24 @@
# YOLOv3

## Introduction
```
@misc{redmon2018yolov3,
title={YOLOv3: An Incremental Improvement},
author={Joseph Redmon and Ali Farhadi},
year={2018},
eprint={1804.02767},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```

## Results and Models

| Backbone | Train Scale | Lr schd | Mem (GB) | Eval Scale | Inf time (fps) | box AP | Download |
| :-------------: | :----------: | :-----: | :------: | :--------: | :------------: | :----: |:--------:|
| DarkNet-53 | Multi-Scale | 273e | 1.8 | 608 * 608 | 44 |**37.6**| [model](https://drive.google.com/file/d/1Ca27fP4hlBFduMCv5b_f-0J9EdfxCgPb/view?usp=sharing) \| [log](https://github.com/open-mmlab/mmdetection/files/4910982/log.zip) |
| - | - | - | - | 416 * 416 | **64** | 34.8 | - |


## Credit
This implementation originates from the project of Haoyu Wu (@wuhy08) at Western Digital.
118 changes: 118 additions & 0 deletions configs/yolo/yolov3_d53_yolo_mstrain_273e_coco.py
@@ -0,0 +1,118 @@
# Copyright (c) 2019 Western Digital Corporation or its affiliates.
_base_ = [
'../_base_/default_runtime.py',
]
# model settings
model = dict(
type='YOLOV3',
pretrained='./work_dirs/darknet_state_dict_only.pth',
backbone=dict(
type='Darknet',
depth=53,
out_indices=(3, 4, 5),
),
neck=dict(
type='YOLOV3Neck',
num_scales=3,
in_channels=[1024, 512, 256],
out_channels=[512, 256, 128],
),
bbox_head=dict(
type='YOLOV3Head',
num_classes=80,
num_scales=3,
num_anchors_per_scale=3,
in_channels=[512, 256, 128],
out_channels=[1024, 512, 256],
strides=[32, 16, 8],
anchor_base_sizes=[
[(116, 90), (156, 198), (373, 326)],
[(30, 61), (62, 45), (59, 119)],
[(10, 13), (16, 30), (33, 23)],
],
))
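As a rough sanity check on the head settings above, the sketch below computes how many channels each prediction map would carry and how many candidate boxes the three scales produce together. It assumes the standard YOLOv3 layout, in which every anchor predicts four box offsets, one objectness score, and one score per class. The helper names are hypothetical and are not part of mmdetection.

```python
# Values copied from the config above; helper names are illustrative only.
NUM_CLASSES = 80
NUM_ANCHORS_PER_SCALE = 3
STRIDES = [32, 16, 8]


def head_channels(num_classes, num_anchors):
    # Per anchor: 4 box offsets + 1 objectness score + one score per class.
    return num_anchors * (4 + 1 + num_classes)


def grid_sizes(img_size, strides):
    # Side length of the square feature map at each stride.
    return [img_size // s for s in strides]


def total_predictions(img_size, strides, num_anchors):
    # Every grid cell at every scale emits `num_anchors` candidate boxes.
    return sum(num_anchors * g * g for g in grid_sizes(img_size, strides))


print(head_channels(NUM_CLASSES, NUM_ANCHORS_PER_SCALE))       # 255
print(grid_sizes(608, STRIDES))                                # [19, 38, 76]
print(total_predictions(416, STRIDES, NUM_ANCHORS_PER_SCALE))  # 10647
```

Under these assumptions, the gap between the number of candidate boxes and `max_per_img=100` shows why the test settings filter by score and apply NMS.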
# training and testing settings
train_cfg = dict(
one_hot_smoother=0., ignore_config=0.5, xy_use_logit=False, debug=False)
test_cfg = dict(
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
conf_thr=0.005,
nms=dict(type='nms', iou_thr=0.45),
max_per_img=100)
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(mean=[0, 0, 0], std=[255., 255., 255.], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile', to_float32=True),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='PhotoMetricDistortion'),
dict(
type='Expand',
mean=img_norm_cfg['mean'],
to_rgb=img_norm_cfg['to_rgb'],
ratio_range=(1, 2)),
dict(
type='MinIoURandomCrop',
min_ious=(0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
min_crop_size=0.3),
dict(type='Resize', img_scale=[(320, 320), (608, 608)], keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(608, 608),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
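To illustrate what `Resize` with `keep_ratio=True` followed by `Pad(size_divisor=32)` does to an image's shape at test time, here is a minimal sketch. It assumes the rescaling rule I understand mmcv's `rescale_size` to use: fit the image within the long-edge and short-edge limits, then round to the nearest integer. `resize_then_pad` is a hypothetical helper, not an mmdetection function.

```python
import math


def resize_then_pad(w, h, scale=(608, 608), divisor=32):
    """Return ((new_w, new_h), (padded_w, padded_h)) for a w x h input."""
    max_long, max_short = max(scale), min(scale)
    # Largest factor that keeps both edges within their limits.
    factor = min(max_long / max(w, h), max_short / min(w, h))
    new_w, new_h = int(w * factor + 0.5), int(h * factor + 0.5)
    # Pad each side up to the next multiple of `divisor`.
    pad_w = math.ceil(new_w / divisor) * divisor
    pad_h = math.ceil(new_h / divisor) * divisor
    return (new_w, new_h), (pad_w, pad_h)


print(resize_then_pad(640, 480))
```

Under this assumption a 640x480 image is resized to 608x456 and padded to 608x480, so the network sees non-square inputs even though `img_scale` is square.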
data = dict(
samples_per_gpu=8,
workers_per_gpu=8,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
pipeline=train_pipeline,
),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline,
),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline,
))
# optimizer
optimizer = dict(type='SGD', lr=5e-4, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=2000, # same as burn-in in darknet
warmup_ratio=0.1,
step=[218, 246])
# runtime settings
total_epochs = 273
work_dir = './work_dirs/yolo_pretrained'
evaluation = dict(interval=1, metric=['bbox'])
find_unused_parameters = True
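The learning-rate policy above can be sketched in plain Python. This is an approximation under stated assumptions, not the mmcv implementation: it assumes the step policy multiplies the rate by `gamma=0.1` at each listed epoch, and that linear warmup ramps from `warmup_ratio * lr` up to the full rate over `warmup_iters` iterations. The function name `lr_at` is hypothetical.

```python
def lr_at(epoch, cur_iter, base_lr=5e-4, steps=(218, 246), gamma=0.1,
          warmup_iters=2000, warmup_ratio=0.1):
    """Approximate LR at a 0-based epoch and a global iteration count."""
    # Step decay: multiply by gamma once for every milestone already reached.
    regular_lr = base_lr * gamma ** sum(epoch >= s for s in steps)
    if cur_iter < warmup_iters:
        # Linear warmup: k shrinks from (1 - warmup_ratio) to 0.
        k = (1 - cur_iter / warmup_iters) * (1 - warmup_ratio)
        return regular_lr * (1 - k)
    return regular_lr


for epoch, it in [(0, 0), (0, 1000), (0, 2000), (218, 10**6), (246, 10**6)]:
    print(epoch, it, lr_at(epoch, it))
```

With these assumptions the rate starts at one tenth of `lr`, reaches the full `5e-4` after 2000 iterations, and drops by a factor of ten at epochs 218 and 246.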
3 changes: 2 additions & 1 deletion mmdet/models/backbones/__init__.py
@@ -1,3 +1,4 @@
+from .darknet import Darknet
 from .detectors_resnet import DetectoRS_ResNet
 from .detectors_resnext import DetectoRS_ResNeXt
 from .hourglass import HourglassNet
@@ -10,5 +11,5 @@

 __all__ = [
     'RegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'Res2Net',
-    'HourglassNet', 'DetectoRS_ResNet', 'DetectoRS_ResNeXt'
+    'HourglassNet', 'DetectoRS_ResNet', 'DetectoRS_ResNeXt', 'Darknet'
 ]
190 changes: 190 additions & 0 deletions mmdet/models/backbones/darknet.py
@@ -0,0 +1,190 @@
# Copyright (c) 2019 Western Digital Corporation or its affiliates.

import logging

import torch.nn as nn
from mmcv.cnn import ConvModule, constant_init, kaiming_init
from mmcv.runner import load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm

from ..builder import BACKBONES


class ResBlock(nn.Module):
"""The basic residual block used in YoloV3. Each ResBlock consists of two
ConvModules, and the input is added to the final output. Each ConvModule is
composed of Conv, BN, and LeakyReLU. In the YoloV3 paper, the first
ConvModule has half as many filters as the second. The first ConvModule
uses 1x1 filters and the second uses 3x3 filters.

Args:
in_channels (int): The input channels. Must be even.
conv_cfg (dict): Config dict for convolution layer. Default: None.
norm_cfg (dict): Dictionary to construct and config norm layer.
Default: dict(type='BN', requires_grad=True)
act_cfg (dict): Config dict for activation layer.
Default: dict(type='LeakyReLU', negative_slope=0.1).
"""

def __init__(self,
in_channels,
conv_cfg=None,
norm_cfg=dict(type='BN', requires_grad=True),
act_cfg=dict(type='LeakyReLU', negative_slope=0.1)):
super(ResBlock, self).__init__()
assert in_channels % 2 == 0 # ensure the in_channels is even
half_in_channels = in_channels // 2

# shorthand: config shared by both ConvModules
cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)

self.conv1 = ConvModule(in_channels, half_in_channels, 1, **cfg)
self.conv2 = ConvModule(
half_in_channels, in_channels, 3, padding=1, **cfg)

def forward(self, x):
residual = x
out = self.conv1(x)
out = self.conv2(out)
out += residual

return out


def make_conv_and_res_block(in_channels,
out_channels,
res_repeat,
conv_cfg=None,
norm_cfg=dict(type='BN', requires_grad=True),
act_cfg=dict(type='LeakyReLU',
negative_slope=0.1)):
"""In Darknet backbone, ConvLayer is usually followed by ResBlock. This
function builds that pair. The Conv layer always has 3x3 filters with
stride=2, and its number of filters equals the number of out channels of
the ResBlock.

Args:
in_channels (int): The number of input channels.
out_channels (int): The number of output channels.
res_repeat (int): The number of ResBlocks.
conv_cfg (dict): Config dict for convolution layer. Default: None.
norm_cfg (dict): Dictionary to construct and config norm layer.
Default: dict(type='BN', requires_grad=True)
act_cfg (dict): Config dict for activation layer.
Default: dict(type='LeakyReLU', negative_slope=0.1).
"""

cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)

model = nn.Sequential()
model.add_module(
'conv',
ConvModule(in_channels, out_channels, 3, stride=2, padding=1, **cfg))
for idx in range(res_repeat):
model.add_module('res{}'.format(idx), ResBlock(out_channels, **cfg))
return model


@BACKBONES.register_module()
class Darknet(nn.Module):
"""Darknet backbone.

Args:
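depth (int): Depth of Darknet. Only 53 is currently supported.
out_indices (Sequence[int]): Indices of the stages whose outputs are
returned. Index 0 is ``conv1``; indices 1-5 are the conv-res blocks.
Default: (3, 4, 5).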
conv_cfg (dict): Config dict for convolution layer. Default: None.
norm_cfg (dict): Dictionary to construct and config norm layer.
Default: dict(type='BN', requires_grad=True)
act_cfg (dict): Config dict for activation layer.
Default: dict(type='LeakyReLU', negative_slope=0.1).
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only.

Example:
>>> from mmdet.models import Darknet
>>> import torch
>>> self = Darknet(depth=53)
>>> self.eval()
>>> inputs = torch.rand(1, 3, 416, 416)
>>> level_outputs = self.forward(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
...
(1, 256, 52, 52)
(1, 512, 26, 26)
(1, 1024, 13, 13)
"""

# Dict(depth: (layers, channels))
arch_settings = {
53: ((1, 2, 8, 8, 4), ((32, 64), (64, 128), (128, 256), (256, 512),
(512, 1024)))
}

def __init__(self,
depth=53,
out_indices=(3, 4, 5),
conv_cfg=None,
norm_cfg=dict(type='BN', requires_grad=True),
act_cfg=dict(type='LeakyReLU', negative_slope=0.1),
norm_eval=True):
super(Darknet, self).__init__()
if depth not in self.arch_settings:
raise KeyError(f'invalid depth {depth} for darknet')
self.depth = depth
self.out_indices = out_indices
self.layers, self.channels = self.arch_settings[depth]

cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)

self.conv1 = ConvModule(3, 32, 3, padding=1, **cfg)

self.cr_blocks = ['conv1']
for i, n_layers in enumerate(self.layers):
layer_name = f'cr_block{i + 1}'
in_c, out_c = self.channels[i]
self.add_module(
layer_name,
make_conv_and_res_block(in_c, out_c, n_layers, **cfg))
self.cr_blocks.append(layer_name)

self.norm_eval = norm_eval

def forward(self, x):
outs = []
for i, layer_name in enumerate(self.cr_blocks):
cr_block = getattr(self, layer_name)
x = cr_block(x)
if i in self.out_indices:
outs.append(x)

return tuple(outs)

def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
logger = logging.getLogger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)

else:
raise TypeError('pretrained must be a str or None')

def _freeze_stages(self):
# NOTE: as written, this freezes every backbone parameter unconditionally.
for param in self.parameters():
param.requires_grad = False

def train(self, mode=True):
super(Darknet, self).train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
# trick: eval() is applied to BatchNorm layers only
if isinstance(m, _BatchNorm):
m.eval()
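The output shapes quoted in the `Darknet` docstring can be reproduced without PyTorch. The sketch below walks the `arch_settings` for depth 53 using the standard convolution output-size formula. The function names are hypothetical. The layer-count helper assumes that the "53" in the name counts these 52 convolutions plus the classifier layer of the original network, which this backbone omits.

```python
# Copied from Darknet.arch_settings[53]: ResBlocks per stage and (in, out) channels.
LAYERS = (1, 2, 8, 8, 4)
CHANNELS = ((32, 64), (64, 128), (128, 256), (256, 512), (512, 1024))


def conv_out(size, kernel=3, stride=1, padding=1):
    # Standard convolution output-size formula (dilation 1).
    return (size + 2 * padding - kernel) // stride + 1


def stage_shapes(img_size, out_indices=(3, 4, 5)):
    """Return (channels, spatial size) for each requested stage index."""
    size = conv_out(img_size)  # conv1: 3x3, stride 1, keeps the resolution
    shapes = {0: (32, size)}
    for i, (_, out_c) in enumerate(CHANNELS, start=1):
        # Each cr_block starts with a stride-2 conv; ResBlocks keep the shape.
        size = conv_out(size, stride=2)
        shapes[i] = (out_c, size)
    return [shapes[i] for i in out_indices]


def conv_layer_count():
    # conv1, plus one downsampling conv and two convs per ResBlock in each stage.
    return 1 + sum(1 + 2 * n for n in LAYERS)


print(stage_shapes(416))   # [(256, 52), (512, 26), (1024, 13)]
print(conv_layer_count())  # 52
```

For a 608 x 608 input the same walk gives spatial sizes of 76, 38, and 19, matching strides 8, 16, and 32.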
3 changes: 2 additions & 1 deletion mmdet/models/dense_heads/__init__.py
@@ -17,11 +17,12 @@
 from .retina_sepbn_head import RetinaSepBNHead
 from .rpn_head import RPNHead
 from .ssd_head import SSDHead
+from .yolo_head import YOLOV3Head

 __all__ = [
     'AnchorFreeHead', 'AnchorHead', 'GuidedAnchorHead', 'FeatureAdaption',
     'RPNHead', 'GARPNHead', 'RetinaHead', 'RetinaSepBNHead', 'GARetinaHead',
     'SSDHead', 'FCOSHead', 'RepPointsHead', 'FoveaHead',
     'FreeAnchorRetinaHead', 'ATSSHead', 'FSAFHead', 'NASFCOSHead',
-    'PISARetinaHead', 'PISASSDHead', 'GFLHead'
+    'PISARetinaHead', 'PISASSDHead', 'GFLHead', 'YOLOV3Head'
 ]