Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

COCO evolution fix #3388

Merged
merged 4 commits into from
May 29, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 28 additions & 30 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ def train(hyp, opt, device, tb_writer=None):
init_seeds(2 + rank)
with open(opt.data) as f:
data_dict = yaml.safe_load(f) # data dict
is_coco = opt.data.endswith('coco.yaml')

# Logging- Doing this before checking the dataset. Might update data_dict
loggers = {'wandb': None} # loggers dict
Expand All @@ -78,6 +77,7 @@ def train(hyp, opt, device, tb_writer=None):
nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes
names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
is_coco = opt.data.endswith('coco.yaml') and nc == 80 # COCO dataset

# Model
pretrained = weights.endswith('.pt')
Expand Down Expand Up @@ -358,6 +358,7 @@ def train(hyp, opt, device, tb_writer=None):
single_cls=opt.single_cls,
dataloader=testloader,
save_dir=save_dir,
save_json=is_coco and final_epoch,
verbose=nc < 50 and final_epoch,
plots=plots and final_epoch,
wandb_logger=wandb_logger,
Expand Down Expand Up @@ -409,41 +410,38 @@ def train(hyp, opt, device, tb_writer=None):
# end epoch ----------------------------------------------------------------------------------------------------
# end training
if rank in [-1, 0]:
# Plots
logger.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
if plots:
plot_results(save_dir=save_dir) # save as results.png
if wandb_logger.wandb:
files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
wandb_logger.log({"Results": [wandb_logger.wandb.Image(str(save_dir / f), caption=f) for f in files
if (save_dir / f).exists()]})
# Test best.pt
logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
if opt.data.endswith('coco.yaml') and nc == 80: # if COCO
for m in [last, best] if best.exists() else [last]: # speed, mAP tests
results, _, _ = test.test(opt.data,
batch_size=batch_size * 2,
imgsz=imgsz_test,
conf_thres=0.001,
iou_thres=0.7,
model=attempt_load(m, device).half(),
single_cls=opt.single_cls,
dataloader=testloader,
save_dir=save_dir,
save_json=True,
plots=False,
is_coco=is_coco)

# Strip optimizers
final = best if best.exists() else last # final model
for f in last, best:
if f.exists():
strip_optimizer(f) # strip optimizers
if opt.bucket:
os.system(f'gsutil cp {final} gs://{opt.bucket}/weights') # upload
if wandb_logger.wandb and not opt.evolve: # Log the stripped model
wandb_logger.wandb.log_artifact(str(final), type='model',
name='run_' + wandb_logger.wandb_run.id + '_model',
aliases=['latest', 'best', 'stripped'])

if not opt.evolve:
if is_coco: # COCO dataset
for m in [last, best] if best.exists() else [last]: # speed, mAP tests
results, _, _ = test.test(opt.data,
batch_size=batch_size * 2,
imgsz=imgsz_test,
conf_thres=0.001,
iou_thres=0.7,
model=attempt_load(m, device).half(),
single_cls=opt.single_cls,
dataloader=testloader,
save_dir=save_dir,
save_json=True,
plots=False,
is_coco=is_coco)

# Strip optimizers
for f in last, best:
if f.exists():
strip_optimizer(f) # strip optimizers
if wandb_logger.wandb: # Log the stripped model
wandb_logger.wandb.log_artifact(str(best if best.exists() else last), type='model',
name='run_' + wandb_logger.wandb_run.id + '_model',
aliases=['latest', 'best', 'stripped'])
wandb_logger.finish_run()
else:
dist.destroy_process_group()
Expand Down