Description
Proposed refactoring or deprecation
We need to avoid patching some of these data hooks onto the LightningModule:
https://github.com/PyTorchLightning/pytorch-lightning/blob/09cf167237e867f1ec67a5db87e5a02c2cea4b69/pytorch_lightning/trainer/connectors/data_connector.py#L238-L242
We have already removed the patching for the dataloader-related hooks in #9764.
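As a rough sketch of the alternative (the helper name and placement below are hypothetical, not the actual Lightning API), the batch-transfer hooks could be resolved at call time instead of being patched onto the LightningModule:

from pytorch_lightning import LightningDataModule


def _resolve_batch_transfer_hook(lightning_module, datamodule, hook_name):
    # Hypothetical helper, sketched for this proposal: pick which object's hook
    # to call for the *current* run instead of permanently overwriting the
    # attribute on the LightningModule.
    if datamodule is not None:
        hook = getattr(type(datamodule), hook_name, None)
        default = getattr(LightningDataModule, hook_name)
        if hook is not None and hook is not default:
            # The user overrode the hook on the DataModule and the DataModule
            # is attached to this run, so it takes precedence.
            return getattr(datamodule, hook_name)
    # Otherwise fall back to the LightningModule's own (possibly overridden) hook.
    return getattr(lightning_module, hook_name)

With call-time resolution, a run that attaches no datamodule would naturally fall back to the LightningModule's own hook, which is exactly the behavior the example below expects.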
Motivation
This patching can produce incorrect behavior in some cases. One example:
import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningDataModule, LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringData(LightningDataModule):
    def __init__(self):
        super().__init__()
        self.batch_size = 16
        self.total_len = 1000

    def train_dataloader(self):
        return DataLoader(RandomDataset(32, self.total_len), batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(RandomDataset(32, self.total_len), batch_size=self.batch_size * 2)

    def transfer_batch_to_device(self, *args, **kwargs):
        # Report which class' hook runs for the current stage.
        print(f'{self.trainer.state.stage}: DataModule.transfer_batch_to_device')
        return super().transfer_batch_to_device(*args, **kwargs)


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.lr = 1e-2

    def transfer_batch_to_device(self, *args, **kwargs):
        # Report which class' hook runs for the current stage.
        print(f'{self.trainer.state.stage}: LightningModule.transfer_batch_to_device')
        return super().transfer_batch_to_device(*args, **kwargs)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        logits = self(batch)
        loss = torch.nn.functional.mse_loss(logits, torch.randn(*logits.shape))
        return loss

    def validation_step(self, batch, batch_idx):
        logits = self(batch)
        loss = torch.nn.functional.mse_loss(logits, torch.randn(*logits.shape))

    def test_step(self, batch, batch_idx):
        logits = self(batch)
        loss = torch.nn.functional.mse_loss(logits, torch.randn(*logits.shape))

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=self.lr)


test_dl = DataLoader(RandomDataset(32, 1000), batch_size=16)
model = BoringModel()
dm = BoringData()

trainer = Trainer(fast_dev_run=True)
# fit attaches the datamodule, so its hook is expected to run during train/validate ...
trainer.fit(model, datamodule=dm)
# ... but test receives raw dataloaders and no datamodule, so the LightningModule hook should run.
trainer.test(model, dataloaders=test_dl)
Prints:
train: DataModule.transfer_batch_to_device
validate: DataModule.transfer_batch_to_device
test: DataModule.transfer_batch_to_device
Expected (trainer.test receives raw dataloaders and no datamodule, so the LightningModule hook should handle the test stage; see the check sketched after the output):
train: DataModule.transfer_batch_to_device
validate: DataModule.transfer_batch_to_device
test: LightningModule.transfer_batch_to_device
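A small way to see where the test-stage call comes from, assuming the data connector assigns the DataModule's bound hook onto the model instance as the linked lines suggest (this is a sketch against the repro above, not an assertion about the exact internals):

# Assumption: the data connector sets the DataModule's bound method as an
# instance attribute on the model. Run this right after the
# trainer.fit(model, datamodule=dm) call in the script above.
hook = model.transfer_batch_to_device
print(type(hook.__self__).__name__)
# With the current patching this would print 'BoringData' rather than
# 'BoringModel', which is why the later trainer.test(model, dataloaders=test_dl)
# call still routes the test batches through the DataModule hook.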
Pitch
Additional context
If you enjoy Lightning, check out our other projects! ⚡
- Metrics: Machine learning metrics for distributed, scalable PyTorch applications.
- Lite: Enables pure PyTorch users to scale their existing code on any kind of device while retaining full control over their own loops and optimization logic.
- Flash: The fastest way to get a Lightning baseline! A collection of tasks for fast prototyping, baselining, fine-tuning, and solving problems with deep learning.
- Bolts: Pretrained SOTA deep learning models, callbacks, and more for research and production with PyTorch Lightning and PyTorch.
- Lightning Transformers: Flexible interface for high-performance research using SOTA Transformers leveraging PyTorch Lightning, Transformers, and Hydra.
cc @justusschock @awaelchli @akihironitta @rohitgr7 @ninginthecloud