Datamodule not calling load_state_dict() when loading from checkpoint #14842

Closed · 5 tasks done
dconathan opened this issue Sep 22, 2022 · 2 comments · Fixed by #14883
Labels: bug (Something isn't working) · checkpointing (Related to checkpointing) · lightningdatamodule (pl.LightningDataModule)

@dconathan (Contributor)

First check

  • I'm sure this is a bug.
  • I've added a descriptive title to this bug.
  • I've provided clear instructions on how to reproduce the bug.
  • I've added a code sample.
  • I've provided any other important info that is required.

Bug description

Sorry if this is part of #14841, but I wanted to make sure it gets fixed as part of that work if it isn't already covered!

In short, when you call MyDataModule.load_from_checkpoint(checkpoint_path), MyDataModule.load_state_dict() does not seem to be called.

How to reproduce the bug

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer, LightningDataModule


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringDataModule(LightningDataModule):
    def state_dict(self):
        print("state_dict()")
        return dict()

    def load_state_dict(self, state_dict):
        print("load_state_dict()")
        raise RuntimeError("this should be raised!")

    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=2)

    def val_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=2)

    def test_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=2)


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():

    datamodule = BoringDataModule()
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
    )
    trainer.fit(model, datamodule)
    trainer.test(model, datamodule)
    checkpoint_path = os.path.join(trainer.log_dir, "checkpoints", "epoch=0-step=1.ckpt")
    # should raise the RuntimeError from BoringDataModule.load_state_dict(), but doesn't
    loaded_datamodule = BoringDataModule.load_from_checkpoint(checkpoint_path)
    assert isinstance(loaded_datamodule, BoringDataModule)


if __name__ == "__main__":
    run()

Error messages and logs

No response

Important info


- Lightning Component: LightningDataModule
- PyTorch Lightning Version: 1.7.6
- PyTorch Version: 1.12.1
- Python version: 3.10
- OS: macOS
- How you installed Lightning: pip
- Running environment of LightningApp: local

More info

No response

@dconathan added the bug (Something isn't working) and needs triage (Waiting to be triaged by maintainers) labels on Sep 22, 2022
@awaelchli (Contributor)

@dconathan The code where this happens is here:

https://github.com/Lightning-AI/lightning/blob/abb6049fa37996440185ffe15323ee6340d125db/src/pytorch_lightning/core/saving.py#L228-L233

As we can see, the DataModule is purposefully skipped. This logic was introduced in #12550. The checkpoint entry for the datamodule gets dumped here:

https://github.com/Lightning-AI/lightning/blob/abb6049fa37996440185ffe15323ee6340d125db/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py#L503-L507

We should probably load it using that key. Maybe @rohitgr7 can give more info on this.
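
To make the asymmetry concrete: the dump code above appears to key the datamodule state by the class qualname, while load_from_checkpoint() never reads that entry back. In the meantime, a manual workaround along these lines should restore the state (a sketch, not an official API; it assumes the qualname key and reuses BoringDataModule from the repro script above):

import torch

# Load the raw checkpoint dict; the datamodule entry (if any) sits
# alongside the model state, keyed by the datamodule's class qualname.
checkpoint = torch.load("epoch=0-step=1.ckpt", map_location="cpu")

datamodule = BoringDataModule()
state = checkpoint.get(BoringDataModule.__qualname__)  # assumed key
if state is not None:
    datamodule.load_state_dict(state)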

@awaelchli added the checkpointing (Related to checkpointing) and lightningdatamodule (pl.LightningDataModule) labels and removed the needs triage (Waiting to be triaged by maintainers) label on Sep 24, 2022
@rohitgr7 (Contributor)

Ah! Looks like this was missed. Thanks for reporting. Sending a fix.
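
For context, the fix presumably needs to teach the shared loading path to restore that checkpoint entry. A rough sketch of the idea (an illustration only, not the actual #14883 diff; the helper name is hypothetical):

import pytorch_lightning as pl

def _restore_datamodule_state(obj, checkpoint):
    # Hypothetical helper: if the restored object is a LightningDataModule
    # and the checkpoint carries an entry under its class qualname, feed
    # it back through load_state_dict().
    if isinstance(obj, pl.LightningDataModule):
        state = checkpoint.get(obj.__class__.__qualname__)
        if state is not None:
            obj.load_state_dict(state)
    return obj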
