Skip to content

Commit

Permalink
[BugFix,Feature] filter_empty in apply (#661)
Browse files Browse the repository at this point in the history
(cherry picked from commit f11eac6)
  • Loading branch information
vmoens committed Mar 25, 2024
1 parent a749b51 commit 1622fb9
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 35 deletions.
9 changes: 7 additions & 2 deletions tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1354,8 +1354,9 @@ def _apply_nest(
named: bool = False,
nested_keys: bool = False,
prefix: tuple = (),
filter_empty: bool | None = None,
**constructor_kwargs,
) -> T:
) -> T | None:
if inplace and any(
arg for arg in (batch_size, device, names, constructor_kwargs)
):
Expand All @@ -1378,6 +1379,7 @@ def _apply_nest(
nested_keys=nested_keys,
prefix=prefix,
inplace=inplace,
filter_empty=filter_empty,
**constructor_kwargs,
)

Expand All @@ -1392,11 +1394,14 @@ def _apply_nest(
default=default,
named=named,
nested_keys=nested_keys,
prefix=prefix + (i,),
prefix=prefix, # + (i,),
inplace=inplace,
filter_empty=filter_empty,
)
for i, (td, *oth) in enumerate(zip(self.tensordicts, *others))
]
if filter_empty and all(r is None for r in results):
return
if not inplace:
out = LazyStackedTensorDict(
*results,
Expand Down
82 changes: 56 additions & 26 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,33 +651,42 @@ def _apply_nest(
named: bool = False,
nested_keys: bool = False,
prefix: tuple = (),
filter_empty: bool | None = None,
**constructor_kwargs,
) -> T:
) -> T | None:
if inplace:
out = self
result = self
is_locked = result.is_locked
elif batch_size is not None:
out = TensorDict(
{},
batch_size=torch.Size(batch_size),
names=names,
device=self.device if not device else device,
_run_checks=False,
**constructor_kwargs,
)

def make_result():
return TensorDict(
{},
batch_size=torch.Size(batch_size),
names=names,
device=self.device if not device else device,
_run_checks=False,
**constructor_kwargs,
)

result = None
is_locked = False
else:
out = TensorDict(
{},
batch_size=self.batch_size,
device=self.device if not device else device,
names=self.names if self._has_names() else None,
_run_checks=False,
**constructor_kwargs,
)

is_locked = out.is_locked
if not inplace and is_locked:
out.unlock_()
def make_result():
return TensorDict(
{},
batch_size=self.batch_size,
device=self.device if not device else device,
names=self.names if self._has_names() else None,
_run_checks=False,
**constructor_kwargs,
)

result = None
is_locked = False

