From fa87d1da769ee7160984e0b905cc8b22d8647656 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 22 Apr 2020 20:09:41 -0400 Subject: [PATCH] added state saving --- pytorch_lightning/trainer/training_io.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py index 47448132df28a..09a65af0485de 100644 --- a/pytorch_lightning/trainer/training_io.py +++ b/pytorch_lightning/trainer/training_io.py @@ -281,6 +281,10 @@ def restore(self, checkpoint_path: str, on_gpu: bool): if on_gpu: model.cuda(self.root_gpu) + # restore amp scaling + if self.use_amp and self.use_native_amp and 'native_amp_scaling_state' in checkpoint: + self.scaler.load_state_dict(checkpoint['native_amp_scaling_state']) + # load training state (affects trainer only) self.restore_training_state(checkpoint) @@ -316,6 +320,10 @@ def dump_checkpoint(self): checkpoint['state_dict'] = model.state_dict() + # restore native amp scaling + if self.use_amp and self.use_native_amp and 'native_amp_scaling_state' in checkpoint: + checkpoint['native_amp_scaling_state'] = self.scaler.state_dict + if hasattr(model, "hparams"): is_namespace = isinstance(model.hparams, Namespace) checkpoint['hparams'] = vars(model.hparams) if is_namespace else model.hparams @@ -441,6 +449,10 @@ def hpc_load(self, folderpath, on_gpu): # load the state_dict on the model automatically model.load_state_dict(checkpoint['state_dict']) + # restore amp scaling + if self.use_amp and self.use_native_amp and 'native_amp_scaling_state' in checkpoint: + self.scaler.load_state_dict(checkpoint['native_amp_scaling_state']) + if self.root_gpu is not None: model.cuda(self.root_gpu)