Improve support for custom DataLoaders when instantiated in *_dataloader hook #12981

Merged (41 commits, Jun 21, 2022)

Commits
- 6f92c30 Improve wrapping dataloader (May 5, 2022)
- b09771b update changelog (May 5, 2022)
- 5e8b6bd Update pytorch_lightning/utilities/data.py (otaj, May 6, 2022)
- c1ed132 Update pytorch_lightning/utilities/data.py (otaj, May 6, 2022)
- 8438814 Update tests/utilities/test_data.py (otaj, May 6, 2022)
- d4f727c Update pytorch_lightning/utilities/data.py (otaj, May 6, 2022)
- 76ccfcd apply suggestions (May 6, 2022)
- c74bbd4 Merge branch 'master' into bugfix/dataloader_wrapper (May 6, 2022)
- b2c403a Update pytorch_lightning/utilities/data.py (otaj, May 6, 2022)
- 577e626 Update pytorch_lightning/utilities/data.py (otaj, May 6, 2022)
- 28a4596 separate dataloader specific parts into own wrapper plus extra test case (May 6, 2022)
- 9f78234 merge master (May 6, 2022)
- 38ec191 Potential change for not relying on dataset (carmocca, May 6, 2022)
- c03378f Revert "Potential change for not relying on dataset" (May 6, 2022)
- 0674459 new "esoteric" test cases (May 6, 2022)
- 80799c7 Sorry Ota! Changed my mind about separting these 2 (carmocca, May 6, 2022)
- 1950aee Filter at beginning (carmocca, May 6, 2022)
- 23fb84a fixing for another ugly corner case (May 9, 2022)
- e6d91c9 Merge branch 'master' into bugfix/dataloader_wrapper (May 9, 2022)
- 9487135 update changelog (May 10, 2022)
- 4429bf8 Update CHANGELOG.md (otaj, May 10, 2022)
- e010575 Update pytorch_lightning/utilities/data.py (otaj, May 10, 2022)
- ac94aa8 Update pytorch_lightning/utilities/data.py (otaj, May 10, 2022)
- 4fed618 apply some suggestions (May 10, 2022)
- 2fd7b53 set operations correctly (May 10, 2022)
- 4c4ed2a Update pytorch_lightning/utilities/data.py (otaj, May 11, 2022)
- 773bbc5 Merge branch 'master' into bugfix/dataloader_wrapper (May 11, 2022)
- 69e3ed0 Merge branch 'master' into bugfix/dataloader_wrapper (May 11, 2022)
- dafeb04 make fault tolerant training work as much as possible (May 11, 2022)
- 23973da resolve test plus suggestions (May 11, 2022)
- 8a8e866 Apply suggestion (carmocca, May 11, 2022)
- 90b581a Simple rename (carmocca, May 11, 2022)
- 4137d32 Fixes (carmocca, May 11, 2022)
- 3e7e254 parametrize tests (May 11, 2022)
- 4f11905 merge master (May 12, 2022)
- d467530 Merge branch 'master' into bugfix/dataloader_wrapper (akihironitta, Jun 1, 2022)
- 341d4d9 Merge branch 'master' into bugfix/dataloader_wrapper (carmocca, Jun 14, 2022)
- 9b3bf00 merge master (Jun 20, 2022)
- e73f07e add tensor (Jun 20, 2022)
- 9da5fd6 merge master (Jun 21, 2022)
- 3f9552a Merge branch 'master' into bugfix/dataloader_wrapper (carmocca, Jun 21, 2022)
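
Before the diff itself, here is a minimal sketch of the scenario this PR improves (the `RandomDataset`, `CustomDataLoader`, and `LitModel` names below are illustrative, not taken from the PR's tests): a `DataLoader` subclass with its own constructor arguments is created inside a `*_dataloader` hook, and Lightning must later be able to re-instantiate it, for example to inject a distributed sampler.

```python
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset


class RandomDataset(Dataset):
    """Tiny map-style dataset, only for illustration."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return float(idx)


class CustomDataLoader(DataLoader):
    """A subclass whose extra argument is not stored on the instance."""

    def __init__(self, dataset, my_option: int = 0, **kwargs):
        super().__init__(dataset, **kwargs)


class LitModel(pl.LightningModule):
    def train_dataloader(self):
        # Lightning patches DataLoader.__init__ around this call, so the
        # arguments used here can be replayed when the loader is re-created.
        return CustomDataLoader(RandomDataset(), my_option=1, batch_size=2)
```

Before this change, re-instantiation relied on every `__init__` argument being discoverable as an instance attribute; the commits above instead record the hook-time call itself.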
8 changes: 7 additions & 1 deletion CHANGELOG.md
@@ -106,8 +106,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Raise an error if there are insufficient training batches when using a float value of `limit_train_batches` ([#12885](https://github.com/PyTorchLightning/pytorch-lightning/pull/12885))


- The `WandbLogger` will now use the run name in the logs folder if it is provided, and otherwise the project name ([#12604](https://github.com/PyTorchLightning/pytorch-lightning/pull/12604))
- `DataLoader` instantiated inside a `*_dataloader` hook will not set the passed arguments as attributes anymore ([#12981](https://github.com/PyTorchLightning/pytorch-lightning/pull/12981))


- The `WandbLogger` will now use the run name in the logs folder if it is provided, and otherwise the project name ([#12604](https://github.com/PyTorchLightning/pytorch-lightning/pull/12604))


### Deprecated
@@ -229,6 +231,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed


- Improved support for custom `DataLoader`s when instantiated in `*_dataloader` hook ([#12981](https://github.com/PyTorchLightning/pytorch-lightning/pull/12981))


- Fixed an issue with unsupported torch.inference_mode() on hpu backends by making it use no_grad ([#13014](https://github.com/PyTorchLightning/pytorch-lightning/pull/13014))


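To make the changelog entry about attributes concrete (a rough behavioural sketch, not code from this PR): arguments passed to a `DataLoader` created inside a hook used to be mirrored onto the instance with `setattr`; after this change they are recorded in private bookkeeping fields instead. The attribute names below come from the patched `__init__` in this PR and are internal details that may change.

```python
from torch.utils.data import DataLoader

# Imagine this call happens inside a `train_dataloader` hook, i.e. while the
# __init__ patching from this PR is active.
loader = DataLoader(list(range(10)), batch_size=4, num_workers=0)

# Previously the patched __init__ ran roughly `setattr(loader, name, value)`
# for every named argument.  With this PR the call is recorded instead:
#
#   loader.__pl_dl_args      ~ (list(range(10)),)
#   loader.__pl_dl_kwargs    ~ {"batch_size": 4, "num_workers": 0}
#   loader.__pl_dl_arg_names ~ ("dataset",)
#
# and Lightning itself no longer writes the passed arguments onto the instance.
```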
6 changes: 3 additions & 3 deletions src/pytorch_lightning/strategies/ipu.py
@@ -29,7 +29,7 @@
from pytorch_lightning.utilities import _IPU_AVAILABLE, _POPTORCH_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.data import _get_dataloader_init_kwargs
from pytorch_lightning.utilities.data import _get_dataloader_init_args_and_kwargs
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -228,9 +228,9 @@ def _convert_to_poptorch_loader(
# the user is returning the `poptorch.DataLoader` directly, don't change anything.
return dataloader

dl_kwargs = _get_dataloader_init_kwargs(dataloader, sampler)
dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler)
opts = self.training_opts if mode == RunningStage.TRAINING else self.inference_opts
dataloader = poptorch.DataLoader(opts, **dl_kwargs)
dataloader = poptorch.DataLoader(opts, *dl_args, **dl_kwargs)
return dataloader

def _handle_gradient_accumulation_steps(self) -> None:
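Since `_get_dataloader_init_args_and_kwargs` now also returns positional arguments, the IPU strategy forwards them when rebuilding the loader. A small hedged sketch of that calling pattern, using a plain `DataLoader` stand-in because `poptorch.DataLoader` needs IPU support (the values are made up):

```python
from torch.utils.data import DataLoader

# Roughly what the helper could hand back for a loader created positionally
# inside a hook:
dl_args = (list(range(16)),)    # positional arguments, here just the dataset
dl_kwargs = {"batch_size": 4}   # keyword arguments, sampler already resolved

# The strategy forwards both, so positional-only usage survives the rebuild;
# with poptorch this becomes `poptorch.DataLoader(opts, *dl_args, **dl_kwargs)`.
rebuilt = DataLoader(*dl_args, **dl_kwargs)
```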
199 changes: 134 additions & 65 deletions src/pytorch_lightning/utilities/data.py
@@ -17,12 +17,19 @@
from contextlib import contextmanager
from dataclasses import fields
from functools import partial
from itertools import chain
from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Set, Type, Union
from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Set, Tuple, Type, Union

import torch
from torch import Tensor
from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler, Sampler, SequentialSampler
from torch.utils.data import (
BatchSampler,
DataLoader,
Dataset,
IterableDataset,
RandomSampler,
Sampler,
SequentialSampler,
)

import pytorch_lightning as pl
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
@@ -179,10 +186,10 @@ def get_len(dataloader: DataLoader) -> Union[int, float]:
def _update_dataloader(
dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None
) -> DataLoader:
dl_kwargs = _get_dataloader_init_kwargs(dataloader, sampler, mode=mode)
dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler, mode=mode)
dl_cls = type(dataloader)
try:
dataloader = dl_cls(**dl_kwargs)
dataloader = dl_cls(*dl_args, **dl_kwargs)
except TypeError as e:
# improve exception message due to an incorrect implementation of the `DataLoader` where multiple subclass
# `__init__` arguments map to one `DataLoader.__init__` argument
@@ -198,38 +205,62 @@ def _update_dataloader(
f" can be passed to its parent's `{argument}=...` `__init__` argument. This is likely caused by allowing"
f" passing both a custom argument that will map to the `{argument}` argument as well as `**kwargs`."
f" `kwargs` should be filtered to make sure they don't contain the `{argument}` key."
" This argument was automatically passed to your DataLoader by PyTorch Lightning."
)
raise MisconfigurationException(message) from e
return dataloader


def _get_dataloader_init_kwargs(
def _get_dataloader_init_args_and_kwargs(
dataloader: DataLoader, sampler: Optional[Sampler], mode: Optional[RunningStage] = None
) -> Dict[str, Any]:
) -> Tuple[Tuple[Any], Dict[str, Any]]:
if not isinstance(dataloader, DataLoader):
raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`")

# get the dataloader instance attributes
attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")}
# not part of `vars`
attrs["multiprocessing_context"] = dataloader.multiprocessing_context
was_wrapped = hasattr(dataloader, "__pl_dl_args")
if was_wrapped:
dl_args = dataloader.__pl_dl_args
dl_kwargs = dataloader.__pl_dl_kwargs
arg_names = dataloader.__pl_dl_arg_names
original_dataset = dataloader.__dataset # we have this saved from _wrap_init
else:
# get the dataloader instance attributes
attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")}
# We cannot be 100% sure the class sets dataset argument. Let's set it to None to be safe
# and hope we can get it from the instance attributes
original_dataset = None
# not part of `vars`
attrs["multiprocessing_context"] = dataloader.multiprocessing_context
arg_names = ()

# get the dataloader instance `__init__` parameters
params = dict(inspect.signature(dataloader.__init__).parameters)
has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
if has_variadic_kwargs:
# if the signature takes **kwargs, assume they will be passed down with `super().__init__(**kwargs)`
params.update(inspect.signature(DataLoader.__init__).parameters)
del params["self"]

# keep only the params whose default is different to the current attr value
non_defaults = {name for name, p in params.items() if name in attrs and p.default != attrs[name]}
# add `dataset` as it might have been replaced with `*args`
non_defaults.add("dataset")
if was_wrapped:
# if the dataloader was wrapped in a hook, only take arguments with default values
# and assume user passes their kwargs correctly
params.update(
{k: v for k, v in inspect.signature(DataLoader.__init__).parameters.items() if v.default is not v.empty}
)
else:
params.update(inspect.signature(DataLoader.__init__).parameters)
params.pop("self", None)

if not was_wrapped:
# keep only the params whose default is different to the current attr value
non_defaults = {name for name, p in params.items() if name in attrs and p.default != attrs[name]}

# kwargs to re-construct the dataloader
dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults}
if isinstance(dl_kwargs["dataset"], IterableDataset):
# add `dataset` as it might have been replaced with `*args`
non_defaults.add("dataset")
# kwargs to re-construct the dataloader
dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults}
dl_args = ()

dataset = dl_kwargs.get("dataset", original_dataset)
if isinstance(dataset, IterableDataset):
dl_kwargs["batch_sampler"] = None
dl_kwargs["sampler"] = None
else:
@@ -238,40 +269,43 @@ def _get_dataloader_init_kwargs(
required_args = {
p.name
for p in params.values()
if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD) and p.default is p.empty and p.name not in dl_kwargs
if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
and p.default is p.empty
and p.name not in dl_kwargs
and p.name not in arg_names
}
# the dataloader has required args which we could not extract from the existing attributes
if required_args:
required_args = sorted(required_args)
dataloader_cls_name = dataloader.__class__.__name__
missing_args_message = ", ".join(f"`self.{arg_name}`" for arg_name in required_args)
raise MisconfigurationException(
f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
f"Trying to inject custom `Sampler` into the `{dataloader_cls_name}` instance. "
"This would fail as some of the `__init__` arguments are not available as instance attributes. "
f"The missing attributes are {required_args}. "
f"HINT: If you wrote the `{dataloader_cls_name}` class, define `self.missing_arg_name` or "
"manually add the `DistributedSampler` as: "
f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
f"The missing attributes are {required_args}. If you instantiate your `{dataloader_cls_name}` inside a "
"`*_dataloader` hook of your module, we will do this for you."
f" Otherwise, define {missing_args_message} inside your `__init__`."
)

if not has_variadic_kwargs:
# the dataloader signature does not allow keyword arguments that need to be passed
missing_kwargs = dl_kwargs.keys() - params.keys()
missing_kwargs = (set(dl_kwargs) | set(arg_names)) - params.keys()
if missing_kwargs:
missing_kwargs = sorted(missing_kwargs)
dataloader_cls_name = dataloader.__class__.__name__
raise MisconfigurationException(
f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
f"Trying to inject parameters into the `{dataloader_cls_name}` instance. "
"This would fail as it doesn't expose all its attributes in the `__init__` signature. "
f"The missing arguments are {missing_kwargs}. "
f"HINT: If you wrote the `{dataloader_cls_name}` class, add the `__init__` arguments or "
"manually add the `DistributedSampler` as: "
f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
f"The missing arguments are {missing_kwargs}. HINT: If you wrote the `{dataloader_cls_name}` class, "
"add the `__init__` arguments or allow passing `**kwargs`"
)

if _FaultTolerantMode.detect_current_mode().is_automatic:
dl_kwargs = _apply_fault_tolerant_automatic_capture_dataset_wrapper(dl_kwargs)
dl_args, dl_kwargs = _apply_fault_tolerant_automatic_capture_dataset_wrapper(
was_wrapped, arg_names, dl_args, dl_kwargs
)

return dl_kwargs
return dl_args, dl_kwargs


def _dataloader_init_kwargs_resolve_sampler(
@@ -321,30 +355,35 @@ def _auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None:
dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank)


def _wrap_init(init: Callable) -> Callable:
"""Wraps the ``__init__`` method of the dataloader in order to enable re-instantiation of custom subclasses of
:class:`~torch.utils.data.DataLoader`."""
def _wrap_dataloader_init(init: Callable) -> Callable:
"""Wraps the ``__init__`` method of :class:`~torch.utils.data.DataLoader` in order to enable re-instantiation
of custom subclasses."""

@functools.wraps(init)
def wrapper(obj: DataLoader, *args: Any, **kwargs: Any) -> None:
# We need to inspect `init`, as inspecting `obj.__init__`
# can lead to inspecting the wrong function with multiple inheritance
params = inspect.signature(init).parameters

param_names = [
param_names = tuple(
param.name
for param in params.values()
if param.name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
]

cls = type(obj)
for arg_name, arg_value in chain(zip(param_names, args), kwargs.items()):
if hasattr(cls, arg_name) and getattr(cls, arg_name).fset is None:
# the class defines a read-only (no setter) property of this name. it's likely that the implementation
# will set `self._arg_name = arg_value` in `__init__` which is the attribute returned by the `arg_name`
# property so we are fine skipping in that case
continue
setattr(obj, arg_name, arg_value)
)
param_names = param_names[: len(args)]

if not hasattr(obj, "__pl_dl_args"):
obj.__pl_dl_args = args
obj.__pl_dl_kwargs = kwargs
obj.__pl_dl_arg_names = param_names

# We want to use the latest possible value for dataset argument (i.e. ideally what gets passed to DataLoader)
# so that we can be sure, that it will not get changed anymore.
# That is why we are setting this in every `__init__`
if "dataset" in param_names:
setattr(obj, "__dataset", args[param_names.index("dataset")])
elif "dataset" in kwargs:
setattr(obj, "__dataset", kwargs["dataset"])

init(obj, *args, **kwargs)

return wrapper
@@ -368,33 +407,63 @@ def recurse(cl: Type[Any]) -> None:
def _replace_dataloader_init_method() -> Generator[None, None, None]:
"""This context manager is used to add support for re-instantiation of custom (subclasses) of
:class:`~torch.utils.data.DataLoader`. It patches the ``__init__`` method."""
subclasses = _get_all_subclasses(DataLoader)
for subclass in subclasses:
subclass._old_init = subclass.__init__
subclass.__init__ = _wrap_init(subclass.__init__)
classes = _get_all_subclasses(DataLoader) | {DataLoader}
wrapped = set()
for cls in classes:
if cls.__init__ not in wrapped:
cls._old_init = cls.__init__
cls.__init__ = _wrap_dataloader_init(cls.__init__)
wrapped.add(cls.__init__)
yield
for subclass in subclasses:
subclass.__init__ = subclass._old_init
del subclass._old_init
for cls in classes:
if hasattr(cls, "_old_init"):
cls.__init__ = cls._old_init
del cls._old_init


def _apply_fault_tolerant_automatic_capture_dataset_wrapper(dl_kwargs: Dict) -> Dict:
dataset = dl_kwargs["dataset"]
def _wrap_with_capture_dataset(dataset: Dataset) -> Dataset:
if isinstance(dataset, IterableDataset):
# wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dataset)
elif get_len(dataset) != float("inf"):
dl_kwargs["dataset"] = CaptureMapDataset(dataset=dataset)
return CaptureIterableDataset(dataset=dataset)
if get_len(dataset) != float("inf"):
return CaptureMapDataset(dataset=dataset)
raise RuntimeError("This shouldn't happen, please open an issue on Lightning Github repository.")


def _apply_fault_tolerant_automatic_capture_dataset_wrapper(
was_wrapped: bool, arg_names: Tuple[str, ...], dl_args: Tuple[Any, ...], dl_kwargs: Dict[str, Any]
) -> Tuple[Tuple[str, ...], Dict[str, Any]]:
if "dataset" in dl_kwargs:
dl_kwargs["dataset"] = _wrap_with_capture_dataset(dl_kwargs["dataset"])
elif "dataset" in arg_names:
dataset_idx = arg_names.index("dataset")
dataset = _wrap_with_capture_dataset(dl_args[dataset_idx])
dl_args = dl_args[:dataset_idx] + (dataset,) + dl_args[dataset_idx + 1 :]
else:
raise MisconfigurationException("This shouldn't happen, please open an issue on Lightning Github repository.")
return dl_kwargs
if was_wrapped:
avoid_message = (
" To avoid this, either pass `DataLoader(dataset=your_dataset)` or the positional dataset argument"
" `DataLoader(your_dataset, ...)`."
)
else:
avoid_message = " To avoid this, define `self.dataset = dataset` inside your DataLoader's `__init__`."

raise MisconfigurationException(
"You enabled automatic Fault Tolerant mode, but we were not able to replace your dataset"
" with Fault Tolerant wrapper, because you have a custom DataLoader." + avoid_message
)

return dl_args, dl_kwargs


def _is_dataloader_shuffled(dataloader: object) -> bool:
if hasattr(dataloader, "shuffle"):
if hasattr(dataloader, "__pl_dl_kwargs"):
# this attribute is not part of PyTorch's DataLoader, but could have been set by
# our `_replace_dataloader_init_method` context manager
return dataloader.shuffle
if "shuffle" in dataloader.__pl_dl_kwargs:
return dataloader.__pl_dl_kwargs["shuffle"]
if "shuffle" in dataloader.__pl_dl_arg_names:
return dataloader.__pl_dl_args[dataloader.__pl_dl_arg_names.index("shuffle")]
if isinstance(dataloader.dataset, IterableDataset):
# shuffling is useless with iterable datasets
return False
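Putting the pieces of `data.py` together, the intended flow is roughly the following sketch. The helpers are private utilities that may change between releases, and `MyLoader` plus the data are illustrative; only the argument bookkeeping mirrors the code above.

```python
from torch.utils.data import DataLoader, SequentialSampler

from pytorch_lightning.utilities.data import (
    _replace_dataloader_init_method,
    _update_dataloader,
)


class MyLoader(DataLoader):
    """A loader with a required extra argument that is not stored on `self`."""

    def __init__(self, dataset, factor: int, **kwargs):
        # note: `factor` is intentionally not saved as an instance attribute
        super().__init__(dataset, **kwargs)


dataset = list(range(10))

# Lightning wraps the call to the `*_dataloader` hook in this context manager,
# so the arguments below are recorded on the instance (`__pl_dl_args` & co.).
with _replace_dataloader_init_method():
    loader = MyLoader(dataset, factor=2, batch_size=5)

# Later, e.g. when injecting a new sampler, the loader is rebuilt from the
# recorded call instead of from guessed instance attributes.
new_loader = _update_dataloader(loader, sampler=SequentialSampler(dataset))
assert isinstance(new_loader, MyLoader)
assert isinstance(new_loader.sampler, SequentialSampler)
```

Without the recorded call, `factor` could not be recovered and the rebuild would fail with the `MisconfigurationException` shown above.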
2 changes: 1 addition & 1 deletion tests/tests_pytorch/lite/test_lite.py
@@ -201,7 +201,7 @@ def __init__(self, new_arg, *args, **kwargs):
with pytest.raises(
MisconfigurationException,
match=(
r"Trying to inject `DistributedSampler` into the `CustomDataLoader` instance.*"
r"Trying to inject custom `Sampler` into the `CustomDataLoader` instance.*"
r"The missing attributes are \['new_arg'\]"
),
):
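For reference, a hedged sketch of the kind of loader this test exercises (mirroring the spirit of the test's `CustomDataLoader`, not its exact code): a required `__init__` argument that is neither stored on `self` nor captured by the hook-time patching cannot be recovered, so sampler injection fails with the message asserted above.

```python
from torch.utils.data import DataLoader


class CustomDataLoader(DataLoader):
    def __init__(self, new_arg, *args, **kwargs):
        # `new_arg` is deliberately not stored, so it cannot be re-passed later
        super().__init__(*args, **kwargs)


# When such a loader is created without Lightning's __init__ patching in
# effect, it cannot be rebuilt with a new sampler; Lightning raises a
# MisconfigurationException matching roughly:
#   "Trying to inject custom `Sampler` into the `CustomDataLoader` instance ...
#    The missing attributes are ['new_arg']"
loader = CustomDataLoader("value", list(range(4)), batch_size=2)
```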