Skip to content

Commit

Permalink
Remove the deprecated resume_from_checkpoint Trainer argument (#16167)
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli authored and carmocca committed Jan 4, 2023
1 parent 6c6b950 commit a162f81
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 107 deletions.
25 changes: 0 additions & 25 deletions docs/source-pytorch/common/trainer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1141,31 +1141,6 @@ By setting to False, you have to add your own distributed sampler:
.. note:: For iterable datasets, we don't do this automatically.

resume_from_checkpoint
^^^^^^^^^^^^^^^^^^^^^^

.. warning:: ``resume_from_checkpoint`` is deprecated in v1.5 and will be removed in v2.0.
Please pass ``trainer.fit(ckpt_path="some/path/to/my_checkpoint.ckpt")`` instead.


.. raw:: html

<video width="50%" max-width="400px" controls
poster="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/resume_from_checkpoint.jpg"
src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/resume_from_checkpoint.mp4"></video>

|
To resume training from a specific checkpoint pass in the path here. If resuming from a mid-epoch
checkpoint, training will start from the beginning of the next epoch.

.. testcode::

# default used by the Trainer
trainer = Trainer(resume_from_checkpoint=None)

# resume from a specific checkpoint
trainer = Trainer(resume_from_checkpoint="some/path/to/my_checkpoint.ckpt")

strategy
^^^^^^^^
Expand Down
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Removed the `Trainer(num_processes=...)` argument


- Removed the deprecated `resume_from_checkpoint` Trainer argument ([#16167](https://github.com/Lightning-AI/lightning/pull/16167))


## [unreleased] - 202Y-MM-DD

### Added
Expand Down
17 changes: 2 additions & 15 deletions src/pytorch_lightning/trainer/connectors/checkpoint_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.migration import pl_legacy_patch
from pytorch_lightning.utilities.migration.utils import _pl_migrate_checkpoint
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn

if _OMEGACONF_AVAILABLE:
from omegaconf import Container
Expand All @@ -46,16 +46,9 @@


class CheckpointConnector:
def __init__(self, trainer: "pl.Trainer", resume_from_checkpoint: Optional[_PATH] = None) -> None:
def __init__(self, trainer: "pl.Trainer") -> None:
self.trainer = trainer
self.resume_checkpoint_path: Optional[_PATH] = None
# TODO: remove resume_from_checkpoint_fit_path in v2.0
self.resume_from_checkpoint_fit_path: Optional[_PATH] = resume_from_checkpoint
if resume_from_checkpoint is not None:
rank_zero_deprecation(
"Setting `Trainer(resume_from_checkpoint=)` is deprecated in v1.5 and"
" will be removed in v2.0. Please pass `Trainer.fit(ckpt_path=)` directly instead."
)
self._loaded_checkpoint: Dict[str, Any] = {}

@property
Expand Down Expand Up @@ -193,12 +186,6 @@ def resume_end(self) -> None:
rank_zero_info(f"Restored all states from the checkpoint file at {self.resume_checkpoint_path}")
elif self.trainer.state.fn in (TrainerFn.VALIDATING, TrainerFn.TESTING, TrainerFn.PREDICTING):
rank_zero_info(f"Loaded model weights from checkpoint at {self.resume_checkpoint_path}")
# TODO: remove resume_from_checkpoint_fit_path in v2.0
if (
self.trainer.state.fn == TrainerFn.FITTING
and self.resume_checkpoint_path == self.resume_from_checkpoint_fit_path
):
self.resume_from_checkpoint_fit_path = None
self.resume_checkpoint_path = None
self._loaded_checkpoint = {}

Expand Down
29 changes: 3 additions & 26 deletions src/pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from contextlib import contextmanager
from copy import deepcopy
from datetime import timedelta
from pathlib import Path
from typing import Any, Dict, Generator, Iterable, List, Optional, Type, Union
from weakref import proxy

Expand Down Expand Up @@ -145,7 +144,6 @@ def __init__(
precision: Union[int, str] = 32,
enable_model_summary: bool = True,
num_sanity_val_steps: int = 2,
resume_from_checkpoint: Optional[Union[Path, str]] = None,
profiler: Optional[Union[Profiler, str]] = None,
benchmark: Optional[bool] = None,
deterministic: Optional[Union[bool, _LITERAL_WARN]] = None,
Expand Down Expand Up @@ -311,14 +309,6 @@ def __init__(
train sampler and ``shuffle=False`` for val/test sampler. If you want to customize it,
you can set ``replace_sampler_ddp=False`` and add your own distributed sampler.
resume_from_checkpoint: Path/URL of the checkpoint from which training is resumed. If there is
no checkpoint file at the path, an exception is raised. If resuming from mid-epoch checkpoint,
training will start from the beginning of the next epoch.
.. deprecated:: v1.5
``resume_from_checkpoint`` is deprecated in v1.5 and will be removed in v2.0.
Please pass the path to ``Trainer.fit(..., ckpt_path=...)`` instead.
strategy: Supports different training strategies with aliases
as well custom strategies.
Default: ``None``.
Expand Down Expand Up @@ -379,7 +369,7 @@ def __init__(
)
self._logger_connector = LoggerConnector(self)
self._callback_connector = CallbackConnector(self)
self._checkpoint_connector = CheckpointConnector(self, resume_from_checkpoint)
self._checkpoint_connector = CheckpointConnector(self)
self._signal_connector = SignalConnector(self)
self.tuner = Tuner(self)

Expand Down Expand Up @@ -581,11 +571,10 @@ def _fit_impl(
model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
)

# TODO: ckpt_path only in v2.0
ckpt_path = ckpt_path or self.resume_from_checkpoint
ckpt_path = ckpt_path
self._ckpt_path = self._checkpoint_connector._set_ckpt_path(
self.state.fn,
ckpt_path, # type: ignore[arg-type]
ckpt_path,
model_provided=True,
model_connected=self.lightning_module is not None,
)
Expand Down Expand Up @@ -1818,18 +1807,6 @@ def progress_bar_callback(self) -> Optional[ProgressBarBase]:
return c
return None

@property
def resume_from_checkpoint(self) -> Optional[Union[str, Path]]:
resume_from_checkpoint = self._checkpoint_connector.resume_from_checkpoint_fit_path
if resume_from_checkpoint is not None:
rank_zero_deprecation(
"`trainer.resume_from_checkpoint` is deprecated in v1.5 and will be removed in v2.0."
" Specify the fit checkpoint path with `trainer.fit(ckpt_path=)` instead.",
stacklevel=5,
)

return resume_from_checkpoint

@property
def ckpt_path(self) -> Optional[str]:
"""Set to the path/URL of a checkpoint loaded via :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`,
Expand Down
41 changes: 0 additions & 41 deletions tests/tests_pytorch/deprecated_api/test_remove_2-0.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,47 +19,6 @@
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.cli import LightningCLI
from pytorch_lightning.demos.boring_classes import BoringModel
from tests_pytorch.callbacks.test_callbacks import OldStatefulCallback


def test_v2_0_0_resume_from_checkpoint_trainer_constructor(tmpdir):
# test resume_from_checkpoint still works until v2.0 deprecation
model = BoringModel()
callback = OldStatefulCallback(state=111)
trainer = Trainer(default_root_dir=tmpdir, max_steps=1, callbacks=[callback])
trainer.fit(model)
ckpt_path = trainer.checkpoint_callback.best_model_path

callback = OldStatefulCallback(state=222)
with pytest.deprecated_call(match=r"Setting `Trainer\(resume_from_checkpoint=\)` is deprecated in v1.5"):
trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback], resume_from_checkpoint=ckpt_path)
with pytest.deprecated_call(match=r"trainer.resume_from_checkpoint` is deprecated in v1.5"):
_ = trainer.resume_from_checkpoint
assert trainer._checkpoint_connector.resume_checkpoint_path is None
assert trainer._checkpoint_connector.resume_from_checkpoint_fit_path == ckpt_path
trainer.validate(model=model, ckpt_path=ckpt_path)
assert callback.state == 222
assert trainer._checkpoint_connector.resume_checkpoint_path is None
assert trainer._checkpoint_connector.resume_from_checkpoint_fit_path == ckpt_path
with pytest.deprecated_call(match=r"trainer.resume_from_checkpoint` is deprecated in v1.5"):
trainer.fit(model)
ckpt_path = trainer.checkpoint_callback.best_model_path # last `fit` replaced the `best_model_path`
assert callback.state == 111
assert trainer._checkpoint_connector.resume_checkpoint_path is None
assert trainer._checkpoint_connector.resume_from_checkpoint_fit_path is None
trainer.predict(model=model, ckpt_path=ckpt_path)
assert trainer._checkpoint_connector.resume_checkpoint_path is None
assert trainer._checkpoint_connector.resume_from_checkpoint_fit_path is None
trainer.fit(model)
assert trainer._checkpoint_connector.resume_checkpoint_path is None
assert trainer._checkpoint_connector.resume_from_checkpoint_fit_path is None

# test fit(ckpt_path=) precedence over Trainer(resume_from_checkpoint=) path
model = BoringModel()
with pytest.deprecated_call(match=r"Setting `Trainer\(resume_from_checkpoint=\)` is deprecated in v1.5"):
trainer = Trainer(resume_from_checkpoint="trainer_arg_path")
with pytest.raises(FileNotFoundError, match="Checkpoint at fit_arg_ckpt_path not found. Aborting training."):
trainer.fit(model, ckpt_path="fit_arg_ckpt_path")


def test_v2_0_0_callback_on_load_checkpoint_hook(tmpdir):
Expand Down

0 comments on commit a162f81

Please sign in to comment.