Improved model+EMA checkpointing #2292

Merged · 13 commits · Feb 25, 2021
test.py (1 change: 0 additions & 1 deletion)

@@ -272,7 +272,6 @@ def test(data,
     if not training:
         s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
         print(f"Results saved to {save_dir}{s}")
-    model.float()  # for training
     maps = np.zeros(nc) + map
     for i, c in enumerate(ap_class):
         maps[c] = ap[i]
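
Note: the removed `model.float()` is not lost; it moves into train.py (below), which now casts weights to FP16 only for the duration of `torch.save()` and restores FP32 immediately after. A minimal sketch of that pattern, with the `save_checkpoint` helper and its arguments hypothetical:

    import torch

    def save_checkpoint(model, ema, optimizer, path):  # hypothetical helper, not part of this PR
        ckpt = {'model': model.half(),                 # nn.Module.half() converts in place and returns the module
                'ema': (ema.ema.half(), ema.updates),  # EMA weights plus its update counter
                'optimizer': optimizer.state_dict()}
        torch.save(ckpt, path)                         # weights serialize in FP16, roughly halving file size
        model.float(), ema.ema.float()                 # restore FP32 before training continues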
train.py (25 changes: 16 additions & 9 deletions)

@@ -31,7 +31,7 @@
 from utils.google_utils import attempt_download
 from utils.loss import ComputeLoss
 from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
-from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first
+from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel

 logger = logging.getLogger(__name__)
@@ -136,6 +136,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                                id=ckpt.get('wandb_id') if 'ckpt' in locals() else None)
         loggers = {'wandb': wandb}  # loggers dict

+    # EMA
+    ema = ModelEMA(model) if rank in [-1, 0] else None
+
     # Resume
     start_epoch, best_fitness = 0, 0.0
     if pretrained:
@@ -144,6 +147,11 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
             optimizer.load_state_dict(ckpt['optimizer'])
             best_fitness = ckpt['best_fitness']

+        # EMA
+        if ema and ckpt.get('ema'):
+            ema.ema.load_state_dict(ckpt['ema'][0].float().state_dict())
+            ema.updates = ckpt['ema'][1]
+
         # Results
         if ckpt.get('training_results') is not None:
             results_file.write_text(ckpt['training_results'])  # write results.txt
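
Persisting `ema.updates` alongside the weights matters because ModelEMA's decay ramps up with the update count; resuming without it would restart the ramp and let early post-resume batches dominate the average. This is also why the `ema.updates = start_epoch * nb // accumulate` approximation is deleted in a later hunk: the exact counter is now restored from the checkpoint. A simplified sketch, patterned on utils/torch_utils.ModelEMA (details may differ):

    import math
    from copy import deepcopy

    class ModelEMA:
        # Simplified sketch of an exponential moving average of model weights
        def __init__(self, model, decay=0.9999, updates=0):
            self.ema = deepcopy(model).eval()  # FP32 copy that accumulates the average
            self.updates = updates             # number of EMA updates applied so far
            self.decay = lambda x: decay * (1 - math.exp(-x / 2000))  # ramps from 0 toward decay
            for p in self.ema.parameters():
                p.requires_grad_(False)

        def update(self, model):
            self.updates += 1
            d = self.decay(self.updates)  # small early on, so a stale counter skews the average
            msd = model.state_dict()
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:  # skip integer buffers such as num_batches_tracked
                    v *= d
                    v += (1.0 - d) * msd[k].detach()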
@@ -173,9 +181,6 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
         logger.info('Using SyncBatchNorm()')

-    # EMA
-    ema = ModelEMA(model) if rank in [-1, 0] else None
-
     # DDP mode
     if cuda and rank != -1:
         model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank)
@@ -191,7 +196,6 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):

     # Process 0
     if rank in [-1, 0]:
-        ema.updates = start_epoch * nb // accumulate  # set EMA updates
         testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, opt,  # testloader
                                        hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
                                        world_size=opt.world_size, workers=opt.workers,
@@ -335,8 +339,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
         # DDP process 0 or single-GPU
         if rank in [-1, 0]:
             # mAP
-            if ema:
-                ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
+            ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
             final_epoch = epoch + 1 == epochs
             if not opt.notest or final_epoch:  # Calculate mAP
                 results, maps, times = test.test(opt.data,
@@ -378,15 +381,19 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                 ckpt = {'epoch': epoch,
                         'best_fitness': best_fitness,
                         'training_results': results_file.read_text(),
-                        'model': ema.ema,
-                        'optimizer': None if final_epoch else optimizer.state_dict(),
+                        'model': (model.module if is_parallel(model) else model).half(),
+                        'ema': (ema.ema.half(), ema.updates),
+                        'optimizer': optimizer.state_dict(),
                         'wandb_id': wandb_run.id if wandb else None}

                 # Save last, best and delete
                 torch.save(ckpt, last)
                 if best_fitness == fi:
                     torch.save(ckpt, best)
                 del ckpt
+
+                model.float(), ema.ema.float()

         # end epoch ----------------------------------------------------------------------------------------------------
     # end training
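
With the new format, `last.pt`/`best.pt` carry the raw model and the EMA pair side by side and always include the optimizer state, so every checkpoint is fully resumable; final slimming is deferred to strip_optimizer() below. A hedged sketch of reading the format back for inference (the path and the preference for EMA weights are illustrative, not APIs introduced by this PR):

    import torch

    ckpt = torch.load('last.pt', map_location='cpu')                # illustrative path
    weights = ckpt['ema'][0] if ckpt.get('ema') else ckpt['model']  # EMA weights usually generalize better
    model = weights.float().eval()                                  # stored in FP16; cast back to FP32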
utils/general.py (4 changes: 2 additions & 2 deletions)

@@ -484,8 +484,8 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None
 def strip_optimizer(f='weights/best.pt', s=''):  # from utils.general import *; strip_optimizer()
     # Strip optimizer from 'f' to finalize training, optionally save as 's'
     x = torch.load(f, map_location=torch.device('cpu'))
-    for key in 'optimizer', 'training_results', 'wandb_id':
-        x[key] = None
+    for k in 'optimizer', 'training_results', 'wandb_id', 'ema':  # keys
+        x[k] = None
     x['epoch'] = -1
     x['model'].half()  # to FP16
     for p in x['model'].parameters():
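
strip_optimizer() now also clears the new 'ema' checkpoint entry, so a finalized checkpoint keeps only the FP16 model. Usage sketch (path illustrative):

    from utils.general import strip_optimizer

    # Finalize in place: nulls optimizer/EMA/training state, sets epoch to -1,
    # and leaves the stored model in FP16 for a smaller file
    strip_optimizer('runs/train/exp/weights/best.pt')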