[Feature] Add pad argument to TensorDict.where (#539)
vmoens authored Oct 10, 2023
1 parent d445682 commit 7747505
Showing 5 changed files with 153 additions and 47 deletions.
7 changes: 7 additions & 0 deletions tensordict/memmap.py
@@ -842,6 +842,13 @@ def _cat(
implements_for_memmap(torch.cat)(_cat)


def _where(condition, input, other):
return torch.where(condition=condition, input=input.as_tensor(), other=other)


implements_for_memmap(torch.where)(_where)
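For context, a minimal usage sketch of the registration above (assuming the `MemmapTensor.from_tensor` constructor of this release): once registered, calling `torch.where` on a `MemmapTensor` dispatches through `__torch_function__` to `_where`, which materializes the memory-mapped data via `as_tensor()` and returns a regular tensor.

```python
import torch
from tensordict import MemmapTensor  # assumed import path for this release

mm = MemmapTensor.from_tensor(torch.arange(4, dtype=torch.float32))
cond = torch.tensor([True, False, True, False])
# Dispatches to _where above; the result is a plain torch.Tensor.
out = torch.where(cond, mm, torch.zeros(4))
```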


def set_transfer_ownership(memmap: MemmapTensor, value: bool = True) -> None:
"""Changes the transfer_ownership attribute of a MemmapTensor."""
if isinstance(memmap, MemmapTensor):
2 changes: 1 addition & 1 deletion tensordict/nn/params.py
@@ -714,7 +714,7 @@ def transpose(self, dim0, dim1):
...

@_carry_over
def where(self, condition, other, *, out=None):
def where(self, condition, other, *, out=None, pad=None):
...
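As a hedged sketch of what the carried-over signature implies (the key name and shapes below are made up): `_carry_over` forwards the call to the wrapped tensordict and re-wraps the result, so `where` on a parameter tensordict still returns a `TensorDictParams`.

```python
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictParams

params = TensorDictParams(TensorDict({"w": torch.randn(4)}, batch_size=[]))
other = TensorDict({"w": torch.zeros(4)}, batch_size=[])
mask = torch.tensor(True)

result = params.where(mask, other, pad=0.0)
assert isinstance(result, TensorDictParams)  # re-wrapped by _carry_over
```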

@_carry_over
6 changes: 4 additions & 2 deletions tensordict/persistent.py
@@ -508,8 +508,10 @@ def is_contiguous(self):
def masked_fill(self, mask, value):
return self.to_tensordict().masked_fill(mask, value)

def where(self, condition, other, *, out=None):
return self.to_tensordict().where(condition=condition, other=other, out=out)
def where(self, condition, other, *, out=None, pad=None):
return self.to_tensordict().where(
condition=condition, other=other, out=out, pad=pad
)

def masked_fill_(self, mask, value):
for key in self.keys(include_nested=True, leaves_only=True):
137 changes: 102 additions & 35 deletions tensordict/tensordict.py
@@ -2932,15 +2932,21 @@ def masked_fill(self, mask: Tensor, value: float | bool) -> T:
"""
raise NotImplementedError

def where(self, condition, other, *, out=None):
def where(self, condition, other, *, out=None, pad=None): # noqa: D417
"""Return a ``TensorDict`` of elements selected from either self or other, depending on condition.
Args:
condition (BoolTensor): When ``True`` (nonzero), yields ``self``,
otherwise yields ``other``.
other (TensorDictBase or Scalar): value (if ``other`` is a scalar)
or values selected at indices where condition is ``False``.
out (Tensor, optional): the output ``TensorDictBase`` instance.
Keyword Args:
out (TensorDictBase, optional): the output ``TensorDictBase`` instance.
pad (scalar, optional): if provided, missing keys from the source
or destination tensordict will be written as `torch.where(mask, self, pad)`
or `torch.where(mask, pad, other)`. Defaults to ``None``, i.e.
missing keys are not tolerated.
"""
raise NotImplementedError
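For illustration, a minimal usage sketch of the new `pad` argument (key names and shapes are made up):

```python
import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.ones(4)}, batch_size=[4])
other = TensorDict({"b": torch.zeros(4)}, batch_size=[4])  # "a" missing here, "b" missing in td
mask = torch.tensor([True, False, True, False])

# Without pad this would raise a KeyError because the key sets differ;
# with pad, entries missing on either side are filled with the pad value.
result = td.where(mask, other, pad=0.0)
assert (result["a"][~mask] == 0).all()  # pad fills where the mask is False
assert (result["b"][mask] == 0).all()   # pad fills where the mask is True
```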
@@ -4802,38 +4808,87 @@ def to(tensor):
result.batch_size = batch_size
return result

def where(self, condition, other, *, out=None):
if out is None:
if _is_tensor_collection(other.__class__):

def func(tensor, _other):
return torch.where(
expand_as_right(condition, tensor), tensor, _other
)
def where(self, condition, other, *, out=None, pad=None):
if _is_tensor_collection(other.__class__):

return self._fast_apply(func, other)
else:
def func(tensor, _other, key):
if tensor is None:
if pad is not None:
tensor = _other
_other = pad
else:
raise KeyError(
f"Key {key} not found and no pad value provided."
)
cond = expand_as_right(~condition, tensor)
elif _other is None:
if pad is not None:
_other = pad
else:
raise KeyError(
f"Key {key} not found and no pad value provided."
)
cond = expand_as_right(condition, tensor)
else:
cond = expand_as_right(condition, tensor)
return torch.where(
condition=cond,
input=tensor,
other=_other,
)

def func(tensor):
return torch.where(
expand_as_right(condition, tensor), tensor, other
result = self.empty() if out is None else out
other_keys = set(other.keys())
# we turn the keys into a list because out could be the same object as self
for key in list(self.keys()):
tensor = self._get_str(key, default=NO_DEFAULT)
_other = other._get_str(key, default=None)
if _is_tensor_collection(type(tensor)):
_out = None if out is None else out._get_str(key, None)
if _other is None:
_other = tensor.empty()
val = tensor.where(
condition=condition, other=_other, out=_out, pad=pad
)

return self._fast_apply(func)
else:
val = func(tensor, _other, key)
result._set_str(key, val, inplace=False, validated=True)
other_keys.discard(key)
for key in other_keys:
tensor = None
_other = other._get_str(key, default=NO_DEFAULT)
if _is_tensor_collection(type(_other)):
try:
tensor = _other.empty()
except NotImplementedError:
# H5 tensordicts do not support select(), which empty() relies on
tensor = _other.to_tensordict().empty()
val = _other.where(
condition=~condition, other=tensor, out=None, pad=pad
)
else:
val = func(tensor, _other, key)
result._set_str(key, val, inplace=False, validated=True)
return result
else:
if _is_tensor_collection(other.__class__):
if out is None:

def func(tensor, _other, _out):
def func(tensor):
return torch.where(
expand_as_right(condition, tensor), tensor, _other, out=_out
condition=expand_as_right(condition, tensor),
input=tensor,
other=other,
)

return self._fast_apply(func, other, out)
return self._fast_apply(func)
else:

def func(tensor, _out):
return torch.where(
expand_as_right(condition, tensor), tensor, other, out=_out
condition=expand_as_right(condition, tensor),
input=tensor,
other=other,
out=_out,
)

return self._fast_apply(func, out)
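A quick sanity check of the operand swap used in the missing-key branch above, with plain tensors: padding the `self` side is done by negating the mask and swapping the operands, which is equivalent to padding directly.

```python
import torch

cond = torch.tensor([True, False, True])
other = torch.tensor([1.0, 2.0, 3.0])
pad = torch.tensor(0.0)

# torch.where(cond, pad, other) == torch.where(~cond, other, pad)
assert torch.equal(torch.where(cond, pad, other), torch.where(~cond, other, pad))
```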
@@ -6179,8 +6234,10 @@ def pin_memory(self) -> T:
def detach_(self) -> T:
raise RuntimeError("Detaching a sub-tensordict in-place cannot be done.")

def where(self, condition, other, *, out=None):
return self.to_tensordict().where(condition=condition, other=other, out=out)
def where(self, condition, other, *, out=None, pad=None):
return self.to_tensordict().where(
condition=condition, other=other, out=out, pad=pad
)

def masked_fill_(self, mask: Tensor, value: float | bool) -> T:
for key, item in self.items():
@@ -7866,26 +7923,34 @@ def sort_keys(element):

rename_key = _renamed_inplace_method(rename_key_)

def where(self, condition, other, *, out=None):
def where(self, condition, other, *, out=None, pad=None):
condition = condition.unbind(self.stack_dim)
if _is_tensor_collection(other.__class__) or (
isinstance(other, Tensor)
and other.shape[: self.stack_dim] == self.shape[: self.stack_dim]
):
other = other.unbind(self.stack_dim)
return torch.stack(
result = torch.stack(
[
td.where(cond, _other)
td.where(cond, _other, pad=pad)
for td, cond, _other in zip(self.tensordicts, condition, other)
],
self.stack_dim,
out=out,
)
return torch.stack(
[td.where(cond, other) for td, cond in zip(self.tensordicts, condition)],
self.stack_dim,
out=out,
)
else:
result = torch.stack(
[
td.where(cond, other, pad=pad)
for td, cond in zip(self.tensordicts, condition)
],
self.stack_dim,
)
# We should not pass out to torch.stack: it would overwrite the existing tensors
# in-place, which we don't want
if out is not None:
out.update(result)
return out
return result
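A small sketch of the design choice in the comment above (illustrative key; assumes `update` rebinds entries rather than copying into them): updating `out` with the stacked result leaves the pre-existing leaves of `out` untouched, whereas `torch.stack(..., out=out)` would write into them in-place.

```python
import torch
from tensordict import TensorDict

out = TensorDict({"a": torch.zeros(2, 3)}, batch_size=[2])
result = TensorDict({"a": torch.ones(2, 3)}, batch_size=[2])

old_leaf = out.get("a")
out.update(result)                 # rebinds "a" to the new tensor
assert out.get("a") is not old_leaf
assert (old_leaf == 0).all()       # original tensor was not written in-place
```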

def masked_fill_(self, mask: Tensor, value: float | bool) -> T:
mask_unbind = mask.unbind(dim=self.stack_dim)
@@ -8371,8 +8436,10 @@ def detach_(self) -> _CustomOpTensorDict:
self._source.detach_()
return self

def where(self, condition, other, *, out=None):
return self.to_tensordict().where(condition=condition, other=other, out=out)
def where(self, condition, other, *, out=None, pad=None):
return self.to_tensordict().where(
condition=condition, other=other, out=out, pad=pad
)

def masked_fill_(self, mask: Tensor, value: float | bool) -> _CustomOpTensorDict:
for key, item in self.items():
48 changes: 39 additions & 9 deletions test/test_tensordict.py
@@ -1241,15 +1241,7 @@ def test_where(self, td_name, device):
for k in td.keys(True, True):
assert (td_where.get(k)[~mask] == 1).all()
td_where = td.clone()
# torch.where(mask, td, torch.zeros((), device=device), out=td_where)
# for k in td.keys(True, True):
# assert (td_where.get(k)[~mask] == 0).all()
if td_name == "td_params":
with pytest.raises(
RuntimeError, match="don't support automatic differentiation"
):
torch.where(mask, td, torch.ones_like(td), out=td_where)
return

if td_name == "td_h5":
with pytest.raises(
RuntimeError,
@@ -1261,6 +1253,44 @@
for k in td.keys(True, True):
assert (td_where.get(k)[~mask] == 1).all()

def test_where_pad(self, td_name, device):
torch.manual_seed(1)
td = getattr(self, td_name)(device)
# test with other empty td
mask = torch.zeros(td.shape, dtype=torch.bool, device=td.device).bernoulli_()
if td_name in ("td_h5",):
td_full = td.to_tensordict()
else:
td_full = td
td_empty = td_full.empty()
result = td.where(mask, td_empty, pad=1)
for v in result.values(True, True):
assert (v[~mask] == 1).all()
td_empty = td_full.empty()
result = td_empty.where(~mask, td, pad=1)
for v in result.values(True, True):
assert (v[~mask] == 1).all()
# with output
td_out = td_full.empty()
result = td.where(mask, td_empty, pad=1, out=td_out)
for v in result.values(True, True):
assert (v[~mask] == 1).all()
if td_name not in ("td_params",):
assert result is td_out
else:
assert isinstance(result, TensorDictParams)
td_out = td_full.empty()
td_empty = td_full.empty()
result = td_empty.where(~mask, td, pad=1, out=td_out)
for v in result.values(True, True):
assert (v[~mask] == 1).all()
assert result is td_out

with pytest.raises(KeyError, match="not found and no pad value provided"):
td.where(mask, td_full.empty())
with pytest.raises(KeyError, match="not found and no pad value provided"):
td_full.empty().where(mask, td)

def test_masking_set(self, td_name, device):
torch.manual_seed(1)
td = getattr(self, td_name)(device)

1 comment on commit 7747505

@github-actions

⚠️ Performance Alert ⚠️

A possible performance regression was detected for the benchmark 'CPU Benchmark Results'.
The results of this commit are worse than the previous results by a ratio exceeding the alert threshold of 2.

| Benchmark suite | Current: 7747505 | Previous: d445682 | Ratio |
| --- | --- | --- | --- |
| benchmarks/common/memmap_benchmarks_test.py::test_add_one[memmap_tensor0] | 15670.286508178997 iter/sec (stddev: 0.00009494664314424387) | 32122.917760474204 iter/sec (stddev: 0.0000294097153603822) | 2.05 |
| benchmarks/common/memmap_benchmarks_test.py::test_memmaptd_index_op | 189.19260881392057 iter/sec (stddev: 0.0004554519369218494) | 423.5088871865344 iter/sec (stddev: 0.000016536187168753856) | 2.24 |
| benchmarks/nn/functional_benchmarks_test.py::test_vmap_mlp_speed[True-True] | 455.95745853793323 iter/sec (stddev: 0.0007778332927176489) | 927.3947237659027 iter/sec (stddev: 0.0001962858442279279) | 2.03 |
| benchmarks/nn/functional_benchmarks_test.py::test_vmap_mlp_speed[True-False] | 876.462017150474 iter/sec (stddev: 0.0004153831061975895) | 1865.028996032084 iter/sec (stddev: 0.000016194502687459375) | 2.13 |
| benchmarks/nn/functional_benchmarks_test.py::test_vmap_mlp_speed[False-True] | 527.1710484337934 iter/sec (stddev: 0.0006221517744050702) | 1089.7057081929795 iter/sec (stddev: 0.0001642519244406766) | 2.07 |
| benchmarks/nn/functional_benchmarks_test.py::test_vmap_mlp_speed[False-False] | 1153.8229353936238 iter/sec (stddev: 0.00020387949991083456) | 2336.442670140485 iter/sec (stddev: 0.0001217527776581118) | 2.02 |

This comment was automatically generated by a workflow using github-action-benchmark.

CC: @vmoens
