Skip to content

Commit

Permalink
Load checkpoint on CPU instead of on GPU (ultralytics#6516)
Browse files Browse the repository at this point in the history
* Load checkpoint on CPU instead of on GPU

* refactor: simplify code

* Cleanup

* Update train.py

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
  • Loading branch information
bilzard and glenn-jocher authored Feb 4, 2022
1 parent 1c184bc commit 24380e6
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
if pretrained:
with torch_distributed_zero_first(LOCAL_RANK):
weights = attempt_download(weights) # download if not found locally
ckpt = torch.load(weights, map_location=device) # load checkpoint
ckpt = torch.load(weights, map_location='cpu') # load checkpoint to CPU to avoid CUDA memory leak
model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
Expand Down

0 comments on commit 24380e6

Please sign in to comment.