From 1622fb96110b071d1616bd4486d41e97bdcf3910 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 5 Feb 2024 18:16:57 +0000 Subject: [PATCH] [BugFix,Feature] filter_empty in apply (#661) (cherry picked from commit f11eac657e89e8476e1a939fe77fe9c8d366d54e) --- tensordict/_lazy.py | 9 ++++- tensordict/_td.py | 82 ++++++++++++++++++++++++++------------- tensordict/base.py | 29 ++++++++++++-- tensordict/nn/params.py | 8 ++-- tensordict/tensorclass.py | 12 ++++++ test/test_tensordict.py | 11 ++++++ 6 files changed, 116 insertions(+), 35 deletions(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index 004d329e8..b0c6ee6cc 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -1354,8 +1354,9 @@ def _apply_nest( named: bool = False, nested_keys: bool = False, prefix: tuple = (), + filter_empty: bool | None = None, **constructor_kwargs, - ) -> T: + ) -> T | None: if inplace and any( arg for arg in (batch_size, device, names, constructor_kwargs) ): @@ -1378,6 +1379,7 @@ def _apply_nest( nested_keys=nested_keys, prefix=prefix, inplace=inplace, + filter_empty=filter_empty, **constructor_kwargs, ) @@ -1392,11 +1394,14 @@ def _apply_nest( default=default, named=named, nested_keys=nested_keys, - prefix=prefix + (i,), + prefix=prefix, # + (i,), inplace=inplace, + filter_empty=filter_empty, ) for i, (td, *oth) in enumerate(zip(self.tensordicts, *others)) ] + if filter_empty and all(r is None for r in results): + return if not inplace: out = LazyStackedTensorDict( *results, diff --git a/tensordict/_td.py b/tensordict/_td.py index 342a75e7a..aa7574a1c 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -651,33 +651,42 @@ def _apply_nest( named: bool = False, nested_keys: bool = False, prefix: tuple = (), + filter_empty: bool | None = None, **constructor_kwargs, - ) -> T: + ) -> T | None: if inplace: - out = self + result = self + is_locked = result.is_locked elif batch_size is not None: - out = TensorDict( - {}, - batch_size=torch.Size(batch_size), - 
names=names, - device=self.device if not device else device, - _run_checks=False, - **constructor_kwargs, - ) + + def make_result(): + return TensorDict( + {}, + batch_size=torch.Size(batch_size), + names=names, + device=self.device if not device else device, + _run_checks=False, + **constructor_kwargs, + ) + + result = None + is_locked = False else: - out = TensorDict( - {}, - batch_size=self.batch_size, - device=self.device if not device else device, - names=self.names if self._has_names() else None, - _run_checks=False, - **constructor_kwargs, - ) - is_locked = out.is_locked - if not inplace and is_locked: - out.unlock_() + def make_result(): + return TensorDict( + {}, + batch_size=self.batch_size, + device=self.device if not device else device, + names=self.names if self._has_names() else None, + _run_checks=False, + **constructor_kwargs, + ) + result = None + is_locked = False + + any_set = False for key, item in self.items(): if not call_on_nested and _is_tensor_collection(item.__class__): if default is not NO_DEFAULT: @@ -702,6 +711,7 @@ def _apply_nest( nested_keys=nested_keys, default=default, prefix=prefix + (key,), + filter_empty=filter_empty, **constructor_kwargs, ) else: @@ -714,19 +724,36 @@ def _apply_nest( else: item_trsf = fn(item, *_others) if item_trsf is not None: + if not any_set: + if result is None: + result = make_result() + any_set = True if isinstance(self, _SubTensorDict): - out.set(key, item_trsf, inplace=inplace) + result.set(key, item_trsf, inplace=inplace) else: - out._set_str( + result._set_str( key, item_trsf, inplace=BEST_ATTEMPT_INPLACE if inplace else False, validated=checked, ) + if filter_empty and not any_set: + return + elif filter_empty is None and not any_set: + warn( + "Your resulting tensordict has no leaves but you did not specify filter_empty=False. " + "Currently, this returns an empty tree (filter_empty=True), but from v0.5 it will return " + "a None unless filter_empty=False. 
" + "To silence this warning, set filter_empty to the desired value in your call to `apply`.", + category=DeprecationWarning, + ) + if result is None: + result = make_result() + if not inplace and is_locked: - out.lock_() - return out + result.lock_() + return result # Functorch compatibility @cache # noqa: B019 @@ -862,7 +889,10 @@ def _expand(tensor): names = [None] * (len(shape) - tensordict_dims) + self.names return self._fast_apply( - _expand, batch_size=shape, call_on_nested=True, names=names + _expand, + batch_size=shape, + call_on_nested=True, + names=names, ) def _unbind(self, dim: int): diff --git a/tensordict/base.py b/tensordict/base.py index 38efe9cb9..64902d147 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -3729,6 +3729,8 @@ def apply_(self, fn: Callable, *others, **kwargs) -> T: *others (sequence of TensorDictBase, optional): the other tensordicts to be used. + Keyword Args: See :meth:`~.apply`. + Returns: self or a copy of self with the function applied @@ -3744,8 +3746,9 @@ def apply( names: Sequence[str] | None = None, inplace: bool = False, default: Any = NO_DEFAULT, + filter_empty: bool | None = None, **constructor_kwargs, - ) -> T: + ) -> T | None: """Applies a callable to all values stored in the tensordict and sets them in a new tensordict. The callable signature must be ``Callable[Tuple[Tensor, ...], Optional[Union[Tensor, TensorDictBase]]]``. @@ -3773,6 +3776,12 @@ def apply( default (Any, optional): default value for missing entries in the other tensordicts. If not provided, missing entries will raise a `KeyError`. + filter_empty (bool, optional): if ``True``, empty tensordicts will be + filtered out. This also comes with a lower computational cost as + empty data structures won't be created and destroyed. Non-tensor data + is considered as a leaf and thereby will be kept in the tensordict even + if left untouched by the function. + Defaults to ``None``, which currently behaves like ``False`` for backward compatibility but may emit a ``DeprecationWarning`` when the result has no leaves. 
**constructor_kwargs: additional keyword arguments to be passed to the TensorDict constructor. @@ -3830,6 +3839,7 @@ def apply( inplace=inplace, checked=False, default=default, + filter_empty=filter_empty, **constructor_kwargs, ) @@ -3843,8 +3853,9 @@ def named_apply( names: Sequence[str] | None = None, inplace: bool = False, default: Any = NO_DEFAULT, + filter_empty: bool | None = None, **constructor_kwargs, - ) -> T: + ) -> T | None: """Applies a key-conditioned callable to all values stored in the tensordict and sets them in a new atensordict. The callable signature must be ``Callable[Tuple[str, Tensor, ...], Optional[Union[Tensor, TensorDictBase]]]``. @@ -3874,6 +3885,10 @@ def named_apply( default (Any, optional): default value for missing entries in the other tensordicts. If not provided, missing entries will raise a `KeyError`. + filter_empty (bool, optional): if ``True``, empty tensordicts will be + filtered out. This also comes with a lower computational cost as + empty data structures won't be created and destroyed. Defaults to + ``None``, which currently behaves like ``False`` for backward compatibility but may emit a ``DeprecationWarning`` when the result has no leaves. **constructor_kwargs: additional keyword arguments to be passed to the TensorDict constructor. @@ -3958,6 +3973,7 @@ def named_apply( default=default, named=True, nested_keys=nested_keys, + filter_empty=filter_empty, **constructor_kwargs, ) @@ -3976,8 +3992,9 @@ def _apply_nest( named: bool = False, nested_keys: bool = False, prefix: tuple = (), + filter_empty: bool | None = None, **constructor_kwargs, - ) -> T: + ) -> T | None: ... def _fast_apply( @@ -3992,8 +4009,11 @@ def _fast_apply( default: Any = NO_DEFAULT, named: bool = False, nested_keys: bool = False, + # filter_empty must be False because we use _fast_apply for all sorts of ops like expand etc + # and non-tensor data will disappear if we use True by default. + filter_empty: bool | None = False, **constructor_kwargs, - ) -> T: + ) -> T | None: """A faster apply method. This method does not run any check after performing the func. 
This @@ -4013,6 +4033,7 @@ def _fast_apply( named=named, default=default, nested_keys=nested_keys, + filter_empty=filter_empty, **constructor_kwargs, ) diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index 9d91ca0ba..81a124c81 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -464,8 +464,9 @@ def apply( names: Sequence[str] | None = None, inplace: bool = False, default: Any = NO_DEFAULT, + filter_empty: bool | None = None, **constructor_kwargs, - ) -> TensorDictBase: + ) -> TensorDictBase | None: ... @_unlock_and_set(inplace=True) @@ -478,8 +479,9 @@ def named_apply( names: Sequence[str] | None = None, inplace: bool = False, default: Any = NO_DEFAULT, + filter_empty: bool | None = None, **constructor_kwargs, - ) -> TensorDictBase: + ) -> TensorDictBase | None: ... @_unlock_and_set(inplace=True) @@ -1079,7 +1081,7 @@ def update_at_( ... @_apply_on_data - def apply_(self, fn: Callable, *others) -> T: + def apply_(self, fn: Callable, *others, **kwargs) -> T: ... 
def _apply(self, fn, recurse=True): diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index a07e8dae7..2c0397422 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -1394,3 +1394,15 @@ def __torch_function__( if not escape_conversion: return _from_tensordict_with_copy(tensorclass_instance, result) return result + + def _apply_nest(self, *args, **kwargs): + kwargs["filter_empty"] = False + return _wrap_method(self, "_apply_nest", self._tensordict._apply_nest)( + *args, **kwargs + ) + + def _fast_apply(self, *args, **kwargs): + kwargs["filter_empty"] = False + return _wrap_method(self, "_fast_apply", self._tensordict._fast_apply)( + *args, **kwargs + ) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index d029eaa3a..de27b62cd 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -1980,6 +1980,15 @@ def get_old_val(newval, oldval): assert key == ("nested", "newkey") assert (td_1[key] == 0).all() + @pytest.mark.parametrize("inplace", [False, True]) + def test_apply_filter(self, td_name, device, inplace): + td = getattr(self, td_name)(device) + assert td.apply(lambda x: None, filter_empty=False) is not None + if td_name != "td_with_non_tensor": + assert td.apply(lambda x: None, filter_empty=True) is None + else: + assert td.apply(lambda x: None, filter_empty=True) is not None + @pytest.mark.parametrize("inplace", [False, True]) def test_apply_other(self, td_name, device, inplace): td = getattr(self, td_name)(device) @@ -4309,6 +4318,7 @@ def test_squeeze_with_none(self, td_name, device, squeeze_dim=None): assert (td.get("a") == 1).all() @pytest.mark.filterwarnings("error") + @set_lazy_legacy(True) def test_stack_onto(self, td_name, device, tmpdir): torch.manual_seed(1) td = getattr(self, td_name)(device) @@ -4361,6 +4371,7 @@ def test_stack_onto(self, td_name, device, tmpdir): assert (td_stack == td_out).all() @pytest.mark.filterwarnings("error") + @set_lazy_legacy(True) def 
test_stack_subclasses_on_td(self, td_name, device): torch.manual_seed(1) td = getattr(self, td_name)(device)