Add BaseModelCheckpoint class to inherit from (Lightning-AI#13024)
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
5 people authored and jerome-habana committed Jul 14, 2022
1 parent 1e245a9 commit 284d95c
Showing 14 changed files with 99 additions and 53 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -136,7 +136,7 @@ ENV/
Datasets/
mnist/
MNIST/
legacy/checkpoints/
tests/legacy/checkpoints/
*.gz
*ubyte

9 changes: 8 additions & 1 deletion docs/source-pytorch/common/checkpointing_expert.rst
@@ -6,7 +6,12 @@
Checkpointing (expert)
######################

TODO: I don't understand this...
*********************************
Writing your own Checkpoint class
*********************************

We provide a ``Checkpoint`` class for easier subclassing. Users may want to subclass it when writing a custom ``ModelCheckpoint``-like callback, so that the ``Trainer`` recognizes the custom class as a checkpointing callback.


***********************
Customize Checkpointing
@@ -23,6 +28,8 @@ and :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_load_checkpoint` met
what's saved in the checkpoint.


TODO: I don't understand this...

******************************
Built-in Checkpoint IO Plugins
******************************
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -73,6 +73,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added breaking of lazy graph across training, validation, test and predict steps when training with habana accelerators to ensure better performance ([#12938](https://github.com/PyTorchLightning/pytorch-lightning/pull/12938))


- Added `Checkpoint` class to inherit from ([#13024](https://github.com/PyTorchLightning/pytorch-lightning/pull/13024))


- Added CPU metric tracking to `DeviceStatsMonitor` ([#11795](https://github.com/PyTorchLightning/pytorch-lightning/pull/11795))


2 changes: 2 additions & 0 deletions src/pytorch_lightning/callbacks/__init__.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.callbacks.checkpoint import Checkpoint
from pytorch_lightning.callbacks.device_stats_monitor import DeviceStatsMonitor
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.finetuning import BackboneFinetuning, BaseFinetuning
@@ -32,6 +33,7 @@
"BackboneFinetuning",
"BaseFinetuning",
"Callback",
"Checkpoint",
"DeviceStatsMonitor",
"EarlyStopping",
"GradientAccumulationScheduler",
9 changes: 9 additions & 0 deletions src/pytorch_lightning/callbacks/checkpoint.py
@@ -0,0 +1,9 @@
from pytorch_lightning.callbacks.callback import Callback


class Checkpoint(Callback):
r"""
This is the base class for model checkpointing. Expert users may want to subclass it when writing a
custom :class:`~pytorch_lightning.callbacks.Checkpoint` callback, so that
the trainer recognizes the custom class as a checkpointing callback.
"""
4 changes: 2 additions & 2 deletions src/pytorch_lightning/callbacks/fault_tolerance.py
@@ -21,11 +21,11 @@
from typing import Any

import pytorch_lightning as pl
from pytorch_lightning import Callback
from pytorch_lightning.callbacks import Checkpoint
from pytorch_lightning.utilities.types import _PATH


class _FaultToleranceCheckpoint(Callback):
class _FaultToleranceCheckpoint(Checkpoint):
"""Used to save a fault-tolerance checkpoint on exception."""

FILE_EXTENSION = ".ckpt"
4 changes: 2 additions & 2 deletions src/pytorch_lightning/callbacks/model_checkpoint.py
@@ -34,7 +34,7 @@
from torch import Tensor

import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.callbacks import Checkpoint
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.logger import _name, _version
@@ -46,7 +46,7 @@
warning_cache = WarningCache()


class ModelCheckpoint(Callback):
class ModelCheckpoint(Checkpoint):
r"""
Save the model periodically by monitoring a quantity. Every metric logged with
:meth:`~pytorch_lightning.core.module.log` or :meth:`~pytorch_lightning.core.module.log_dict` in
6 changes: 3 additions & 3 deletions src/pytorch_lightning/loggers/logger.py
@@ -25,7 +25,7 @@
import numpy as np

import pytorch_lightning as pl
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks import Checkpoint
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_only


@@ -86,7 +86,7 @@ def __init__(
else:
self._agg_default_func = np.mean

def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[Checkpoint]") -> None:
"""Called after model checkpoint callback saves a new checkpoint.
Args:
@@ -221,7 +221,7 @@ def __init__(self, logger_iterable: Iterable[Logger]):
def __getitem__(self, index: int) -> Logger:
return list(self._logger_iterable)[index]

def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[Checkpoint]") -> None:
for logger in self._logger_iterable:
logger.after_save_checkpoint(checkpoint_callback)

34 changes: 19 additions & 15 deletions src/pytorch_lightning/loggers/neptune.py
@@ -31,7 +31,7 @@
from torch import Tensor

from pytorch_lightning import __version__
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks import Checkpoint
from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
from pytorch_lightning.utilities.imports import _NEPTUNE_AVAILABLE, _NEPTUNE_GREATER_EQUAL_0_9
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params
@@ -534,7 +534,7 @@ def log_model_summary(self, model, max_depth=-1):
)

@rank_zero_only
def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[Checkpoint]") -> None:
"""Automatically log checkpointed model. Called after model checkpoint callback saves a new checkpoint.
Args:
@@ -547,19 +547,20 @@ def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpo
checkpoints_namespace = self._construct_path_with_prefix("model/checkpoints")

# save last model
if checkpoint_callback.last_model_path:
if hasattr(checkpoint_callback, "last_model_path") and checkpoint_callback.last_model_path:
model_last_name = self._get_full_model_name(checkpoint_callback.last_model_path, checkpoint_callback)
file_names.add(model_last_name)
self.run[f"{checkpoints_namespace}/{model_last_name}"].upload(checkpoint_callback.last_model_path)

# save best k models
for key in checkpoint_callback.best_k_models.keys():
model_name = self._get_full_model_name(key, checkpoint_callback)
file_names.add(model_name)
self.run[f"{checkpoints_namespace}/{model_name}"].upload(key)
if hasattr(checkpoint_callback, "best_k_models"):
for key in checkpoint_callback.best_k_models.keys():
model_name = self._get_full_model_name(key, checkpoint_callback)
file_names.add(model_name)
self.run[f"{checkpoints_namespace}/{model_name}"].upload(key)

# log best model path and checkpoint
if checkpoint_callback.best_model_path:
if hasattr(checkpoint_callback, "best_model_path") and checkpoint_callback.best_model_path:
self.run[self._construct_path_with_prefix("model/best_model_path")] = checkpoint_callback.best_model_path

model_name = self._get_full_model_name(checkpoint_callback.best_model_path, checkpoint_callback)
@@ -575,19 +576,22 @@ def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpo
del self.run[f"{checkpoints_namespace}/{file_to_drop}"]

# log best model score
if checkpoint_callback.best_model_score:
if hasattr(checkpoint_callback, "best_model_score") and checkpoint_callback.best_model_score:
self.run[self._construct_path_with_prefix("model/best_model_score")] = (
checkpoint_callback.best_model_score.cpu().detach().numpy()
)

@staticmethod
def _get_full_model_name(model_path: str, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> str:
def _get_full_model_name(model_path: str, checkpoint_callback: "ReferenceType[Checkpoint]") -> str:
"""Returns model name which is string `model_path` appended to `checkpoint_callback.dirpath`."""
expected_model_path = f"{checkpoint_callback.dirpath}{os.path.sep}"
if not model_path.startswith(expected_model_path):
raise ValueError(f"{model_path} was expected to start with {expected_model_path}.")
# Remove extension from filepath
filepath, _ = os.path.splitext(model_path[len(expected_model_path) :])
if hasattr(checkpoint_callback, "dirpath"):
expected_model_path = f"{checkpoint_callback.dirpath}{os.path.sep}"
if not model_path.startswith(expected_model_path):
raise ValueError(f"{model_path} was expected to start with {expected_model_path}.")
# Remove extension from filepath
filepath, _ = os.path.splitext(model_path[len(expected_model_path) :])
else:
filepath = model_path

return filepath

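The ``hasattr`` guards above exist because ``after_save_checkpoint`` may now receive any ``Checkpoint`` subclass, and only ``ModelCheckpoint`` is guaranteed to carry attributes such as ``last_model_path``, ``best_k_models``, ``best_model_path``, and ``dirpath``. A small sketch of the difference (assuming this version of PyTorch Lightning is installed):

```python
from pytorch_lightning.callbacks import Checkpoint, ModelCheckpoint


class MinimalCheckpoint(Checkpoint):
    """Bare subclass with none of ModelCheckpoint's bookkeeping attributes."""


print(hasattr(ModelCheckpoint(), "best_k_models"))    # True
print(hasattr(MinimalCheckpoint(), "best_k_models"))  # False -> loggers must guard attribute access
```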
40 changes: 26 additions & 14 deletions src/pytorch_lightning/loggers/wandb.py
@@ -23,7 +23,7 @@

import torch.nn as nn

from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks import Checkpoint
from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _WANDB_GREATER_EQUAL_0_10_22, _WANDB_GREATER_EQUAL_0_12_10
@@ -461,9 +461,14 @@ def version(self) -> Optional[str]:
# don't create an experiment if we don't have one
return self._experiment.id if self._experiment else self._id

def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[Checkpoint]") -> None:
# log checkpoints as artifacts
if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1:
if (
self._log_model == "all"
or self._log_model is True
and hasattr(checkpoint_callback, "save_top_k")
and checkpoint_callback.save_top_k == -1
):
self._scan_and_log_checkpoints(checkpoint_callback)
elif self._log_model is True:
self._checkpoint_callback = checkpoint_callback
@@ -474,25 +479,33 @@ def finalize(self, status: str) -> None:
if self._checkpoint_callback:
self._scan_and_log_checkpoints(self._checkpoint_callback)

def _scan_and_log_checkpoints(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
def _scan_and_log_checkpoints(self, checkpoint_callback: "ReferenceType[Checkpoint]") -> None:
# get checkpoints to be saved with associated score
checkpoints = {
checkpoint_callback.last_model_path: checkpoint_callback.current_score,
checkpoint_callback.best_model_path: checkpoint_callback.best_model_score,
**checkpoint_callback.best_k_models,
}
checkpoints = sorted((Path(p).stat().st_mtime, p, s) for p, s in checkpoints.items() if Path(p).is_file())
checkpoints = dict()
if hasattr(checkpoint_callback, "last_model_path") and hasattr(checkpoint_callback, "current_score"):
checkpoints[checkpoint_callback.last_model_path] = (checkpoint_callback.current_score, "latest")

if hasattr(checkpoint_callback, "best_model_path") and hasattr(checkpoint_callback, "best_model_score"):
checkpoints[checkpoint_callback.best_model_path] = (checkpoint_callback.best_model_score, "best")

if hasattr(checkpoint_callback, "best_k_models"):
for key, value in checkpoint_callback.best_k_models.items():
checkpoints[key] = (value, "best_k")

checkpoints = sorted(
(Path(p).stat().st_mtime, p, s, tag) for p, (s, tag) in checkpoints.items() if Path(p).is_file()
)
checkpoints = [
c for c in checkpoints if c[1] not in self._logged_model_time.keys() or self._logged_model_time[c[1]] < c[0]
]

# log iteratively all new checkpoints
for t, p, s in checkpoints:
for t, p, s, tag in checkpoints:
metadata = (
{
"score": s,
"original_filename": Path(p).name,
"ModelCheckpoint": {
checkpoint_callback.__class__.__name__: {
k: getattr(checkpoint_callback, k)
for k in [
"monitor",
@@ -511,7 +524,6 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: "ReferenceType[ModelChe
)
artifact = wandb.Artifact(name=f"model-{self.experiment.id}", type="model", metadata=metadata)
artifact.add_file(p, name="model.ckpt")
aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]
self.experiment.log_artifact(artifact, aliases=aliases)
self.experiment.log_artifact(artifact, aliases=[tag])
# remember logged models - timestamp needed in case filename didn't change (last.ckpt or custom name)
self._logged_model_time[p] = t
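In practice, each newly saved checkpoint file is now uploaded once as a W&B artifact and aliased with the tag computed above (``latest``, ``best``, or ``best_k``) instead of the previous ``best_model_path`` comparison. A usage sketch with a standard ``ModelCheckpoint`` (project name, model, and datamodule are placeholders):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

logger = WandbLogger(project="my-project", log_model="all")  # upload every checkpoint as an artifact
ckpt = ModelCheckpoint(monitor="val_loss", save_top_k=3, save_last=True)
trainer = Trainer(logger=logger, callbacks=[ckpt], max_epochs=5)
# trainer.fit(model, datamodule=dm)  # each saved .ckpt is logged with a "latest", "best" or "best_k" alias
```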
8 changes: 6 additions & 2 deletions src/pytorch_lightning/strategies/launchers/spawn.py
@@ -109,7 +109,7 @@ def _wrapping_function(

def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer: "pl.Trainer") -> None:
# transfer back the best path to the trainer
if trainer.checkpoint_callback:
if trainer.checkpoint_callback and hasattr(trainer.checkpoint_callback, "best_model_path"):
trainer.checkpoint_callback.best_model_path = str(spawn_output.best_model_path)

# TODO: pass also best score
@@ -131,7 +131,11 @@ def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer
def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]:
rank_zero_debug("Finalizing the DDP spawn environment.")
checkpoint_callback = trainer.checkpoint_callback
best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None
best_model_path = (
checkpoint_callback.best_model_path
if checkpoint_callback and hasattr(checkpoint_callback, "best_model_path")
else None
)

# requires to compute the state_dict on all processes in case Metrics are present
state_dict = trainer.lightning_module.state_dict()
6 changes: 5 additions & 1 deletion src/pytorch_lightning/strategies/launchers/xla_spawn.py
@@ -115,7 +115,11 @@ def _wrapping_function(
def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]:
rank_zero_debug("Finalizing the TPU spawn environment.")
checkpoint_callback = trainer.checkpoint_callback
best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None
best_model_path = (
checkpoint_callback.best_model_path
if checkpoint_callback and hasattr(checkpoint_callback, "best_model_path")
else None
)

# requires to compute the state_dict on all processes in case Metrics are present
state_dict = trainer.lightning_module.state_dict()
src/pytorch_lightning/trainer/connectors/callback_connector.py
@@ -19,6 +19,7 @@

from pytorch_lightning.callbacks import (
Callback,
Checkpoint,
GradientAccumulationScheduler,
ModelCheckpoint,
ModelSummary,
@@ -232,18 +233,18 @@ def _attach_model_callbacks(self) -> None:

@staticmethod
def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]:
"""Moves all ModelCheckpoint callbacks to the end of the list. The sequential order within the group of
"""Moves all Checkpoint callbacks to the end of the list. The sequential order within the group of
checkpoint callbacks is preserved, as well as the order of all other callbacks.
Args:
callbacks: A list of callbacks.
Return:
A new list in which the last elements are ModelCheckpoints if there were any present in the
A new list in which the last elements are Checkpoint callbacks if there were any present in the
input.
"""
checkpoints = [c for c in callbacks if isinstance(c, ModelCheckpoint)]
not_checkpoints = [c for c in callbacks if not isinstance(c, ModelCheckpoint)]
checkpoints = [c for c in callbacks if isinstance(c, Checkpoint)]
not_checkpoints = [c for c in callbacks if not isinstance(c, Checkpoint)]
return not_checkpoints + checkpoints


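Because the reordering now keys on the ``Checkpoint`` base class, custom subclasses are moved to the end of the callback list together with ``ModelCheckpoint``, while the relative order within each group is preserved. A rough illustration (``MyCheckpoint`` is hypothetical):

```python
from pytorch_lightning.callbacks import Checkpoint, EarlyStopping, ModelCheckpoint
from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector


class MyCheckpoint(Checkpoint):
    pass


callbacks = [MyCheckpoint(), EarlyStopping(monitor="val_loss"), ModelCheckpoint()]
reordered = CallbackConnector._reorder_callbacks(callbacks)
print([type(c).__name__ for c in reordered])
# ['EarlyStopping', 'MyCheckpoint', 'ModelCheckpoint'] -> checkpoint callbacks last, order preserved
```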
16 changes: 8 additions & 8 deletions src/pytorch_lightning/trainer/trainer.py
@@ -44,7 +44,7 @@
MPSAccelerator,
TPUAccelerator,
)
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ProgressBarBase
from pytorch_lightning.callbacks import Callback, Checkpoint, EarlyStopping, ProgressBarBase
from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.optimizer import LightningOptimizer
@@ -1406,7 +1406,7 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_
f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured.'
)

if not self.checkpoint_callback.best_model_path:
if hasattr(self.checkpoint_callback, "best_model_path") and not self.checkpoint_callback.best_model_path:
if self.fast_dev_run:
raise MisconfigurationException(
f'You cannot execute `.{fn}(ckpt_path="best")` with `fast_dev_run=True`.'
@@ -1416,11 +1416,11 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_
f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured to save the best model.'
)
# load best weights
ckpt_path = self.checkpoint_callback.best_model_path
ckpt_path = getattr(self.checkpoint_callback, "best_model_path", None)

if ckpt_path == "last":
candidates = [ft.ckpt_path for ft in ft_checkpoints] + [
cb.last_model_path for cb in self.checkpoint_callbacks
candidates = [getattr(ft, "ckpt_path", None) for ft in ft_checkpoints] + [
getattr(cb, "last_model_path", None) for cb in self.checkpoint_callbacks
]
candidates_fs = {path: get_filesystem(path) for path in candidates if path}
candidates_ts = {path: fs.modified(path) for path, fs in candidates_fs.items() if fs.exists(path)}
@@ -2308,17 +2308,17 @@ def prediction_writer_callbacks(self) -> List[BasePredictionWriter]:
return [cb for cb in self.callbacks if isinstance(cb, BasePredictionWriter)]

@property
def checkpoint_callback(self) -> Optional[ModelCheckpoint]:
def checkpoint_callback(self) -> Optional[Checkpoint]:
"""The first :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callback in the
Trainer.callbacks list, or ``None`` if it doesn't exist."""
callbacks = self.checkpoint_callbacks
return callbacks[0] if len(callbacks) > 0 else None

@property
def checkpoint_callbacks(self) -> List[ModelCheckpoint]:
def checkpoint_callbacks(self) -> List[Checkpoint]:
"""A list of all instances of :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` found
in the Trainer.callbacks list."""
return [c for c in self.callbacks if isinstance(c, ModelCheckpoint)]
return [c for c in self.callbacks if isinstance(c, Checkpoint)]

@property
def progress_bar_callback(self) -> Optional[ProgressBarBase]:
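With the relaxed typing above, ``trainer.checkpoint_callback`` and ``trainer.checkpoint_callbacks`` may contain any ``Checkpoint`` subclass, so downstream code should not assume ``ModelCheckpoint``-specific attributes. A small sketch (``MyCheckpoint`` is hypothetical):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Checkpoint


class MyCheckpoint(Checkpoint):
    pass


trainer = Trainer(callbacks=[MyCheckpoint()])
print([type(cb).__name__ for cb in trainer.checkpoint_callbacks])     # includes 'MyCheckpoint'
print(getattr(trainer.checkpoint_callback, "best_model_path", None))  # guard: may be None for custom classes
```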
