Add BaseModelCheckpoint class to inherit from #13024

Merged into master from feature/base_model_checkpoint on Jun 30, 2022 (26 commits).
Changes shown are from 3 of the 26 commits.

Commits (26)
8a5f6fd - add BaseModelCheckpoint (May 10, 2022)
7583e28 - update changelog (May 10, 2022)
7b1aff3 - doc + other places (May 11, 2022)
898feaf - Update pytorch_lightning/callbacks/model_checkpoint.py (otaj, May 16, 2022)
aedbce1 - Merge branch 'master' into feature/base_model_checkpoint (May 16, 2022)
144a4c6 - move docs (May 16, 2022)
a7ea8dc - checkpoint (May 25, 2022)
4e7d8fc - merge master (May 25, 2022)
a366491 - Update pytorch_lightning/trainer/connectors/callback_connector.py (otaj, May 25, 2022)
800043c - Update pytorch_lightning/trainer/connectors/callback_connector.py (otaj, May 25, 2022)
b761faa - merge master (Jun 21, 2022)
06a3e6d - Merge branch 'master' into feature/base_model_checkpoint (Jun 22, 2022)
9a1b492 - guarding with hasattr (Jun 22, 2022)
2faaf6a - Merge branch 'master' into feature/base_model_checkpoint (Jun 22, 2022)
1319924 - remove unrelated changelog line (Jun 22, 2022)
c3248fe - Apply suggestions from code review (otaj, Jun 24, 2022)
90524ec - Merge branch 'master' into feature/base_model_checkpoint (Jun 24, 2022)
5062180 - drop changing in app part of the repo (Jun 24, 2022)
c627027 - correct class in docstring (Jun 24, 2022)
fc2b574 - merge master (Jun 27, 2022)
8361ecd - Merge branch 'master' into feature/base_model_checkpoint (Borda, Jun 27, 2022)
ff24752 - move todo docs (Jun 28, 2022)
04365cd - Merge branch 'master' into feature/base_model_checkpoint (Jun 28, 2022)
e4776fe - Merge branch 'master' into feature/base_model_checkpoint (carmocca, Jun 29, 2022)
8a5633f - Merge branch 'master' into feature/base_model_checkpoint (carmocca, Jun 29, 2022)
bf56731 - Merge branch 'master' into feature/base_model_checkpoint (Jun 30, 2022)
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -60,7 +60,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added breaking of lazy graph across training, validation, test and predict steps when training with habana accelerators to ensure better performance ([#12938](https://github.com/PyTorchLightning/pytorch-lightning/pull/12938))


-
- Added `BaseModelCheckpoint` class to inherit from ([#13024](https://github.com/PyTorchLightning/pytorch-lightning/pull/13024))

### Changed

7 changes: 7 additions & 0 deletions docs/source/common/checkpointing_expert.rst
@@ -87,3 +87,10 @@ Custom Checkpoint IO Plugin
.. note::

Some ``TrainingTypePlugins`` like ``DeepSpeedStrategy`` do not support custom ``CheckpointIO`` as checkpointing logic is not modifiable.


*********************************
Writing your own Checkpoint class
*********************************

We provide the ``BaseModelCheckpoint`` class for easier subclassing. Users writing a custom ``ModelCheckpoint`` callback may want to subclass it so that the ``Trainer`` recognizes the custom class as a checkpointing callback.
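As an illustration of the subclassing described above, here is a minimal sketch of a custom checkpoint callback. The callback name, its constructor argument, and its saving logic are hypothetical and only demonstrate the intended use of ``BaseModelCheckpoint``; they are not part of this PR.

import os

from pytorch_lightning.callbacks.model_checkpoint import BaseModelCheckpoint


class EpochEndCheckpoint(BaseModelCheckpoint):
    """Illustrative callback that saves a full checkpoint after every training epoch."""

    def __init__(self, dirpath: str) -> None:
        super().__init__()
        self.dirpath = dirpath

    def on_train_epoch_end(self, trainer, pl_module) -> None:
        os.makedirs(self.dirpath, exist_ok=True)
        filepath = os.path.join(self.dirpath, f"epoch={trainer.current_epoch}.ckpt")
        # Delegate the actual writing to the Trainer so the checkpoint
        # contains the full training state (model, optimizers, loops, ...).
        trainer.save_checkpoint(filepath)

Because the callback inherits from ``BaseModelCheckpoint`` rather than plain ``Callback``, the ``Trainer`` treats it as a checkpointing callback (see the ``callback_connector.py`` and ``trainer.py`` changes below).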
10 changes: 9 additions & 1 deletion pytorch_lightning/callbacks/model_checkpoint.py
@@ -45,7 +45,15 @@
warning_cache = WarningCache()


class ModelCheckpoint(Callback):
class BaseModelCheckpoint(Callback):
r"""
Base class for model checkpointing. Expert users may want to subclass it when writing a
custom :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callback, so that
the trainer recognizes the custom class as a checkpointing callback.
"""


class ModelCheckpoint(BaseModelCheckpoint):
r"""
Save the model periodically by monitoring a quantity. Every metric logged with
:meth:`~pytorch_lightning.core.lightning.log` or :meth:`~pytorch_lightning.core.lightning.log_dict` in
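The net effect of the class split above, assuming this PR: the existing ``ModelCheckpoint`` becomes a subclass of the new ``BaseModelCheckpoint``, so existing ``isinstance`` checks keep working. A quick sketch:

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.model_checkpoint import BaseModelCheckpoint

# The built-in callback is itself a BaseModelCheckpoint after this change.
assert issubclass(ModelCheckpoint, BaseModelCheckpoint)
assert isinstance(ModelCheckpoint(), BaseModelCheckpoint)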
6 changes: 3 additions & 3 deletions pytorch_lightning/loggers/logger.py
@@ -25,7 +25,7 @@
import numpy as np

import pytorch_lightning as pl
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks.model_checkpoint import BaseModelCheckpoint
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_only


@@ -86,7 +86,7 @@ def __init__(
else:
self._agg_default_func = np.mean

def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[BaseModelCheckpoint]") -> None:
"""Called after model checkpoint callback saves a new checkpoint.

Args:
@@ -236,7 +236,7 @@ def __init__(self, logger_iterable: Iterable[Logger]):
def __getitem__(self, index: int) -> Logger:
return list(self._logger_iterable)[index]

def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[BaseModelCheckpoint]") -> None:
for logger in self._logger_iterable:
logger.after_save_checkpoint(checkpoint_callback)

6 changes: 3 additions & 3 deletions pytorch_lightning/loggers/neptune.py
@@ -30,7 +30,7 @@
import torch

from pytorch_lightning import __version__
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks.model_checkpoint import BaseModelCheckpoint
from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
from pytorch_lightning.utilities.imports import _NEPTUNE_AVAILABLE, _NEPTUNE_GREATER_EQUAL_0_9
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params
@@ -533,7 +533,7 @@ def log_model_summary(self, model, max_depth=-1):
)

@rank_zero_only
def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[BaseModelCheckpoint]") -> None:
"""Automatically log checkpointed model. Called after model checkpoint callback saves a new checkpoint.

Args:
@@ -580,7 +580,7 @@ def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpo
)

@staticmethod
def _get_full_model_name(model_path: str, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> str:
def _get_full_model_name(model_path: str, checkpoint_callback: "ReferenceType[BaseModelCheckpoint]") -> str:
"""Returns model name which is string `model_path` appended to `checkpoint_callback.dirpath`."""
expected_model_path = f"{checkpoint_callback.dirpath}{os.path.sep}"
if not model_path.startswith(expected_model_path):
6 changes: 3 additions & 3 deletions pytorch_lightning/loggers/wandb.py
@@ -23,7 +23,7 @@

import torch.nn as nn

from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks.model_checkpoint import BaseModelCheckpoint
from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _WANDB_GREATER_EQUAL_0_10_22, _WANDB_GREATER_EQUAL_0_12_10
@@ -461,7 +461,7 @@ def version(self) -> Optional[str]:
# don't create an experiment if we don't have one
return self._experiment.id if self._experiment else self._id

def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[BaseModelCheckpoint]") -> None:
# log checkpoints as artifacts
if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1:
self._scan_and_log_checkpoints(checkpoint_callback)
@@ -474,7 +474,7 @@ def finalize(self, status: str) -> None:
if self._checkpoint_callback:
self._scan_and_log_checkpoints(self._checkpoint_callback)

def _scan_and_log_checkpoints(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
def _scan_and_log_checkpoints(self, checkpoint_callback: "ReferenceType[BaseModelCheckpoint]") -> None:
# get checkpoints to be saved with associated score
checkpoints = {
checkpoint_callback.last_model_path: checkpoint_callback.current_score,
9 changes: 5 additions & 4 deletions pytorch_lightning/trainer/connectors/callback_connector.py
@@ -26,6 +26,7 @@
RichProgressBar,
TQDMProgressBar,
)
from pytorch_lightning.callbacks.model_checkpoint import BaseModelCheckpoint
from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary
from pytorch_lightning.callbacks.timer import Timer
from pytorch_lightning.utilities.enums import ModelSummaryMode
@@ -276,18 +277,18 @@ def _attach_model_callbacks(self) -> None:

@staticmethod
def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]:
"""Moves all ModelCheckpoint callbacks to the end of the list. The sequential order within the group of
"""Moves all BaseModelCheckpoint callbacks to the end of the list. The sequential order within the group of
checkpoint callbacks is preserved, as well as the order of all other callbacks.

Args:
callbacks: A list of callbacks.

Return:
A new list in which the last elements are ModelCheckpoints if there were any present in the
A new list in which the last elements are BaseModelCheckpoints if there were any present in the
input.
"""
checkpoints = [c for c in callbacks if isinstance(c, ModelCheckpoint)]
not_checkpoints = [c for c in callbacks if not isinstance(c, ModelCheckpoint)]
checkpoints = [c for c in callbacks if isinstance(c, BaseModelCheckpoint)]
not_checkpoints = [c for c in callbacks if not isinstance(c, BaseModelCheckpoint)]
return not_checkpoints + checkpoints


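The reordering rule can be sketched with a hypothetical do-nothing subclass (the class name is illustrative, not part of this PR): any ``BaseModelCheckpoint`` instance, not only ``ModelCheckpoint``, is now moved to the end of the callback list.

from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import BaseModelCheckpoint


class MyCheckpoint(BaseModelCheckpoint):
    """Do-nothing checkpoint callback, used here only to demonstrate ordering."""


callbacks = [MyCheckpoint(), EarlyStopping(monitor="val_loss")]
# Same logic as _reorder_callbacks above: checkpoint callbacks go last,
# and the relative order within each group is preserved.
checkpoints = [c for c in callbacks if isinstance(c, BaseModelCheckpoint)]
not_checkpoints = [c for c in callbacks if not isinstance(c, BaseModelCheckpoint)]
reordered = not_checkpoints + checkpoints  # [EarlyStopping(...), MyCheckpoint()]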
9 changes: 5 additions & 4 deletions pytorch_lightning/trainer/trainer.py
@@ -36,7 +36,8 @@

import pytorch_lightning as pl
from pytorch_lightning.accelerators import Accelerator, GPUAccelerator, HPUAccelerator, IPUAccelerator, TPUAccelerator
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ProgressBarBase
from pytorch_lightning.callbacks import Callback, EarlyStopping, ProgressBarBase
from pytorch_lightning.callbacks.model_checkpoint import BaseModelCheckpoint
from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.optimizer import LightningOptimizer
@@ -2309,17 +2310,17 @@ def prediction_writer_callbacks(self) -> List[BasePredictionWriter]:
return [cb for cb in self.callbacks if isinstance(cb, BasePredictionWriter)]

@property
def checkpoint_callback(self) -> Optional[ModelCheckpoint]:
def checkpoint_callback(self) -> Optional[BaseModelCheckpoint]:
"""The first :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callback in the
Trainer.callbacks list, or ``None`` if it doesn't exist."""
callbacks = self.checkpoint_callbacks
return callbacks[0] if len(callbacks) > 0 else None

@property
def checkpoint_callbacks(self) -> List[ModelCheckpoint]:
def checkpoint_callbacks(self) -> List[BaseModelCheckpoint]:
"""A list of all instances of :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` found
in the Trainer.callbacks list."""
return [c for c in self.callbacks if isinstance(c, ModelCheckpoint)]
return [c for c in self.callbacks if isinstance(c, BaseModelCheckpoint)]

@property
def progress_bar_callback(self) -> Optional[ProgressBarBase]:
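With the widened property types above, a custom subclass is also surfaced by the ``Trainer`` properties. A minimal sketch, again using an illustrative do-nothing subclass that is not part of this PR:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.model_checkpoint import BaseModelCheckpoint


class MyCheckpoint(BaseModelCheckpoint):
    """Do-nothing checkpoint callback, used here only to exercise the properties."""


trainer = Trainer(callbacks=[MyCheckpoint()])
# The custom callback now appears in trainer.checkpoint_callbacks, just like
# a ModelCheckpoint instance would, because the properties filter on
# isinstance(c, BaseModelCheckpoint).
assert any(isinstance(cb, MyCheckpoint) for cb in trainer.checkpoint_callbacks)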