Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Datamodule #2668

Merged
merged 18 commits into from
Jul 24, 2020
Merged

Datamodule #2668

merged 18 commits into from
Jul 24, 2020

Conversation

nateraw
Copy link
Contributor

@nateraw nateraw commented Jul 22, 2020

What does this PR do?

Introduces DataModule as an optional way to decouple data related hooks from LightningModule. Originally implemented in pytorch-lightning-bolts, but moving to integrate with Lightning here.

Work in progress. We can play with the DataModule class a bunch more, but for now I'm looking for feedback on how they are handled in trainer.py.

Before submitting

  • Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together? Otherwise, we ask you to create a separate PR for every change.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?
  • Did you verify new and existing tests pass locally with your changes?
  • If you made a notable change (that affects users), did you update the CHANGELOG?

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@pep8speaks
Copy link

pep8speaks commented Jul 22, 2020

Hello @nateraw! Thanks for updating this PR.

Line 1249:37: E203 whitespace before ':'
Line 1253:60: E203 whitespace before ':'

Line 16:1: W293 blank line contains whitespace

Comment last updated at 2020-07-24 12:50:20 UTC

@mergify mergify bot requested a review from a team July 22, 2020 06:08
@Borda Borda added feature Is an improvement or enhancement Important labels Jul 22, 2020
@Borda Borda added this to the 0.9.0 milestone Jul 22, 2020
@nateraw
Copy link
Contributor Author

nateraw commented Jul 22, 2020

Been playing with this script locally...including it here if anybody wants to try running it/breaking it.

import os

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import datasets, transforms


import pytorch_lightning as pl


class CoolSystem(pl.LightningModule):
    
    def __init__(self, datamodule):
        super().__init__()
        self.datamodule = datamodule
        self.l1 = nn.Linear(28*28, 10)
        
    def forward(self, x):
        batch_size, channels, width, height = x.size()

        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, -1)

        return torch.log_softmax(self.l1(x), dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return {'val_loss': loss}

    def validation_epoch_end(self, outputs):
        # called at the end of the validation epoch
        # outputs is an array with what you returned in validation_step for each batch
        # outputs = [{'loss': batch_0_loss}, {'loss': batch_1_loss}, ..., {'loss': batch_n_loss}] 
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'avg_val_loss': avg_loss, 'log': tensorboard_logs, 'progress_bar': tensorboard_logs}

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        return {'pred': logits.detach().cpu().numpy()}
    
    def test_epoch_end(self, outputs):
        preds = np.concatenate([x["pred"] for x in outputs], axis=0)
        self.preds = np.argmax(preds, axis=1)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


class MNISTDataModule(pl.LightningDataModule):

    def __init__(self, data_dir=os.getcwd()):
        super().__init__()
        self.data_dir = data_dir

    def setup(self, stage):

        # transforms for images
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

        # prepare transforms standard to MNIST
        mnist_train = MNIST(self.data_dir, train=True, download=False, transform=transform)
        self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])
        self.mnist_test = MNIST(self.data_dir, train=False, download=False, transform=transform)

    def prepare_data(self):

        # prepare transforms standard to MNIST
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=32, num_workers=8)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=32, num_workers=8)
    
    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=32, num_workers=8)


if __name__ == '__main__':
    dm = MNISTDataModule()
    model = CoolSystem(dm)
    trainer = pl.Trainer(gpus=2, max_epochs=3, distributed_backend='ddp')
    # trainer = pl.Trainer(max_epochs=3)
    trainer.fit(model)
    trainer.test()

@mergify mergify bot requested a review from a team July 22, 2020 06:54
@mergify mergify bot requested a review from a team July 22, 2020 07:51
@williamFalcon
Copy link
Contributor

@nateraw what if i don't pass in the dm to init, but instead do it in fit. Have we handled updating all the warnings etc?

@awaelchli
Copy link
Contributor

awaelchli commented Jul 22, 2020

@nateraw I see a potential for a refactor here. It looks like the signature and docstrings for all these hooks like train_dataloader, val_dataloader etc. are going to be the same in LightningModule and DataModule. So if that's the case I propose to pull them out into a common interface like "DataHooks" and inherit it in LightningModule and DataModule. Otherwise, if the api changes we have to go and change docs and signature in two places :) What are your thoughts?

@nateraw
Copy link
Contributor Author

nateraw commented Jul 22, 2020

@nateraw what if i don't pass in the dm to init, but instead do it in fit. Have we handled updating all the warnings etc?

@williamFalcon Warnings aren't updated (will be soon!), but the following 4 cases should all work:

dm = MNISTDataModule(**kwargs)
model = CoolSystem(datamodule=dm)  # self.datamodule = datamodule in __init__
model = CoolSystem()
model.datamodule = dm
model = CoolSystem()
trainer.fit(model, datamodule=dm)
# Not sure why you'd want to do it this way, but you can.
model = CoolSystem()
trainer.fit(model)
trainer.test(model, datamodule=dm)

If you pass a datamodule on .fit or .test, it will prefer that instead of one found at model.datamodule. Jirka mentioned adding misconfiguration exception when user trys to pass dataloaders + datamodule, which I think is a great idea too.

@nateraw
Copy link
Contributor Author

nateraw commented Jul 22, 2020

@nateraw I see a potential for a refactor here. It looks like the signature and docstrings for all these hooks like train_dataloader, val_dataloader etc. are going to be the same in LightningModule and DataModule. So if that's the case I propose to pull them out into a common interface like "DataHooks" and inherit it in LightningModule and DataModule. Otherwise, if the api changes we have to go and change docs and signature in two places :) What are your thoughts?

@awaelchli This is a wonderful idea. I'll throw that together. I didn't notice a setup() docstring anywhere...does it exist somewhere that I wasn't seeing?

EDIT - I tried this and didn't like it. This doesn't solve the fact that Callbacks also have same hooks. We can rethink this in a different PR. We'll keep it simple for now.

@nateraw nateraw force-pushed the datamodule branch 2 times, most recently from fd80559 to 081b2a8 Compare July 22, 2020 23:20
@nateraw
Copy link
Contributor Author

nateraw commented Jul 23, 2020

Now it's like:

if __name__ == '__main__':
    
    # Enforce random seed
    pl.seed_everything(42)

    # Init your datamodule
    # After subclass defined init, prepare_data and setup will run implicitly
    dm = MNISTDataModule()
    
    # Get anything from datamodule that you might need, such as input dims
    input_dims = np.prod(dm.size())

    # Init model
    model = CoolSystem(input_dims=input_dims)

    # Init Trainer
    trainer = pl.Trainer(max_epochs=1)

    # Train and test on your datamodule's data
    trainer.fit(model, datamodule=dm)
    trainer.test()

I think what we did with the DataModule wrapper is messing up add_argparse_args, though 🙁

@mergify mergify bot requested a review from a team July 23, 2020 16:21
@nateraw nateraw marked this pull request as ready for review July 23, 2020 21:12
@williamFalcon williamFalcon merged commit 1caf8be into Lightning-AI:master Jul 24, 2020
@awaelchli
Copy link
Contributor

@nateraw did the refactoring with an intermediate hooks class for train_dataloader etc not work out?

@nateraw
Copy link
Contributor Author

nateraw commented Jul 24, 2020

@awaelchli I did it and it worked, but I think we want to think about that more. Decided to keep it simple and not rearrange a bunch of things for this PR 😄 .

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
feature Is an improvement or enhancement
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants