
improve amp training and fix nan error (#8305)
zhangting2020 authored Jun 5, 2023
1 parent 129ddbb commit 2277804
Showing 7 changed files with 34 additions and 10 deletions.
3 changes: 3 additions & 0 deletions configs/keypoint/tiny_pose/tinypose_128x96.yml
@@ -14,6 +14,9 @@ trainsize: &trainsize [*train_width, *train_height]
 hmsize: &hmsize [24, 32]
 flip_perm: &flip_perm [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
 
+# AMP training
+init_loss_scaling: 32752
+master_grad: true
 
 #####model
 architecture: TopDownHRNet
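
The two added keys configure the AMP path: init_loss_scaling sets the gradient scaler's starting loss-scale factor, and master_grad asks Paddle to keep float32 master gradients for the optimizer step under O2 (Paddle >= 2.5, per the trainer change below). A minimal standalone sketch of where such keys typically land, assuming a GPU build; the model, optimizer, and data here are placeholders, not ppdet code:

import paddle

model = paddle.nn.Linear(16, 2)
opt = paddle.optimizer.Momentum(parameters=model.parameters())

# master_grad=True keeps float32 gradients for the update (Paddle >= 2.5)
model, opt = paddle.amp.decorate(
    models=model, optimizers=opt, level='O2', master_grad=True)

# init_loss_scaling is the starting factor applied to the loss
scaler = paddle.amp.GradScaler(init_loss_scaling=32752)

x = paddle.randn([8, 16])
with paddle.amp.auto_cast(level='O2'):
    loss = model(x).mean()
scaled = scaler.scale(loss)
scaled.backward()
scaler.minimize(opt, scaled)  # unscales, skips the step on overflow, adjusts the scale
opt.clear_grad()
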
3 changes: 3 additions & 0 deletions configs/ppyolo/_base_/ppyolov2_r50vd_dcn.yml
@@ -4,6 +4,9 @@ norm_type: sync_bn
 use_ema: true
 ema_decay: 0.9998
 
+# AMP training
+master_grad: true
+
 YOLOv3:
   backbone: ResNet
   neck: PPYOLOPAN
3 changes: 3 additions & 0 deletions configs/ppyolo/ppyolo_mbv3_large_coco.yml
@@ -9,6 +9,9 @@ _BASE_: [
 snapshot_epoch: 10
 weights: output/ppyolo_mbv3_large_coco/model_final
 
+# AMP training
+master_grad: true
+
 TrainReader:
   inputs_def:
     num_max_boxes: 90
18 changes: 14 additions & 4 deletions ppdet/engine/trainer.py
@@ -73,6 +73,7 @@ def __init__(self, cfg, mode='train'):
         self.amp_level = self.cfg.get('amp_level', 'O1')
         self.custom_white_list = self.cfg.get('custom_white_list', None)
         self.custom_black_list = self.cfg.get('custom_black_list', None)
+        self.use_master_grad = self.cfg.get('master_grad', False)
         if 'slim' in cfg and cfg['slim_type'] == 'PTQ':
             self.cfg['TestDataset'] = create('TestDataset')()
@@ -180,10 +181,19 @@ def __init__(self, cfg, mode='train'):
             self.pruner = create('UnstructuredPruner')(self.model,
                                                        steps_per_epoch)
         if self.use_amp and self.amp_level == 'O2':
-            self.model, self.optimizer = paddle.amp.decorate(
-                models=self.model,
-                optimizers=self.optimizer,
-                level=self.amp_level)
+            paddle_version = paddle.__version__[:3]
+            # paddle version >= 2.5.0 or develop
+            if paddle_version in ["2.5", "0.0"]:
+                self.model, self.optimizer = paddle.amp.decorate(
+                    models=self.model,
+                    optimizers=self.optimizer,
+                    level=self.amp_level,
+                    master_grad=self.use_master_grad)
+            else:
+                self.model, self.optimizer = paddle.amp.decorate(
+                    models=self.model,
+                    optimizers=self.optimizer,
+                    level=self.amp_level)
         self.use_ema = ('use_ema' in cfg and cfg['use_ema'])
         if self.use_ema:
             ema_decay = self.cfg.get('ema_decay', 0.9998)
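
Here paddle.__version__[:3] is "2.5" for 2.5.x releases and "0.0" for develop builds (which report version 0.0.0); only those builds accept the master_grad argument to paddle.amp.decorate, so older installs fall back to the original call. The same gate factored into a helper, as a sketch (the helper name is illustrative, not ppdet code):

import paddle

def supports_master_grad():
    # "2.5" matches 2.5.x releases; develop builds report "0.0.0"
    return paddle.__version__[:3] in ["2.5", "0.0"]

extra = {'master_grad': True} if supports_master_grad() else {}
# model, opt = paddle.amp.decorate(models=model, optimizers=opt,
#                                  level='O2', **extra)

A prefix check like this will not match later releases such as 2.6, so those would take the fallback path even though the argument exists; comparing parsed version tuples would be more future-proof.
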
2 changes: 2 additions & 0 deletions ppdet/modeling/layers.py
@@ -362,6 +362,8 @@ def forward(self, x):
                 padding=self.block_size // 2,
                 data_format=self.data_format)
             mask = 1. - mask_inv
+            mask = mask.astype('float32')
+            x = x.astype('float32')
             y = x * mask * (mask.numel() / mask.sum())
             return y
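
The two casts force DropBlock's normalization into float32. Under AMP the pooled mask can be float16, whose largest finite value is 65504; mask.sum() over a typical feature map far exceeds that and overflows to inf, corrupting the x * mask * (mask.numel() / mask.sum()) term and its gradients (one plausible source of the NaNs this commit targets). A small illustration, assuming a GPU build with float16 reduction support:

import paddle

mask = paddle.ones([1, 256, 32, 32], dtype='float16')  # 262,144 elements
print(mask.sum())                    # float16 result overflows 65504 -> inf
print(mask.astype('float32').sum())  # 262144.0 as expected
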
5 changes: 3 additions & 2 deletions ppdet/modeling/losses/yolo_loss.py
@@ -190,8 +190,9 @@ def forward(self, inputs, targets, anchors):
         self.distill_pairs.clear()
         for x, t, anchor, downsample in zip(inputs, gt_targets, anchors,
                                             self.downsample):
-            yolo_loss = self.yolov3_loss(x, t, gt_box, anchor, downsample,
-                                         self.scale_x_y)
+            yolo_loss = self.yolov3_loss(
+                x.astype('float32'), t, gt_box, anchor, downsample,
+                self.scale_x_y)
             for k, v in yolo_loss.items():
                 if k in yolo_losses:
                     yolo_losses[k] += v
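
Casting x to float32 before yolov3_loss keeps the loss's numerically sensitive sigmoid/log/exp arithmetic out of float16 even when the forward pass ran under auto_cast. The same pattern in isolation, as a hedged sketch (bce_fp32 is an illustrative helper, not ppdet's loss):

import paddle
import paddle.nn.functional as F

def bce_fp32(logits, labels):
    # compute the numerically sensitive BCE in float32,
    # even if the network emitted float16 logits
    return F.binary_cross_entropy_with_logits(
        logits.astype('float32'), labels.astype('float32'))
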
10 changes: 6 additions & 4 deletions ppdet/optimizer/ema.py
@@ -69,9 +69,9 @@ def __init__(self,
         self.state_dict = dict()
         for k, v in model.state_dict().items():
             if k in self.ema_black_list:
-                self.state_dict[k] = v
+                self.state_dict[k] = v.astype('float32')
             else:
-                self.state_dict[k] = paddle.zeros_like(v)
+                self.state_dict[k] = paddle.zeros_like(v, dtype='float32')
 
         self._model_state = {
             k: weakref.ref(p)
@@ -114,7 +114,7 @@ def update(self, model=None):
 
         for k, v in self.state_dict.items():
             if k not in self.ema_black_list:
-                v = decay * v + (1 - decay) * model_dict[k]
+                v = decay * v + (1 - decay) * model_dict[k].astype('float32')
                 v.stop_gradient = True
                 self.state_dict[k] = v
         self.step += 1
@@ -123,13 +123,15 @@ def apply(self):
         if self.step == 0:
             return self.state_dict
         state_dict = dict()
+        model_dict = {k: p() for k, p in self._model_state.items()}
         for k, v in self.state_dict.items():
             if k in self.ema_black_list:
                 v.stop_gradient = True
-                state_dict[k] = v
+                state_dict[k] = v.astype(model_dict[k].dtype)
             else:
                 if self.ema_decay_type != 'exponential':
                     v = v / (1 - self._decay**self.step)
+                v = v.astype(model_dict[k].dtype)
                 v.stop_gradient = True
                 state_dict[k] = v
         self.epoch += 1
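
Taken together, the ema.py changes store and update the EMA shadow weights in float32 and cast back to each parameter's own dtype only when the averaged weights are applied. This matters because with decay 0.9998 each step contributes only (1 - decay) = 2e-4 of the current weight, below float16's unit roundoff (about 4.9e-4), so a float16 EMA would round most updates away. A standalone sketch of the idea (class and method names are illustrative, not ppdet's ModelEMA API):

import paddle

class Float32EMA:
    """Illustrative EMA holder whose shadow weights live in float32."""

    def __init__(self, model, decay=0.9998):
        self.decay = decay
        self.shadow = {k: v.astype('float32')
                       for k, v in model.state_dict().items()}

    @paddle.no_grad()
    def update(self, model):
        for k, v in model.state_dict().items():
            # accumulate in float32 so the tiny (1 - decay) term survives
            self.shadow[k] = (self.decay * self.shadow[k] +
                              (1 - self.decay) * v.astype('float32'))

    def averaged_state_dict(self, model):
        # cast back to each live parameter's dtype before loading
        live = model.state_dict()
        return {k: v.astype(live[k].dtype) for k, v in self.shadow.items()}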
