From d0a3e252e1acd3c6a7afcd46096d64fc619d1b43 Mon Sep 17 00:00:00 2001
From: Yongrae Jo
Date: Fri, 15 Nov 2019 22:09:04 +0900
Subject: [PATCH 1/6] Add resume_from_checkpoint

---
 pytorch_lightning/trainer/trainer.py    |  5 ++++-
 pytorch_lightning/trainer/trainer_io.py | 17 ++++++++++++++++-
 2 files changed, 20 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index eebba265c354b..24903157d6368 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -82,7 +82,8 @@ def __init__(self,
                  weights_save_path=None,
                  amp_level='O1',
                  nb_sanity_val_steps=5,
-                 truncated_bptt_steps=None):
+                 truncated_bptt_steps=None,
+                 resume_from_checkpoint=None):
         """
 
         :param logger: Logger for experiment tracking
@@ -119,6 +120,7 @@ def __init__(self,
         :param amp_level: str. Check nvidia docs for level
         :param nb_sanity_val_steps: int. How many val steps before a full train loop.
         :param truncated_bptt_steps: int. Enables multiple backward passes for each batch.
+        :param resume_from_checkpoint: str or os.PathLike object. Resume from specific checkpoint.
         """
         # Transfer params
         self.nb_gpu_nodes = nb_gpu_nodes
@@ -139,6 +141,7 @@ def __init__(self,
         self.nb_sanity_val_steps = nb_sanity_val_steps
         self.print_nan_grads = print_nan_grads
         self.truncated_bptt_steps = truncated_bptt_steps
+        self.resume_from_checkpoint = resume_from_checkpoint
         self.shown_warnings = set()
 
         self.fast_dev_run = fast_dev_run
diff --git a/pytorch_lightning/trainer/trainer_io.py b/pytorch_lightning/trainer/trainer_io.py
index 11bbb1d55045a..a62c2a60e306f 100644
--- a/pytorch_lightning/trainer/trainer_io.py
+++ b/pytorch_lightning/trainer/trainer_io.py
@@ -2,6 +2,7 @@
 import re
 import signal
 import warnings
+from pathlib import Path
 from subprocess import call
 import logging
 
@@ -46,7 +47,9 @@ def restore_weights(self, model):
 
         if not did_restore_hpc_weights:
             # restore weights if same exp version
-            self.restore_state_if_checkpoint_exists(model)
+            did_restore_last_checkpoint = self.restore_state_if_checkpoint_exists(model)
+            if not did_restore_last_checkpoint and self.restore_from_checkpoint is not None:
+                self.restore_state_from_checkpoint(self.restore_from_checkpoint)
 
         # wait for all models to restore weights
         if self.use_ddp or self.use_ddp2:
@@ -93,6 +96,18 @@ def restore_state_if_checkpoint_exists(self, model):
 
         return did_restore
 
+    def restore_state_from_checkpoint(self, checkpoint_path):
+        did_restore = False
+
+        checkpoint_path = Path(checkpoint_path)
+        if not checkpoint_path.exists():
+            return did_restore
+
+        self.restore(checkpoint_path, self.on_gpu)
+        did_restore = True
+
+        return did_restore
+
     # --------------------
     # HPC SIGNAL HANDLING
     # --------------------

From 915b5c530f5712ac2cdc0fbb00691fe55333aff2 Mon Sep 17 00:00:00 2001
From: Yongrae Jo
Date: Fri, 15 Nov 2019 22:12:28 +0900
Subject: [PATCH 2/6] Fix variable name

---
 pytorch_lightning/trainer/trainer_io.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/trainer/trainer_io.py b/pytorch_lightning/trainer/trainer_io.py
index a62c2a60e306f..95385684c9425 100644
--- a/pytorch_lightning/trainer/trainer_io.py
+++ b/pytorch_lightning/trainer/trainer_io.py
@@ -48,8 +48,8 @@ def restore_weights(self, model):
         if not did_restore_hpc_weights:
             # restore weights if same exp version
             did_restore_last_checkpoint = self.restore_state_if_checkpoint_exists(model)
-            if not did_restore_last_checkpoint and self.restore_from_checkpoint is not None:
-                self.restore_state_from_checkpoint(self.restore_from_checkpoint)
+            if not did_restore_last_checkpoint and self.resume_from_checkpoint is not None:
+                self.restore_state_from_checkpoint(self.resume_from_checkpoint)
 
         # wait for all models to restore weights
         if self.use_ddp or self.use_ddp2:

From 0db87852e43739784099f70cf9cf81e8995f7e4b Mon Sep 17 00:00:00 2001
From: Yongrae Jo
Date: Mon, 18 Nov 2019 15:00:01 +0900
Subject: [PATCH 3/6] #515 Remove did_restore

---
 pytorch_lightning/trainer/trainer_io.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/pytorch_lightning/trainer/trainer_io.py b/pytorch_lightning/trainer/trainer_io.py
index 95385684c9425..18b6b3a677e74 100644
--- a/pytorch_lightning/trainer/trainer_io.py
+++ b/pytorch_lightning/trainer/trainer_io.py
@@ -97,16 +97,12 @@ def restore_state_if_checkpoint_exists(self, model):
         return did_restore
 
     def restore_state_from_checkpoint(self, checkpoint_path):
-        did_restore = False
-
         checkpoint_path = Path(checkpoint_path)
         if not checkpoint_path.exists():
-            return did_restore
+            return False
 
         self.restore(checkpoint_path, self.on_gpu)
-        did_restore = True
-
-        return did_restore
+        return True
 
     # --------------------
     # HPC SIGNAL HANDLING
     # --------------------

From bcebf600903ea99b298b01f73a02c6e226bdd662 Mon Sep 17 00:00:00 2001
From: Yongrae Jo
Date: Mon, 18 Nov 2019 15:09:16 +0900
Subject: [PATCH 4/6] #515 Simplify code

---
 pytorch_lightning/trainer/trainer_io.py | 18 +++++-------------
 1 file changed, 5 insertions(+), 13 deletions(-)

diff --git a/pytorch_lightning/trainer/trainer_io.py b/pytorch_lightning/trainer/trainer_io.py
index 18b6b3a677e74..d33d4a87b3833 100644
--- a/pytorch_lightning/trainer/trainer_io.py
+++ b/pytorch_lightning/trainer/trainer_io.py
@@ -2,7 +2,6 @@
 import re
 import signal
 import warnings
-from pathlib import Path
 from subprocess import call
 import logging
 
@@ -46,10 +45,11 @@ def restore_weights(self, model):
             torch.cuda.empty_cache()
 
         if not did_restore_hpc_weights:
-            # restore weights if same exp version
-            did_restore_last_checkpoint = self.restore_state_if_checkpoint_exists(model)
-            if not did_restore_last_checkpoint and self.resume_from_checkpoint is not None:
-                self.restore_state_from_checkpoint(self.resume_from_checkpoint)
+            if self.resume_from_checkpoint is not None:
+                self.restore(self.resume_from_checkpoint)
+            else:
+                # restore weights if same exp version
+                self.restore_state_if_checkpoint_exists(model)
 
         # wait for all models to restore weights
         if self.use_ddp or self.use_ddp2:
@@ -96,14 +96,6 @@ def restore_state_if_checkpoint_exists(self, model):
 
         return did_restore
 
-    def restore_state_from_checkpoint(self, checkpoint_path):
-        checkpoint_path = Path(checkpoint_path)
-        if not checkpoint_path.exists():
-            return False
-
-        self.restore(checkpoint_path, self.on_gpu)
-        return True
-
     # --------------------
     # HPC SIGNAL HANDLING
     # --------------------

From 70cbd8c67e697e2674b023ccdd573ff310e456d2 Mon Sep 17 00:00:00 2001
From: Yongrae Jo
Date: Mon, 18 Nov 2019 15:12:15 +0900
Subject: [PATCH 5/6] #515 Update doc for resume_from_checkpoint

---
 pytorch_lightning/trainer/trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 24903157d6368..ae326700e5896 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -120,7 +120,7 @@ def __init__(self,
         :param amp_level: str. Check nvidia docs for level
         :param nb_sanity_val_steps: int. How many val steps before a full train loop.
         :param truncated_bptt_steps: int. Enables multiple backward passes for each batch.
-        :param resume_from_checkpoint: str or os.PathLike object. Resume from specific checkpoint.
+        :param resume_from_checkpoint: file-like object or str containing a file name. Resume from specific checkpoint.
         """
         # Transfer params
         self.nb_gpu_nodes = nb_gpu_nodes

From 34771080959f1a0cf0449044aff60e83dd1fdda4 Mon Sep 17 00:00:00 2001
From: Yongrae Jo
Date: Mon, 18 Nov 2019 15:24:39 +0900
Subject: [PATCH 6/6] #515 Add on_gpu

---
 pytorch_lightning/trainer/trainer_io.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/trainer_io.py b/pytorch_lightning/trainer/trainer_io.py
index d33d4a87b3833..330d4a88286f6 100644
--- a/pytorch_lightning/trainer/trainer_io.py
+++ b/pytorch_lightning/trainer/trainer_io.py
@@ -46,7 +46,7 @@ def restore_weights(self, model):
 
         if not did_restore_hpc_weights:
             if self.resume_from_checkpoint is not None:
-                self.restore(self.resume_from_checkpoint)
+                self.restore(self.resume_from_checkpoint, on_gpu=self.on_gpu)
             else:
                 # restore weights if same exp version
                 self.restore_state_if_checkpoint_exists(model)
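
For context, not part of the patch series itself: with these changes applied, resuming from an explicit checkpoint would look roughly like the sketch below. `MyLightningModule` and the checkpoint path are placeholder names, not identifiers taken from the diff.

    from pytorch_lightning import Trainer

    # Placeholder: any LightningModule subclass defined by the user.
    model = MyLightningModule()

    # When resume_from_checkpoint is set, restore_weights() calls
    # self.restore(self.resume_from_checkpoint, on_gpu=self.on_gpu)
    # instead of searching the experiment directory for the last checkpoint.
    trainer = Trainer(resume_from_checkpoint='path/to/checkpoint.ckpt')
    trainer.fit(model)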