Datamodule #2668
Hello @nateraw! Thanks for updating this PR.
Comment last updated at 2020-07-24 12:50:20 UTC
Been playing with this script locally...including it here if anybody wants to try running it/breaking it.

import os

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl


class CoolSystem(pl.LightningModule):

    def __init__(self, datamodule):
        super().__init__()
        self.datamodule = datamodule
        self.l1 = nn.Linear(28 * 28, 10)

    def forward(self, x):
        batch_size, channels, width, height = x.size()
        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, -1)
        return torch.log_softmax(self.l1(x), dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return {'val_loss': loss}

    def validation_epoch_end(self, outputs):
        # called at the end of the validation epoch
        # outputs is a list with what you returned in validation_step for each batch
        # outputs = [{'val_loss': batch_0_loss}, {'val_loss': batch_1_loss}, ..., {'val_loss': batch_n_loss}]
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'avg_val_loss': avg_loss, 'log': tensorboard_logs, 'progress_bar': tensorboard_logs}

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        return {'pred': logits.detach().cpu().numpy()}

    def test_epoch_end(self, outputs):
        # stack per-batch log-probabilities and keep the predicted class per sample
        preds = np.concatenate([x["pred"] for x in outputs], axis=0)
        self.preds = np.argmax(preds, axis=1)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


class MNISTDataModule(pl.LightningDataModule):

    def __init__(self, data_dir=os.getcwd()):
        super().__init__()
        self.data_dir = data_dir

    def setup(self, stage):
        # transforms standard to MNIST
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        # build the train/val/test datasets (data was already downloaded in prepare_data)
        mnist_train = MNIST(self.data_dir, train=True, download=False, transform=transform)
        self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])
        self.mnist_test = MNIST(self.data_dir, train=False, download=False, transform=transform)

    def prepare_data(self):
        # download the MNIST train and test sets
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=32, num_workers=8)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=32, num_workers=8)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=32, num_workers=8)


if __name__ == '__main__':
    dm = MNISTDataModule()
    model = CoolSystem(dm)
    trainer = pl.Trainer(gpus=2, max_epochs=3, distributed_backend='ddp')
    # trainer = pl.Trainer(max_epochs=3)
    trainer.fit(model)
    trainer.test()
@nateraw What if I don't pass the dm in to init, but instead do it in fit? Have we handled updating all the warnings, etc.?
@nateraw I see potential for a refactor here. It looks like the signatures and docstrings for hooks like train_dataloader, val_dataloader, etc. are going to be the same in LightningModule and DataModule. If that's the case, I propose pulling them out into a common interface like "DataHooks" and inheriting it in both LightningModule and DataModule. Otherwise, if the API changes, we have to go and change the docs and signatures in two places :) What are your thoughts?
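For illustration, the proposed shared interface could look roughly like the sketch below. This is not code from the PR: `DataHooks` is the name suggested in the comment above, while `LightningModuleSketch` and `DataModuleSketch` are hypothetical stand-ins for the real classes, and the hook list is only a guess at what would be shared.

```python
class DataHooks:
    """Hypothetical mixin holding the data-related hook signatures and docs."""

    def prepare_data(self):
        """Download or otherwise materialise the data."""

    def train_dataloader(self):
        """Return the DataLoader used for training."""
        raise NotImplementedError

    def val_dataloader(self):
        """Return the DataLoader used for validation."""
        raise NotImplementedError

    def test_dataloader(self):
        """Return the DataLoader used for testing."""
        raise NotImplementedError


# Stand-ins for LightningModule and DataModule (assumed names, not the real classes).
class LightningModuleSketch(DataHooks):
    pass


class DataModuleSketch(DataHooks):
    pass
```

With this layout, a change to a hook's signature or docstring is made once in `DataHooks` and is picked up by both subclasses.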
@williamFalcon Warnings aren't updated (will be soon!), but the following 4 cases should all work:

dm = MNISTDataModule(**kwargs)
model = CoolSystem(datamodule=dm)  # self.datamodule = datamodule in __init__

model = CoolSystem()
model.datamodule = dm

model = CoolSystem()
trainer.fit(model, datamodule=dm)

# Not sure why you'd want to do it this way, but you can.
model = CoolSystem()
trainer.fit(model)
trainer.test(model, datamodule=dm)

If you pass a datamodule on .fit or .test, it will prefer that instead of one found at
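The precedence described above can be sketched in plain Python. This is only an illustration, not the trainer's actual logic: `resolve_datamodule` is a hypothetical helper, and it assumes the fallback location is a `datamodule` attribute on the model, as in the first two cases.

```python
from types import SimpleNamespace


def resolve_datamodule(model, datamodule=None):
    """Pick the datamodule to use (hypothetical helper).

    An explicitly passed datamodule wins; otherwise fall back to
    model.datamodule if the model has one, else None.
    """
    if datamodule is not None:
        return datamodule
    return getattr(model, "datamodule", None)


# Toy objects standing in for a model and two datamodules.
attached = SimpleNamespace(name="attached")
explicit = SimpleNamespace(name="explicit")
model = SimpleNamespace(datamodule=attached)

chosen_with_argument = resolve_datamodule(model, datamodule=explicit)
chosen_without_argument = resolve_datamodule(model)
```

Here `chosen_with_argument` is the explicitly passed object and `chosen_without_argument` is the one attached to the model.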
@awaelchli This is a wonderful idea. I'll throw that together. I didn't notice a
EDIT - I tried this and didn't like it. This doesn't solve the fact that
fd80559 to 081b2a8
Now it's like:

if __name__ == '__main__':
    # Enforce random seed
    pl.seed_everything(42)
    # Init your datamodule
    # After subclass defined init, prepare_data and setup will run implicitly
    dm = MNISTDataModule()
    # Get anything from datamodule that you might need, such as input dims
    input_dims = np.prod(dm.size())
    # Init model
    model = CoolSystem(input_dims=input_dims)
    # Init Trainer
    trainer = pl.Trainer(max_epochs=1)
    # Train and test on your datamodule's data
    trainer.fit(model, datamodule=dm)
    trainer.test()

I think what we did with the DataModule wrapper is messing up
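One way the "prepare_data and setup will run implicitly" behaviour from the snippet above could be achieved is a metaclass that calls both hooks right after `__init__`. The sketch below is an assumption for illustration only, not the wrapper used in this PR; `_RunHooksAfterInit`, `DataModuleSketch`, and `ToyDataModule` are hypothetical names, and `setup` is called without a `stage` argument for simplicity.

```python
import math


class _RunHooksAfterInit(type):
    """Hypothetical metaclass: run prepare_data() and setup() after __init__."""

    def __call__(cls, *args, **kwargs):
        obj = super().__call__(*args, **kwargs)  # runs __new__ and __init__
        obj.prepare_data()
        obj.setup()
        return obj


class DataModuleSketch(metaclass=_RunHooksAfterInit):
    """Stand-in base class (assumed name, not the real DataModule)."""

    def prepare_data(self):
        pass

    def setup(self):
        pass

    def size(self):
        raise NotImplementedError


class ToyDataModule(DataModuleSketch):
    def __init__(self):
        self.calls = []

    def prepare_data(self):
        self.calls.append("prepare_data")

    def setup(self):
        self.calls.append("setup")
        self.dims = (1, 28, 28)  # MNIST-like image shape

    def size(self):
        return self.dims


dm = ToyDataModule()
# Both hooks have already run, so the dims are available right after construction.
input_dims = math.prod(dm.size())
```

Running hooks at construction time is convenient for reading things like input dimensions before building the model, at the cost of doing data work as a side effect of instantiation.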
@nateraw Did the refactoring with an intermediate hooks class for train_dataloader etc. not work out?
@awaelchli I did it and it worked, but I think we want to think about that more. Decided to keep it simple and not rearrange a bunch of things for this PR 😄.
What does this PR do?
Introduces DataModule as an optional way to decouple data related hooks from LightningModule. Originally implemented in pytorch-lightning-bolts, but moving to integrate with Lightning here.
Work in progress. We can play with the DataModule class a bunch more, but for now I'm looking for feedback on how they are handled in trainer.py.
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