diff --git a/CHANGELOG.md b/CHANGELOG.md index 5b9b705459510..68941743ed00e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added `resume_from_checkpoint` accept non-existing file path ([#4402](https://github.com/PyTorchLightning/pytorch-lightning/pull/4402)) + ### Changed diff --git a/docs/source/conf.py b/docs/source/conf.py index 655e8dba30a36..2b861623599a6 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -294,10 +294,14 @@ def setup(app): # Ignoring Third-party packages # https://stackoverflow.com/questions/15889621/sphinx-how-to-exclude-imports-in-automodule def package_list_from_file(file): + """List up package name (not containing version and extras) from a package list file + """ mocked_packages = [] with open(file, 'r') as fp: for ln in fp.readlines(): - found = [ln.index(ch) for ch in list(',=<>#') if ch in ln] + # Example: `tqdm>=4.41.0` => `tqdm` + # `[` is for package with extras + found = [ln.index(ch) for ch in list(',=<>#[') if ch in ln] pkg = ln[:min(found)] if found else ln if pkg.rstrip(): mocked_packages.append(pkg.rstrip()) diff --git a/environment.yml b/environment.yml index 3d59c1eeed0dd..1278f15f718e9 100644 --- a/environment.yml +++ b/environment.yml @@ -30,7 +30,7 @@ dependencies: - future>=0.17.1 - PyYAML>=5.1 - tqdm>=4.41.0 - - fsspec>=0.8.0 + - fsspec[http]>=0.8.1 #- tensorboard>=2.2.0 # not needed, already included in pytorch # Optional diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index e912462d2491b..c71cbe6ce6180 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -43,7 +43,7 @@ def __init__(self, trainer): # used to validate checkpointing logic self.has_trained = False - def restore_weights(self, model: LightningModule): + def restore_weights(self, model: LightningModule) -> None: """ Attempt to restore a checkpoint (e.g. weights) in this priority: 1. from HPC weights @@ -73,11 +73,16 @@ def restore_weights(self, model: LightningModule): if self.trainer.on_gpu: torch.cuda.empty_cache() - def restore(self, checkpoint_path: str, on_gpu: bool): + def restore(self, checkpoint_path: str, on_gpu: bool) -> bool: """ Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore. All restored states are listed in return value description of `dump_checkpoint`. """ + # Try to read the checkpoint file at `checkpoint_path`. If not exist, do not restore checkpoint. + fs = get_filesystem(checkpoint_path) + if not fs.exists(checkpoint_path): + rank_zero_warn("No checkpoint file exists at `resume_from_checkpoint`. Start from scratch") + return False # read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path` checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) @@ -94,6 +99,9 @@ def restore(self, checkpoint_path: str, on_gpu: bool): # restore training state self.restore_training_state(checkpoint) + rank_zero_info(f"Restored states from the checkpoint file at {checkpoint_path}") + return True + def restore_model_state(self, model: LightningModule, checkpoint) -> None: """ Restore model states from a 'PyTorch-Lightning checkpoint' dictionary object diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 25dffa52dcdab..f2e943d2783af 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -251,9 +251,9 @@ def __init__( train sampler and ``shuffle=False`` for val/test sampler. If you want to customize it, you can set ``replace_sampler_ddp=False`` and add your own distributed sampler. - resume_from_checkpoint: To resume training from a specific checkpoint pass in the path here. - This can be a URL. If resuming from mid-epoch checkpoint, training will start from - the beginning of the next epoch. + resume_from_checkpoint: Path/URL of the checkpoint from which training is resumed. If there is + no checkpoint file at the path, start from scratch. If resuming from mid-epoch checkpoint, + training will start from the beginning of the next epoch. sync_batchnorm: Synchronize batch norm layers between process groups/whole world. diff --git a/requirements.txt b/requirements.txt index 4b8a3efb5c841..2dd5378649851 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,5 +6,5 @@ future>=0.17.1 # required for builtins in setup.py # pyyaml>=3.13 PyYAML>=5.1 # OmegaConf requirement >=5.1 tqdm>=4.41.0 -fsspec>=0.8.0 +fsspec[http]>=0.8.1 tensorboard>=2.2.0 diff --git a/requirements/docs.txt b/requirements/docs.txt index 0f8f2005b88b1..df596ed2bdda8 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -11,4 +11,4 @@ https://github.com/PyTorchLightning/lightning_sphinx_theme/archive/master.zip#eg sphinx-autodoc-typehints sphinx-paramlinks<0.4.0 sphinx-togglebutton -sphinx-copybutton +sphinx-copybutton \ No newline at end of file diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 17821570bdfa7..f7773f63aa8c2 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -27,7 +27,7 @@ import tests.base.develop_utils as tutils from pytorch_lightning import Callback, LightningModule, Trainer, seed_everything from pytorch_lightning.callbacks import ModelCheckpoint -from tests.base import EvalModelTemplate, GenericEvalModelTemplate, TrialMNIST +from tests.base import BoringModel, EvalModelTemplate, GenericEvalModelTemplate, TrialMNIST class ModelTrainerPropertyParity(Callback): @@ -73,6 +73,25 @@ def test_model_properties_resume_from_checkpoint(enable_pl_optimizer, tmpdir): trainer.fit(model) +def test_try_resume_from_non_existing_checkpoint(tmpdir): + """ Test that trying to resume from non-existing `resume_from_checkpoint` fail without error.""" + model = BoringModel() + checkpoint_cb = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True) + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + logger=False, + callbacks=[checkpoint_cb], + limit_train_batches=0.1, + limit_val_batches=0.1, + ) + # Generate checkpoint `last.ckpt` with BoringModel + trainer.fit(model) + # `True` if resume/restore successfully else `False` + assert trainer.checkpoint_connector.restore(str(tmpdir / "last.ckpt"), trainer.on_gpu) + assert not trainer.checkpoint_connector.restore(str(tmpdir / "last_non_existing.ckpt"), trainer.on_gpu) + + class CaptureCallbacksBeforeTraining(Callback): callbacks = []