Avoid patching DataHooks #10498

Closed
rohitgr7 opened this issue Nov 12, 2021 · 12 comments · Fixed by #10603
Labels: data handling, refactor
Milestone: 1.6

Comments

@rohitgr7
Contributor

rohitgr7 commented Nov 12, 2021

Proposed refactoring or deprecation

We need to avoid patching some of these datahooks:
https://github.com/PyTorchLightning/pytorch-lightning/blob/09cf167237e867f1ec67a5db87e5a02c2cea4b69/pytorch_lightning/trainer/connectors/data_connector.py#L238-L242
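
Roughly, the patching there copies the datamodule's batch-transfer hooks onto the LightningModule along these lines (an illustrative sketch, not the exact source):

batch_transfer_hooks = ("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer")
for hook in batch_transfer_hooks:
    if is_overridden(hook, datamodule):
        # the datamodule's bound method replaces the LightningModule's attribute
        # and stays bound for later trainer stages unless it is un-patched
        setattr(model, hook, getattr(datamodule, hook))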

We have already removed the patching for dataloader related hooks here: #9764

Motivation

Patching can lead to the wrong hook being called in some cases. One example:

Code:

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer, LightningDataModule


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

class BoringData(LightningDataModule):
    def __init__(self):
        super().__init__()
        self.batch_size = 16
        self.total_len = 1000
        
    def train_dataloader(self):
        return DataLoader(RandomDataset(32, self.total_len), batch_size=self.batch_size)
    
    def val_dataloader(self):
        return DataLoader(RandomDataset(32, self.total_len), batch_size=self.batch_size*2)
    
    def transfer_batch_to_device(self, *args, **kwargs):
        print(f'{self.trainer.state.stage}: DataModule.transfer_batch_to_device')
        return super().transfer_batch_to_device(*args, **kwargs)

class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.lr = 1e-2
        
    def transfer_batch_to_device(self, *args, **kwargs):
        print(f'{self.trainer.state.stage}: LightningModule.transfer_batch_to_device')
        return super().transfer_batch_to_device(*args, **kwargs)
        
    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        logits = self(batch)
        loss = torch.nn.functional.mse_loss(logits, torch.randn(*logits.shape))
        return loss
    
    def validation_step(self, batch, batch_idx):
        logits = self(batch)
        loss = torch.nn.functional.mse_loss(logits, torch.randn(*logits.shape))
    
    def test_step(self, batch, batch_idx):
        logits = self(batch)
        loss = torch.nn.functional.mse_loss(logits, torch.randn(*logits.shape))

    def configure_optimizers(self):
        opt = torch.optim.SGD(self.layer.parameters(), lr=1e-2)
        return opt

test_dl = DataLoader(RandomDataset(32, 1000), batch_size=16)
model = BoringModel()
dm = BoringData()

trainer = Trainer(fast_dev_run=True)
# fit uses the datamodule, so its transfer_batch_to_device is expected to run
trainer.fit(model, datamodule=dm)
# test is given a plain dataloader, so the LightningModule's hook is expected to run
trainer.test(model, dataloaders=test_dl)

Prints:

train: DataModule.transfer_batch_to_device
validate: DataModule.transfer_batch_to_device
test: DataModule.transfer_batch_to_device

Expected:

train: DataModule.transfer_batch_to_device
validate: DataModule.transfer_batch_to_device
test: LightningModule.transfer_batch_to_device

Even though the test dataloader is passed directly to trainer.test, the DataModule's hook still runs because it was patched onto the LightningModule during fit and is never un-patched.

Pitch

Additional context


cc @justusschock @awaelchli @akihironitta @rohitgr7 @ninginthecloud

@tchaton
Contributor

tchaton commented Nov 15, 2021

Sounds good to me.

@ananthsub added the data handling (Generic data-related topic) label Nov 18, 2021
@ananthsub
Contributor

If both DataModule and LightningModule provide implementations for transfer_batch_to_device then we select the implementation based on which entity provided the data?

Should this be an error instead if both are overridden and used together? What if the DataModule has separate logic for on_before_batch_transfer and on_after_batch_transfer which is meant to be used with the LightningModule?

@carmocca
Contributor

A question about context. Do we know why these were added for both the LightningModule and LightningDataModule? What would be a good example where one would need both?

And if there's one, why did we not do this as:

datamodule.on_before_batch_transfer(batch)
model.on_before_batch_transfer(batch)

datamodule.transfer_batch_to_device(batch)
model.transfer_batch_to_device(batch)

datamodule.on_after_batch_transfer(batch)
model.on_after_batch_transfer(batch)

which would avoid the question of "what should happen when both are implemented"

@ananthsub
Contributor

ananthsub commented Nov 19, 2021

@carmocca only one of the data module or lightning module can actually move the batch to device, so transfer_batch_to_device can only be called once.

Other than that, I share your questions about the original context for these hooks and this mixin sharing.

@ninginthecloud
Contributor

A question about context. Do we know why these were added for both the LightningModule and LightningDataModule? What would be a good example where one would need both?

And if there's one, why did we not do this as:

datamodule.on_before_batch_transfer(batch)
model.on_before_batch_transfer(batch)

datamodule.transfer_batch_to_device(batch)
model.transfer_batch_to_device(batch)

datamodule.on_after_batch_transfer(batch)
model.on_after_batch_transfer(batch)

which would avoid the question of "what should happen when both are implemented"

+1 for your question. @carmocca
Based on your suggestion, I am thinking there are three situations for current hook execution:

  1. Trainer.fit(model)
    There's no confusion in terms of hook execution.
  2. Trainer.fit(model, train_dataloader)
    We still rely on the hooks from the LightningModule. How about we wrap *_dataloader in a temp datamodule, temp_dm, with only the *_dataloader() methods implemented (see the sketch below)? In this way, only the transfer_batch_to_device from the LightningModule is called. After fit is done, temp_dm will be deleted.
  3. Trainer.fit(model, dm)
    We check if both the LightningModule and the datamodule override transfer_batch_to_device; let's call the one from the dm, since the user passed it explicitly, but we should warn users.

Situation 3 may not be addressed perfectly and I'm open to better ideas, but in summary, none of the three situations needs hook patching.
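
For situation 2, a minimal sketch of such a temporary wrapper (the class and attribute names here are hypothetical, not an existing Lightning API):

from pytorch_lightning import LightningDataModule

class _TempDataModule(LightningDataModule):
    """Hypothetical wrapper exposing an explicitly passed dataloader as a datamodule."""

    def __init__(self, train_dataloader):
        super().__init__()
        self._train_dl = train_dataloader

    def train_dataloader(self):
        # only the *_dataloader hooks are implemented; no batch-transfer hooks,
        # so the LightningModule's implementations keep being used
        return self._train_dl

# trainer.fit(model, train_dataloader) would internally become roughly:
# trainer.fit(model, datamodule=_TempDataModule(train_dataloader))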

@rohitgr7
Contributor Author

rohitgr7 commented Nov 19, 2021

If both DataModule and LightningModule provide implementations for transfer_batch_to_device then we select the implementation based on which entity provided the data?

Yes, if the dataloaders come from the datamodule, then we rely on the implementation provided by the datamodule, else the LightningModule's.

Should this be an error instead if both are overridden and used together? What if the DataModule has separate logic for on_before_batch_transfer and on_after_batch_transfer which is meant to be used with the LightningModule?

@carmocca @ananthsub I'd say a warning in such a case would be good enough, and we can prioritize the datamodule. The reason: a user might import a LightningModule component from some other package where, say, transfer_batch_to_device is overridden, and then implement a datamodule with some augmentation logic inside on_after_batch_transfer. It would not be easy for them to avoid the transfer_batch_to_device from that LightningModule, because even if they disable it we would still detect it as overridden; but if they re-implement the same logic within their datamodule, it will work with just a warning.
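
A minimal sketch of that selection logic (the helper function name is hypothetical; is_overridden and rank_zero_warn are assumed to be importable from pytorch_lightning.utilities):

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.model_helpers import is_overridden

def _resolve_batch_transfer_owner(lightning_module, datamodule):
    # sketch only: pick which object's transfer_batch_to_device to call
    dm_overridden = datamodule is not None and is_overridden("transfer_batch_to_device", datamodule)
    lm_overridden = is_overridden("transfer_batch_to_device", lightning_module)

    if dm_overridden and lm_overridden:
        rank_zero_warn(
            "`transfer_batch_to_device` is overridden on both the LightningModule and the"
            " LightningDataModule; the DataModule implementation will be used."
        )
    # prefer the datamodule when it provided an implementation, else fall back to the module
    return datamodule if dm_overridden else lightning_module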

  1. Trainer.fit(model, train_dataloader)
    We still rely on the hooks from the LightningModule. How about we wrap *_dataloader in a temp datamodule, temp_dm, with only the *_dataloader() methods implemented? In this way, only the transfer_batch_to_device from the LightningModule is called. After fit is done, temp_dm will be deleted.
  2. Trainer.fit(model, dm)
    We check if both the LightningModule and the datamodule override transfer_batch_to_device; let's call the one from the dm, since the user passed it explicitly, but we should warn users.

@ninginthecloud if we go with this, then in the second case it would be hard to know whether to use the datamodule or the LightningModule for these hooks.

@ninginthecloud
Contributor

Trainer.fit(model, train_dataloader)
We still rely on the hooks from the LightningModule. How about we wrap *_dataloader in a temp datamodule, temp_dm, with only the *_dataloader() methods implemented? In this way, only the transfer_batch_to_device from the LightningModule is called. After fit is done, temp_dm will be deleted.

Hi @rohitgr7, for the second situation, since the temporary datamodule does not override hooks like transfer_batch_to_device, we just use whatever we have from the LightningModule.

My initial idea is based on what @carmocca suggested about executing hooks from the dm and the LightningModule in order. How about:

if is_overridden("on_before_batch_transfer", datamodule):
    datamodule.on_before_batch_transfer(batch)
elif is_overridden("on_before_batch_transfer", lightningmodule):
    lightningmodule.on_before_batch_transfer(batch)

Here, datamodule could come from user input, trainer.fit(model, datamodule), or it could be a temporary datamodule generated based on the train_dataloader from trainer.fit(model, train_dataloader).

This way there are several pros: 1) we avoid patching and de-patching hooks onto the LightningModule, 2) the way we process a datamodule and a dataloader stays consistent, and 3) for every hook, we know which implementation gets executed.

@rohitgr7
Contributor Author

Sounds good to me 😃.

I'll add the logic now in the attached PR; we can integrate a temp datamodule in another one, maybe.

@stale

stale bot commented Dec 27, 2021

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

stale bot added the won't fix (This will not be worked on) label Dec 27, 2021
@carmocca removed the won't fix label Jan 3, 2022
@ninginthecloud
Contributor

Hi @rohitgr7, I'd like to follow up on this issue. Are there any updates? I'm wondering if there's anything I can do here.

@rohitgr7
Contributor Author

rohitgr7 commented Jan 6, 2022

Hey @ninginthecloud!
There is an open PR for this fix: #10603
I'd like to get your review there. There are some conflicts; I'll update them tomorrow.

Regarding your proposal of creating a temp datamodule if dataloaders are passed explicitly, it can be configured easily with some additional changes once this issue gets fixed.

@stale

stale bot commented Feb 6, 2022

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

stale bot added the won't fix label Feb 6, 2022
@carmocca removed the won't fix label Feb 7, 2022
@carmocca added this to the 1.6 milestone Feb 7, 2022
@carmocca moved this to In Progress in Frameworks Planning Feb 16, 2022
@carmocca moved this from In Progress to In Review in Frameworks Planning Feb 21, 2022
Repository owner moved this from In Review to Done in Frameworks Planning Feb 25, 2022