support for native amp #1561
```diff
@@ -281,6 +281,10 @@ def restore(self, checkpoint_path: str, on_gpu: bool):
         if on_gpu:
             model.cuda(self.root_gpu)
 
+        # restore amp scaling
+        if self.use_amp and self.use_native_amp and 'native_amp_scaling_state' in checkpoint:
+            self.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
+
         # load training state (affects trainer only)
         self.restore_training_state(checkpoint)
```

Review thread on the added `if self.use_amp ...` line:

> @mcarilli sanity check this loading?

> Looks good if you fix the saving (https://github.com/PyTorchLightning/pytorch-lightning/pull/1561/files#r413418705). Like saving, loading should occur either at the very beginning of an iteration (before any training-related scaler calls for that iteration) or at the very end, after `scaler.update()`.
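For reference, the guarded load above can be exercised outside the Trainer with a few lines. This is a minimal sketch assuming a `torch.cuda.amp.GradScaler` and a plain dict checkpoint; `restore_scaler` is a hypothetical helper, not a Lightning API:

```python
import torch

def restore_scaler(scaler, checkpoint, use_amp, use_native_amp):
    # mirrors the guard added in restore()/hpc_load(): only touch the scaler
    # when native amp is active and the checkpoint actually carries the key
    if use_amp and use_native_amp and 'native_amp_scaling_state' in checkpoint:
        scaler.load_state_dict(checkpoint['native_amp_scaling_state'])

scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
checkpoint = {'native_amp_scaling_state': scaler.state_dict()}
restore_scaler(scaler, checkpoint, use_amp=True, use_native_amp=True)
```

Because the key check runs first, checkpoints written before this PR (which lack `'native_amp_scaling_state'`) still restore without errors.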
```diff
@@ -316,6 +320,10 @@ def dump_checkpoint(self):
 
         checkpoint['state_dict'] = model.state_dict()
 
+        # restore native amp scaling
+        if self.use_amp and self.use_native_amp and 'native_amp_scaling_state' in checkpoint:
+            checkpoint['native_amp_scaling_state'] = self.scaler.state_dict()
+
         if hasattr(model, "hparams"):
             is_namespace = isinstance(model.hparams, Namespace)
             checkpoint['hparams'] = vars(model.hparams) if is_namespace else model.hparams
```

Review thread on the added `if self.use_amp ...` line:

> @mcarilli sanity check this saving?

> Also you should make sure `state_dict()` is called either at the very beginning of an iteration (before any scaler calls for that iteration) or at the very end (after `scaler.update()`). I can't tell from these lines alone if the calling code occurs at a spot that obeys those criteria.

> i thought it was a property haha, but i guess it's consistent with the other `state_dict()` calls haha

> lol i see. it's consistent with the rest

> Another thing to consider is that with […]

> I think your […]

> yeah this code works.
>
> - Case 1 (train with amp, resume with amp): works fine.
> - Case 2 (train with amp, resume without amp): Lightning loads the amp state, but amp is disabled, so the user doesn't use it at all.
> - Case 3 (train regular, resume regular): works fine.
> - Case 4 (train regular, resume with amp): the checkpoint has no amp state, so the scaler starts fresh but training runs on amp.
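The four cases in the last comment can be traced with a small self-contained sketch. The `dump_checkpoint` and `restore` functions below are simplified stand-ins, not the Trainer methods, and the save path drops the `'native_amp_scaling_state' in checkpoint` guard for brevity:

```python
import torch

def dump_checkpoint(model, scaler, use_native_amp):
    # simplified stand-in for Trainer.dump_checkpoint
    checkpoint = {'state_dict': model.state_dict()}
    if use_native_amp:
        # cases 1 and 2: an amp run stores its scaler state
        checkpoint['native_amp_scaling_state'] = scaler.state_dict()
    return checkpoint  # cases 3 and 4: a regular run stores no amp state

def restore(model, scaler, checkpoint, use_native_amp):
    # simplified stand-in for Trainer.restore
    model.load_state_dict(checkpoint['state_dict'])
    if use_native_amp and 'native_amp_scaling_state' in checkpoint:
        # case 1: amp -> amp restores the scaler
        scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
    # case 2: amp state present but amp disabled -> state is simply ignored
    # case 4: amp enabled but checkpoint has no amp state -> scaler starts fresh

model = torch.nn.Linear(4, 2)
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
ckpt = dump_checkpoint(model, scaler, use_native_amp=True)
restore(model, scaler, ckpt, use_native_amp=True)   # case 1
restore(model, scaler, ckpt, use_native_amp=False)  # case 2
```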
```diff
@@ -441,6 +449,10 @@ def hpc_load(self, folderpath, on_gpu):
         # load the state_dict on the model automatically
         model.load_state_dict(checkpoint['state_dict'])
 
+        # restore amp scaling
+        if self.use_amp and self.use_native_amp and 'native_amp_scaling_state' in checkpoint:
+            self.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
+
         if self.root_gpu is not None:
             model.cuda(self.root_gpu)
```
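Regarding the review note that scaler state should be saved or loaded either before any scaler calls in an iteration or after `scaler.update()`, here is a rough sketch of a native amp training step with those two safe points marked; it assumes nothing about where Lightning actually invokes checkpointing:

```python
import torch
import torch.nn.functional as F

device = 'cuda' if torch.cuda.is_available() else 'cpu'
use_native_amp = device == 'cuda'

model = torch.nn.Linear(8, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=use_native_amp)

for step in range(2):
    # safe point 1: before any scaler calls for this iteration,
    # e.g. scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
    x = torch.randn(16, 8, device=device)
    y = torch.randn(16, 1, device=device)

    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=use_native_amp):
        loss = F.mse_loss(model(x), y)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    # safe point 2: after scaler.update(), the scale state is consistent,
    # e.g. checkpoint['native_amp_scaling_state'] = scaler.state_dict()
    amp_state = scaler.state_dict()
```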