diff --git a/CHANGELOG.md b/CHANGELOG.md
index aa7c4f9b056bc..200db58986e83 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -72,6 +72,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added a `MisconfigurationException` if user provided `opt_idx` in scheduler config doesn't match with actual optimizer index of its respective optimizer ([#11247](https://github.com/PyTorchLightning/pytorch-lightning/pull/11247))
 
+- Added a `_Stateful` support for `LightningDataModule` ([#11637](https://github.com/PyTorchLightning/pytorch-lightning/pull/11637))
+
+
 ### Changed
 
 - Set the `prog_bar` flag to False in `LightningModule.log_grad_norm` ([#11472](https://github.com/PyTorchLightning/pytorch-lightning/pull/11472))
diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py
index 98e14b128b2ff..ab7e06bd48f58 100644
--- a/pytorch_lightning/core/datamodule.py
+++ b/pytorch_lightning/core/datamodule.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 """LightningDataModule for loading DataLoaders with ease."""
 from argparse import ArgumentParser, Namespace
-from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
 
 from torch.utils.data import DataLoader, Dataset, IterableDataset
 
@@ -257,3 +257,19 @@ def test_dataloader():
         if test_dataset is not None:
             datamodule.test_dataloader = test_dataloader
         return datamodule
+
+    def state_dict(self) -> Dict[str, Any]:
+        """Called when saving a checkpoint, implement to generate and save datamodule state.
+
+        Returns:
+            A dictionary containing datamodule state.
+        """
+        return {}
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        """Called when loading a checkpoint, implement to reload datamodule state given datamodule state_dict.
+
+        Args:
+            state_dict: the datamodule state returned by ``state_dict``.
+        """
+        pass
diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 5c437bfd889b2..9dfc59ffbac30 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -154,6 +154,8 @@ def restore_datamodule(self) -> None:
         datamodule = self.trainer.datamodule
         if datamodule is not None:
             datamodule.on_load_checkpoint(self._loaded_checkpoint)
+            if datamodule.__class__.__qualname__ in self._loaded_checkpoint:
+                datamodule.load_state_dict(self._loaded_checkpoint[datamodule.__class__.__qualname__])
 
     def restore_model(self) -> None:
         """Restores a model's weights from a PyTorch Lightning checkpoint.
@@ -331,7 +333,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
                 CHECKPOINT_HYPER_PARAMS_KEY:
                 CHECKPOINT_HYPER_PARAMS_TYPE:
                 something_cool_i_want_to_save: anything you define through model.on_save_checkpoint
-                LightningDataModule.__class__.__name__: pl DataModule's state
+                LightningDataModule.__class__.__qualname__: pl DataModule's state
             }
         """
 
@@ -385,10 +387,17 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
             else:
                 checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams)
 
-        # give the model a chance to dump a few things
+        # dump stateful datamodule
+        datamodule = self.trainer.datamodule
+        if datamodule is not None:
+            datamodule_state_dict = datamodule.state_dict()
+            if datamodule_state_dict:
+                checkpoint[datamodule.__class__.__qualname__] = datamodule_state_dict
+
+        # on_save_checkpoint hooks
         model.on_save_checkpoint(checkpoint)
-        if self.trainer.datamodule is not None:
-            self.trainer.datamodule.on_save_checkpoint(checkpoint)
+        if datamodule is not None:
+            datamodule.on_save_checkpoint(checkpoint)
 
         # TODO: remove this in v1.8.
         environment = self.trainer._accelerator_connector.cluster_environment
diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py
index 738083091d93d..22878d834086c 100644
--- a/tests/core/test_datamodules.py
+++ b/tests/core/test_datamodules.py
@@ -196,11 +196,18 @@ def validation_step(self, batch, batch_idx):
             return out
 
     class CustomBoringDataModule(BoringDataModule):
+        def state_dict(self) -> Dict[str, Any]:
+            return {"my": "state_dict"}
+
+        def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+            self.my_state_dict = state_dict
+
         def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-            checkpoint[self.__class__.__name__] = self.__class__.__name__
+            checkpoint[self.__class__.__qualname__].update({"on_save": "update"})
 
         def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-            self.checkpoint_state = checkpoint.get(self.__class__.__name__)
+            self.checkpoint_state = checkpoint.get(self.__class__.__qualname__).copy()
+            checkpoint[self.__class__.__qualname__].pop("on_save")
 
     reset_seed()
     dm = CustomBoringDataModule()
@@ -220,14 +227,14 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
     assert trainer.state.finished, f"Training failed with {trainer.state}"
     checkpoint_path = list(trainer.checkpoint_callback.best_k_models.keys())[0]
     checkpoint = torch.load(checkpoint_path)
-    assert dm.__class__.__name__ in checkpoint
-    assert checkpoint[dm.__class__.__name__] == dm.__class__.__name__
+    assert dm.__class__.__qualname__ in checkpoint
+    assert checkpoint[dm.__class__.__qualname__] == {"my": "state_dict", "on_save": "update"}
 
     for trainer_fn in TrainerFn:
         trainer.state.fn = trainer_fn
-        with mock.patch.object(dm, "on_load_checkpoint") as dm_mock:
-            trainer._restore_modules_and_callbacks(checkpoint_path)
-            dm_mock.assert_called_once()
+        trainer._restore_modules_and_callbacks(checkpoint_path)
+        assert dm.checkpoint_state == {"my": "state_dict", "on_save": "update"}
+        assert dm.my_state_dict == {"my": "state_dict"}
 
 
 def test_full_loop(tmpdir):
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 5f20d7bb4115a..bf981575979d7 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -866,6 +866,7 @@ def call(hook, fn, *args, **kwargs):
         dict(name="setup", kwargs=dict(stage="fit")),
         dict(name="val_dataloader"),
         dict(name="train_dataloader"),
+        dict(name="state_dict"),
         dict(name="on_save_checkpoint", args=(ANY,)),
         dict(name="teardown", kwargs=dict(stage="fit")),
     ]
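
Usage sketch (not part of the patch): the diff wires `dump_checkpoint` to store whatever `LightningDataModule.state_dict()` returns under the datamodule's `__qualname__` key, and `restore_datamodule` to pass that entry back into `load_state_dict()`. A minimal user-side example, assuming the illustrative class name `MyDataModule` and the `current_fold` attribute (neither appears in this PR):

from typing import Any, Dict

from pytorch_lightning import LightningDataModule


class MyDataModule(LightningDataModule):
    """Hypothetical datamodule whose fold counter should survive a checkpoint round-trip."""

    def __init__(self) -> None:
        super().__init__()
        self.current_fold = 0  # illustrative state worth persisting

    def state_dict(self) -> Dict[str, Any]:
        # Collected during dump_checkpoint() and saved under this class's __qualname__ key.
        return {"current_fold": self.current_fold}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Called by restore_datamodule() when the loaded checkpoint contains that key.
        self.current_fold = state_dict["current_fold"]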