Yolox improve with REPConv/ASFF/TOOD #154

Merged
merged 85 commits into from
Aug 24, 2022
Changes from 43 commits
Commits
85 commits
1167efd
try pr
May 16, 2022
9988716
add attention layer and more loss function
May 17, 2022
4f55bda
add attention layer and various loss functions
May 19, 2022
1f4b054
solve the conflict
May 19, 2022
18db2c5
Merge branch 'alibaba:master' into master
zouxinyi0625 May 29, 2022
42472ff
add siou loss
zouxinyi0625 May 31, 2022
a1e5bc1
add tah,various attention layers, and different loss functions
zouxinyi0625 Jun 13, 2022
d0c1297
resolve conflict
zouxinyi0625 Jun 28, 2022
19388ed
Merge branch 'alibaba:master' into master
zouxinyi0625 Jul 11, 2022
a88c747
fix export error
zouxinyi0625 Jul 11, 2022
a837d97
add asff sim, gsconv
zouxinyi0625 Jul 27, 2022
425d6c2
Merge branch 'alibaba:master' into master
zouxinyi0625 Jul 27, 2022
2e87257
fix config
haiasd Jul 27, 2022
7b08d4f
merge yolox-pai with master
haiasd Aug 5, 2022
f0f7ffc
Merge branch 'alibaba:master' into master
zouxinyi0625 Aug 5, 2022
d9bb9af
Merge branch 'master' into yolox_improve
haiasd Aug 5, 2022
5bc6e00
fix repvgg_yolox_backbone refer2 repvgg
haiasd Aug 5, 2022
e812089
fix asff tood code
haiasd Aug 5, 2022
6a9241e
fix bug
haiasd Aug 6, 2022
bc2b238
fix bug
haiasd Aug 6, 2022
e9e0dfc
fix tood act bug
haiasd Aug 6, 2022
bb16a3f
blade utils fit faster
haiasd Aug 8, 2022
bbdfb9f
blade optimize for yolox static & fp16
Aug 8, 2022
2f5c6c4
decode output for yolox control by cfg
Aug 8, 2022
71c6ea2
fix some bug
Aug 8, 2022
6fcb293
fix tood stem to repconv
Aug 9, 2022
5ecb695
fix tood interconv with repconv
Aug 9, 2022
2346aaf
add reparameterize_models for export
Aug 9, 2022
4066723
Merge branch 'alibaba:master' into yolox_improve
Aug 9, 2022
b1d7b6c
Merge branch 'master' into yolox_improve
Aug 9, 2022
437342c
pre-commit fix
Aug 9, 2022
166fe9b
pre-commit fix 1
Aug 9, 2022
48975bb
pre-commit fix 2
Aug 9, 2022
58e3f0c
pre-commit fix 3
Aug 9, 2022
95028f4
pre-commit fix 4
Aug 9, 2022
33e0740
Merge branch 'yolox_improve' of github.com:wuziheng/EasyCV into yolox…
Aug 9, 2022
1d21505
fix yolox configs
Aug 9, 2022
ab87f09
fix lint bug
Aug 9, 2022
ce4e7b5
fix lint bug1
Aug 9, 2022
502a9dc
fix make_divisible
Aug 10, 2022
6379b27
lint
Aug 10, 2022
47f7e2e
first version check by zl
Aug 10, 2022
81a2051
rm compute model params
Aug 10, 2022
cac888a
modify asff v1
zouxinyi0625 Aug 12, 2022
8dbe360
add trt nms demo for detection
Aug 12, 2022
feb0eeb
Merge branch 'yolox_improve_new' of github.com:wuziheng/EasyCV into y…
Aug 12, 2022
b8b2646
merge ASFF & ASFF_sim and registbackbone & head
zouxinyi0625 Aug 13, 2022
7e07113
merge ASFF & ASFF_sim and registbackbone & head
zouxinyi0625 Aug 13, 2022
e1e9e8a
fix cr mentioned bug
zouxinyi0625 Aug 14, 2022
2f7534f
remove useless file
zouxinyi0625 Aug 14, 2022
a364d60
fix cr bug
zouxinyi0625 Aug 15, 2022
10f549e
fix cr problem
zouxinyi0625 Aug 15, 2022
1eb0a4b
fix cr problem
zouxinyi0625 Aug 15, 2022
08bbba8
fix ut problem
zouxinyi0625 Aug 15, 2022
a8ba8fd
e2e trt_nms plugin export support and numeric test
Aug 15, 2022
7844e64
merge fix
Aug 15, 2022
77973f1
fix bug
zouxinyi0625 Aug 16, 2022
67140ae
Merge branch 'yolox_improve_new' of https://github.com/wuziheng/EasyC…
Aug 16, 2022
e8b6607
fix interface for yolox use_trt_efficientnms
Aug 16, 2022
525f7c0
split preprocess from end2end+blade, speedup from 17ms->7.2ms
Aug 16, 2022
cfde0cd
fix cr bug
zouxinyi0625 Aug 16, 2022
4ffd599
Merge branch 'yolox_improve_new' of github.com:wuziheng/EasyCV into c…
zouxinyi0625 Aug 16, 2022
600e2c2
fix ut
zouxinyi0625 Aug 17, 2022
68454fd
fix config bug
zouxinyi0625 Aug 17, 2022
a6236d3
refactor export and restore yolox_edge
zouxinyi0625 Aug 18, 2022
47c4581
fix jit/blade bug
zouxinyi0625 Aug 19, 2022
eeb269a
change ut
zouxinyi0625 Aug 20, 2022
53183ec
save conflict of YOLOXLrUpdaterHook
zouxinyi0625 Aug 20, 2022
1fefb7e
remove useless uttest
zouxinyi0625 Aug 21, 2022
93c0254
ut
zouxinyi0625 Aug 22, 2022
c22bf98
update yolox.md
zouxinyi0625 Aug 22, 2022
887e2be
update export.md
zouxinyi0625 Aug 22, 2022
b865306
format
zouxinyi0625 Aug 22, 2022
b0fa9e6
complete result
zouxinyi0625 Aug 22, 2022
a7a914c
hook
zouxinyi0625 Aug 22, 2022
3ff49fe
fix ut
zouxinyi0625 Aug 22, 2022
55c3c46
fix ut bug and cr problem
zouxinyi0625 Aug 23, 2022
3490d45
skip correct ut
zouxinyi0625 Aug 23, 2022
991d2ce
ut
zouxinyi0625 Aug 23, 2022
99e0d4c
ut
zouxinyi0625 Aug 23, 2022
3eb7cd2
ut
zouxinyi0625 Aug 23, 2022
1b063fc
ut
zouxinyi0625 Aug 23, 2022
26b458f
ut
zouxinyi0625 Aug 23, 2022
e08509e
fix cr problem
zouxinyi0625 Aug 24, 2022
59a3916
fix cr problem
zouxinyi0625 Aug 24, 2022
2 changes: 1 addition & 1 deletion configs/detection/yolox/yolox_l_8xb8_300e_coco.py
@@ -1,4 +1,4 @@
_base_ = './yolox_s_8xb16_300e_coco.py'
_base_ = 'configs/detection/yolox/yolox_s_8xb16_300e_coco.py'
Collaborator: Why modify the base path?

Collaborator: Changed it to match how the base path is done elsewhere.

Collaborator: Imitate configs/detection/fcos and reconstruct the config.

Collaborator: Done.

# model settings
model = dict(model_type='l')
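For context, the thread above is about the config-inheritance pattern: the "l" config inherits everything from the small-model config and overrides only what differs. A minimal sketch of the resulting file (an illustration based on the two fields visible in this diff, not an authoritative copy):

# configs/detection/yolox/yolox_l_8xb8_300e_coco.py (illustrative sketch)
# Inherit all settings from the YOLOX-S config; override only the model size.
_base_ = 'configs/detection/yolox/yolox_s_8xb16_300e_coco.py'

# model settings
model = dict(model_type='l')
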
41 changes: 36 additions & 5 deletions easycv/apis/export.py
@@ -13,6 +13,7 @@
from easycv.file import io
from easycv.models import (DINO, MOCO, SWAV, YOLOX, Classification, MoBY,
build_model)
from easycv.models.backbones.repvgg_yolox_backbone import RepVGGBlock
from easycv.utils.bbox_util import scale_coords
from easycv.utils.checkpoint import load_checkpoint

@@ -21,6 +22,24 @@
]


def reparameterize_models(model):
Collaborator: This does not seem like a common case; move it to model.forward_export.

Collaborator (Author): OK, thanks.

""" reparameterize model for inference, especially for
1. rep conv block : merge 3x3 weight 1x1 weights
call module switch_to_deploy recursively
Args:
model: nn.Module
"""
reparameterize_count = 0
for layer in model.modules():
if isinstance(layer, RepVGGBlock):
reparameterize_count += 1
layer.switch_to_deploy()
logging.info(
'export : PAI-export switched {} RepVGGBlock(s) to deploy mode'.format(
reparameterize_count))
return model


def export(cfg, ckpt_path, filename):
""" export model for inference

@@ -34,6 +53,7 @@ def export(cfg, ckpt_path, filename):
load_checkpoint(model, ckpt_path, map_location='cpu')
else:
cfg.model.backbone.pretrained = False
model = reparameterize_models(model)

if isinstance(model, MOCO) or isinstance(model, DINO):
_export_moco(model, cfg, filename)
@@ -171,6 +191,7 @@ def _export_yolox(model, cfg, filename):
raise ValueError('`end2end` only support torch1.7.0 and later!')

batch_size = cfg.export.get('batch_size', 1)
static_opt = cfg.export.get('static_opt', True)
img_scale = cfg.get('img_scale', (640, 640))
assert (
len(img_scale) == 2
@@ -194,7 +215,6 @@
# Using trace is a little bit faster than script, but it is not supported in an end2end model.
if end2end:
yolox_trace = torch.jit.script(model_export)

else:
yolox_trace = torch.jit.trace(model_export, input.to(device))

@@ -207,13 +227,17 @@
assert blade_env_assert()

if end2end:
input = 255 * torch.rand(img_scale + (3, ))
if batch_size == 1:
input = 255 * torch.rand(img_scale + (3, ))
else:
input = 255 * torch.rand(img_scale + (3, batch_size))

yolox_blade = blade_optimize(
script_model=model,
model=yolox_trace,
inputs=(input.to(device), ),
blade_config=blade_config)
blade_config=blade_config,
static_opt=static_opt)

with io.open(filename + '.blade', 'wb') as ofile:
torch.jit.save(yolox_blade, ofile)
@@ -644,7 +668,15 @@ def __init__(self,

self.example_inputs = example_inputs
self.preprocess_fn = preprocess_fn
self.postprocess_fn = postprocess_fn
self.ignore_postprocess = getattr(self.model, 'ignore_postprocess',
False)
if not self.ignore_postprocess:
Collaborator: Why add ignore_postprocess? postprocess_fn=None is already supported, and ignore_postprocess should also be removed from the model.

Collaborator (Author): OK, this part of the design is a bit confusing; we will try to fix it.

self.postprocess_fn = postprocess_fn
else:
self.postprocess_fn = None
logging.warning(
'Model {} ignore_postprocess set to {} during export!'.format(
type(model), self.ignore_postprocess))
self.trace_model = trace_model
if self.trace_model:
self.trace_module()
@@ -669,7 +701,6 @@ def forward(self, image):
image = output

model_output = self.model.forward_export(image)

if self.postprocess_fn is not None:
model_output = self.postprocess_fn(model_output,
*preprocess_outputs)
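To make the reparameterization step above concrete, here is a minimal, self-contained sketch. Only reparameterize_models and the RepVGGBlock import are taken from this diff; the RepVGGBlock constructor arguments follow the standard RepVGG implementation and are an assumption, so treat this as an illustration rather than the project's documented usage.

# Illustrative sketch, not part of the diff.
import torch
from torch import nn

from easycv.apis.export import reparameterize_models
from easycv.models.backbones.repvgg_yolox_backbone import RepVGGBlock

# Toy module containing a single RepVGGBlock (constructor args are assumed).
toy = nn.Sequential(RepVGGBlock(in_channels=16, out_channels=16), nn.ReLU())
toy.eval()

x = torch.rand(1, 16, 32, 32)
with torch.no_grad():
    before = toy(x)
    # Each RepVGGBlock fuses its 3x3, 1x1 and identity branches into one conv.
    toy = reparameterize_models(toy)
    after = toy(x)

# The fused model should be numerically equivalent up to floating-point error.
print(torch.allclose(before, after, atol=1e-4))
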
1 change: 1 addition & 0 deletions easycv/apis/train.py
@@ -89,6 +89,7 @@ def train_model(model,

# SyncBatchNorm
open_sync_bn = cfg.get('sync_bn', False)

if open_sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
logger.info('Using SyncBatchNorm()')
3 changes: 3 additions & 0 deletions easycv/datasets/detection/raw.py
@@ -126,6 +126,9 @@ def visualize(self, results, vis_num=10, score_thr=0.3, **kwargs):
dict of image meta info, containing filename, img_shape,
origin_img_shape, scale_factor and so on.
"""
import copy
results = copy.deepcopy(results)
Collaborator: Why add the deepcopy?

Collaborator: To fix the mAP=0 problem during evaluation.

Collaborator: That is already fixed, please refer to: #67. Remove the deepcopy.

class_names = None
if hasattr(self.data_source, 'CLASSES'):
class_names = self.data_source.CLASSES
3 changes: 3 additions & 0 deletions easycv/hooks/yolox_mode_switch_hook.py
@@ -40,3 +40,6 @@ def before_train_epoch(self, runner):
train_loader.dataset.update_skip_type_keys(self.skip_type_keys)
runner.logger.info('Add additional L1 loss now!')
model.head.use_l1 = True

if hasattr(runner.model.module, 'epoch_counter'):
runner.model.module.epoch_counter = epoch
Collaborator: What is epoch_counter for?

Collaborator: It was previously used to change the label assigner by epoch, but it could not get a better result; it will be removed now.

1 change: 1 addition & 0 deletions easycv/models/backbones/__init__.py
@@ -12,6 +12,7 @@
from .mnasnet import MNASNet
from .mobilenetv2 import MobileNetV2
from .pytorch_image_models_wrapper import *
from .repvgg_yolox_backbone import RepVGGYOLOX
from .resnest import ResNeSt
from .resnet import ResNet
from .resnet_jit import ResNetJIT
107 changes: 73 additions & 34 deletions easycv/models/backbones/darknet.py
@@ -1,22 +1,22 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import torch
from torch import nn

from .network_blocks import (BaseConv, CSPLayer, DWConv, Focus, ResLayer,
SPPBottleneck)
SPPBottleneck, SPPFBottleneck)


class Darknet(nn.Module):
# number of blocks from dark2 to dark5.
depth2blocks = {21: [1, 2, 2, 1], 53: [2, 8, 8, 4]}

def __init__(
self,
depth,
in_channels=3,
stem_out_channels=32,
out_features=('dark3', 'dark4', 'dark5'),
):
def __init__(self,
depth,
in_channels=3,
stem_out_channels=32,
out_features=('dark3', 'dark4', 'dark5'),
spp_type='spp'):
"""
Args:
depth (int): depth of darknet used in model, usually use [21, 53] for this param.
@@ -49,11 +49,18 @@ def __init__(
*self.make_group_layer(in_channels, num_blocks[2], stride=2))
in_channels *= 2 # 512

self.dark5 = nn.Sequential(
*self.make_group_layer(in_channels, num_blocks[3], stride=2),
*self.make_spp_block([in_channels, in_channels * 2],
in_channels * 2),
)
if spp_type == 'spp':
self.dark5 = nn.Sequential(
*self.make_group_layer(in_channels, num_blocks[3], stride=2),
*self.make_spp_block([in_channels, in_channels * 2],
in_channels * 2),
)
elif spp_type == 'sppf':
self.dark5 = nn.Sequential(
*self.make_group_layer(in_channels, num_blocks[3], stride=2),
*self.make_sppf_block([in_channels, in_channels * 2],
in_channels * 2),
)

def make_group_layer(self,
in_channels: int,
@@ -87,6 +94,23 @@ def make_spp_block(self, filters_list, in_filters):
])
return m

def make_sppf_block(self, filters_list, in_filters):
m = nn.Sequential(*[
BaseConv(in_filters, filters_list[0], 1, stride=1, act='lrelu'),
BaseConv(
filters_list[0], filters_list[1], 3, stride=1, act='lrelu'),
SPPFBottleneck(
in_channels=filters_list[1],
out_channels=filters_list[0],
activation='lrelu',
),
BaseConv(
filters_list[0], filters_list[1], 3, stride=1, act='lrelu'),
BaseConv(
filters_list[1], filters_list[0], 1, stride=1, act='lrelu'),
])
return m

def forward(self, x):
outputs = {}
x = self.stem(x)
@@ -104,14 +128,13 @@ def forward(self, x):

class CSPDarknet(nn.Module):

def __init__(
self,
dep_mul,
wid_mul,
out_features=('dark3', 'dark4', 'dark5'),
depthwise=False,
act='silu',
):
def __init__(self,
dep_mul,
wid_mul,
out_features=('dark3', 'dark4', 'dark5'),
depthwise=False,
act='silu',
spp_type='spp'):
super().__init__()
assert out_features, 'please provide output features of Darknet'
self.out_features = out_features
@@ -160,19 +183,35 @@ def __init__(
)

# dark5
self.dark5 = nn.Sequential(
Conv(base_channels * 8, base_channels * 16, 3, 2, act=act),
SPPBottleneck(
base_channels * 16, base_channels * 16, activation=act),
CSPLayer(
base_channels * 16,
base_channels * 16,
n=base_depth,
shortcut=False,
depthwise=depthwise,
act=act,
),
)
if spp_type == 'spp':
self.dark5 = nn.Sequential(
Conv(base_channels * 8, base_channels * 16, 3, 2, act=act),
SPPBottleneck(
base_channels * 16, base_channels * 16, activation=act),
CSPLayer(
base_channels * 16,
base_channels * 16,
n=base_depth,
shortcut=False,
depthwise=depthwise,
act=act,
),
)

elif spp_type == 'sppf':
self.dark5 = nn.Sequential(
Conv(base_channels * 8, base_channels * 16, 3, 2, act=act),
SPPFBottleneck(
base_channels * 16, base_channels * 16, activation=act),
CSPLayer(
base_channels * 16,
base_channels * 16,
n=base_depth,
shortcut=False,
depthwise=depthwise,
act=act,
),
)

def forward(self, x):
outputs = {}
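To show how the new spp_type switch is meant to be used, here is a small usage sketch (the multipliers and input size are illustrative; the constructor arguments mirror this diff, so treat it as an assumption-based example rather than documented API usage):

# Illustrative sketch, not part of the diff.
import torch

from easycv.models.backbones.darknet import CSPDarknet

# spp_type='spp' keeps the original parallel multi-kernel pooling block,
# while spp_type='sppf' selects the cheaper serial SPPF variant added in this PR.
backbone = CSPDarknet(dep_mul=0.33, wid_mul=0.5, spp_type='sppf')
backbone.eval()

with torch.no_grad():
    feats = backbone(torch.rand(1, 3, 640, 640))

# forward() is assumed to return a dict keyed by the requested out_features,
# as in the upstream YOLOX implementation.
for name, feat in feats.items():
    print(name, tuple(feat.shape))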