diff --git a/.gitignore b/.gitignore index eb56709276b25..47b9bfff92523 100644 --- a/.gitignore +++ b/.gitignore @@ -136,7 +136,7 @@ ENV/ Datasets/ mnist/ MNIST/ -legacy/checkpoints/ +tests/legacy/checkpoints/ *.gz *ubyte diff --git a/docs/source-pytorch/common/checkpointing_expert.rst b/docs/source-pytorch/common/checkpointing_expert.rst index c1859d60ecf52..c4a948a34cb9d 100644 --- a/docs/source-pytorch/common/checkpointing_expert.rst +++ b/docs/source-pytorch/common/checkpointing_expert.rst @@ -6,7 +6,12 @@ Checkpointing (expert) ###################### -TODO: I don't understand this... +********************************* +Writing your own Checkpoint class +********************************* + +We provide a ``Checkpoint`` base class for easier subclassing. Users may want to subclass it when writing a custom ``ModelCheckpoint``-style callback, so that the ``Trainer`` recognizes the custom class as a checkpointing callback. + *********************** Customize Checkpointing @@ -23,6 +28,8 @@ and :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_load_checkpoint` met what's saved in the checkpoint. +TODO: I don't understand this... + ****************************** Built-in Checkpoint IO Plugins ****************************** diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 28695785c367c..38da5a36a40ab 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -73,6 +73,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added breaking of lazy graph across training, validation, test and predict steps when training with habana accelerators to ensure better performance ([#12938](https://github.com/PyTorchLightning/pytorch-lightning/pull/12938)) +- Added `Checkpoint` class to inherit from ([#13024](https://github.com/PyTorchLightning/pytorch-lightning/pull/13024)) + + - Added CPU metric tracking to `DeviceStatsMonitor` ([#11795](https://github.com/PyTorchLightning/pytorch-lightning/pull/11795)) diff --git a/src/pytorch_lightning/callbacks/__init__.py b/src/pytorch_lightning/callbacks/__init__.py index 6e37b84ce204a..b3d2035f33496 100644 --- a/src/pytorch_lightning/callbacks/__init__.py +++ b/src/pytorch_lightning/callbacks/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from pytorch_lightning.callbacks.callback import Callback +from pytorch_lightning.callbacks.checkpoint import Checkpoint from pytorch_lightning.callbacks.device_stats_monitor import DeviceStatsMonitor from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.callbacks.finetuning import BackboneFinetuning, BaseFinetuning @@ -32,6 +33,7 @@ "BackboneFinetuning", "BaseFinetuning", "Callback", + "Checkpoint", "DeviceStatsMonitor", "EarlyStopping", "GradientAccumulationScheduler", diff --git a/src/pytorch_lightning/callbacks/checkpoint.py b/src/pytorch_lightning/callbacks/checkpoint.py new file mode 100644 index 0000000000000..405f29876c6fc --- /dev/null +++ b/src/pytorch_lightning/callbacks/checkpoint.py @@ -0,0 +1,9 @@ +from pytorch_lightning.callbacks.callback import Callback + + +class Checkpoint(Callback): + r""" + This is the base class for model checkpointing. Expert users may want to subclass it when writing a + custom :class:`~pytorch_lightning.callbacks.Checkpoint` callback, so that + the trainer recognizes the custom class as a checkpointing callback. 
+ """ diff --git a/src/pytorch_lightning/callbacks/fault_tolerance.py b/src/pytorch_lightning/callbacks/fault_tolerance.py index 59b8d31f46506..9d04fc86b62ce 100644 --- a/src/pytorch_lightning/callbacks/fault_tolerance.py +++ b/src/pytorch_lightning/callbacks/fault_tolerance.py @@ -21,11 +21,11 @@ from typing import Any import pytorch_lightning as pl -from pytorch_lightning import Callback +from pytorch_lightning.callbacks import Checkpoint from pytorch_lightning.utilities.types import _PATH -class _FaultToleranceCheckpoint(Callback): +class _FaultToleranceCheckpoint(Checkpoint): """Used to save a fault-tolerance checkpoint on exception.""" FILE_EXTENSION = ".ckpt" diff --git a/src/pytorch_lightning/callbacks/model_checkpoint.py b/src/pytorch_lightning/callbacks/model_checkpoint.py index 8522bb49b7292..bb6d0a9a9b0b6 100644 --- a/src/pytorch_lightning/callbacks/model_checkpoint.py +++ b/src/pytorch_lightning/callbacks/model_checkpoint.py @@ -34,7 +34,7 @@ from torch import Tensor import pytorch_lightning as pl -from pytorch_lightning.callbacks.callback import Callback +from pytorch_lightning.callbacks import Checkpoint from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.logger import _name, _version @@ -46,7 +46,7 @@ warning_cache = WarningCache() -class ModelCheckpoint(Callback): +class ModelCheckpoint(Checkpoint): r""" Save the model periodically by monitoring a quantity. 
Every metric logged with :meth:`~pytorch_lightning.core.module.log` or :meth:`~pytorch_lightning.core.module.log_dict` in diff --git a/src/pytorch_lightning/loggers/logger.py b/src/pytorch_lightning/loggers/logger.py index c1eecb93fc8bf..d532aae413650 100644 --- a/src/pytorch_lightning/loggers/logger.py +++ b/src/pytorch_lightning/loggers/logger.py @@ -25,7 +25,7 @@ import numpy as np import pytorch_lightning as pl -from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +from pytorch_lightning.callbacks import Checkpoint from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_only @@ -86,7 +86,7 @@ def __init__( else: self._agg_default_func = np.mean - def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None: + def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[Checkpoint]") -> None: """Called after model checkpoint callback saves a new checkpoint. Args: @@ -221,7 +221,7 @@ def __init__(self, logger_iterable: Iterable[Logger]): def __getitem__(self, index: int) -> Logger: return list(self._logger_iterable)[index] - def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None: + def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[Checkpoint]") -> None: for logger in self._logger_iterable: logger.after_save_checkpoint(checkpoint_callback) diff --git a/src/pytorch_lightning/loggers/neptune.py b/src/pytorch_lightning/loggers/neptune.py index 4d2f6897a21aa..44ae3f0f5bfdc 100644 --- a/src/pytorch_lightning/loggers/neptune.py +++ b/src/pytorch_lightning/loggers/neptune.py @@ -31,7 +31,7 @@ from torch import Tensor from pytorch_lightning import __version__ -from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +from pytorch_lightning.callbacks import Checkpoint from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment from pytorch_lightning.utilities.imports import _NEPTUNE_AVAILABLE, 
_NEPTUNE_GREATER_EQUAL_0_9 from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params @@ -534,7 +534,7 @@ def log_model_summary(self, model, max_depth=-1): ) @rank_zero_only - def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None: + def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[Checkpoint]") -> None: """Automatically log checkpointed model. Called after model checkpoint callback saves a new checkpoint. Args: @@ -547,19 +547,20 @@ def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpo checkpoints_namespace = self._construct_path_with_prefix("model/checkpoints") # save last model - if checkpoint_callback.last_model_path: + if hasattr(checkpoint_callback, "last_model_path") and checkpoint_callback.last_model_path: model_last_name = self._get_full_model_name(checkpoint_callback.last_model_path, checkpoint_callback) file_names.add(model_last_name) self.run[f"{checkpoints_namespace}/{model_last_name}"].upload(checkpoint_callback.last_model_path) # save best k models - for key in checkpoint_callback.best_k_models.keys(): - model_name = self._get_full_model_name(key, checkpoint_callback) - file_names.add(model_name) - self.run[f"{checkpoints_namespace}/{model_name}"].upload(key) + if hasattr(checkpoint_callback, "best_k_models"): + for key in checkpoint_callback.best_k_models.keys(): + model_name = self._get_full_model_name(key, checkpoint_callback) + file_names.add(model_name) + self.run[f"{checkpoints_namespace}/{model_name}"].upload(key) # log best model path and checkpoint - if checkpoint_callback.best_model_path: + if hasattr(checkpoint_callback, "best_model_path") and checkpoint_callback.best_model_path: self.run[self._construct_path_with_prefix("model/best_model_path")] = checkpoint_callback.best_model_path model_name = self._get_full_model_name(checkpoint_callback.best_model_path, checkpoint_callback) @@ -575,19 +576,22 @@ def 
after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpo del self.run[f"{checkpoints_namespace}/{file_to_drop}"] # log best model score - if checkpoint_callback.best_model_score: + if hasattr(checkpoint_callback, "best_model_score") and checkpoint_callback.best_model_score: self.run[self._construct_path_with_prefix("model/best_model_score")] = ( checkpoint_callback.best_model_score.cpu().detach().numpy() ) @staticmethod - def _get_full_model_name(model_path: str, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> str: + def _get_full_model_name(model_path: str, checkpoint_callback: "ReferenceType[Checkpoint]") -> str: """Returns model name which is string `model_path` appended to `checkpoint_callback.dirpath`.""" - expected_model_path = f"{checkpoint_callback.dirpath}{os.path.sep}" - if not model_path.startswith(expected_model_path): - raise ValueError(f"{model_path} was expected to start with {expected_model_path}.") - # Remove extension from filepath - filepath, _ = os.path.splitext(model_path[len(expected_model_path) :]) + if hasattr(checkpoint_callback, "dirpath"): + expected_model_path = f"{checkpoint_callback.dirpath}{os.path.sep}" + if not model_path.startswith(expected_model_path): + raise ValueError(f"{model_path} was expected to start with {expected_model_path}.") + # Remove extension from filepath + filepath, _ = os.path.splitext(model_path[len(expected_model_path) :]) + else: + filepath = model_path return filepath diff --git a/src/pytorch_lightning/loggers/wandb.py b/src/pytorch_lightning/loggers/wandb.py index 53103dfdfd154..88439cd9435db 100644 --- a/src/pytorch_lightning/loggers/wandb.py +++ b/src/pytorch_lightning/loggers/wandb.py @@ -23,7 +23,7 @@ import torch.nn as nn -from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +from pytorch_lightning.callbacks import Checkpoint from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment from pytorch_lightning.utilities.exceptions import 
MisconfigurationException from pytorch_lightning.utilities.imports import _WANDB_GREATER_EQUAL_0_10_22, _WANDB_GREATER_EQUAL_0_12_10 @@ -461,9 +461,14 @@ def version(self) -> Optional[str]: # don't create an experiment if we don't have one return self._experiment.id if self._experiment else self._id - def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None: + def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[Checkpoint]") -> None: # log checkpoints as artifacts - if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1: + if ( + self._log_model == "all" + or self._log_model is True + and hasattr(checkpoint_callback, "save_top_k") + and checkpoint_callback.save_top_k == -1 + ): self._scan_and_log_checkpoints(checkpoint_callback) elif self._log_model is True: self._checkpoint_callback = checkpoint_callback @@ -474,25 +479,33 @@ def finalize(self, status: str) -> None: if self._checkpoint_callback: self._scan_and_log_checkpoints(self._checkpoint_callback) - def _scan_and_log_checkpoints(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None: + def _scan_and_log_checkpoints(self, checkpoint_callback: "ReferenceType[Checkpoint]") -> None: # get checkpoints to be saved with associated score - checkpoints = { - checkpoint_callback.last_model_path: checkpoint_callback.current_score, - checkpoint_callback.best_model_path: checkpoint_callback.best_model_score, - **checkpoint_callback.best_k_models, - } - checkpoints = sorted((Path(p).stat().st_mtime, p, s) for p, s in checkpoints.items() if Path(p).is_file()) + checkpoints = dict() + if hasattr(checkpoint_callback, "last_model_path") and hasattr(checkpoint_callback, "current_score"): + checkpoints[checkpoint_callback.last_model_path] = (checkpoint_callback.current_score, "latest") + + if hasattr(checkpoint_callback, "best_model_path") and hasattr(checkpoint_callback, "best_model_score"): + 
checkpoints[checkpoint_callback.best_model_path] = (checkpoint_callback.best_model_score, "best") + + if hasattr(checkpoint_callback, "best_k_models"): + for key, value in checkpoint_callback.best_k_models.items(): + checkpoints[key] = (value, "best_k") + + checkpoints = sorted( + (Path(p).stat().st_mtime, p, s, tag) for p, (s, tag) in checkpoints.items() if Path(p).is_file() + ) checkpoints = [ c for c in checkpoints if c[1] not in self._logged_model_time.keys() or self._logged_model_time[c[1]] < c[0] ] # log iteratively all new checkpoints - for t, p, s in checkpoints: + for t, p, s, tag in checkpoints: metadata = ( { "score": s, "original_filename": Path(p).name, - "ModelCheckpoint": { + checkpoint_callback.__class__.__name__: { k: getattr(checkpoint_callback, k) for k in [ "monitor", @@ -511,7 +524,6 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: "ReferenceType[ModelChe ) artifact = wandb.Artifact(name=f"model-{self.experiment.id}", type="model", metadata=metadata) artifact.add_file(p, name="model.ckpt") - aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"] - self.experiment.log_artifact(artifact, aliases=aliases) + self.experiment.log_artifact(artifact, aliases=[tag]) # remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name) self._logged_model_time[p] = t diff --git a/src/pytorch_lightning/strategies/launchers/spawn.py b/src/pytorch_lightning/strategies/launchers/spawn.py index 6af2688e47419..d94909b778a83 100644 --- a/src/pytorch_lightning/strategies/launchers/spawn.py +++ b/src/pytorch_lightning/strategies/launchers/spawn.py @@ -109,7 +109,7 @@ def _wrapping_function( def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer: "pl.Trainer") -> None: # transfer back the best path to the trainer - if trainer.checkpoint_callback: + if trainer.checkpoint_callback and hasattr(trainer.checkpoint_callback, "best_model_path"): 
trainer.checkpoint_callback.best_model_path = str(spawn_output.best_model_path) # TODO: pass also best score @@ -131,7 +131,11 @@ def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]: rank_zero_debug("Finalizing the DDP spawn environment.") checkpoint_callback = trainer.checkpoint_callback - best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None + best_model_path = ( + checkpoint_callback.best_model_path + if checkpoint_callback and hasattr(checkpoint_callback, "best_model_path") + else None + ) # requires to compute the state_dict on all processes in case Metrics are present state_dict = trainer.lightning_module.state_dict() diff --git a/src/pytorch_lightning/strategies/launchers/xla_spawn.py b/src/pytorch_lightning/strategies/launchers/xla_spawn.py index b3e1bf3465203..13c948577ca5b 100644 --- a/src/pytorch_lightning/strategies/launchers/xla_spawn.py +++ b/src/pytorch_lightning/strategies/launchers/xla_spawn.py @@ -115,7 +115,11 @@ def _wrapping_function( def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]: rank_zero_debug("Finalizing the TPU spawn environment.") checkpoint_callback = trainer.checkpoint_callback - best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None + best_model_path = ( + checkpoint_callback.best_model_path + if checkpoint_callback and hasattr(checkpoint_callback, "best_model_path") + else None + ) # requires to compute the state_dict on all processes in case Metrics are present state_dict = trainer.lightning_module.state_dict() diff --git a/src/pytorch_lightning/trainer/connectors/callback_connector.py b/src/pytorch_lightning/trainer/connectors/callback_connector.py index eddc2e2a84716..83881905beeb1 100644 --- a/src/pytorch_lightning/trainer/connectors/callback_connector.py +++ 
b/src/pytorch_lightning/trainer/connectors/callback_connector.py @@ -19,6 +19,7 @@ from pytorch_lightning.callbacks import ( Callback, + Checkpoint, GradientAccumulationScheduler, ModelCheckpoint, ModelSummary, @@ -232,18 +233,18 @@ def _attach_model_callbacks(self) -> None: @staticmethod def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]: - """Moves all ModelCheckpoint callbacks to the end of the list. The sequential order within the group of + """Moves all Checkpoint callbacks to the end of the list. The sequential order within the group of checkpoint callbacks is preserved, as well as the order of all other callbacks. Args: callbacks: A list of callbacks. Return: - A new list in which the last elements are ModelCheckpoints if there were any present in the + A new list in which the last elements are Checkpoint callbacks if there were any present in the input. """ - checkpoints = [c for c in callbacks if isinstance(c, ModelCheckpoint)] - not_checkpoints = [c for c in callbacks if not isinstance(c, ModelCheckpoint)] + checkpoints = [c for c in callbacks if isinstance(c, Checkpoint)] + not_checkpoints = [c for c in callbacks if not isinstance(c, Checkpoint)] return not_checkpoints + checkpoints diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index e823ff7e08eb0..7201ef53501c0 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -44,7 +44,7 @@ MPSAccelerator, TPUAccelerator, ) -from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ProgressBarBase +from pytorch_lightning.callbacks import Callback, Checkpoint, EarlyStopping, ProgressBarBase from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.optimizer import LightningOptimizer @@ -1406,7 +1406,7 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: 
bool, model_ f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured.' ) - if not self.checkpoint_callback.best_model_path: + if hasattr(self.checkpoint_callback, "best_model_path") and not self.checkpoint_callback.best_model_path: if self.fast_dev_run: raise MisconfigurationException( f'You cannot execute `.{fn}(ckpt_path="best")` with `fast_dev_run=True`.' @@ -1416,11 +1416,11 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_ f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured to save the best model.' ) # load best weights - ckpt_path = self.checkpoint_callback.best_model_path + ckpt_path = getattr(self.checkpoint_callback, "best_model_path", None) if ckpt_path == "last": - candidates = [ft.ckpt_path for ft in ft_checkpoints] + [ - cb.last_model_path for cb in self.checkpoint_callbacks + candidates = [getattr(ft, "ckpt_path", None) for ft in ft_checkpoints] + [ + getattr(cb, "last_model_path", None) for cb in self.checkpoint_callbacks ] candidates_fs = {path: get_filesystem(path) for path in candidates if path} candidates_ts = {path: fs.modified(path) for path, fs in candidates_fs.items() if fs.exists(path)} @@ -2308,17 +2308,17 @@ def prediction_writer_callbacks(self) -> List[BasePredictionWriter]: return [cb for cb in self.callbacks if isinstance(cb, BasePredictionWriter)] @property - def checkpoint_callback(self) -> Optional[ModelCheckpoint]: + def checkpoint_callback(self) -> Optional[Checkpoint]: """The first :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callback in the Trainer.callbacks list, or ``None`` if it doesn't exist.""" callbacks = self.checkpoint_callbacks return callbacks[0] if len(callbacks) > 0 else None @property - def checkpoint_callbacks(self) -> List[ModelCheckpoint]: + def checkpoint_callbacks(self) -> List[Checkpoint]: """A list of all instances of :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` found in the 
Trainer.callbacks list.""" - return [c for c in self.callbacks if isinstance(c, ModelCheckpoint)] + return [c for c in self.callbacks if isinstance(c, Checkpoint)] @property def progress_bar_callback(self) -> Optional[ProgressBarBase]: