Avoid patching DataHooks #10498

Closed
@rohitgr7

Description

Proposed refactoring or deprecation

We need to avoid patching some of these DataHooks methods:
https://github.com/PyTorchLightning/pytorch-lightning/blob/09cf167237e867f1ec67a5db87e5a02c2cea4b69/pytorch_lightning/trainer/connectors/data_connector.py#L238-L242

We have already removed the patching for the dataloader-related hooks in #9764.
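
For context, the linked connector code roughly does the following at fit() time (a simplified sketch based on the linked lines, not the exact implementation; the function name below is only for illustration):

from pytorch_lightning.utilities.model_helpers import is_overridden

batch_transfer_hooks = (
    "on_before_batch_transfer",
    "transfer_batch_to_device",
    "on_after_batch_transfer",
)

def _patch_datamodule_hooks(model, datamodule):
    # If the DataModule overrides a batch-transfer hook, its bound method is
    # copied onto the LightningModule instance, shadowing the module's own hook
    # for the rest of the trainer's lifetime.
    for hook in batch_transfer_hooks:
        if is_overridden(hook, datamodule):
            setattr(model, hook, getattr(datamodule, hook))

The key problem is that the assignment happens once when the DataModule is attached and is never undone.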

Motivation

This patching can produce incorrect behaviour in some cases. One example:

Code:

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer, LightningDataModule


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

class BoringData(LightningDataModule):
    def __init__(self):
        super().__init__()
        self.batch_size = 16
        self.total_len = 1000

    def train_dataloader(self):
        return DataLoader(RandomDataset(32, self.total_len), batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(RandomDataset(32, self.total_len), batch_size=self.batch_size * 2)

    def transfer_batch_to_device(self, *args, **kwargs):
        print(f'{self.trainer.state.stage}: DataModule.transfer_batch_to_device')
        return super().transfer_batch_to_device(*args, **kwargs)

class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.lr = 1e-2
        
    def transfer_batch_to_device(self, *args, **kwargs):
        print(f'{self.trainer.state.stage}: LightningModule.transfer_batch_to_device')
        return super().transfer_batch_to_device(*args, **kwargs)
        
    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        logits = self(batch)
        loss = torch.nn.functional.mse_loss(logits, torch.randn(*logits.shape))
        return loss
    
    def validation_step(self, batch, batch_idx):
        logits = self(batch)
        loss = torch.nn.functional.mse_loss(logits, torch.randn(*logits.shape))
    
    def test_step(self, batch, batch_idx):
        logits = self(batch)
        loss = torch.nn.functional.mse_loss(logits, torch.randn(*logits.shape))

    def configure_optimizers(self):
        opt = torch.optim.SGD(self.layer.parameters(), lr=self.lr)
        return opt

test_dl = DataLoader(RandomDataset(32, 1000), batch_size=16)
model = BoringModel()
dm = BoringData()

trainer = Trainer(fast_dev_run=True)
trainer.fit(model, datamodule=dm)
trainer.test(model, dataloaders=test_dl)

Prints:

train: DataModule.transfer_batch_to_device
validate: DataModule.transfer_batch_to_device
test: DataModule.transfer_batch_to_device

Expected:

train: DataModule.transfer_batch_to_device
validate: DataModule.transfer_batch_to_device
test: LightningModule.transfer_batch_to_device
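
The mismatch happens because fit() copies the DataModule's bound transfer_batch_to_device onto the model instance, and that patched attribute survives the fit() call, so the later test(model, dataloaders=test_dl) still routes through the DataModule's hook even though no DataModule is involved in that run. A rough sketch of the alternative direction, resolving the hook owner per call instead of patching at setup (the helper name is hypothetical and the attachment semantics are an assumption):

from pytorch_lightning.utilities.model_helpers import is_overridden

def _resolve_batch_transfer_owner(trainer, hook_name):
    # Hypothetical helper: decide, per trainer call, whose hook should run.
    # Assumption: trainer.datamodule refers to the datamodule attached for the
    # current run and is not carried over when test() is called without one.
    datamodule = trainer.datamodule
    if datamodule is not None and is_overridden(hook_name, datamodule):
        return datamodule
    return trainer.lightning_module

# inside the batch-transfer call path, something like:
# owner = _resolve_batch_transfer_owner(trainer, "transfer_batch_to_device")
# batch = owner.transfer_batch_to_device(batch, device, dataloader_idx)

With call-time resolution along these lines, the DataModule's hook would only be used for runs where the DataModule is actually attached, matching the expected output above.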

Pitch

Additional context



cc @justusschock @awaelchli @akihironitta @rohitgr7 @ninginthecloud
