Add a trainer.ckpt_path setter for stateful loading (#16187)
carmocca committed Jan 11, 2023
1 parent c8a7d48 commit 96fb863
Showing 6 changed files with 149 additions and 63 deletions.
2 changes: 2 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -11,6 +11,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added migration logic to warn about checkpoints with apex AMP state ([#16161](https://github.com/Lightning-AI/lightning/pull/16161))

- Added the `Trainer.ckpt_path = ...` setter to statefully set the checkpoint path to load. This can act as a replacement for the removed `Trainer(resume_from_checkpoint=...)` flag ([#16187](https://github.com/Lightning-AI/lightning/pull/16187))

### Removed

- Removed the `pytorch_lightning.lite` module in favor of `lightning_fabric` ([#15953](https://github.com/Lightning-AI/lightning/pull/15953))
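To make the CHANGELOG entry above concrete, here is a minimal usage sketch of the new stateful pattern; it is an illustration only, and the checkpoint paths are placeholders rather than anything from this commit:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel

model = BoringModel()
trainer = Trainer(max_steps=2)

# replacement for the removed `Trainer(resume_from_checkpoint=...)` flag:
# set the path statefully, then call the entry point without `ckpt_path`
trainer.ckpt_path = "path/to/some.ckpt"  # placeholder path
trainer.fit(model)  # restores model, optimizer and loop state before training

# the path is not cleared automatically; reset it before other entry points
trainer.ckpt_path = None
trainer.test(model)
```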
69 changes: 53 additions & 16 deletions src/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -48,7 +48,9 @@
class CheckpointConnector:
def __init__(self, trainer: "pl.Trainer") -> None:
self.trainer = trainer
self.resume_checkpoint_path: Optional[_PATH] = None
self._ckpt_path: Optional[_PATH] = None
# flag to know if the user is changing the checkpoint path statefully. See `trainer.ckpt_path.setter`
self._user_managed: bool = False
self._loaded_checkpoint: Dict[str, Any] = {}

@property
@@ -73,7 +75,7 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:
3. from `checkpoint_path` file if provided
4. don't restore
"""
self.resume_checkpoint_path = checkpoint_path
self._ckpt_path = checkpoint_path
if not checkpoint_path:
log.detail("`checkpoint_path` not specified. Skipping checkpoint loading.")
return
@@ -83,9 +85,41 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:
loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path)
self._loaded_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint, checkpoint_path)

def _set_ckpt_path(
self, state_fn: TrainerFn, ckpt_path: Optional[str], model_provided: bool, model_connected: bool
) -> Optional[str]:
def _select_ckpt_path(
self, state_fn: TrainerFn, ckpt_path: Optional[_PATH], model_provided: bool, model_connected: bool
) -> Optional[_PATH]:
"""Called by the ``Trainer`` to select the checkpoint path source."""
if self._user_managed:
if ckpt_path:
rank_zero_warn(
f"`trainer.ckpt_path = {self._ckpt_path!r}` was called but then you"
f" passed `trainer.fit(ckpt_path={ckpt_path!r})`. The latter will be loaded."
)
# reset the previous path
self._ckpt_path = None
self._user_managed = False
ckpt_path = self._parse_ckpt_path(
state_fn,
ckpt_path,
model_provided=model_provided,
model_connected=model_connected,
)
else:
ckpt_path = self._ckpt_path
else:
ckpt_path = self._parse_ckpt_path(
state_fn,
ckpt_path,
model_provided=model_provided,
model_connected=model_connected,
)
return ckpt_path

def _parse_ckpt_path(
self, state_fn: TrainerFn, ckpt_path: Optional[_PATH], model_provided: bool, model_connected: bool
) -> Optional[_PATH]:
"""Converts the ``ckpt_path`` special values into an actual filepath, depending on the trainer
configuration."""
if ckpt_path is None and SLURMEnvironment.detect() and self._hpc_resume_path is not None:
ckpt_path = "hpc"

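A condensed restatement of the precedence that `_select_ckpt_path` implements above, written as a standalone sketch; the function name and the `parse` callback are placeholders for illustration, not code from the commit:

```python
from typing import Callable, Optional

def select_ckpt_path(
    user_managed_path: Optional[str],  # set via `trainer.ckpt_path = ...`
    passed_path: Optional[str],        # passed to e.g. `trainer.fit(ckpt_path=...)`
    parse: Callable[[Optional[str]], Optional[str]],  # resolves "best"/"last"/"hpc"
) -> Optional[str]:
    if user_managed_path is not None:
        if passed_path:
            # an explicitly passed argument wins; warn and fall back to normal parsing
            print("warning: both were set, loading the explicitly passed ckpt_path")
            return parse(passed_path)
        # the stateful value is used verbatim, with no special-value parsing
        return user_managed_path
    return parse(passed_path)
```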
@@ -181,15 +215,12 @@ def resume_end(self) -> None:
"""Signal the connector that all states have resumed and memory for the checkpoint object can be
released."""
assert self.trainer.state.fn is not None
if self.resume_checkpoint_path:
if self.trainer.state.fn == TrainerFn.FITTING:
rank_zero_info(f"Restored all states from the checkpoint file at {self.resume_checkpoint_path}")
elif self.trainer.state.fn in (TrainerFn.VALIDATING, TrainerFn.TESTING, TrainerFn.PREDICTING):
rank_zero_info(f"Loaded model weights from checkpoint at {self.resume_checkpoint_path}")
self.resume_checkpoint_path = None
self._loaded_checkpoint = {}
if self._ckpt_path:
message = "Restored all states" if self.trainer.state.fn == TrainerFn.FITTING else "Loaded model weights"
rank_zero_info(f"{message} from the checkpoint at {self._ckpt_path}")

# clear cache after restore
# free memory
self._loaded_checkpoint = {}
torch.cuda.empty_cache()

# wait for all to catch up
@@ -391,9 +422,15 @@ def restore_lr_schedulers(self) -> None:
for config, lrs_state in zip(self.trainer.lr_scheduler_configs, lr_schedulers):
config.scheduler.load_state_dict(lrs_state)

# ----------------------------------
# PRIVATE OPS
# ----------------------------------
def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None) -> None:
# restore modules after setup
self.resume_start(checkpoint_path)
self._restore_quantization_callbacks()
self.restore_model()
self.restore_datamodule()
if self.trainer.state.fn == TrainerFn.FITTING:
# restore callback states
self.restore_callbacks()

def dump_checkpoint(self, weights_only: bool = False) -> dict:
"""Creating a model checkpoint dictionary object from various component states.
69 changes: 31 additions & 38 deletions src/pytorch_lightning/trainer/trainer.py
@@ -392,9 +392,6 @@ def __init__(
# default .predict() loop
self.predict_loop = PredictionLoop()

# set when a checkpoint is loaded via `Trainer.{fit,validate,test,predict}`.
self._ckpt_path: Optional[str] = None

# init callbacks
# Declare attributes to be set in _callback_connector on_trainer_init
self._callback_connector.on_trainer_init(
@@ -574,14 +571,13 @@ def _fit_impl(
model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
)

ckpt_path = ckpt_path
self._ckpt_path = self._checkpoint_connector._set_ckpt_path(
ckpt_path = self._checkpoint_connector._select_ckpt_path(
self.state.fn,
ckpt_path,
model_provided=True,
model_connected=self.lightning_module is not None,
)
self._run(model, ckpt_path=self.ckpt_path)
self._run(model, ckpt_path=ckpt_path)

assert self.state.stopped
self.training = False
@@ -665,14 +661,10 @@ def _validate_impl(
# links data to the trainer
self._data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule)

self._ckpt_path = self._checkpoint_connector._set_ckpt_path(
ckpt_path = self._checkpoint_connector._select_ckpt_path(
self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
)

self._validated_ckpt_path = self.ckpt_path # TODO: remove in v1.8

# run validate
results = self._run(model, ckpt_path=self.ckpt_path)
results = self._run(model, ckpt_path=ckpt_path)

assert self.state.stopped
self.validating = False
@@ -758,14 +750,10 @@ def _test_impl(
# links data to the trainer
self._data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule)

self._ckpt_path = self._checkpoint_connector._set_ckpt_path(
ckpt_path = self._checkpoint_connector._select_ckpt_path(
self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
)

self._tested_ckpt_path = self.ckpt_path # TODO: remove in v1.8

# run test
results = self._run(model, ckpt_path=self.ckpt_path)
results = self._run(model, ckpt_path=ckpt_path)

assert self.state.stopped
self.testing = False
@@ -851,13 +839,10 @@ def _predict_impl(
# links data to the trainer
self._data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule)

self._ckpt_path = self._checkpoint_connector._set_ckpt_path(
ckpt_path = self._checkpoint_connector._select_ckpt_path(
self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
)

self._predicted_ckpt_path = self.ckpt_path # TODO: remove in v1.8

results = self._run(model, ckpt_path=self.ckpt_path)
results = self._run(model, ckpt_path=ckpt_path)

assert self.state.stopped
self.predicting = False
@@ -918,18 +903,8 @@ def tune(

return result

def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None) -> None:
# restore modules after setup
self._checkpoint_connector.resume_start(checkpoint_path)
self._checkpoint_connector._restore_quantization_callbacks()
self._checkpoint_connector.restore_model()
self._checkpoint_connector.restore_datamodule()
if self.state.fn == TrainerFn.FITTING:
# restore callback states
self._checkpoint_connector.restore_callbacks()

def _run(
self, model: "pl.LightningModule", ckpt_path: Optional[str] = None
self, model: "pl.LightningModule", ckpt_path: Optional[_PATH] = None
) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
if model._compiler_ctx is not None:
supported_strategies = [SingleDeviceStrategy, DDPStrategy, DDPFullyShardedNativeStrategy]
@@ -978,7 +953,7 @@ def _run(
# check if we should delay restoring checkpoint till later
if not self.strategy.restore_checkpoint_after_setup:
log.detail(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
self._restore_modules_and_callbacks(ckpt_path)
self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)

log.detail(f"{self.__class__.__name__}: configuring sharded model")
self._call_configure_sharded_model() # allow user to setup in model sharded environment
@@ -1026,7 +1001,7 @@ def _run(

if self.strategy.restore_checkpoint_after_setup:
log.detail(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
self._restore_modules_and_callbacks(ckpt_path)
self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)

# restore optimizers, etc.
log.detail(f"{self.__class__.__name__}: restoring training state")
@@ -1811,12 +1786,30 @@ def progress_bar_callback(self) -> Optional[ProgressBarBase]:
return None

@property
def ckpt_path(self) -> Optional[str]:
def ckpt_path(self) -> Optional[_PATH]:
"""Set to the path/URL of a checkpoint loaded via :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`,
:meth:`~pytorch_lightning.trainer.trainer.Trainer.validate`,
:meth:`~pytorch_lightning.trainer.trainer.Trainer.test`, or
:meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`. ``None`` otherwise."""
return self._ckpt_path
return self._checkpoint_connector._ckpt_path

@ckpt_path.setter
def ckpt_path(self, ckpt_path: Optional[_PATH]) -> None:
"""Allows you to manage which checkpoint is loaded statefully.
Examples::
trainer = Trainer()
trainer.ckpt_path = "my/checkpoint/file.ckpt"
trainer.fit(model)
...
# you will be in charge of resetting this
trainer.ckpt_path = None
trainer.test(model)
"""
self._checkpoint_connector._ckpt_path = ckpt_path
self._checkpoint_connector._user_managed = bool(ckpt_path)

def save_checkpoint(
self, filepath: _PATH, weights_only: bool = False, storage_options: Optional[Any] = None
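As a quick check of what the new `ckpt_path` property and setter above do, here is a sketch grounded in the diff and the test further below; the attribute names mirror the connector internals and the path is a placeholder:

```python
from pytorch_lightning import Trainer

trainer = Trainer()

trainer.ckpt_path = "foo.ckpt"  # placeholder path
assert trainer._checkpoint_connector._ckpt_path == "foo.ckpt"
assert trainer._checkpoint_connector._user_managed is True

trainer.ckpt_path = None  # resetting clears both the path and the flag
assert trainer._checkpoint_connector._ckpt_path is None
assert trainer._checkpoint_connector._user_managed is False
```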
2 changes: 1 addition & 1 deletion tests/tests_pytorch/core/test_datamodules.py
@@ -230,7 +230,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:

for trainer_fn in TrainerFn:
trainer.state.fn = trainer_fn
trainer._restore_modules_and_callbacks(checkpoint_path)
trainer._checkpoint_connector._restore_modules_and_callbacks(checkpoint_path)
assert dm.my_state_dict == {"my": "state_dict"}


tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py
@@ -13,6 +13,7 @@
# limitations under the License.
import os
from unittest import mock
from unittest.mock import Mock

import pytest
import torch
@@ -21,6 +22,7 @@
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.migration.utils import _set_version


def test_preloaded_checkpoint_lifecycle(tmpdir):
Expand All @@ -31,26 +33,27 @@ def test_preloaded_checkpoint_lifecycle(tmpdir):

connector = trainer._checkpoint_connector

assert not connector.resume_checkpoint_path
assert not connector._ckpt_path
assert not connector._loaded_checkpoint

connector.resume_start()
assert not connector.resume_checkpoint_path
assert not connector._ckpt_path
assert not connector._loaded_checkpoint
connector.resume_end()
assert not connector.resume_checkpoint_path
assert not connector._ckpt_path
assert not connector._loaded_checkpoint

ckpt_path = trainer.checkpoint_callback.best_model_path
trainer = Trainer(default_root_dir=tmpdir, max_steps=2)
connector = trainer._checkpoint_connector
connector.resume_start(ckpt_path)
assert connector.resume_checkpoint_path == ckpt_path
assert connector._ckpt_path == ckpt_path
assert connector._loaded_checkpoint
assert isinstance(connector._loaded_checkpoint, dict)
trainer.state.fn = TrainerFn.FITTING
connector.resume_end()
assert not connector.resume_checkpoint_path
# not cleared until next restoration, as the user might access it through `trainer.ckpt_path`
assert connector._ckpt_path == ckpt_path
assert not connector._loaded_checkpoint


@@ -166,3 +169,54 @@ def test_loops_restore(tmpdir):
if fn2 != fn:
trainer_loop2 = getattr(trainer, f"{fn2}_loop")
trainer_loop2.load_state_dict.assert_not_called()


def test_stateful_trainer_ckpt_path_support(tmp_path):
"""Tests support for the pattern used by NeMo's experiment manager."""
model = BoringModel()

# dummy ckpt data
ckpt_data = {"state_dict": model.state_dict(), "optimizer_states": {}, "lr_schedulers": {}}
_set_version(ckpt_data, "2.0.0")

# save a "checkpoint"
ckpt_path = tmp_path / "foo.ckpt"
torch.save(ckpt_data, ckpt_path)

# mock model checkpoint instance that has saved a last checkpoint
model_checkpoint = Mock(spec=ModelCheckpoint)
last_path = tmp_path / "last.ckpt"
torch.save(ckpt_data, last_path)
model_checkpoint._find_last_checkpoints.return_value = {last_path}

trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=True, callbacks=model_checkpoint)

# set the ckpt path statefully
trainer.ckpt_path = ckpt_path
trainer.fit(model)
assert trainer.ckpt_path == ckpt_path # not automatically cleaned
assert trainer._checkpoint_connector._user_managed

# now conflict with ckpt_path functionally
with pytest.warns(UserWarning, match="trainer.ckpt_path =.*but then you passed"):
trainer.fit(model, ckpt_path="last")
assert trainer.ckpt_path == last_path
assert not trainer._checkpoint_connector._user_managed

# mock model checkpoint instance that has saved a last checkpoint
best_path = tmp_path / "best.ckpt"
torch.save(ckpt_data, best_path)
model_checkpoint.best_model_path = best_path

# `trainer.test` will use this over "best" if statefully set
trainer.ckpt_path = ckpt_path
trainer.test()
assert trainer.ckpt_path == ckpt_path

# ckpt_path = "best" still works if it's reset
trainer.ckpt_path = None
# the state is cleared
assert trainer._checkpoint_connector._ckpt_path is None
assert not trainer._checkpoint_connector._user_managed
trainer.test()
assert trainer.ckpt_path == best_path