
Commit b4250e5

awaelchli and otaj committed
Allowed setting attributes on DataLoader and BatchSampler when instantiated inside *_dataloader hooks (#14212)
Co-authored-by: otaj <6065855+otaj@users.noreply.github.com>
1 parent 20984f5 commit b4250e5

7 files changed: +227 -70 lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Avoid raising the sampler warning if num_replicas=1 ([#14097](https://github.com/Lightning-AI/lightning/pull/14097))
 - Fixed resuming from a checkpoint when using Stochastic Weight Averaging (SWA) ([#9938](https://github.com/Lightning-AI/lightning/pull/9938))
 - Avoided requiring the FairScale package to use precision with the fsdp native strategy ([#14092](https://github.com/Lightning-AI/lightning/pull/14092))
+- Fixed not preserving set attributes on `DataLoader` and `BatchSampler` when instantiated inside `*_dataloader` hooks ([#14212](https://github.com/Lightning-AI/lightning/pull/14212))


 ## [1.7.1] - 2022-08-09
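
For context, a minimal sketch (not part of this commit; the dataset class and the `tag` attribute are made up) of the user-facing behavior the new changelog entry describes: an attribute set on a DataLoader after construction inside a `*_dataloader` hook is now recorded and replayed when Lightning re-instantiates the loader.

    # Hypothetical illustration of the fixed behavior.
    import torch
    from torch.utils.data import DataLoader, Dataset


    class RandomDataset(Dataset):
        def __len__(self):
            return 64

        def __getitem__(self, idx):
            return torch.randn(32)


    def train_dataloader():  # stands in for LightningModule.train_dataloader
        loader = DataLoader(RandomDataset(), batch_size=8)
        # Before this fix, attributes set here (after __init__) were lost when Lightning
        # rebuilt the loader, e.g. to inject a DistributedSampler; now they are preserved.
        loader.tag = "train"
        return loader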

src/pytorch_lightning/lite/lite.py

Lines changed: 4 additions & 4 deletions
@@ -35,7 +35,7 @@
 from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
 from pytorch_lightning.utilities.data import (
     _auto_add_worker_init_fn,
-    _replace_init_method,
+    _replace_dunder_methods,
     _update_dataloader,
     has_iterable_dataset,
 )
@@ -403,9 +403,9 @@ def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:

     def _run_with_strategy_setup(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
         self._strategy.setup_environment()
-        with self._strategy.model_sharded_context(), _replace_init_method(DataLoader, "dataset"), _replace_init_method(
-            BatchSampler
-        ):
+        with self._strategy.model_sharded_context(), _replace_dunder_methods(
+            DataLoader, "dataset"
+        ), _replace_dunder_methods(BatchSampler):
             return run_method(*args, **kwargs)

     def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module:
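
A rough sketch of what entering the renamed context managers does, using only the names visible in this diff (the surrounding Lite machinery is omitted):

    # Simplified sketch, assuming a Lightning version where these private helpers exist.
    from torch.utils.data import BatchSampler, DataLoader
    from pytorch_lightning.utilities.data import _replace_dunder_methods

    with _replace_dunder_methods(DataLoader, "dataset"), _replace_dunder_methods(BatchSampler):
        # While patched, DataLoader/BatchSampler (and their subclasses) record their
        # __init__ arguments and any later attribute set/delete on each instance.
        loader = DataLoader(range(10), batch_size=2)

    # The captured constructor arguments persist on the instance after the patch is undone.
    assert hasattr(loader, "__pl_saved_args")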

src/pytorch_lightning/strategies/ipu.py

Lines changed: 4 additions & 2 deletions
@@ -30,7 +30,7 @@
 from pytorch_lightning.utilities import _IPU_AVAILABLE, _POPTORCH_AVAILABLE, rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.cloud_io import get_filesystem
-from pytorch_lightning.utilities.data import _get_dataloader_init_args_and_kwargs
+from pytorch_lightning.utilities.data import _get_dataloader_init_args_and_kwargs, _reinstantiate_wrapped_cls
 from pytorch_lightning.utilities.enums import PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -239,7 +239,9 @@ def _convert_to_poptorch_loader(
             dataloader, sampler, mode, self.replication_factor > 1
         )
         opts = self.training_opts if mode == RunningStage.TRAINING else self.inference_opts
-        dataloader = poptorch.DataLoader(opts, *dl_args, **dl_kwargs)
+        dataloader = _reinstantiate_wrapped_cls(
+            dataloader, opts, *dl_args, explicit_cls=poptorch.DataLoader, **dl_kwargs
+        )
         return dataloader

     def _handle_gradient_accumulation_steps(self) -> None:
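
The IPU strategy previously constructed a `poptorch.DataLoader` directly; the new helper keeps that behavior via `explicit_cls` while also replaying any attribute changes recorded on the original loader. A toy sketch of that path, with a hypothetical stand-in class so poptorch is not required:

    # Hypothetical sketch; WrappingLoader stands in for poptorch.DataLoader.
    from torch.utils.data import DataLoader
    from pytorch_lightning.utilities.data import _reinstantiate_wrapped_cls


    class WrappingLoader(DataLoader):
        pass


    orig = DataLoader(range(8), batch_size=4)
    # Rebuild under a different class; attribute changes recorded on `orig` (if any) are replayed.
    new_loader = _reinstantiate_wrapped_cls(orig, range(8), explicit_cls=WrappingLoader, batch_size=4)
    assert type(new_loader) is WrappingLoader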

src/pytorch_lightning/trainer/connectors/data_connector.py

Lines changed: 5 additions & 3 deletions
@@ -31,7 +31,7 @@
 from pytorch_lightning.utilities.data import (
     _auto_add_worker_init_fn,
     _is_dataloader_shuffled,
-    _replace_init_method,
+    _replace_dunder_methods,
     _update_dataloader,
     has_iterable_dataset,
     has_len_all_ranks,
@@ -428,9 +428,11 @@ def _request_dataloader(self, stage: RunningStage) -> Union[DataLoader, List[DataLoader]]:
         """
         source = getattr(self, f"_{stage.dataloader_prefix}_dataloader_source")

-        with _replace_init_method(DataLoader, "dataset"), _replace_init_method(BatchSampler):
+        with _replace_dunder_methods(DataLoader, "dataset"), _replace_dunder_methods(BatchSampler):
             # under this context manager, the arguments passed to `DataLoader.__init__` will be captured and saved as
-            # attributes on the instance in case the dataloader needs to be re-instantiated later by Lightning
+            # attributes on the instance in case the dataloader needs to be re-instantiated later by Lightning.
+            # Also, it records all attribute setting and deletion using patched `__setattr__` and `__delattr__`
+            # methods so that the re-instantiated object is as close to the original as possible.
             dataloader = source.dataloader()
             if isinstance(dataloader, tuple):
                 dataloader = list(dataloader)
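
As the expanded comment notes, attribute sets and deletes that happen while the hook runs are recorded on the instance so they can be replayed later. A small illustrative sketch using the helpers named in this diff (the `note` attribute is made up; `__pl_attrs_record` is the Lightning-internal attribute shown in the data.py diff below):

    # Illustration only, assuming a Lightning version containing this commit.
    from torch.utils.data import DataLoader
    from pytorch_lightning.utilities.data import _replace_dunder_methods

    with _replace_dunder_methods(DataLoader, "dataset"):
        loader = DataLoader(range(10), batch_size=5)
        loader.note = "set after __init__"  # recorded by the patched __setattr__
        del loader.note                      # recorded by the patched __delattr__

    # Each (args, tag) pair can later be replayed, in order, onto a re-instantiated loader.
    print(getattr(loader, "__pl_attrs_record"))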

src/pytorch_lightning/utilities/data.py

Lines changed: 99 additions & 37 deletions
@@ -37,7 +37,7 @@
 from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.utilities.apply_func import _is_dataclass_instance
 from pytorch_lightning.utilities.auto_restart import CaptureIterableDataset, CaptureMapDataset, FastForwardSampler
-from pytorch_lightning.utilities.enums import _FaultTolerantMode
+from pytorch_lightning.utilities.enums import _FaultTolerantMode, LightningEnum
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.rank_zero import rank_zero_warn
 from pytorch_lightning.utilities.seed import pl_worker_init_function
@@ -48,6 +48,18 @@
 warning_cache = WarningCache()


+class _WrapAttrTag(LightningEnum):
+    SET = "set"
+    DEL = "del"
+
+    def __call__(self, *args):
+        if self == self.SET:
+            fn = setattr
+        else:
+            fn = delattr
+        return fn(*args)
+
+
 def _extract_batch_size(batch: BType) -> Generator[int, None, None]:
     if isinstance(batch, Tensor):
         if batch.ndim == 0:
@@ -188,27 +200,7 @@ def _update_dataloader(
     dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None
 ) -> DataLoader:
     dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler, mode)
-    dl_cls = type(dataloader)
-    try:
-        dataloader = dl_cls(*dl_args, **dl_kwargs)
-    except TypeError as e:
-        # improve exception message due to an incorrect implementation of the `DataLoader` where multiple subclass
-        # `__init__` arguments map to one `DataLoader.__init__` argument
-        import re
-
-        match = re.match(r".*__init__\(\) got multiple values .* '(\w+)'", str(e))
-        if not match:
-            # an unexpected `TypeError`, continue failure
-            raise
-        argument = match.groups()[0]
-        message = (
-            f"The {dl_cls.__name__} `DataLoader` implementation has an error where more than one `__init__` argument"
-            f" can be passed to its parent's `{argument}=...` `__init__` argument. This is likely caused by allowing"
-            f" passing both a custom argument that will map to the `{argument}` argument as well as `**kwargs`."
-            f" `kwargs` should be filtered to make sure they don't contain the `{argument}` key."
-            " This argument was automatically passed to your DataLoader by PyTorch Lightning."
-        )
-        raise MisconfigurationException(message) from e
+    dataloader = _reinstantiate_wrapped_cls(dataloader, *dl_args, **dl_kwargs)
     return dataloader


@@ -374,7 +366,7 @@ def _dataloader_init_kwargs_resolve_sampler(
                         "this, expose an argument `sampler` in the `__init__` method of your custom class."
                     )

-            batch_sampler = batch_sampler_cls(*args, **kwargs)
+            batch_sampler = _reinstantiate_wrapped_cls(batch_sampler, *args, **kwargs)
         else:
             try:
                 batch_sampler = batch_sampler_cls(
@@ -449,6 +441,37 @@ def _auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None:
     dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank)


+def _reinstantiate_wrapped_cls(orig_object: Any, *args: Any, explicit_cls: Optional[Type] = None, **kwargs: Any) -> Any:
+    constructor = type(orig_object) if explicit_cls is None else explicit_cls
+
+    try:
+        result = constructor(*args, **kwargs)
+    except TypeError as e:
+        # improve exception message due to an incorrect implementation of the `DataLoader` where multiple subclass
+        # `__init__` arguments map to one `DataLoader.__init__` argument
+        import re
+
+        match = re.match(r".*__init__\(\) got multiple values .* '(\w+)'", str(e))
+        if not match:
+            # an unexpected `TypeError`, continue failure
+            raise
+        argument = match.groups()[0]
+        message = (
+            f"The {constructor.__name__} implementation has an error where more than one `__init__` argument"
+            f" can be passed to its parent's `{argument}=...` `__init__` argument. This is likely caused by allowing"
+            f" passing both a custom argument that will map to the `{argument}` argument as well as `**kwargs`."
+            f" `kwargs` should be filtered to make sure they don't contain the `{argument}` key."
+            " This argument was automatically passed to your object by PyTorch Lightning."
+        )
+        raise MisconfigurationException(message) from e
+
+    attrs_record = getattr(orig_object, "__pl_attrs_record", list())
+    for args, fn in attrs_record:
+        fn(result, *args)
+
+    return result
+
+
 def _wrap_init_method(init: Callable, store_explicit_arg: Optional[str] = None) -> Callable:
     """Wraps the ``__init__`` method of classes (currently :class:`~torch.utils.data.DataLoader` and
     :class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses."""
@@ -457,6 +480,8 @@ def _wrap_init_method(init: Callable, store_explicit_arg: Optional[str] = None)
     def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None:
         # We need to inspect `init`, as inspecting `obj.__init__`
         # can lead to inspecting the wrong function with multiple inheritance
+        old_inside_init = getattr(obj, "__pl_inside_init", False)
+        object.__setattr__(obj, "__pl_inside_init", True)
         params = inspect.signature(init).parameters

         parameters_defaults = OrderedDict(
@@ -474,21 +499,49 @@ def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None:
         }

         if not hasattr(obj, "__pl_saved_args"):
-            obj.__pl_saved_args = args
-            obj.__pl_saved_kwargs = kwargs
-            obj.__pl_saved_arg_names = param_names
-            obj.__pl_saved_default_kwargs = default_kwargs
+            object.__setattr__(obj, "__pl_saved_args", args)
+            object.__setattr__(obj, "__pl_saved_kwargs", kwargs)
+            object.__setattr__(obj, "__pl_saved_arg_names", param_names)
+            object.__setattr__(obj, "__pl_saved_default_kwargs", default_kwargs)

         # We want to use the latest possible value for explicit argument (i.e. ideally what gets passed to base class)
         # so that we can be sure, that it will not get changed anymore.
         # That is why we are setting this in every `__init__`
         if store_explicit_arg is not None:
             if store_explicit_arg in param_names:
-                setattr(obj, f"__{store_explicit_arg}", args[param_names.index(store_explicit_arg)])
+                object.__setattr__(obj, f"__{store_explicit_arg}", args[param_names.index(store_explicit_arg)])
             elif store_explicit_arg in kwargs:
-                setattr(obj, f"__{store_explicit_arg}", kwargs[store_explicit_arg])
+                object.__setattr__(obj, f"__{store_explicit_arg}", kwargs[store_explicit_arg])

         init(obj, *args, **kwargs)
+        object.__setattr__(obj, "__pl_inside_init", old_inside_init)
+
+    return wrapper
+
+
+def _wrap_attr_method(method: Callable, tag: _WrapAttrTag) -> Callable:
+    """Wraps the ``__setattr__`` or ``__delattr__`` method of classes (currently :class:`~torch.utils.data.DataLoader`
+    and :class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses."""
+
+    @functools.wraps(method)
+    def wrapper(obj: Any, *args: Any):
+        # First, let's find out if we're the first in inheritance chain calling the patched method.
+        name, *_ = args
+        prev_call_name, prev_call_method = getattr(obj, "__pl_current_call", (None, "method"))
+        first_call = not (prev_call_name == name and prev_call_method == tag)
+
+        # Then mark the current called method
+        object.__setattr__(obj, "__pl_current_call", (name, tag))
+
+        # call original method
+        method(obj, *args)
+        if first_call and not getattr(obj, "__pl_inside_init", True):
+            # and save the value it was called with to the internal list,
+            # if we're outside of __init__ and the original call did not fail and we're the first call
+            attrs_record = getattr(obj, "__pl_attrs_record", list())
+            attrs_record.append((args, tag))
+            object.__setattr__(obj, "__pl_attrs_record", attrs_record)
+        object.__setattr__(obj, "__pl_current_call", (prev_call_name, prev_call_method))

     return wrapper

@@ -508,25 +561,34 @@ def recurse(cl: Type[Any]) -> None:


 @contextmanager
-def _replace_init_method(base_cls: Type, store_explicit_arg: Optional[str] = None) -> Generator[None, None, None]:
+def _replace_dunder_methods(base_cls: Type, store_explicit_arg: Optional[str] = None) -> Generator[None, None, None]:
     """This context manager is used to add support for re-instantiation of custom (subclasses) of `base_cls`.

-    It patches the ``__init__`` method.
+    It patches the ``__init__``, ``__setattr__`` and ``__delattr__`` methods.
     """
     classes = _get_all_subclasses(base_cls) | {base_cls}
     for cls in classes:
         # Check that __init__ belongs to the class
         # https://stackoverflow.com/a/5253424
         if "__init__" in cls.__dict__:
-            cls._old_init = cls.__init__
+            cls.__old__init__ = cls.__init__
             cls.__init__ = _wrap_init_method(cls.__init__, store_explicit_arg)
+
+        # we want at least one setattr/delattr in the chain to be patched and it can happen, that none of the
+        # subclasses implement `__setattr__`/`__delattr__`. Therefore, we are always patching the `base_cls`
+        for patch_fn_name, tag in (("__setattr__", _WrapAttrTag.SET), ("__delattr__", _WrapAttrTag.DEL)):
+            if patch_fn_name in cls.__dict__ or cls is base_cls:
+                saved_name = f"__old{patch_fn_name}"
+                setattr(cls, saved_name, getattr(cls, patch_fn_name))
+                setattr(cls, patch_fn_name, _wrap_attr_method(getattr(cls, patch_fn_name), tag))
     yield
     for cls in classes:
-        # Check that _old_init belongs to the class
-        # https://stackoverflow.com/a/5253424
-        if "_old_init" in cls.__dict__:
-            cls.__init__ = cls._old_init
-            del cls._old_init
+        for patched_name in ("__setattr__", "__delattr__", "__init__"):
+            # Check that __old__{init,setattr,delattr} belongs to the class
+            # https://stackoverflow.com/a/5253424
+            if f"__old{patched_name}" in cls.__dict__:
+                setattr(cls, patched_name, getattr(cls, f"__old{patched_name}"))
+                delattr(cls, f"__old{patched_name}")


 def _wrap_with_capture_dataset(dataset: Dataset) -> Dataset:
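
To make the mechanism above concrete, here is a minimal standalone sketch (not Lightning code; all names are made up) of the same idea: patch `__setattr__` so attribute changes made after `__init__` are recorded and can be replayed onto a freshly constructed instance, which is what `_reinstantiate_wrapped_cls` does with `__pl_attrs_record`.

    # Toy illustration of the record-and-replay technique used by the wrappers above.
    class Recorded:
        def __init__(self, x):
            object.__setattr__(self, "_record", [])
            object.__setattr__(self, "_in_init", True)
            self.x = x
            object.__setattr__(self, "_in_init", False)

        def __setattr__(self, name, value):
            super().__setattr__(name, value)
            if not self._in_init:  # only record changes made after construction
                self._record.append((name, value))


    original = Recorded(x=1)
    original.y = 2  # recorded, much like the patched DataLoader.__setattr__
    clone = Recorded(x=original.x)
    for name, value in original._record:  # replay, as _reinstantiate_wrapped_cls does
        setattr(clone, name, value)
    assert clone.y == 2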

tests/tests_pytorch/lite/test_lite.py

Lines changed: 1 addition & 1 deletion
@@ -177,7 +177,7 @@ def test_setup_dataloaders_return_type():
     assert lite_dataloader1.dataset is dataset1


-@mock.patch("pytorch_lightning.lite.lite._replace_init_method")
+@mock.patch("pytorch_lightning.lite.lite._replace_dunder_methods")
 def test_setup_dataloaders_captures_dataloader_arguments(ctx_manager):
     """Test that Lite intercepts the DataLoader constructor arguments with a context manager in its run method."""
