Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve support for custom DataLoaders when instantiated in *_dataloader hook #12981

Merged
merged 41 commits into from
Jun 21, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
6f92c30
Improve wrapping dataloader
May 5, 2022
b09771b
update changelog
May 5, 2022
5e8b6bd
Update pytorch_lightning/utilities/data.py
otaj May 6, 2022
c1ed132
Update pytorch_lightning/utilities/data.py
otaj May 6, 2022
8438814
Update tests/utilities/test_data.py
otaj May 6, 2022
d4f727c
Update pytorch_lightning/utilities/data.py
otaj May 6, 2022
76ccfcd
apply suggestions
May 6, 2022
c74bbd4
Merge branch 'master' into bugfix/dataloader_wrapper
May 6, 2022
b2c403a
Update pytorch_lightning/utilities/data.py
otaj May 6, 2022
577e626
Update pytorch_lightning/utilities/data.py
otaj May 6, 2022
28a4596
separate dataloader specific parts into own wrapper plus extra test case
May 6, 2022
9f78234
merge master
May 6, 2022
38ec191
Potential change for not relying on dataset
carmocca May 6, 2022
c03378f
Revert "Potential change for not relying on dataset"
May 6, 2022
0674459
new "esoteric" test cases
May 6, 2022
80799c7
Sorry Ota! Changed my mind about separting these 2
carmocca May 6, 2022
1950aee
Filter at beginning
carmocca May 6, 2022
23fb84a
fixing for another ugly corner case
May 9, 2022
e6d91c9
Merge branch 'master' into bugfix/dataloader_wrapper
May 9, 2022
9487135
update changelog
May 10, 2022
4429bf8
Update CHANGELOG.md
otaj May 10, 2022
e010575
Update pytorch_lightning/utilities/data.py
otaj May 10, 2022
ac94aa8
Update pytorch_lightning/utilities/data.py
otaj May 10, 2022
4fed618
apply some suggestions
May 10, 2022
2fd7b53
set operations correctly
May 10, 2022
4c4ed2a
Update pytorch_lightning/utilities/data.py
otaj May 11, 2022
773bbc5
Merge branch 'master' into bugfix/dataloader_wrapper
May 11, 2022
69e3ed0
Merge branch 'master' into bugfix/dataloader_wrapper
May 11, 2022
dafeb04
make fault tolerant training work as much as possible
May 11, 2022
23973da
resolve test plus suggestions
May 11, 2022
8a8e866
Apply suggestion
carmocca May 11, 2022
90b581a
Simple rename
carmocca May 11, 2022
4137d32
Fixes
carmocca May 11, 2022
3e7e254
parametrize tests
May 11, 2022
4f11905
merge master
May 12, 2022
d467530
Merge branch 'master' into bugfix/dataloader_wrapper
akihironitta Jun 1, 2022
341d4d9
Merge branch 'master' into bugfix/dataloader_wrapper
carmocca Jun 14, 2022
9b3bf00
merge master
Jun 20, 2022
e73f07e
add tensor
Jun 20, 2022
9da5fd6
merge master
Jun 21, 2022
3f9552a
Merge branch 'master' into bugfix/dataloader_wrapper
carmocca Jun 21, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue causing zero-division error for empty dataloaders ([#12885](https://github.com/PyTorchLightning/pytorch-lightning/pull/12885))


-
- Improved support for custom `DataLoader`s when instantiated in `*_dataloader` hook ([#12981](https://github.com/PyTorchLightning/pytorch-lightning/pull/12981))


-
Expand Down
84 changes: 58 additions & 26 deletions pytorch_lightning/utilities/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@ def _update_dataloader(
f" can be passed to its parent's `{argument}=...` `__init__` argument. This is likely caused by allowing"
f" passing both a custom argument that will map to the `{argument}` argument as well as `**kwargs`."
f" `kwargs` should be filtered to make sure they don't contain the `{argument}` key."
" This argument was automatically passed to your DataLoader by PyTorch Lightning,"
" however you are also passing it in your implementation."
otaj marked this conversation as resolved.
Show resolved Hide resolved
)
raise MisconfigurationException(message) from e
return dataloader
Expand All @@ -208,8 +210,16 @@ def _get_dataloader_init_kwargs(
if not isinstance(dataloader, DataLoader):
raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`")

# get the dataloader instance attributes
attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")}
was_wrapped = hasattr(dataloader, "_set_arg_names")
if was_wrapped:
attrs = {k: getattr(dataloader, "__" + k) for k in dataloader._set_arg_names}
original_dataset = dataloader.__dataset # we have this saved from _wrap_init
else:
# get the dataloader instance attributes
attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")}
# We cannot be 100% sure the class sets dataset argument. Let's set it to None to be safe
# and hope we will get it from `dl_kwargs`. If not, we will fail in `if required_args:` check.
otaj marked this conversation as resolved.
Show resolved Hide resolved
original_dataset = None
# not part of `vars`
attrs["multiprocessing_context"] = dataloader.multiprocessing_context

Expand All @@ -218,17 +228,28 @@ def _get_dataloader_init_kwargs(
has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
if has_variadic_kwargs:
# if the signature takes **kwargs, assume they will be passed down with `super().__init__(**kwargs)`
carmocca marked this conversation as resolved.
Show resolved Hide resolved
params.update(inspect.signature(DataLoader.__init__).parameters)
del params["self"]

if was_wrapped:
# if the dataloader was wrapped in a hook, only take arguments with default values
# and assume user passes their kwargs correctly
otaj marked this conversation as resolved.
Show resolved Hide resolved
params.update(
{k: v for k, v in inspect.signature(DataLoader.__init__).parameters.items() if v.default is not v.empty}
)
else:
params.update(inspect.signature(DataLoader.__init__).parameters)
params.pop("self", None)

# keep only the params whose default is different to the current attr value
non_defaults = {name for name, p in params.items() if name in attrs and p.default != attrs[name]}
# add `dataset` as it might have been replaced with `*args`
non_defaults.add("dataset")

if not was_wrapped:
# add `dataset` as it might have been replaced with `*args`
non_defaults.add("dataset")

# kwargs to re-construct the dataloader
dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults}
if isinstance(dl_kwargs["dataset"], IterableDataset):
dataset = dl_kwargs.get("dataset", original_dataset)
if isinstance(dataset, IterableDataset):
dl_kwargs["batch_sampler"] = None
dl_kwargs["sampler"] = None
else:
Expand All @@ -244,12 +265,10 @@ def _get_dataloader_init_kwargs(
required_args = sorted(required_args)
dataloader_cls_name = dataloader.__class__.__name__
raise MisconfigurationException(
f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
f"Trying to inject custom `Sampler` into the `{dataloader_cls_name}` instance. "
"This would fail as some of the `__init__` arguments are not available as instance attributes. "
f"The missing attributes are {required_args}. "
f"HINT: If you wrote the `{dataloader_cls_name}` class, define `self.missing_arg_name` or "
"manually add the `DistributedSampler` as: "
f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
f"The missing attributes are {required_args}. If you instantiate your `{dataloader_cls_name}` inside a "
"`*_dataloader` hook of your module, we will do this for you. Otherwise, define `self.missing_arg_name` inside your `__init__`."
)

if not has_variadic_kwargs:
Expand All @@ -259,12 +278,10 @@ def _get_dataloader_init_kwargs(
missing_kwargs = sorted(missing_kwargs)
dataloader_cls_name = dataloader.__class__.__name__
raise MisconfigurationException(
f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
f"Trying to inject parameters into the `{dataloader_cls_name}` instance. "
"This would fail as it doesn't expose all its attributes in the `__init__` signature. "
f"The missing arguments are {missing_kwargs}. "
f"HINT: If you wrote the `{dataloader_cls_name}` class, add the `__init__` arguments or "
"manually add the `DistributedSampler` as: "
f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
f"The missing arguments are {missing_kwargs}. HINT: If you wrote the `{dataloader_cls_name}` class, "
"add the `__init__` arguments or allow passing `**kwargs`"
)

if _FaultTolerantMode.detect_current_mode().is_automatic:
Expand Down Expand Up @@ -336,14 +353,29 @@ def wrapper(obj: DataLoader, *args: Any, **kwargs: Any) -> None:
if param.name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
]

cls = type(obj)
for arg_name, arg_value in chain(zip(param_names, args), kwargs.items()):
if hasattr(cls, arg_name) and getattr(cls, arg_name).fset is None:
# the class defines a read-only (no setter) property of this name. it's likely that the implementation
# will set `self._arg_name = arg_value` in `__init__` which is the attribute returned by the `arg_name`
# property so we are fine skipping in that case
continue
setattr(obj, arg_name, arg_value)
otaj marked this conversation as resolved.
Show resolved Hide resolved
otaj marked this conversation as resolved.
Show resolved Hide resolved
if not hasattr(obj, "_set_arg_names"):
carmocca marked this conversation as resolved.
Show resolved Hide resolved
set_arg_names = set()
for arg_name, arg_value in chain(zip(param_names, args), kwargs.items()):
if not hasattr(obj, "__" + arg_name):
otaj marked this conversation as resolved.
Show resolved Hide resolved
# Set the value privately if it has not been set yet. This achieves two things:
# 1. The argument is the one that was passed in the outermost call
# and not overriden by a call to `super().__init__()`
# 2. The argument is the passed value and does not get overwritten in `__init__()`
setattr(obj, "__" + arg_name, arg_value)
set_arg_names.add(arg_name)
obj._set_arg_names = set(set_arg_names)

if not hasattr(obj, "__dataset"):
# We have not found value for dataset yet, but we will find it eventually
# because it has to be passed to the original DataLoader
dataset_value = None
if "dataset" in param_names[: len(args)]:
dataset_value = args[param_names.index("dataset")]
elif "dataset" in kwargs:
dataset_value = kwargs["dataset"]
if dataset_value is not None:
otaj marked this conversation as resolved.
Show resolved Hide resolved
setattr(obj, "__dataset", dataset_value)
otaj marked this conversation as resolved.
Show resolved Hide resolved

init(obj, *args, **kwargs)

return wrapper
Expand All @@ -367,7 +399,7 @@ def recurse(cl: Type[Any]) -> None:
def _replace_dataloader_init_method() -> Generator[None, None, None]:
"""This context manager is used to add support for re-instantiation of custom (subclasses) of
:class:`~torch.utils.data.DataLoader`. It patches the ``__init__`` method."""
subclasses = _get_all_subclasses(DataLoader)
subclasses = _get_all_subclasses(DataLoader) | {DataLoader}
for subclass in subclasses:
subclass._old_init = subclass.__init__
subclass.__init__ = _wrap_init(subclass.__init__)
Expand Down
2 changes: 1 addition & 1 deletion tests/lite/test_lite.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def __init__(self, new_arg, *args, **kwargs):
with pytest.raises(
MisconfigurationException,
match=(
r"Trying to inject `DistributedSampler` into the `CustomDataLoader` instance.*"
r"Trying to inject custom `Sampler` into the `CustomDataLoader` instance.*"
r"The missing attributes are \['new_arg'\]"
),
):
Expand Down
41 changes: 29 additions & 12 deletions tests/utilities/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,23 +133,28 @@ def test_has_len_all_rank():


def test_update_dataloader_typerror_custom_exception():
class BadImpl(DataLoader):
class BadStandaloneGoodHookImpl(DataLoader):
def __init__(self, foo, *args, **kwargs):
self.foo = foo
# positional conflict with `dataset`
super().__init__(foo, *args, **kwargs)

dataloader = BadImpl([1, 2, 3])
dataloader = BadStandaloneGoodHookImpl([1, 2, 3])
with pytest.raises(MisconfigurationException, match="`DataLoader` implementation has an error.*`dataset`"):
_update_dataloader(dataloader, dataloader.sampler)

class BadImpl2(DataLoader):
with _replace_dataloader_init_method():
dataloader = BadStandaloneGoodHookImpl([1, 2, 3])
new_dataloader = _update_dataloader(dataloader, dataloader.sampler)
assert isinstance(new_dataloader, BadStandaloneGoodHookImpl)

class BadImpl(DataLoader):
def __init__(self, randomize, *args, **kwargs):
self.randomize = randomize
# keyword conflict with `shuffle`
super().__init__(*args, shuffle=randomize, **kwargs)

dataloader = BadImpl2(False, [])
dataloader = BadImpl(False, [])
with pytest.raises(MisconfigurationException, match="`DataLoader` implementation has an error.*`shuffle`"):
_update_dataloader(dataloader, dataloader.sampler)

Expand All @@ -166,30 +171,36 @@ def __init__(self, randomize, *args, **kwargs):

def test_replace_dataloader_init_method():
"""Test that context manager intercepts arguments passed to custom subclasses of torch.utils.DataLoader and
sets them as attributes."""
sets them as private attributes."""

class DataLoaderSubclass1(DataLoader):
def __init__(self, attribute1, *args, **kwargs):
# intentionally not setting this attribute, calling super with different args
# self.attribute1 = attribute1
self.at1 = attribute1
super().__init__(*args, **kwargs)

class DataLoaderSubclass2(DataLoaderSubclass1):
def __init__(self, attribute2, *args, **kwargs):
# intentionally not setting this attribute, calling super with different args
# self.attribute2 = attribute2
self.at2 = attribute2
super().__init__(attribute2 + "-2", *args, **kwargs)

with _replace_dataloader_init_method():
dataloader = DataLoaderSubclass1("attribute1", dataset=range(4), batch_size=2)

assert dataloader.attribute1 == "attribute1"
assert dataloader.__attribute1 == "attribute1"
assert dataloader._set_arg_names == {"attribute1", "dataset", "batch_size"}
assert dataloader.dataset == range(4)
assert dataloader.batch_size == 2
assert dataloader.at1 == "attribute1" # But the value still gets passed when it should

with _replace_dataloader_init_method():
dataloader = DataLoaderSubclass2("attribute2", dataset=range(4), batch_size=2)

assert dataloader.attribute1 == "attribute2-2"
assert dataloader.attribute2 == "attribute2"
assert dataloader.__attribute2 == "attribute2"
assert dataloader._set_arg_names == {"attribute2", "dataset", "batch_size"}
assert dataloader.dataset == range(4)
assert dataloader.batch_size == 2
assert dataloader.at1 == "attribute2-2" # But the value still gets passed when it should
assert dataloader.at2 == "attribute2" # But the value still gets passed when it should

# Failing test case from issue 12564
class MyBaseDataLoader(DataLoader):
Expand All @@ -207,6 +218,9 @@ def __init__(self, data: torch.Tensor, *args, **kwargs):

assert dataloader.data is data
assert dataloader.dataset == range(10)
assert dataloader.__data is data
assert dataloader.__dataset == range(10)
assert dataloader._set_arg_names == {"data", "batch_size"}

# `poptorch.DataLoader` uses this pattern, simulate it
class PoptorchDataLoader(DataLoader):
Expand All @@ -221,12 +235,15 @@ def options(self):
# †his read-only property pattern is fine
dataloader = PoptorchDataLoader(123, [1])
assert dataloader.options == 123
assert not hasattr(dataloader, "__options")

# still works with the init replacement
with _replace_dataloader_init_method():
dataloader = PoptorchDataLoader(123, [1])

assert dataloader.options == 123
assert dataloader._set_arg_names == {"options"}
assert dataloader.__options == 123


@pytest.mark.parametrize("mode", [RunningStage.TRAINING, RunningStage.PREDICTING, RunningStage.TESTING])
Expand Down