remove dataloader patching on the LightningModule #9764

Merged: Oct 20, 2021 (57 commits)

Changes from 47 commits

Commits
874614f
draft
awaelchli Sep 29, 2021
6d6b3ea
draft
awaelchli Sep 29, 2021
a10c526
clean up
awaelchli Sep 29, 2021
a1c5537
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 29, 2021
5c36fbc
check availability of val/test dataloader
awaelchli Sep 29, 2021
84086e6
availability check / property
awaelchli Sep 29, 2021
5d93ee4
hack around on tpu test
awaelchli Sep 30, 2021
7b19d5e
fix test_dataloaders_reset_and_attach test
awaelchli Sep 30, 2021
1fc7c2a
wip
awaelchli Sep 30, 2021
3e05cfc
specify error message in test
awaelchli Sep 30, 2021
35bb187
fix scale batch size test
awaelchli Sep 30, 2021
37dddfc
remove patch_loader_code check from is_overridden util
awaelchli Sep 30, 2021
e8b8dcb
remove patch_loader_code reference from plugins registry
awaelchli Sep 30, 2021
91c24db
add is_module method
awaelchli Sep 30, 2021
5af46ce
update tests for is_module() check
awaelchli Sep 30, 2021
921e5d8
Merge branch 'master' into feature/remove-dataloader-patching
awaelchli Oct 11, 2021
a8d4a26
update tests
awaelchli Oct 11, 2021
8da6bfc
fix unused imports
awaelchli Oct 11, 2021
791ee39
update unit tests to use attach data function
awaelchli Oct 11, 2021
54c1086
use dataloader from trainer
awaelchli Oct 11, 2021
28b95ee
fix test not using the right dataloader
awaelchli Oct 11, 2021
5d76890
fix test
awaelchli Oct 12, 2021
8d8f817
fix test
awaelchli Oct 12, 2021
7881a5d
remove redundant fixme comment
awaelchli Oct 12, 2021
7916083
remove comment
awaelchli Oct 12, 2021
062de83
Merge branch 'master' into feature/remove-dataloader-patching
awaelchli Oct 12, 2021
7ec1b5c
update
awaelchli Oct 12, 2021
66e1e47
Merge branch 'master' into feature/remove-dataloader-patching
awaelchli Oct 12, 2021
abc8bf6
rename dataloader source
awaelchli Oct 12, 2021
5587ea4
typing dataloaders
awaelchli Oct 12, 2021
4baa4da
add docs
awaelchli Oct 12, 2021
6cca816
update changelog
awaelchli Oct 12, 2021
caa869b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 12, 2021
ccbe477
add unit tests
awaelchli Oct 12, 2021
3628ba6
Merge remote-tracking branch 'origin/feature/remove-dataloader-patchi…
awaelchli Oct 12, 2021
909489c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 12, 2021
8dbb918
delete methods
awaelchli Oct 12, 2021
bcd376e
Merge remote-tracking branch 'origin/feature/remove-dataloader-patchi…
awaelchli Oct 12, 2021
c496779
address fixme
awaelchli Oct 12, 2021
ff906c1
val sanity
rohitgr7 Oct 12, 2021
fb0f347
val sanity
rohitgr7 Oct 12, 2021
2f2a431
is_available
awaelchli Oct 14, 2021
f3bb2fb
simplify
awaelchli Oct 14, 2021
a68bbe3
use call_hook() for LightningModule
awaelchli Oct 14, 2021
02d01e4
Merge branch 'master' into feature/remove-dataloader-patching
awaelchli Oct 14, 2021
c355a9d
ensure model has a trainer in unit tests
awaelchli Oct 14, 2021
3052405
fix deepspeed dl request
rohitgr7 Oct 14, 2021
535f423
Apply suggestions from code review
awaelchli Oct 14, 2021
5f9b699
Update pytorch_lightning/utilities/model_helpers.py
awaelchli Oct 14, 2021
2e0496e
rename is_available
awaelchli Oct 14, 2021
c264368
Merge branch 'master' into feature/remove-dataloader-patching
awaelchli Oct 20, 2021
c56eb60
resolve merge error
awaelchli Oct 20, 2021
7fc97ed
address reviews
awaelchli Oct 20, 2021
c194b44
rename reqest -> dataloader
awaelchli Oct 20, 2021
7883b22
update predict check
awaelchli Oct 20, 2021
b8e8034
fix bug in example
awaelchli Oct 20, 2021
4fb6081
add datamodules test() and predict() calls, otherwise loops get skipp…
awaelchli Oct 20, 2021
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -466,6 +466,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed `should_rank_save_checkpoint` property from Trainer ([#9433](https://github.com/PyTorchLightning/pytorch-lightning/pull/9433))


- Removed automatic patching of `{train,val,test,predict}_dataloader()` on the `LightningModule` ([#9764](https://github.com/PyTorchLightning/pytorch-lightning/pull/9764))


### Fixed


@@ -525,6 +528,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Reset `val_dataloader` in `tuner/batch_size_scaling` ([#9857](https://github.com/PyTorchLightning/pytorch-lightning/pull/9857))


- Fixed undesired side effects being caused by `Trainer` patching dataloader methods on the `LightningModule` ([#9764](https://github.com/PyTorchLightning/pytorch-lightning/pull/9764))


## [1.4.9] - 2021-09-30

- Fixed `lr_find` to generate same results on multiple calls ([#9704](https://github.com/PyTorchLightning/pytorch-lightning/pull/9704))
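For context, the behavioral difference behind these two entries can be sketched as follows. This is a rough illustration, assuming a PyTorch Lightning build that already contains this PR; `Model` and `loader` are throwaway example names, and the "before" behavior is paraphrased from the `_PatchDataLoader` code removed further down in this diff.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class Model(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(2, 1)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


model = Model()
loader = DataLoader(TensorDataset(torch.randn(8, 2)), batch_size=4)

pl.Trainer(fast_dev_run=True).fit(model, train_dataloaders=loader)

# Previously, fit() would set an instance attribute shadowing the hook, roughly
#   model.train_dataloader = _PatchDataLoader(loader, "train")
# which leaked into is_overridden() checks and into later reuse of the model.
# With this PR the loader is tracked by the Trainer's DataConnector instead and
# the module is left untouched:
assert "train_dataloader" not in vars(model)
```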
6 changes: 1 addition & 5 deletions pytorch_lightning/plugins/plugins_registry.py
@@ -127,11 +127,7 @@ def is_register_plugins_overridden(plugin: type) -> bool:
else:
return False

if hasattr(plugin_attr, "patch_loader_code"):
is_overridden = plugin_attr.patch_loader_code != str(super_attr.__code__)
else:
is_overridden = plugin_attr.__code__ is not super_attr.__code__
return is_overridden
return plugin_attr.__code__ is not super_attr.__code__


def call_training_type_register_plugins(root: Path, base_module: str) -> None:
5 changes: 3 additions & 2 deletions pytorch_lightning/plugins/training_type/deepspeed.py
@@ -589,8 +589,9 @@ def _auto_select_batch_size(self):
# train_micro_batch_size_per_gpu is used for throughput logging purposes
# by default we try to use the batch size of the loader
batch_size = 1
if hasattr(self.lightning_module, "train_dataloader"):
train_dataloader = self.lightning_module.train_dataloader()
train_dl_source = self.lightning_module.trainer.data_connector._train_dataloader_source
if train_dl_source.is_available():
train_dataloader = train_dl_source.request()
if hasattr(train_dataloader, "batch_sampler"):
batch_size = train_dataloader.batch_sampler.batch_size
return batch_size
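The fallback above simply reads the batch size off the requested loader's batch sampler for DeepSpeed's throughput logging. A standalone paraphrase (the `infer_batch_size` helper is a hypothetical name used only for illustration, not part of the plugin):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def infer_batch_size(train_dataloader, default: int = 1) -> int:
    # prefer the batch size of the loader's batch_sampler, as in _auto_select_batch_size;
    # otherwise fall back to the default of 1
    if getattr(train_dataloader, "batch_sampler", None) is not None:
        return train_dataloader.batch_sampler.batch_size
    return default


loader = DataLoader(TensorDataset(torch.randn(32, 4)), batch_size=8)
print(infer_batch_size(loader))  # 8
```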
26 changes: 12 additions & 14 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -19,14 +19,13 @@

import torch
import torch.multiprocessing as mp
from torch.nn import Module
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import (
_OMEGACONF_AVAILABLE,
@@ -95,19 +94,18 @@ def _validate_dataloader(dataloaders: Union[List[DataLoader], DataLoader]) -> No
)

@staticmethod
def _validate_patched_dataloaders(model: Module) -> None:
def _validate_patched_dataloaders(model: "pl.LightningModule") -> None:
"""Validate and fail fast if the dataloaders were passed directly to fit."""
if hasattr(model, "train_dataloader") and isinstance(model.train_dataloader, _PatchDataLoader):
TPUSpawnPlugin._validate_dataloader(model.train_dataloader.dataloader)

if hasattr(model, "val_dataloader") and isinstance(model.val_dataloader, _PatchDataLoader):
TPUSpawnPlugin._validate_dataloader(model.val_dataloader.dataloader)

if hasattr(model, "test_dataloader") and isinstance(model.test_dataloader, _PatchDataLoader):
TPUSpawnPlugin._validate_dataloader(model.test_dataloader.dataloader)

if hasattr(model, "predict_dataloader") and isinstance(model.predict_dataloader, _PatchDataLoader):
TPUSpawnPlugin._validate_dataloader(model.predict_dataloader.dataloader)
connector: DataConnector = model.trainer.data_connector
sources = (
connector._train_dataloader_source,
connector._val_dataloader_source,
connector._test_dataloader_source,
connector._predict_dataloader_source,
)
for dataloader_source in sources:
if not dataloader_source.is_module():
TPUSpawnPlugin._validate_dataloader(dataloader_source.instance)

def connect(self, model: "pl.LightningModule") -> None:
TPUSpawnPlugin._validate_patched_dataloaders(model)
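The new check only inspects dataloaders that were handed to the Trainer directly; sources backed by a LightningModule or LightningDataModule are skipped because their loaders do not exist until the hook is called. A stripped-down paraphrase with a hypothetical stand-in class instead of `_DataLoaderSource`:

```python
from dataclasses import dataclass
from typing import Any, Callable, Iterable


@dataclass
class StubSource:
    """Hypothetical stand-in exposing only what the TPU check needs."""

    instance: Any
    backed_by_module: bool

    def is_module(self) -> bool:
        return self.backed_by_module


def validate_direct_dataloaders(sources: Iterable[StubSource], validate: Callable[[Any], None]) -> None:
    # fail fast only on dataloaders passed directly to fit()/validate()/test()/predict();
    # hook-defined loaders are validated later, once they are actually requested
    for source in sources:
        if not source.is_module():
            validate(source.instance)
```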
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/configuration_validator.py
@@ -67,7 +67,7 @@ def __verify_train_loop_configuration(self, model: "pl.LightningModule") -> None
# -----------------------------------
# verify model has a train dataloader
# -----------------------------------
has_train_dataloader = is_overridden("train_dataloader", model)
has_train_dataloader = self.trainer.data_connector._train_dataloader_source.is_available()
if not has_train_dataloader:
raise MisconfigurationException(
"No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a"
@@ -176,7 +176,7 @@ def __verify_eval_loop_configuration(self, model: "pl.LightningModule", stage: s
)

def __verify_predict_loop_configuration(self, model: "pl.LightningModule") -> None:
has_predict_dataloader = is_overridden("predict_dataloader", model)
has_predict_dataloader = self.trainer.data_connector._predict_dataloader_source.is_available()
if not has_predict_dataloader:
raise MisconfigurationException("Dataloader not found for `Trainer.predict`")
# ----------------------------------------------
132 changes: 76 additions & 56 deletions pytorch_lightning/trainer/connectors/data_connector.py
@@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dataclasses import dataclass
from functools import partial
from typing import Callable, Iterable, Optional, Union
from typing import Iterable, Optional, Union

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_deprecation
@@ -47,6 +48,11 @@ def __init__(
self.test_data_fetcher = test_data_fetcher
self.sanity_check_data_fetcher: Optional[AbstractDataFetcher] = None

self._train_dataloader_source = _DataLoaderSource()
self._val_dataloader_source = _DataLoaderSource()
self._test_dataloader_source = _DataLoaderSource()
self._predict_dataloader_source = _DataLoaderSource()

@property
def evaluation_data_fetcher(self) -> Optional[AbstractDataFetcher]:
if self.trainer.sanity_checking:
@@ -190,27 +196,23 @@ def attach_dataloaders(
test_dataloaders: Optional[EVAL_DATALOADERS] = None,
predict_dataloaders: Optional[EVAL_DATALOADERS] = None,
) -> None:
# when dataloader is passed via fit, patch the train_dataloader
# functions to overwrite with these implementations
if train_dataloaders is not None:
self.trainer.train_dataloader = None
train_dataloader = _PatchDataLoader(train_dataloaders, "train")
train_dataloader.patch(model)

if val_dataloaders is not None:
self.trainer.val_dataloaders = None
val_dataloader = _PatchDataLoader(val_dataloaders, "val")
val_dataloader.patch(model)

if test_dataloaders is not None:
self.trainer.test_dataloaders = None
test_dataloader = _PatchDataLoader(test_dataloaders, "test")
test_dataloader.patch(model)

if predict_dataloaders is not None:
self.trainer.predict_dataloaders = None
predict_dataloader = _PatchDataLoader(predict_dataloaders, "predict")
predict_dataloader.patch(model)
self.trainer.train_dataloader = None
self.trainer.val_dataloaders = None
self.trainer.test_dataloaders = None
self.trainer.predict_dataloaders = None

self._train_dataloader_source = _DataLoaderSource(
train_dataloaders if train_dataloaders is not None else model, "train_dataloader"
)
self._val_dataloader_source = _DataLoaderSource(
val_dataloaders if val_dataloaders is not None else model, "val_dataloader"
)
self._test_dataloader_source = _DataLoaderSource(
test_dataloaders if test_dataloaders is not None else model, "test_dataloader"
)
self._predict_dataloader_source = _DataLoaderSource(
predict_dataloaders if predict_dataloaders is not None else model, "predict_dataloader"
)

def attach_datamodule(
self, model: "pl.LightningModule", datamodule: Optional["pl.LightningDataModule"] = None
@@ -219,11 +221,10 @@ def attach_datamodule(
if datamodule is None:
return

# Override loader hooks
dl_methods = ("train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader")
for method in dl_methods:
if is_overridden(method, datamodule):
setattr(model, method, getattr(datamodule, method))
self._train_dataloader_source = _DataLoaderSource(datamodule, "train_dataloader")
self._val_dataloader_source = _DataLoaderSource(datamodule, "val_dataloader")
self._test_dataloader_source = _DataLoaderSource(datamodule, "test_dataloader")
self._predict_dataloader_source = _DataLoaderSource(datamodule, "predict_dataloader")

# Override data transfer hooks if dataset-specific to_device logic has been defined in datamodule
batch_transfer_hooks = ("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer")
@@ -238,13 +239,6 @@
if hasattr(datamodule, "data_pipeline"):
model.data_pipeline = datamodule.data_pipeline

@staticmethod
def detach_data(model: "pl.LightningModule") -> None:
for stage in ("train", "val", "test", "predict"):
loader = getattr(model, f"{stage}_dataloader", None)
if isinstance(loader, _PatchDataLoader):
loader.unpatch(model)

def teardown(self) -> None:
if self.train_data_fetcher:
self.train_data_fetcher.teardown()
@@ -260,32 +254,58 @@
self.sanity_check_data_fetcher = None


class _PatchDataLoader:
r"""
Callable object for patching dataloaders passed into trainer.fit().
Use this class to override model.*_dataloader() and be pickle-compatible.
@dataclass
class _DataLoaderSource:
"""Stores the information where the dataloaders come from.

The source can be

Args:
dataloader: Dataloader object to return when called.
1. from a ``*_dataloader()`` method on the :class:`~pytorch_lightning.core.lightning.LightningModule`,
2. from a ``*_dataloader()`` method on the :class:`~pytorch_lightning.core.datamodule.LightningDataModule`,
3. a direct instance of a :class:`~torch.utils.data.DataLoader` or supported collections thereof.

Arguments:
instance: A LightningModule, LightningDataModule, or (a collection of) dataloader(s).
name: A name for this dataloader source. If the instance is a module, the name corresponds to the hook
that returns the desired dataloader(s).
"""

def __init__(self, dataloader: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS], stage: str) -> None:
self.dataloader = dataloader
instance: Optional[
Union[TRAIN_DATALOADERS, EVAL_DATALOADERS, "pl.LightningModule", "pl.LightningDataModule"]
] = None
name: str = ""

def request(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]:
"""Returns the dataloader from the source.

If the source is a module, the method with the corresponding
:attr:`name` gets called.
"""
from pytorch_lightning import LightningDataModule, LightningModule # prevent cyclic import

if not self.name:
return self.instance

if isinstance(self.instance, LightningModule):
return self.instance.trainer.call_hook(self.name, pl_module=self.instance)

if isinstance(self.instance, LightningDataModule):
return getattr(self.instance, self.name)()

return self.instance

def is_available(self) -> bool:
"""Returns whether the source dataloader is available.

# cannot pickle __code__ so cannot verify if PatchDataloader
# exists which shows dataloader methods have been overwritten.
# so, we hack it by using the string representation
self.patch_loader_code = str(self.__call__.__code__)
self._old_loader: Optional[Callable] = None
self.stage = stage
If the source is a module it checks that the method with given :attr:`name` is overridden.
"""
return not self.is_module() or is_overridden(self.name, self.instance)

def __call__(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]:
return self.dataloader
def is_module(self) -> bool:
"""Returns whether the the DataLoader source is a LightningModule or a LightningDataModule.

def patch(self, model: "pl.LightningModule") -> None:
self._old_loader = getattr(model, self.stage + "_dataloader")
setattr(model, self.stage + "_dataloader", self)
It does not check whether ``*_dataloader`` methods are actually overridden.
"""
from pytorch_lightning import LightningDataModule, LightningModule # prevent cyclic import

def unpatch(self, model: "pl.LightningModule") -> None:
setattr(model, self.stage + "_dataloader", self._old_loader)
self._old_loader = None
return isinstance(self.instance, (LightningModule, LightningDataModule))
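Taken together, the resolution rules can be exercised directly. A minimal sketch, assuming a PyTorch Lightning version that contains this PR; it pokes at the private names introduced in this diff (`data_connector`, `attach_dataloaders`, `_train_dataloader_source`), so treat it as illustration rather than supported API:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class Model(pl.LightningModule):
    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(8, 2)), batch_size=4)


model = Model()
trainer = pl.Trainer()
explicit = DataLoader(TensorDataset(torch.randn(8, 2)), batch_size=2)

# Nothing passed in: the module's hook becomes the source.
trainer.data_connector.attach_dataloaders(model)
source = trainer.data_connector._train_dataloader_source
assert source.is_module() and source.is_available()

# A loader passed directly takes precedence and is returned verbatim by request().
trainer.data_connector.attach_dataloaders(model, train_dataloaders=explicit)
source = trainer.data_connector._train_dataloader_source
assert not source.is_module()
assert source.request() is explicit
```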
20 changes: 11 additions & 9 deletions pytorch_lightning/trainer/data_loading.py
@@ -487,10 +487,10 @@ def reset_val_dataloader(self, model: Optional["pl.LightningModule"] = None) ->
Args:
model: The `LightningModule` if called outside of the trainer scope.
"""
source = self.data_connector._val_dataloader_source
pl_module = self.lightning_module or model
has_loader = is_overridden("val_dataloader", pl_module)
has_step = is_overridden("validation_step", pl_module)
if has_loader and has_step:
if source.is_available() and has_step:
self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader(
RunningStage.VALIDATING, model=pl_module
)
@@ -501,10 +501,10 @@ def reset_test_dataloader(self, model: Optional["pl.LightningModule"] = None) ->
Args:
model: The `LightningModule` if called outside of the trainer scope.
"""
source = self.data_connector._test_dataloader_source
pl_module = self.lightning_module or model
has_loader = is_overridden("test_dataloader", pl_module)
has_step = is_overridden("test_step", pl_module)
if has_loader and has_step:
if source.is_available() and has_step:
self.num_test_batches, self.test_dataloaders = self._reset_eval_dataloader(
RunningStage.TESTING, model=pl_module
)
@@ -515,9 +515,9 @@ def reset_predict_dataloader(self, model: Optional["pl.LightningModule"] = None)
Args:
model: The `LightningModule` if called outside of the trainer scope.
"""
source = self.data_connector._predict_dataloader_source
pl_module = self.lightning_module or model
has_loader = is_overridden("predict_dataloader", pl_module)
if has_loader:
if source.is_available():
self.num_predict_batches, self.predict_dataloaders = self._reset_eval_dataloader(
RunningStage.PREDICTING, model=pl_module
)
@@ -539,14 +539,16 @@ def reset_train_val_dataloaders(self, model: Optional["pl.LightningModule"] = No
def request_dataloader(
self, stage: RunningStage, model: Optional["pl.LightningModule"] = None
) -> Union[DataLoader, List[DataLoader]]:
"""Handles downloading data in the GPU or TPU case.
"""Requests a dataloader from the given model by calling dataloader hooks corresponding to the given stage.

Returns:
The dataloader
The requested dataloader
"""
source = getattr(self.data_connector, f"_{stage.dataloader_prefix}_dataloader_source")

hook = f"{stage.dataloader_prefix}_dataloader"
self.call_hook("on_" + hook, pl_module=model)
dataloader = self.call_hook(hook, pl_module=model)
dataloader = source.request()
if isinstance(dataloader, tuple):
dataloader = list(dataloader)
self.training_type_plugin.barrier("get_dataloaders")
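So `request_dataloader` now reduces to: map the stage prefix onto the matching source attribute, fire the `on_*_dataloader` hook, resolve the loader(s) from the source, and normalize tuples to lists. A rough paraphrase (the function name and the `call_hook` parameter are placeholders, not Trainer API):

```python
def request_dataloader_sketch(data_connector, call_hook, prefix: str, model=None):
    # e.g. prefix == "val" resolves to data_connector._val_dataloader_source
    source = getattr(data_connector, f"_{prefix}_dataloader_source")
    call_hook(f"on_{prefix}_dataloader", pl_module=model)
    dataloader = source.request()
    if isinstance(dataloader, tuple):
        # several loaders returned as a tuple are normalized to a list
        dataloader = list(dataloader)
    return dataloader
```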
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -1259,7 +1259,9 @@ def _run_predict(self) -> Optional[_PREDICT_OUTPUT]:
return self.predict_loop.run()

def _run_sanity_check(self, ref_model):
using_val_step = ref_model.val_dataloader is not None and is_overridden("validation_step", ref_model)
using_val_step = self.data_connector._val_dataloader_source.is_available() and is_overridden(
"validation_step", ref_model
)
should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0

# run tiny validation (if validation defined)
@@ -1358,8 +1360,6 @@ def _call_teardown_hook(self) -> None:
if self.datamodule is not None:
self.datamodule.teardown(stage=fn)

self.data_connector.detach_data(self.lightning_module)

self.call_hook("teardown", stage=fn)

self.lightning_module._current_fx_name = None
2 changes: 1 addition & 1 deletion pytorch_lightning/tuner/batch_size_scaling.py
@@ -51,7 +51,7 @@ def scale_batch_size(
" If this is not the intended behavior, please remove either one."
)

if hasattr(model.train_dataloader, "patch_loader_code"):
if not trainer.data_connector._train_dataloader_source.is_module():
raise MisconfigurationException(
"The batch scaling feature cannot be used with dataloaders passed directly to `.fit()`."
" Please disable the feature or incorporate the dataloader into the model."
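In other words, the batch-size finder keeps working only when the dataloader is defined on the module (or datamodule) so that it can be re-created after each resize. A small sketch of the supported setup, assuming a PyTorch Lightning version with this PR; the model below is a throwaway example:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class ScalableModel(pl.LightningModule):
    def __init__(self, batch_size: int = 2):
        super().__init__()
        self.batch_size = batch_size  # the attribute the batch-size finder adjusts
        self.layer = nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        # defined on the module, so the tuner can request a fresh loader after each resize
        return DataLoader(TensorDataset(torch.randn(640, 32)), batch_size=self.batch_size)


trainer = pl.Trainer(auto_scale_batch_size=True, max_epochs=1)
trainer.tune(ScalableModel())
# passing the dataloader to tune()/fit() directly would hit the
# MisconfigurationException raised above
```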
5 changes: 1 addition & 4 deletions pytorch_lightning/utilities/model_helpers.py
@@ -64,10 +64,7 @@ def is_overridden(
if parent_attr is None:
raise ValueError("The parent should define the method")

# cannot pickle `__code__` so cannot verify if `PatchDataloader`
# exists which shows dataloader methods have been overwritten.
# so, we hack it by using the string representation
instance_code = getattr(instance_attr, "patch_loader_code", None) or instance_attr.__code__
instance_code = instance_attr.__code__
parent_code = parent_attr.__code__

return instance_code != parent_code
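The simplified `is_overridden` relies purely on comparing `__code__` objects, now that no patched attribute can stand in for the real method. A minimal standalone illustration of that idea (the classes and helper below are made up for the example):

```python
class Base:
    def train_dataloader(self):
        raise NotImplementedError


class WithLoader(Base):
    def train_dataloader(self):
        return [1, 2, 3]


class WithoutLoader(Base):
    pass


def is_overridden(method_name: str, instance: object, parent: type) -> bool:
    # the child's bound method carries a different code object once the child redefines it
    instance_attr = getattr(instance, method_name)
    parent_attr = getattr(parent, method_name)
    return instance_attr.__code__ is not parent_attr.__code__


print(is_overridden("train_dataloader", WithLoader(), Base))     # True
print(is_overridden("train_dataloader", WithoutLoader(), Base))  # False
```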