remove dataloader patching on the LightningModule (#9764)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rohitgr7 <rohitgr1998@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
4 people authored Oct 20, 2021
1 parent 6701526 commit 2c16f1d
Showing 19 changed files with 198 additions and 125 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -513,6 +513,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed `process_idx` from the `{DDPSpawnPlugin,TPUSpawnPlugin}.new_process` methods ([#10022](https://github.com/PyTorchLightning/pytorch-lightning/pull/10022))


- Removed automatic patching of `{train,val,test,predict}_dataloader()` on the `LightningModule` ([#9764](https://github.com/PyTorchLightning/pytorch-lightning/pull/9764))


### Fixed


@@ -594,6 +597,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).



- Fixed undesired side effects being caused by `Trainer` patching dataloader methods on the `LightningModule` ([#9764](https://github.com/PyTorchLightning/pytorch-lightning/pull/9764))


## [1.4.9] - 2021-09-30

- Fixed `lr_find` to generate same results on multiple calls ([#9704](https://github.com/PyTorchLightning/pytorch-lightning/pull/9704))
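For context, a minimal sketch of what this entry means for user code (illustrative only; `LitModel` and the synthetic data are made up, not part of the commit). Dataloaders passed to `Trainer.fit` are now tracked by the trainer's data connector rather than being patched onto the model as `train_dataloader()`:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    """Tiny module used only to illustrate the data flow; not from the commit."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


train_loader = DataLoader(TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,))), batch_size=8)

model = LitModel()
trainer = pl.Trainer(max_epochs=1)
# The loader is stored on the trainer's data connector; `model.train_dataloader`
# is no longer replaced by a patched callable after this call (assumes a Lightning
# version where fit() accepts the `train_dataloaders` argument).
trainer.fit(model, train_dataloaders=train_loader)
```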
4 changes: 2 additions & 2 deletions pl_examples/basic_examples/autoencoder.py
@@ -181,8 +181,8 @@ def cli_main():
trainer_defaults={"callbacks": ImageSampler(), "max_epochs": 10},
)
cli.trainer.fit(cli.model, datamodule=cli.datamodule)
-    cli.trainer.test(ckpt_path="best")
-    predictions = cli.trainer.predict(ckpt_path="best")
+    cli.trainer.test(ckpt_path="best", datamodule=cli.datamodule)
+    predictions = cli.trainer.predict(ckpt_path="best", datamodule=cli.datamodule)
print(predictions[0])


4 changes: 2 additions & 2 deletions pl_examples/basic_examples/backbone_image_classifier.py
@@ -124,8 +124,8 @@ def predict_dataloader(self):
def cli_main():
cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False)
cli.trainer.fit(cli.model, datamodule=cli.datamodule)
-    cli.trainer.test(ckpt_path="best")
-    predictions = cli.trainer.predict(ckpt_path="best")
+    cli.trainer.test(ckpt_path="best", datamodule=cli.datamodule)
+    predictions = cli.trainer.predict(ckpt_path="best", datamodule=cli.datamodule)
print(predictions[0])


2 changes: 1 addition & 1 deletion pl_examples/basic_examples/dali_image_classifier.py
@@ -194,7 +194,7 @@ def cli_main():

cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False)
cli.trainer.fit(cli.model, datamodule=cli.datamodule)
-    cli.trainer.test(ckpt_path="best")
+    cli.trainer.test(ckpt_path="best", datamodule=cli.datamodule)


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion pl_examples/basic_examples/simple_image_classifier.py
@@ -74,7 +74,7 @@ def cli_main():
LitClassifier, MNISTDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False
)
cli.trainer.fit(cli.model, datamodule=cli.datamodule)
-    cli.trainer.test(ckpt_path="best")
+    cli.trainer.test(ckpt_path="best", datamodule=cli.datamodule)


if __name__ == "__main__":
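The four example-script changes above all work around the same user-visible effect. A hedged, self-contained sketch (made-up `RandomDataModule` and `LitClassifier`, synthetic data; not taken from the repository): once patching is gone, evaluation entry points no longer "remember" data from a previous `fit()` call, so the datamodule (or explicit loaders) must be passed again:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class RandomDataModule(pl.LightningDataModule):
    def _loader(self):
        return DataLoader(TensorDataset(torch.randn(32, 16), torch.randint(0, 2, (32,))), batch_size=4)

    def train_dataloader(self):
        return self._loader()

    def test_dataloader(self):
        return self._loader()


class LitClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(16, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def test_step(self, batch, batch_idx):
        x, y = batch
        self.log("test_loss", torch.nn.functional.cross_entropy(self.layer(x), y))

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


dm = RandomDataModule()
trainer = pl.Trainer(max_epochs=1)
trainer.fit(LitClassifier(), datamodule=dm)
# Passing the datamodule again is now required; the model no longer carries
# patched copies of the datamodule's *_dataloader() hooks after fit().
trainer.test(ckpt_path="best", datamodule=dm)
```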
6 changes: 1 addition & 5 deletions pytorch_lightning/plugins/plugins_registry.py
@@ -127,11 +127,7 @@ def is_register_plugins_overridden(plugin: type) -> bool:
else:
return False

-    if hasattr(plugin_attr, "patch_loader_code"):
-        is_overridden = plugin_attr.patch_loader_code != str(super_attr.__code__)
-    else:
-        is_overridden = plugin_attr.__code__ is not super_attr.__code__
-    return is_overridden
+    return plugin_attr.__code__ is not super_attr.__code__


def call_training_type_register_plugins(root: Path, base_module: str) -> None:
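An illustrative, standalone sketch (not from the diff; class and function names are made up) of the override check that remains after dropping the `patch_loader_code` special case: a method counts as overridden when its `__code__` object differs from the parent class implementation.

```python
class Base:
    def register_plugins(cls):
        pass


class Child(Base):
    def register_plugins(cls):
        return "custom"


def is_overridden_simple(method_name: str, subclass: type, parent: type = Base) -> bool:
    # Identical __code__ objects mean the subclass inherited the parent's implementation.
    return getattr(subclass, method_name).__code__ is not getattr(parent, method_name).__code__


print(is_overridden_simple("register_plugins", Child))  # True: Child provides its own code object
print(is_overridden_simple("register_plugins", Base))   # False: comparing Base against itself
```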
5 changes: 3 additions & 2 deletions pytorch_lightning/plugins/training_type/deepspeed.py
@@ -628,8 +628,9 @@ def _auto_select_batch_size(self):
# train_micro_batch_size_per_gpu is used for throughput logging purposes
# by default we try to use the batch size of the loader
batch_size = 1
-        if hasattr(self.lightning_module, "train_dataloader"):
-            train_dataloader = self.lightning_module.train_dataloader()
+        train_dl_source = self.lightning_module.trainer.data_connector._train_dataloader_source
+        if train_dl_source.is_defined():
+            train_dataloader = train_dl_source.dataloader()
if hasattr(train_dataloader, "batch_sampler"):
batch_size = train_dataloader.batch_sampler.batch_size
return batch_size
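A small, plain-PyTorch sketch (no DeepSpeed required; nothing here is Lightning API) of the lookup the plugin performs above: the effective batch size is read from the loader's `batch_sampler` when one is present, otherwise the default of 1 is kept.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

loader = DataLoader(TensorDataset(torch.arange(100.0)), batch_size=16)

batch_size = 1  # fallback, mirroring the default above
if hasattr(loader, "batch_sampler") and loader.batch_sampler is not None:
    batch_size = loader.batch_sampler.batch_size
print(batch_size)  # 16
```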
26 changes: 12 additions & 14 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -20,14 +20,13 @@

import torch
import torch.multiprocessing as mp
-from torch.nn import Module
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
-from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
+from pytorch_lightning.trainer.connectors.data_connector import DataConnector
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import (
_OMEGACONF_AVAILABLE,
@@ -96,19 +95,18 @@ def _validate_dataloader(dataloaders: Union[List[DataLoader], DataLoader]) -> No
)

@staticmethod
-    def _validate_patched_dataloaders(model: Module) -> None:
+    def _validate_patched_dataloaders(model: "pl.LightningModule") -> None:
"""Validate and fail fast if the dataloaders were passed directly to fit."""
-        if hasattr(model, "train_dataloader") and isinstance(model.train_dataloader, _PatchDataLoader):
-            TPUSpawnPlugin._validate_dataloader(model.train_dataloader.dataloader)
-
-        if hasattr(model, "val_dataloader") and isinstance(model.val_dataloader, _PatchDataLoader):
-            TPUSpawnPlugin._validate_dataloader(model.val_dataloader.dataloader)
-
-        if hasattr(model, "test_dataloader") and isinstance(model.test_dataloader, _PatchDataLoader):
-            TPUSpawnPlugin._validate_dataloader(model.test_dataloader.dataloader)
-
-        if hasattr(model, "predict_dataloader") and isinstance(model.predict_dataloader, _PatchDataLoader):
-            TPUSpawnPlugin._validate_dataloader(model.predict_dataloader.dataloader)
+        connector: DataConnector = model.trainer.data_connector
+        sources = (
+            connector._train_dataloader_source,
+            connector._val_dataloader_source,
+            connector._test_dataloader_source,
+            connector._predict_dataloader_source,
+        )
+        for source in sources:
+            if not source.is_module():
+                TPUSpawnPlugin._validate_dataloader(source.instance)

def connect(self, model: "pl.LightningModule") -> None:
TPUSpawnPlugin._validate_patched_dataloaders(model)
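A rough, standalone illustration (plain Python, no TPU; `has_length` is a simplified stand-in, not Lightning's helper) of the kind of fail-fast check `_validate_dataloader` applies to directly passed loaders: loaders whose dataset has no length cannot be validated up front.

```python
import torch
from torch.utils.data import DataLoader, IterableDataset, TensorDataset


def has_length(loader: DataLoader) -> bool:
    """Simplified stand-in for a has_len-style check."""
    try:
        return len(loader) > 0
    except TypeError:
        return False


map_style = DataLoader(TensorDataset(torch.randn(8, 2)), batch_size=4)


class Stream(IterableDataset):
    def __iter__(self):
        yield torch.randn(2)


print(has_length(map_style))             # True: length is known, validation can proceed
print(has_length(DataLoader(Stream())))  # False: would trigger a fail-fast error on TPU
```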
8 changes: 4 additions & 4 deletions pytorch_lightning/trainer/configuration_validator.py
@@ -38,7 +38,7 @@ def verify_loop_configurations(trainer: "pl.Trainer", model: "pl.LightningModule
elif trainer.state.fn == TrainerFn.TESTING:
__verify_eval_loop_configuration(model, "test")
elif trainer.state.fn == TrainerFn.PREDICTING:
-        __verify_predict_loop_configuration(model)
+        __verify_predict_loop_configuration(trainer, model)
__verify_dp_batch_transfer_support(trainer, model)
_check_add_get_queue(model)
# TODO(@daniellepintz): Delete _check_progress_bar in v1.7
@@ -65,7 +65,7 @@ def __verify_train_loop_configuration(trainer: "pl.Trainer", model: "pl.Lightnin
# -----------------------------------
# verify model has a train dataloader
# -----------------------------------
-    has_train_dataloader = is_overridden("train_dataloader", model)
+    has_train_dataloader = trainer.data_connector._train_dataloader_source.is_defined()
if not has_train_dataloader:
raise MisconfigurationException(
"No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a"
@@ -175,8 +175,8 @@ def __verify_eval_loop_configuration(model: "pl.LightningModule", stage: str) ->
)


-def __verify_predict_loop_configuration(model: "pl.LightningModule") -> None:
-    has_predict_dataloader = is_overridden("predict_dataloader", model)
+def __verify_predict_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
+    has_predict_dataloader = trainer.data_connector._predict_dataloader_source.is_defined()
if not has_predict_dataloader:
raise MisconfigurationException("Dataloader not found for `Trainer.predict`")
# ----------------------------------------------
130 changes: 74 additions & 56 deletions pytorch_lightning/trainer/connectors/data_connector.py
@@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
+from dataclasses import dataclass
from functools import partial
-from typing import Callable, Iterable, Optional, Union
+from typing import Iterable, Optional, Union

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_deprecation
@@ -47,6 +48,11 @@ def __init__(
self.test_data_fetcher = test_data_fetcher
self.sanity_check_data_fetcher: Optional[AbstractDataFetcher] = None

+        self._train_dataloader_source = _DataLoaderSource(None, "")
+        self._val_dataloader_source = _DataLoaderSource(None, "")
+        self._test_dataloader_source = _DataLoaderSource(None, "")
+        self._predict_dataloader_source = _DataLoaderSource(None, "")

@property
def evaluation_data_fetcher(self) -> Optional[AbstractDataFetcher]:
if self.trainer.sanity_checking:
@@ -190,27 +196,23 @@ def attach_dataloaders(
test_dataloaders: Optional[EVAL_DATALOADERS] = None,
predict_dataloaders: Optional[EVAL_DATALOADERS] = None,
) -> None:
-        # when dataloader is passed via fit, patch the train_dataloader
-        # functions to overwrite with these implementations
-        if train_dataloaders is not None:
-            self.trainer.train_dataloader = None
-            train_dataloader = _PatchDataLoader(train_dataloaders, "train")
-            train_dataloader.patch(model)
-
-        if val_dataloaders is not None:
-            self.trainer.val_dataloaders = None
-            val_dataloader = _PatchDataLoader(val_dataloaders, "val")
-            val_dataloader.patch(model)
-
-        if test_dataloaders is not None:
-            self.trainer.test_dataloaders = None
-            test_dataloader = _PatchDataLoader(test_dataloaders, "test")
-            test_dataloader.patch(model)
-
-        if predict_dataloaders is not None:
-            self.trainer.predict_dataloaders = None
-            predict_dataloader = _PatchDataLoader(predict_dataloaders, "predict")
-            predict_dataloader.patch(model)
+        self.trainer.train_dataloader = None
+        self.trainer.val_dataloaders = None
+        self.trainer.test_dataloaders = None
+        self.trainer.predict_dataloaders = None
+
+        self._train_dataloader_source = _DataLoaderSource(
+            train_dataloaders if train_dataloaders is not None else model, "train_dataloader"
+        )
+        self._val_dataloader_source = _DataLoaderSource(
+            val_dataloaders if val_dataloaders is not None else model, "val_dataloader"
+        )
+        self._test_dataloader_source = _DataLoaderSource(
+            test_dataloaders if test_dataloaders is not None else model, "test_dataloader"
+        )
+        self._predict_dataloader_source = _DataLoaderSource(
+            predict_dataloaders if predict_dataloaders is not None else model, "predict_dataloader"
+        )

def attach_datamodule(
self, model: "pl.LightningModule", datamodule: Optional["pl.LightningDataModule"] = None
@@ -219,11 +221,10 @@ def attach_datamodule(
if datamodule is None:
return

-        # Override loader hooks
-        dl_methods = ("train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader")
-        for method in dl_methods:
-            if is_overridden(method, datamodule):
-                setattr(model, method, getattr(datamodule, method))
+        self._train_dataloader_source = _DataLoaderSource(datamodule, "train_dataloader")
+        self._val_dataloader_source = _DataLoaderSource(datamodule, "val_dataloader")
+        self._test_dataloader_source = _DataLoaderSource(datamodule, "test_dataloader")
+        self._predict_dataloader_source = _DataLoaderSource(datamodule, "predict_dataloader")

# Override data transfer hooks if dataset-specific to_device logic has been defined in datamodule
batch_transfer_hooks = ("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer")
@@ -238,13 +239,6 @@
if hasattr(datamodule, "data_pipeline"):
model.data_pipeline = datamodule.data_pipeline

-    @staticmethod
-    def detach_data(model: "pl.LightningModule") -> None:
-        for stage in ("train", "val", "test", "predict"):
-            loader = getattr(model, f"{stage}_dataloader", None)
-            if isinstance(loader, _PatchDataLoader):
-                loader.unpatch(model)

def teardown(self) -> None:
if self.train_data_fetcher:
self.train_data_fetcher.teardown()
@@ -260,32 +254,56 @@ def teardown(self) -> None:
self.sanity_check_data_fetcher = None


-class _PatchDataLoader:
-    r"""
-    Callable object for patching dataloaders passed into trainer.fit().
-    Use this class to override model.*_dataloader() and be pickle-compatible.
-
-    Args:
-        dataloader: Dataloader object to return when called.
-    """
-
-    def __init__(self, dataloader: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS], stage: str) -> None:
-        self.dataloader = dataloader
-
-        # cannot pickle __code__ so cannot verify if PatchDataloader
-        # exists which shows dataloader methods have been overwritten.
-        # so, we hack it by using the string representation
-        self.patch_loader_code = str(self.__call__.__code__)
-        self._old_loader: Optional[Callable] = None
-        self.stage = stage
-
-    def __call__(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]:
-        return self.dataloader
-
-    def patch(self, model: "pl.LightningModule") -> None:
-        self._old_loader = getattr(model, self.stage + "_dataloader")
-        setattr(model, self.stage + "_dataloader", self)
-
-    def unpatch(self, model: "pl.LightningModule") -> None:
-        setattr(model, self.stage + "_dataloader", self._old_loader)
-        self._old_loader = None
+@dataclass
+class _DataLoaderSource:
+    """Stores the information where the dataloaders come from.
+
+    The source can be
+
+    1. from a ``*_dataloader()`` method on the :class:`~pytorch_lightning.core.lightning.LightningModule`,
+    2. from a ``*_dataloader()`` method on the :class:`~pytorch_lightning.core.datamodule.LightningDataModule`,
+    3. a direct instance of a :class:`~torch.utils.data.DataLoader` or supported collections thereof.
+
+    Arguments:
+        instance: A LightningModule, LightningDataModule, or (a collection of) dataloader(s).
+        name: A name for this dataloader source. If the instance is a module, the name corresponds to the hook
+            that returns the desired dataloader(s).
+    """
+
+    instance: Optional[Union[TRAIN_DATALOADERS, EVAL_DATALOADERS, "pl.LightningModule", "pl.LightningDataModule"]]
+    name: str
+
+    def dataloader(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]:
+        """Returns the dataloader from the source.
+
+        If the source is a module, the method with the corresponding :attr:`name` gets called.
+        """
+        from pytorch_lightning import LightningDataModule, LightningModule  # prevent cyclic import
+
+        if not self.name:
+            return self.instance
+
+        if isinstance(self.instance, LightningModule):
+            return self.instance.trainer.call_hook(self.name, pl_module=self.instance)
+
+        if isinstance(self.instance, LightningDataModule):
+            method = getattr(self.instance, self.name)
+            return method()
+
+        return self.instance
+
+    def is_defined(self) -> bool:
+        """Returns whether the source dataloader can be retrieved or not.
+
+        If the source is a module it checks that the method with given :attr:`name` is overridden.
+        """
+        return not self.is_module() or is_overridden(self.name, self.instance)
+
+    def is_module(self) -> bool:
+        """Returns whether the DataLoader source is a LightningModule or a LightningDataModule.
+
+        It does not check whether ``*_dataloader`` methods are actually overridden.
+        """
+        from pytorch_lightning import LightningDataModule, LightningModule  # prevent cyclic import
+
+        return isinstance(self.instance, (LightningModule, LightningDataModule))
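A standalone sketch of the pattern the new `_DataLoaderSource` dataclass implements (simplified, with made-up names; it is not the real class): a small record that either holds dataloaders directly or points at an object whose named hook produces them, resolved lazily on request.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class LoaderSource:
    instance: Optional[Any]  # dataloader(s) passed directly, or an object exposing a hook
    name: str                # hook name such as "train_dataloader"; unused for direct loaders

    def is_module(self) -> bool:
        return callable(getattr(self.instance, self.name, None))

    def dataloader(self):
        # Resolve lazily: call the hook if the instance provides it, otherwise return as-is.
        return getattr(self.instance, self.name)() if self.is_module() else self.instance


class DummyDataModule:
    def train_dataloader(self):
        return ["batch-0", "batch-1"]


print(LoaderSource(DummyDataModule(), "train_dataloader").dataloader())  # resolved via the hook
print(LoaderSource([1, 2, 3], "train_dataloader").dataloader())          # returned unchanged
```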
20 changes: 11 additions & 9 deletions pytorch_lightning/trainer/data_loading.py
@@ -488,10 +488,10 @@ def reset_val_dataloader(self, model: Optional["pl.LightningModule"] = None) ->
Args:
model: The `LightningModule` if called outside of the trainer scope.
"""
+        source = self.data_connector._val_dataloader_source
        pl_module = self.lightning_module or model
-        has_loader = is_overridden("val_dataloader", pl_module)
        has_step = is_overridden("validation_step", pl_module)
-        if has_loader and has_step:
+        if source.is_defined() and has_step:
self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader(
RunningStage.VALIDATING, model=pl_module
)
@@ -502,10 +502,10 @@ def reset_test_dataloader(self, model: Optional["pl.LightningModule"] = None) ->
Args:
model: The `LightningModule` if called outside of the trainer scope.
"""
+        source = self.data_connector._test_dataloader_source
        pl_module = self.lightning_module or model
-        has_loader = is_overridden("test_dataloader", pl_module)
        has_step = is_overridden("test_step", pl_module)
-        if has_loader and has_step:
+        if source.is_defined() and has_step:
self.num_test_batches, self.test_dataloaders = self._reset_eval_dataloader(
RunningStage.TESTING, model=pl_module
)
@@ -516,9 +516,9 @@ def reset_predict_dataloader(self, model: Optional["pl.LightningModule"] = None)
Args:
model: The `LightningModule` if called outside of the trainer scope.
"""
+        source = self.data_connector._predict_dataloader_source
        pl_module = self.lightning_module or model
-        has_loader = is_overridden("predict_dataloader", pl_module)
-        if has_loader:
+        if source.is_defined():
self.num_predict_batches, self.predict_dataloaders = self._reset_eval_dataloader(
RunningStage.PREDICTING, model=pl_module
)
@@ -540,14 +540,16 @@ def reset_train_val_dataloaders(self, model: Optional["pl.LightningModule"] = No
def request_dataloader(
self, stage: RunningStage, model: Optional["pl.LightningModule"] = None
) -> Union[DataLoader, List[DataLoader]]:
"""Handles downloading data in the GPU or TPU case.
"""Requests a dataloader from the given model by calling dataloader hooks corresponding to the given stage.
Returns:
The dataloader
The requested dataloader
"""
source = getattr(self.data_connector, f"_{stage.dataloader_prefix}_dataloader_source")

hook = f"{stage.dataloader_prefix}_dataloader"
self.call_hook("on_" + hook, pl_module=model)
dataloader = self.call_hook(hook, pl_module=model)
dataloader = source.dataloader()
if isinstance(dataloader, tuple):
dataloader = list(dataloader)
self.training_type_plugin.barrier("get_dataloaders")
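A standalone illustration (hypothetical `FakeConnector`, not Lightning API) of the lookup pattern `request_dataloader` now relies on: the per-stage source attribute is resolved dynamically from the stage's dataloader prefix.

```python
from dataclasses import dataclass


@dataclass
class FakeConnector:
    """Stand-in for the trainer's data connector; attribute names mirror the diff."""

    _train_dataloader_source: str = "train source"
    _val_dataloader_source: str = "val source"
    _test_dataloader_source: str = "test source"
    _predict_dataloader_source: str = "predict source"


connector = FakeConnector()
for prefix in ("train", "val", "test", "predict"):
    # request_dataloader builds the attribute name from the running stage's prefix.
    print(prefix, "->", getattr(connector, f"_{prefix}_dataloader_source"))
```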