Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#38 from heavengate/fix_yolo_backbone
Browse files Browse the repository at this point in the history
fix yolo backbone
  • Loading branch information
heavengate authored Apr 15, 2020
2 parents 308447b + 0e8b317 commit f9f2d42
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 9 deletions.
1 change: 1 addition & 0 deletions examples/yolov3/.gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
dataset/voc*
pretrain_weights/darknet53_pretrained.pdparams
11 changes: 8 additions & 3 deletions examples/yolov3/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,17 @@ YOLOv3 的网络结构由基础特征提取网络、multi-scale特征融合层
| ...
```

```bash
sh pretrain_weights/download.sh
```

### 模型训练

数据准备完毕后,可使用`main.py`脚本启动训练和评估,如下脚本会自动每epoch交替进行训练和模型评估,并将checkpoint默认保存在`yolo_checkpoint`目录下。
数据准备完成后,可使用`main.py`脚本启动训练和评估,如下脚本会自动每epoch交替进行训练和模型评估,并将checkpoint默认保存在`yolo_checkpoint`目录下。

YOLOv3模型训练的总batch_size为64,以下以使用4卡Tesla P40、每卡batch_size为16为例介绍训练方式。对于静态图和动态图,多卡训练中`--batch_size`为每卡上的batch_size,即总batch_size为`--batch_size`乘以卡数。

YOLOv3模型训练须加载骨干网络DarkNet53的预训练权重,可在训练时通过`--pretrain_weights`指定,若指定为URL,将自动下载权重至`~/.cache/paddle/hapi/weights`目录并加载。

`main.py`脚本参数可通过如下命令查询

Expand All @@ -117,7 +122,7 @@ python main.py --help
使用如下方式进行多卡训练:

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch main.py --data=<path/to/dataset> --batch_size=16
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch main.py --data=<path/to/dataset> --batch_size=16 --pretrain_weights=https://paddlemodels.bj.bcebos.com/hapi/darknet53_pretrained.pdparams
```

#### 动态图训练
Expand All @@ -127,7 +132,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch main.py --data=
使用如下方式进行多卡训练:

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch main.py --data=<path/to/dataset> --batch_size=16 -d
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch main.py --data=<path/to/dataset> --batch_size=16 -d --pretrain_weights=https://paddlemodels.bj.bcebos.com/hapi/darknet53_pretrained.pdparams
```


Expand Down
11 changes: 8 additions & 3 deletions examples/yolov3/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

from hapi.model import Model, Input, set_device
from hapi.distributed import DistributedBatchSampler
from hapi.download import is_url, get_weights_path
from hapi.datasets import COCODataset
from hapi.vision.transforms import *
from hapi.vision.models import yolov3_darknet53, YoloLoss
Expand Down Expand Up @@ -124,8 +125,11 @@ def main():
model_mode='eval' if FLAGS.eval_only else 'train',
pretrained=pretrained)

if FLAGS.pretrain_weights is not None:
model.load(FLAGS.pretrain_weights, skip_mismatch=True, reset_optimizer=True)
if FLAGS.pretrain_weights and not FLAGS.eval_only:
pretrain_weights = FLAGS.pretrain_weights
if is_url(pretrain_weights):
pretrain_weights = get_weights_path(pretrain_weights)
model.load(pretrain_weights, skip_mismatch=True, reset_optimizer=True)

optim = make_optimizer(len(batch_sampler), parameter_list=model.parameters())

Expand Down Expand Up @@ -196,7 +200,8 @@ def main():
parser.add_argument(
"-j", "--num_workers", default=4, type=int, help="reader worker number")
parser.add_argument(
"-p", "--pretrain_weights", default=None, type=str,
"-p", "--pretrain_weights",
default="./pretrain_weights/darknet53_pretrained", type=str,
help="path to pretrained weights")
parser.add_argument(
"-r", "--resume", default=None, type=str,
Expand Down
12 changes: 11 additions & 1 deletion hapi/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,22 @@
import logging
logger = logging.getLogger(__name__)

__all__ = ['get_weights_path']
__all__ = ['get_weights_path', 'is_url']

WEIGHTS_HOME = osp.expanduser("~/.cache/paddle/hapi/weights")

DOWNLOAD_RETRY_LIMIT = 3


def is_url(path):
    """
    Whether path is URL.

    Only the ``http://`` and ``https://`` schemes are recognized. The
    check is a case-sensitive prefix match; the remainder of the string
    is not validated.

    Args:
        path (string): URL string or not.

    Returns:
        bool: True if ``path`` starts with ``http://`` or ``https://``,
            False otherwise.
    """
    # str.startswith accepts a tuple of prefixes, so both schemes are
    # tested in a single call.
    return path.startswith(('http://', 'https://'))


def get_weights_path(url, md5sum=None):
"""Get weights path from WEIGHT_HOME, if not exists,
download it from url.
Expand All @@ -62,6 +71,7 @@ def get_path(url, root_dir, md5sum=None, check_exist=True):
WEIGHTS_HOME or DATASET_HOME
md5sum (str): md5 sum of download package
"""
assert is_url(url), "downloading from {} not a url".format(url)
# parse path after download to decompress under root_dir
fullpath = map_path(url, root_dir)

Expand Down
7 changes: 7 additions & 0 deletions hapi/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,13 @@ def _check_match(key, param):
format(key, list(state.shape), list(param.shape)))
return param, state

def _strip_postfix(path):
    """Return ``path`` with its weights-file extension removed.

    A path with no extension is returned unchanged. The only extensions
    accepted are ``.pdparams``, ``.pdopt`` and ``.pdmodel``; any other
    extension fails the assertion.
    """
    known_postfixes = ['', '.pdparams', '.pdopt', '.pdmodel']
    stem, postfix = os.path.splitext(path)
    assert postfix in known_postfixes, \
        "Unknown postfix {} from weights".format(postfix)
    return stem

path = _strip_postfix(path)
param_state = _load_state_from_path(path + ".pdparams")
assert param_state, "Failed to load parameters, please check path."

Expand Down
2 changes: 1 addition & 1 deletion hapi/vision/models/darknet.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def forward(self,inputs):
DarkNet_cfg = {53: ([1, 2, 8, 8, 4])}


class DarkNet(Model):
class DarkNet(fluid.dygraph.Layer):
"""DarkNet model from
`"YOLOv3: An Incremental Improvement" <https://arxiv.org/abs/1804.02767>`_
Expand Down
2 changes: 1 addition & 1 deletion hapi/vision/models/yolov3.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def __init__(self, num_classes=80, model_mode='train'):
self.nms_posk = 100
self.draw_thresh = 0.5

self.backbone = darknet53(pretrained=(model_mode=='train'))
self.backbone = darknet53(pretrained=False)
self.block_outputs = []
self.yolo_blocks = []
self.route_blocks = []
Expand Down

0 comments on commit f9f2d42

Please sign in to comment.