any_set = False
for key, item in self.items():
if not call_on_nested and _is_tensor_collection(item.__class__):
if default is not NO_DEFAULT:
Expand All @@ -702,6 +711,7 @@ def _apply_nest(
nested_keys=nested_keys,
default=default,
prefix=prefix + (key,),
filter_empty=filter_empty,
**constructor_kwargs,
)
else:
Expand All @@ -714,19 +724,36 @@ def _apply_nest(
else:
item_trsf = fn(item, *_others)
if item_trsf is not None:
if not any_set:
if result is None:
result = make_result()
any_set = True
if isinstance(self, _SubTensorDict):
out.set(key, item_trsf, inplace=inplace)
result.set(key, item_trsf, inplace=inplace)
else:
out._set_str(
result._set_str(
key,
item_trsf,
inplace=BEST_ATTEMPT_INPLACE if inplace else False,
validated=checked,
)

if filter_empty and not any_set:
return
elif filter_empty is None and not any_set:
warn(
"Your resulting tensordict has no leaves but you did not specify filter_empty=False. "
"Currently, this returns an empty tree (filter_empty=True), but from v0.5 it will return "
"a None unless filter_empty=False. "
"To silence this warning, set filter_empty to the desired value in your call to `apply`.",
category=DeprecationWarning,
)
if result is None:
result = make_result()

if not inplace and is_locked:
out.lock_()
return out
result.lock_()
return result

# Functorch compatibility
@cache # noqa: B019
Expand Down Expand Up @@ -862,7 +889,10 @@ def _expand(tensor):

names = [None] * (len(shape) - tensordict_dims) + self.names
return self._fast_apply(
_expand, batch_size=shape, call_on_nested=True, names=names
_expand,
batch_size=shape,
call_on_nested=True,
names=names,
)

def _unbind(self, dim: int):
Expand Down
29 changes: 25 additions & 4 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3729,6 +3729,8 @@ def apply_(self, fn: Callable, *others, **kwargs) -> T:
*others (sequence of TensorDictBase, optional): the other
tensordicts to be used.
Keyword Args: See :meth:`~.apply`.
Returns:
self or a copy of self with the function applied
Expand All @@ -3744,8 +3746,9 @@ def apply(
names: Sequence[str] | None = None,
inplace: bool = False,
default: Any = NO_DEFAULT,
filter_empty: bool | None = None,
**constructor_kwargs,
) -> T:
) -> T | None:
"""Applies a callable to all values stored in the tensordict and sets them in a new tensordict.
The callable signature must be ``Callable[Tuple[Tensor, ...], Optional[Union[Tensor, TensorDictBase]]]``.
Expand Down Expand Up @@ -3773,6 +3776,12 @@ def apply(
default (Any, optional): default value for missing entries in the
other tensordicts. If not provided, missing entries will
raise a `KeyError`.
filter_empty (bool, optional): if ``True``, empty tensordicts will be
filtered out. This also comes with a lower computational cost as
empty data structures won't be created and destroyed. Non-tensor data
is considered as a leaf and thereby will be kept in the tensordict even
if left untouched by the function.
Defaults to ``False`` for backward compatibility.
**constructor_kwargs: additional keyword arguments to be passed to the
TensorDict constructor.
Expand Down Expand Up @@ -3830,6 +3839,7 @@ def apply(
inplace=inplace,
checked=False,
default=default,
filter_empty=filter_empty,
**constructor_kwargs,
)

Expand All @@ -3843,8 +3853,9 @@ def named_apply(
names: Sequence[str] | None = None,
inplace: bool = False,
default: Any = NO_DEFAULT,
filter_empty: bool | None = None,
**constructor_kwargs,
) -> T:
) -> T | None:
"""Applies a key-conditioned callable to all values stored in the tensordict and sets them in a new tensordict.
The callable signature must be ``Callable[Tuple[str, Tensor, ...], Optional[Union[Tensor, TensorDictBase]]]``.
Expand Down Expand Up @@ -3874,6 +3885,10 @@ def named_apply(
default (Any, optional): default value for missing entries in the
other tensordicts. If not provided, missing entries will
raise a `KeyError`.
filter_empty (bool, optional): if ``True``, empty tensordicts will be
filtered out. This also comes with a lower computational cost as
empty data structures won't be created and destroyed. Defaults to
``False`` for backward compatibility.
**constructor_kwargs: additional keyword arguments to be passed to the
TensorDict constructor.
Expand Down Expand Up @@ -3958,6 +3973,7 @@ def named_apply(
default=default,
named=True,
nested_keys=nested_keys,
filter_empty=filter_empty,
**constructor_kwargs,
)

Expand All @@ -3976,8 +3992,9 @@ def _apply_nest(
named: bool = False,
nested_keys: bool = False,
prefix: tuple = (),
filter_empty: bool | None = None,
**constructor_kwargs,
) -> T:
) -> T | None:
...

def _fast_apply(
Expand All @@ -3992,8 +4009,11 @@ def _fast_apply(
default: Any = NO_DEFAULT,
named: bool = False,
nested_keys: bool = False,
# filter_empty must be False because we use _fast_apply for all sorts of ops like expand etc
# and non-tensor data will disappear if we use True by default.
filter_empty: bool | None = False,
**constructor_kwargs,
) -> T:
) -> T | None:
"""A faster apply method.
This method does not run any check after performing the func. This
Expand All @@ -4013,6 +4033,7 @@ def _fast_apply(
named=named,
default=default,
nested_keys=nested_keys,
filter_empty=filter_empty,
**constructor_kwargs,
)

Expand Down
8 changes: 5 additions & 3 deletions tensordict/nn/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,8 +464,9 @@ def apply(
names: Sequence[str] | None = None,
inplace: bool = False,
default: Any = NO_DEFAULT,
filter_empty: bool | None = None,
**constructor_kwargs,
) -> TensorDictBase:
) -> TensorDictBase | None:
...

@_unlock_and_set(inplace=True)
Expand All @@ -478,8 +479,9 @@ def named_apply(
names: Sequence[str] | None = None,
inplace: bool = False,
default: Any = NO_DEFAULT,
filter_empty: bool | None = None,
**constructor_kwargs,
) -> TensorDictBase:
) -> TensorDictBase | None:
...

@_unlock_and_set(inplace=True)
Expand Down Expand Up @@ -1079,7 +1081,7 @@ def update_at_(
...

@_apply_on_data
def apply_(self, fn: Callable, *others) -> T:
def apply_(self, fn: Callable, *others, **kwargs) -> T:
...

def _apply(self, fn, recurse=True):
Expand Down
12 changes: 12 additions & 0 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -1394,3 +1394,15 @@ def __torch_function__(
if not escape_conversion:
return _from_tensordict_with_copy(tensorclass_instance, result)
return result

def _apply_nest(self, *args, **kwargs):
kwargs["filter_empty"] = False
return _wrap_method(self, "_apply_nest", self._tensordict._apply_nest)(
*args, **kwargs
)

def _fast_apply(self, *args, **kwargs):
kwargs["filter_empty"] = False
return _wrap_method(self, "_fast_apply", self._tensordict._fast_apply)(
*args, **kwargs
)
11 changes: 11 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -1980,6 +1980,15 @@ def get_old_val(newval, oldval):
assert key == ("nested", "newkey")
assert (td_1[key] == 0).all()

@pytest.mark.parametrize("inplace", [False, True])
def test_apply_filter(self, td_name, device, inplace):
td = getattr(self, td_name)(device)
assert td.apply(lambda x: None, filter_empty=False) is not None
if td_name != "td_with_non_tensor":
assert td.apply(lambda x: None, filter_empty=True) is None
else:
assert td.apply(lambda x: None, filter_empty=True) is not None

@pytest.mark.parametrize("inplace", [False, True])
def test_apply_other(self, td_name, device, inplace):
td = getattr(self, td_name)(device)
Expand Down Expand Up @@ -4309,6 +4318,7 @@ def test_squeeze_with_none(self, td_name, device, squeeze_dim=None):
assert (td.get("a") == 1).all()

@pytest.mark.filterwarnings("error")
@set_lazy_legacy(True)
def test_stack_onto(self, td_name, device, tmpdir):
torch.manual_seed(1)
td = getattr(self, td_name)(device)
Expand Down Expand Up @@ -4361,6 +4371,7 @@ def test_stack_onto(self, td_name, device, tmpdir):
assert (td_stack == td_out).all()

@pytest.mark.filterwarnings("error")
@set_lazy_legacy(True)
def test_stack_subclasses_on_td(self, td_name, device):
torch.manual_seed(1)
td = getattr(self, td_name)(device)
Expand Down

0 comments on commit 1622fb9

Please sign in to comment.