 from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.utilities.apply_func import _is_dataclass_instance
 from pytorch_lightning.utilities.auto_restart import CaptureIterableDataset, CaptureMapDataset, FastForwardSampler
-from pytorch_lightning.utilities.enums import _FaultTolerantMode
+from pytorch_lightning.utilities.enums import _FaultTolerantMode, LightningEnum
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.rank_zero import rank_zero_warn
 from pytorch_lightning.utilities.seed import pl_worker_init_function
 warning_cache = WarningCache()
 
 
+class _WrapAttrTag(LightningEnum):
+    SET = "set"
+    DEL = "del"
+
+    def __call__(self, *args):
+        if self == self.SET:
+            fn = setattr
+        else:
+            fn = delattr
+        return fn(*args)
+
+
 def _extract_batch_size(batch: BType) -> Generator[int, None, None]:
     if isinstance(batch, Tensor):
         if batch.ndim == 0:
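The enum members are callable, so a recorded `(args, tag)` pair can later be replayed as `tag(obj, *args)`. A minimal standalone sketch of this dispatch (using a plain `Enum` in place of `LightningEnum`, which here only adds string-enum conveniences; `Holder` is just for illustration):

```python
from enum import Enum


class WrapAttrTag(Enum):
    SET = "set"
    DEL = "del"

    def __call__(self, *args):
        # calling a member applies the corresponding builtin
        fn = setattr if self is WrapAttrTag.SET else delattr
        return fn(*args)


class Holder:
    pass


obj = Holder()
WrapAttrTag.SET(obj, "batch_size", 8)  # same as setattr(obj, "batch_size", 8)
assert obj.batch_size == 8
WrapAttrTag.DEL(obj, "batch_size")     # same as delattr(obj, "batch_size")
assert not hasattr(obj, "batch_size")
```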
@@ -188,27 +200,7 @@ def _update_dataloader(
     dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None
 ) -> DataLoader:
     dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler, mode)
-    dl_cls = type(dataloader)
-    try:
-        dataloader = dl_cls(*dl_args, **dl_kwargs)
-    except TypeError as e:
-        # improve exception message due to an incorrect implementation of the `DataLoader` where multiple subclass
-        # `__init__` arguments map to one `DataLoader.__init__` argument
-        import re
-
-        match = re.match(r".*__init__\(\) got multiple values .* '(\w+)'", str(e))
-        if not match:
-            # an unexpected `TypeError`, continue failure
-            raise
-        argument = match.groups()[0]
-        message = (
-            f"The {dl_cls.__name__} `DataLoader` implementation has an error where more than one `__init__` argument"
-            f" can be passed to its parent's `{argument}=...` `__init__` argument. This is likely caused by allowing"
-            f" passing both a custom argument that will map to the `{argument}` argument as well as `**kwargs`."
-            f" `kwargs` should be filtered to make sure they don't contain the `{argument}` key."
-            " This argument was automatically passed to your DataLoader by PyTorch Lightning."
-        )
-        raise MisconfigurationException(message) from e
+    dataloader = _reinstantiate_wrapped_cls(dataloader, *dl_args, **dl_kwargs)
     return dataloader
 
 
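For reference, this is the kind of subclass bug the (now relocated) error message describes; `BadLoader` and `my_sampler` are hypothetical:

```python
from torch.utils.data import DataLoader


class BadLoader(DataLoader):
    # Buggy subclass: `my_sampler` maps onto the parent's `sampler` argument,
    # but **kwargs is not filtered. When Lightning re-instantiates the loader
    # and passes `sampler=...` through kwargs, `DataLoader.__init__` receives
    # two values for `sampler` and raises
    # "TypeError: __init__() got multiple values for keyword argument 'sampler'".
    def __init__(self, dataset, my_sampler=None, **kwargs):
        super().__init__(dataset, sampler=my_sampler, **kwargs)
```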
@@ -374,7 +366,7 @@ def _dataloader_init_kwargs_resolve_sampler(
                     "this, expose an argument `sampler` in the `__init__` method of your custom class."
                 )
 
-            batch_sampler = batch_sampler_cls(*args, **kwargs)
+            batch_sampler = _reinstantiate_wrapped_cls(batch_sampler, *args, **kwargs)
         else:
             try:
                 batch_sampler = batch_sampler_cls(
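By contrast, a re-instantiation-friendly custom batch sampler exposes `sampler` in its own signature, as the error message above asks; a sketch (`MyBatchSampler` and `extra_arg` are made up):

```python
from torch.utils.data import BatchSampler, Sampler


class MyBatchSampler(BatchSampler):
    def __init__(self, sampler: Sampler, extra_arg: str, batch_size: int = 8, drop_last: bool = False):
        # custom state; the saved-args machinery captures it for re-instantiation
        self.extra_arg = extra_arg
        super().__init__(sampler, batch_size, drop_last)
```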
@@ -449,6 +441,37 @@ def _auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None:
     dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank)
 
 
+def _reinstantiate_wrapped_cls(orig_object: Any, *args: Any, explicit_cls: Optional[Type] = None, **kwargs: Any) -> Any:
+    constructor = type(orig_object) if explicit_cls is None else explicit_cls
+
+    try:
+        result = constructor(*args, **kwargs)
+    except TypeError as e:
+        # improve exception message due to an incorrect implementation of the `DataLoader` where multiple subclass
+        # `__init__` arguments map to one `DataLoader.__init__` argument
+        import re
+
+        match = re.match(r".*__init__\(\) got multiple values .* '(\w+)'", str(e))
+        if not match:
+            # an unexpected `TypeError`, continue failure
+            raise
+        argument = match.groups()[0]
+        message = (
+            f"The {constructor.__name__} implementation has an error where more than one `__init__` argument"
+            f" can be passed to its parent's `{argument}=...` `__init__` argument. This is likely caused by allowing"
+            f" passing both a custom argument that will map to the `{argument}` argument as well as `**kwargs`."
+            f" `kwargs` should be filtered to make sure they don't contain the `{argument}` key."
+            " This argument was automatically passed to your object by PyTorch Lightning."
+        )
+        raise MisconfigurationException(message) from e
+
+    attrs_record = getattr(orig_object, "__pl_attrs_record", list())
+    for args, fn in attrs_record:
+        fn(result, *args)
+
+    return result
+
+
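The replay loop at the end is the counterpart of the `__setattr__`/`__delattr__` tracking added below: each record entry is `(args, tag)` with a callable tag, so re-applying a mutation is one call. A self-contained sketch, using the builtins directly in place of `_WrapAttrTag` members:

```python
class _Obj:
    pass


fresh = _Obj()
# what `__pl_attrs_record` conceptually holds after `obj.batch_size = 16; del obj.batch_size`
attrs_record = [
    (("batch_size", 16), setattr),  # stands in for _WrapAttrTag.SET
    (("batch_size",), delattr),     # stands in for _WrapAttrTag.DEL
]
for rec_args, fn in attrs_record:
    fn(fresh, *rec_args)  # replays the mutation on the new instance
assert not hasattr(fresh, "batch_size")
```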
 def _wrap_init_method(init: Callable, store_explicit_arg: Optional[str] = None) -> Callable:
     """Wraps the ``__init__`` method of classes (currently :class:`~torch.utils.data.DataLoader` and
     :class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses."""
@@ -457,6 +480,8 @@ def _wrap_init_method(init: Callable, store_explicit_arg: Optional[str] = None)
     def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None:
         # We need to inspect `init`, as inspecting `obj.__init__`
         # can lead to inspecting the wrong function with multiple inheritance
+        old_inside_init = getattr(obj, "__pl_inside_init", False)
+        object.__setattr__(obj, "__pl_inside_init", True)
         params = inspect.signature(init).parameters
 
         parameters_defaults = OrderedDict(
@@ -474,21 +499,49 @@ def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None:
         }
 
         if not hasattr(obj, "__pl_saved_args"):
-            obj.__pl_saved_args = args
-            obj.__pl_saved_kwargs = kwargs
-            obj.__pl_saved_arg_names = param_names
-            obj.__pl_saved_default_kwargs = default_kwargs
+            object.__setattr__(obj, "__pl_saved_args", args)
+            object.__setattr__(obj, "__pl_saved_kwargs", kwargs)
+            object.__setattr__(obj, "__pl_saved_arg_names", param_names)
+            object.__setattr__(obj, "__pl_saved_default_kwargs", default_kwargs)
 
         # We want to use the latest possible value for the explicit argument (i.e. ideally what gets passed
         # to the base class) so that we can be sure that it will not get changed anymore.
         # That is why we are setting this in every `__init__`
         if store_explicit_arg is not None:
             if store_explicit_arg in param_names:
-                setattr(obj, f"__{store_explicit_arg}", args[param_names.index(store_explicit_arg)])
+                object.__setattr__(obj, f"__{store_explicit_arg}", args[param_names.index(store_explicit_arg)])
             elif store_explicit_arg in kwargs:
-                setattr(obj, f"__{store_explicit_arg}", kwargs[store_explicit_arg])
+                object.__setattr__(obj, f"__{store_explicit_arg}", kwargs[store_explicit_arg])
 
         init(obj, *args, **kwargs)
+        object.__setattr__(obj, "__pl_inside_init", old_inside_init)
+
+    return wrapper
+
+
+def _wrap_attr_method(method: Callable, tag: _WrapAttrTag) -> Callable:
+    """Wraps the ``__setattr__`` or ``__delattr__`` method of classes (currently
+    :class:`~torch.utils.data.DataLoader` and :class:`~torch.utils.data.BatchSampler`) in order to enable
+    re-instantiation of custom subclasses."""
+
+    @functools.wraps(method)
+    def wrapper(obj: Any, *args: Any):
+        # First, let's find out if we're the first in the inheritance chain calling the patched method.
+        name, *_ = args
+        prev_call_name, prev_call_method = getattr(obj, "__pl_current_call", (None, "method"))
+        first_call = not (prev_call_name == name and prev_call_method == tag)
+
+        # Then mark the currently called method
+        object.__setattr__(obj, "__pl_current_call", (name, tag))
+
+        # call the original method
+        method(obj, *args)
+        if first_call and not getattr(obj, "__pl_inside_init", True):
+            # if we're outside `__init__`, the original call did not fail, and this is the first call in the
+            # chain, save the arguments to the internal record
+            attrs_record = getattr(obj, "__pl_attrs_record", list())
+            attrs_record.append((args, tag))
+            object.__setattr__(obj, "__pl_attrs_record", attrs_record)
+        object.__setattr__(obj, "__pl_current_call", (prev_call_name, prev_call_method))
 
     return wrapper
 
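The `__pl_current_call` bookkeeping guards against double recording when several classes in an inheritance chain have a patched `__setattr__`; a hypothetical chain showing the re-entry it detects:

```python
class Base:
    def __setattr__(self, name, value):   # would be patched into wrapper A
        object.__setattr__(self, name, value)


class Sub(Base):
    def __setattr__(self, name, value):   # would be patched into wrapper B
        super().__setattr__(name, value)  # re-enters wrapper A with the same name


# With both methods patched, `obj.x = 1` enters wrapper B first (first_call=True,
# so it records), marks ("x", SET) as the current call; wrapper A then sees the
# same marker, treats it as a nested call, and skips recording the duplicate.
```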
@@ -508,25 +561,34 @@ def recurse(cl: Type[Any]) -> None:
 
 
 @contextmanager
-def _replace_init_method(base_cls: Type, store_explicit_arg: Optional[str] = None) -> Generator[None, None, None]:
+def _replace_dunder_methods(base_cls: Type, store_explicit_arg: Optional[str] = None) -> Generator[None, None, None]:
     """This context manager is used to add support for re-instantiation of custom (subclasses) of `base_cls`.
 
-    It patches the ``__init__`` method.
+    It patches the ``__init__``, ``__setattr__`` and ``__delattr__`` methods.
     """
     classes = _get_all_subclasses(base_cls) | {base_cls}
     for cls in classes:
         # Check that __init__ belongs to the class
         # https://stackoverflow.com/a/5253424
         if "__init__" in cls.__dict__:
-            cls._old_init = cls.__init__
+            cls.__old__init__ = cls.__init__
             cls.__init__ = _wrap_init_method(cls.__init__, store_explicit_arg)
+
+        # we want at least one setattr/delattr in the chain to be patched and it can happen that none of the
+        # subclasses implement `__setattr__`/`__delattr__`. Therefore, we are always patching the `base_cls`
+        for patch_fn_name, tag in (("__setattr__", _WrapAttrTag.SET), ("__delattr__", _WrapAttrTag.DEL)):
+            if patch_fn_name in cls.__dict__ or cls is base_cls:
+                saved_name = f"__old{patch_fn_name}"
+                setattr(cls, saved_name, getattr(cls, patch_fn_name))
+                setattr(cls, patch_fn_name, _wrap_attr_method(getattr(cls, patch_fn_name), tag))
     yield
     for cls in classes:
-        # Check that _old_init belongs to the class
-        # https://stackoverflow.com/a/5253424
-        if "_old_init" in cls.__dict__:
-            cls.__init__ = cls._old_init
-            del cls._old_init
+        for patched_name in ("__setattr__", "__delattr__", "__init__"):
+            # Check that __old__{init,setattr,delattr} belongs to the class
+            # https://stackoverflow.com/a/5253424
+            if f"__old{patched_name}" in cls.__dict__:
+                setattr(cls, patched_name, getattr(cls, f"__old{patched_name}"))
+                delattr(cls, f"__old{patched_name}")
 
 
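Putting the pieces together, a simplified usage sketch of how the context manager and `_reinstantiate_wrapped_cls` compose (the `my_note` attribute and `range(10)` dataset are arbitrary stand-ins; Lightning itself goes through `_get_dataloader_init_args_and_kwargs` rather than reading the saved args directly):

```python
from torch.utils.data import DataLoader

with _replace_dunder_methods(DataLoader, "dataset"):
    loader = DataLoader(range(10), batch_size=4)  # __init__ args are captured
    loader.my_note = "tweaked"  # outside __init__: recorded in __pl_attrs_record

# Rebuild from the saved constructor arguments; the recorded post-construction
# mutation is replayed on the new object by _reinstantiate_wrapped_cls.
# (Name mangling does not apply at module scope, so these attributes are readable.)
new_loader = _reinstantiate_wrapped_cls(loader, *loader.__pl_saved_args, **loader.__pl_saved_kwargs)
assert new_loader.my_note == "tweaked"
```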
 def _wrap_with_capture_dataset(dataset: Dataset) -> Dataset: