Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement]: Refactor SSD #5291

Merged
merged 17 commits into from
Jun 22, 2021
Prev Previous commit
Next Next commit
change ssdvgg backbone
RangiLyu committed Jun 2, 2021
commit 168d5fe834a19eaff665fed833da0d3d63e2bd53
119 changes: 119 additions & 0 deletions configs/ssd/ssd_vgg16_caffe_300_10x_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Base configs this file extends: COCO detection dataset, 2x schedule and the
# default runtime settings.
_base_ = [
    '../_base_/datasets/coco_detection.py',
    '../_base_/schedules/schedule_2x.py',
    '../_base_/default_runtime.py',
]

# ---------------------------------------------------------------------------
# Model settings
# ---------------------------------------------------------------------------
# Square input resolution; reused by the anchor generator below.
input_size = 300

model = dict(
    type='SingleStageDetector',
    pretrained='open-mmlab://vgg16_caffe',
    # VGG-16 backbone; two feature maps are tapped (see out_feature_indices).
    backbone=dict(
        type='SSDVGG',
        depth=16,
        with_last_pool=False,
        ceil_mode=True,
        out_indices=(3, 4),
        out_feature_indices=(22, 34),
    ),
    # The neck takes the two backbone outputs (512 and 1024 channels) and
    # lists six output widths; level_strides / level_paddings carry four
    # entries each.
    neck=dict(
        type='SSDNeck',
        in_channels=(512, 1024),
        out_channels=(512, 1024, 512, 256, 256, 256),
        level_strides=(2, 2, 1, 1),
        level_paddings=(1, 1, 0, 0),
        is_vgg_neck=False,
        l2_norm_scale=20,
    ),
    bbox_head=dict(
        type='SSDHead',
        # Same six widths as neck.out_channels above.
        in_channels=(512, 1024, 512, 256, 256, 256),
        num_classes=80,
        # strides and ratios have six entries each, matching in_channels.
        anchor_generator=dict(
            type='SSDAnchorGenerator',
            scale_major=False,
            input_size=input_size,
            basesize_ratio_range=(0.15, 0.9),
            strides=[8, 16, 32, 64, 100, 300],
            ratios=[[2], [2, 3], [2, 3], [2, 3], [2], [2]],
        ),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[0.1, 0.1, 0.2, 0.2],
        ),
    ),
    # Training-time settings.
    train_cfg=dict(
        assigner=dict(
            type='MaxIoUAssigner',
            pos_iou_thr=0.5,
            neg_iou_thr=0.5,
            min_pos_iou=0.,
            ignore_iof_thr=-1,
            gt_max_assign_all=False,
        ),
        smoothl1_beta=1.,
        allowed_border=-1,
        pos_weight=-1,
        neg_pos_ratio=3,
        debug=False,
    ),
    # Test-time settings.
    test_cfg=dict(
        nms_pre=1000,
        nms=dict(type='nms', iou_threshold=0.45),
        min_bbox_size=0,
        score_thr=0.02,
        max_per_img=200,
    ),
)
cudnn_benchmark = True

# ---------------------------------------------------------------------------
# Dataset settings
# ---------------------------------------------------------------------------
dataset_type = 'CocoDataset'
data_root = 'data/coco/'

# Shared by the Normalize steps of both pipelines and by Expand (mean/to_rgb).
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53],
    std=[1, 1, 1],
    to_rgb=True,
)

# Training pipeline: load -> photometric distortion -> expand -> random crop
# -> resize to 300x300 -> normalize -> random flip -> bundle/collect.
train_pipeline = [
    dict(type='LoadImageFromFile', to_float32=True),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='PhotoMetricDistortion',
        brightness_delta=32,
        contrast_range=(0.5, 1.5),
        saturation_range=(0.5, 1.5),
        hue_delta=18,
    ),
    dict(
        type='Expand',
        mean=img_norm_cfg['mean'],
        to_rgb=img_norm_cfg['to_rgb'],
        ratio_range=(1, 4),
    ),
    dict(
        type='MinIoURandomCrop',
        min_ious=(0.1, 0.3, 0.5, 0.7, 0.9),
        min_crop_size=0.3,
    ),
    dict(type='Resize', img_scale=(300, 300), keep_ratio=False),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]

# Test pipeline: a single 300x300 scale with flipping disabled.
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(300, 300),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=False),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ],
    ),
]

data = dict(
    samples_per_gpu=8,
    workers_per_gpu=3,
    # `_delete_=True` is the mmcv Config marker for replacing the inherited
    # `train` entry instead of merging into it.
    # NOTE(review): 5 repeats on top of the 2x base schedule presumably
    # account for the "10x" in this config's file name -- confirm.
    train=dict(
        _delete_=True,
        type='RepeatDataset',
        times=5,
        dataset=dict(
            type=dataset_type,
            ann_file=data_root + 'annotations/instances_train2017.json',
            img_prefix=data_root + 'train2017/',
            pipeline=train_pipeline,
        ),
    ),
    # val / test only override the pipeline of the inherited entries.
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline),
)
# Optimizer: SGD with momentum and weight decay.
optimizer = dict(
    type='SGD',
    lr=0.002,
    momentum=0.9,
    weight_decay=0.0005,
)
# `_delete_=True` (mmcv Config marker) drops the inherited optimizer_config
# entries rather than merging with them, leaving an empty config.
optimizer_config = dict(_delete_=True)
84 changes: 2 additions & 82 deletions mmdet/models/backbones/ssd_vgg.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import VGG
from mmcv.runner import BaseModule, Sequential
from mmcv.runner import BaseModule

