Add non-existing resume_from_checkpoint acceptance for auto-resubmit (#4402)

* Add empty resume_from_checkpoint acceptance #4366

* Fix general error catch with focused file check

* Add fsspec HTTP extras

Add fsspec's HTTPFileSystem support through the http extras.
PL has already supported remote HTTP files (e.g. #2925),
so this commit does not add new functionality.

* Fix potential excessive logging in DDP

* Add PR changelog

* Add well-written argument explanation

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* Fix DDP-compatible restore logging

Notify from where the states are restored.
This feature was temporarily removed as a result of PR review.
After a subsequent review, it was re-added with DDP compatibility.

* Fix utility import paths

* Refactor load step commentaries

* Refactor hpc ckpt suffix acquisition

* Refactor restore/hpc_load match

* Refactor hpc load trial

* Refactor checkpoint dir check

* Refactor unneeded function nest

* Refactor nested If

* Refactor duplicated cache clear

* Refactor attempt flow with if/elif

* Fix pep8

* Refactor hook commentary

Co-authored-by: chaton <thomas@grid.ai>

* Fix pep8

* Refactor hpc load checkpoint path acquisition

* Fix pep8

* Fix typo

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* Fix typo

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* Fix doc

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* Refactor None Union type with Optional

* Fix build-doc CI failure debugged in #5329

* Fix fsspec import during build-doc #5329

* Fix test epoch

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* Fix test with latest test models

* .

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: chaton <thomas@grid.ai>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
Co-authored-by: Roger Shieh <sh.rog@protonmail.ch>
6 people authored Jan 5, 2021
1 parent dd442b6 commit b0051e8
Showing 8 changed files with 43 additions and 10 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `resume_from_checkpoint` to accept a non-existing file path ([#4402](https://github.com/PyTorchLightning/pytorch-lightning/pull/4402))


### Changed

6 changes: 5 additions & 1 deletion docs/source/conf.py
@@ -294,10 +294,14 @@ def setup(app):
# Ignoring Third-party packages
# https://stackoverflow.com/questions/15889621/sphinx-how-to-exclude-imports-in-automodule
def package_list_from_file(file):
    """List package names (without versions or extras) from a package list file."""
    mocked_packages = []
    with open(file, 'r') as fp:
        for ln in fp.readlines():
            found = [ln.index(ch) for ch in list(',=<>#') if ch in ln]
            # Example: `tqdm>=4.41.0` => `tqdm`
            # `[` is for package with extras
            found = [ln.index(ch) for ch in list(',=<>#[') if ch in ln]
            pkg = ln[:min(found)] if found else ln
            if pkg.rstrip():
                mocked_packages.append(pkg.rstrip())
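For context, a minimal sketch of what the extended delimiter set changes; the requirement lines below are made up for illustration and the loop simply replays the parsing rule above on in-memory strings:

    # illustrative only: replay the parsing rule on a few requirement lines
    lines = ["tqdm>=4.41.0\n", "fsspec[http]>=0.8.1\n", "# a comment\n"]
    mocked_packages = []
    for ln in lines:
        found = [ln.index(ch) for ch in list(',=<>#[') if ch in ln]
        pkg = ln[:min(found)] if found else ln
        if pkg.rstrip():
            mocked_packages.append(pkg.rstrip())
    print(mocked_packages)  # ['tqdm', 'fsspec'] -- the `[http]` extras marker no longer leaks into the mocked name
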
2 changes: 1 addition & 1 deletion environment.yml
@@ -30,7 +30,7 @@ dependencies:
- future>=0.17.1
- PyYAML>=5.1
- tqdm>=4.41.0
- fsspec>=0.8.0
- fsspec[http]>=0.8.1
#- tensorboard>=2.2.0 # not needed, already included in pytorch

# Optional
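The `[http]` extra pulls in fsspec's optional HTTP dependencies (aiohttp), which is what allows checkpoint paths to be URLs, per the commit message above. A small sanity-check sketch, assuming the extra is installed; the URL in the comment is purely illustrative:

    # sketch: verify the HTTP filesystem backend is available once fsspec[http] is installed
    import fsspec

    fs = fsspec.filesystem("http")
    print(type(fs).__name__)  # HTTPFileSystem
    # fs.exists("https://example.com/checkpoints/last.ckpt")  # remote existence check of the kind the connector relies on
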
12 changes: 10 additions & 2 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -43,7 +43,7 @@ def __init__(self, trainer):
        # used to validate checkpointing logic
        self.has_trained = False

    def restore_weights(self, model: LightningModule):
    def restore_weights(self, model: LightningModule) -> None:
        """
        Attempt to restore a checkpoint (e.g. weights) in this priority:
        1. from HPC weights
@@ -73,11 +73,16 @@ def restore_weights(self, model: LightningModule):
        if self.trainer.on_gpu:
            torch.cuda.empty_cache()

    def restore(self, checkpoint_path: str, on_gpu: bool):
    def restore(self, checkpoint_path: str, on_gpu: bool) -> bool:
        """
        Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore.
        All restored states are listed in the return value description of `dump_checkpoint`.
        """
        # Try to read the checkpoint file at `checkpoint_path`. If it does not exist, do not restore the checkpoint.
        fs = get_filesystem(checkpoint_path)
        if not fs.exists(checkpoint_path):
            rank_zero_warn("No checkpoint file exists at `resume_from_checkpoint`. Start from scratch")
            return False

        # read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path`
        checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
@@ -94,6 +99,9 @@ def restore(self, checkpoint_path: str, on_gpu: bool):
        # restore training state
        self.restore_training_state(checkpoint)

        rank_zero_info(f"Restored states from the checkpoint file at {checkpoint_path}")
        return True

    def restore_model_state(self, model: LightningModule, checkpoint) -> None:
        """
        Restore model states from a 'PyTorch-Lightning checkpoint' dictionary object
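The behavioral change is the early `return False`. A short sketch of that existence check in isolation, assuming `get_filesystem` is importable from `pytorch_lightning.utilities.cloud_io` and using an illustrative local path:

    # sketch: resolve a filesystem for the path/URL, then test existence before attempting a restore
    from pytorch_lightning.utilities.cloud_io import get_filesystem  # assumed import path

    ckpt_path = "lightning_logs/version_0/checkpoints/last.ckpt"  # illustrative
    fs = get_filesystem(ckpt_path)  # local paths and http(s) URLs resolve to different fsspec backends
    if fs.exists(ckpt_path):
        print("restore() would load the file and return True")
    else:
        print("restore() would warn and return False; training starts from scratch")
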
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -251,9 +251,9 @@ def __init__(
                train sampler and ``shuffle=False`` for val/test sampler. If you want to customize it,
                you can set ``replace_sampler_ddp=False`` and add your own distributed sampler.
            resume_from_checkpoint: To resume training from a specific checkpoint pass in the path here.
                This can be a URL. If resuming from mid-epoch checkpoint, training will start from
                the beginning of the next epoch.
            resume_from_checkpoint: Path/URL of the checkpoint from which training is resumed. If there is
                no checkpoint file at the path, start from scratch. If resuming from mid-epoch checkpoint,
                training will start from the beginning of the next epoch.
            sync_batchnorm: Synchronize batch norm layers between process groups/whole world.
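In user terms, the reworded docstring means an auto-resubmitted job can always pass the same checkpoint path, whether or not a checkpoint has been written yet. A small sketch, with the model and path as placeholders:

    # sketch: the same Trainer arguments work on the first run and after resubmission
    from pytorch_lightning import Trainer

    ckpt = "checkpoints/last.ckpt"  # placeholder path
    trainer = Trainer(max_epochs=2, resume_from_checkpoint=ckpt)
    # first run: no file at `ckpt` -> warning is logged, training starts from scratch
    # resubmitted run: file exists -> model/trainer states are restored from `ckpt`
    trainer.fit(model)  # `model` is a LightningModule instance defined elsewhere
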
2 changes: 1 addition & 1 deletion requirements.txt
@@ -6,5 +6,5 @@ future>=0.17.1 # required for builtins in setup.py
# pyyaml>=3.13
PyYAML>=5.1 # OmegaConf requirement >=5.1
tqdm>=4.41.0
fsspec>=0.8.0
fsspec[http]>=0.8.1
tensorboard>=2.2.0
2 changes: 1 addition & 1 deletion requirements/docs.txt
@@ -11,4 +11,4 @@ https://github.com/PyTorchLightning/lightning_sphinx_theme/archive/master.zip#eg
sphinx-autodoc-typehints
sphinx-paramlinks<0.4.0
sphinx-togglebutton
sphinx-copybutton
sphinx-copybutton
21 changes: 20 additions & 1 deletion tests/models/test_restore.py
@@ -27,7 +27,7 @@
import tests.base.develop_utils as tutils
from pytorch_lightning import Callback, LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from tests.base import EvalModelTemplate, GenericEvalModelTemplate, TrialMNIST
from tests.base import BoringModel, EvalModelTemplate, GenericEvalModelTemplate, TrialMNIST


class ModelTrainerPropertyParity(Callback):
@@ -73,6 +73,25 @@ def test_model_properties_resume_from_checkpoint(enable_pl_optimizer, tmpdir):
    trainer.fit(model)


def test_try_resume_from_non_existing_checkpoint(tmpdir):
    """Test that trying to resume from a non-existing `resume_from_checkpoint` fails without error."""
    model = BoringModel()
    checkpoint_cb = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=False,
        callbacks=[checkpoint_cb],
        limit_train_batches=0.1,
        limit_val_batches=0.1,
    )
    # Generate checkpoint `last.ckpt` with BoringModel
    trainer.fit(model)
    # `True` if resumed/restored successfully, else `False`
    assert trainer.checkpoint_connector.restore(str(tmpdir / "last.ckpt"), trainer.on_gpu)
    assert not trainer.checkpoint_connector.restore(str(tmpdir / "last_non_existing.ckpt"), trainer.on_gpu)


class CaptureCallbacksBeforeTraining(Callback):
    callbacks = []