from ..builder import BACKBONES

@@ -40,13 +38,11 @@ class SSDVGG(VGG, BaseModule):
}

def __init__(self,
input_size,
depth,
with_last_pool=False,
ceil_mode=True,
out_indices=(3, 4),
out_feature_indices=(22, 34),
l2_norm_scale=20.,
pretrained=None,
init_cfg=None):
# TODO: in_channels for mmcv.VGG
@@ -55,8 +51,6 @@ def __init__(self,
with_last_pool=with_last_pool,
ceil_mode=ceil_mode,
out_indices=out_indices)
assert input_size in (300, 512)
self.input_size = input_size

self.features.add_module(
str(len(self.features)),
@@ -72,12 +66,6 @@ def __init__(self,
str(len(self.features)), nn.ReLU(inplace=True))
self.out_feature_indices = out_feature_indices

self.inplanes = 1024
self.extra = self._make_extra_layers(self.extra_setting[input_size])
self.l2_norm = L2Norm(
self.features[out_feature_indices[0] - 1].out_channels,
l2_norm_scale)

assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be setting at the same time'
if isinstance(pretrained, str):
@@ -94,18 +82,6 @@ def __init__(self,
else:
raise TypeError('pretrained must be a str or None')

if init_cfg is None:
self.init_cfg += [
dict(
type='Xavier',
distribution='uniform',
override=dict(name='extra')),
dict(
type='Constant',
val=self.l2_norm.scale,
override=dict(name='l2_norm'))
]

def init_weights(self, pretrained=None):
super(VGG, self).init_weights()

@@ -116,64 +92,8 @@ def forward(self, x):
x = layer(x)
if i in self.out_feature_indices:
outs.append(x)
for i, layer in enumerate(self.extra):
x = F.relu(layer(x), inplace=True)
if i % 2 == 1:
outs.append(x)
outs[0] = self.l2_norm(outs[0])

if len(outs) == 1:
return outs[0]
else:
return tuple(outs)

def _make_extra_layers(self, outplanes):
layers = []
kernel_sizes = (1, 3)
num_layers = 0
outplane = None
for i in range(len(outplanes)):
if self.inplanes == 'S':
self.inplanes = outplane
continue
k = kernel_sizes[num_layers % 2]
if outplanes[i] == 'S':
outplane = outplanes[i + 1]
conv = nn.Conv2d(
self.inplanes, outplane, k, stride=2, padding=1)
else:
outplane = outplanes[i]
conv = nn.Conv2d(
self.inplanes, outplane, k, stride=1, padding=0)
layers.append(conv)
self.inplanes = outplanes[i]
num_layers += 1
if self.input_size == 512:
layers.append(nn.Conv2d(self.inplanes, 256, 4, padding=1))

return Sequential(*layers)


class L2Norm(nn.Module):
    """L2 normalization over dim 1 with a learnable per-channel weight."""

    def __init__(self, n_dims, scale=20., eps=1e-10):
        """L2 normalization layer.

        Args:
            n_dims (int): Number of dimensions to be normalized
            scale (float, optional): Initial value of every entry of the
                learnable weight. Defaults to 20..
            eps (float, optional): Used to avoid division by zero.
                Defaults to 1e-10.
        """
        super(L2Norm, self).__init__()
        self.n_dims = n_dims
        self.eps = eps
        self.scale = scale
        # Fill the weight with `scale` instead of leaving it as
        # uninitialized memory, so the layer is usable even when no
        # external initializer or checkpoint sets the weight afterwards.
        # `float(scale)` keeps the parameter floating point when an int
        # scale (e.g. 20) is passed.
        self.weight = nn.Parameter(
            torch.full((self.n_dims, ), float(scale)))

    def forward(self, x):
        """Normalize ``x`` along dim 1 and rescale it by the weight.

        Args:
            x (torch.Tensor): Input indexed as a 4-D tensor whose dim 1 has
                ``n_dims`` entries (the weight is broadcast as
                ``weight[None, :, None, None]``).

        Returns:
            torch.Tensor: Tensor with the shape and dtype of ``x``.
        """
        # normalization layer convert to FP32 in FP16 training
        x_float = x.float()
        # eps is added after the square root, so an all-zero vector is
        # divided by eps rather than by zero.
        norm = x_float.pow(2).sum(1, keepdim=True).sqrt() + self.eps
        weight = self.weight[None, :, None, None].float().expand_as(x_float)
        return (weight * x_float / norm).type_as(x)