[Feature] Densify lazy tensordicts #955

Merged · 4 commits · Aug 9, 2024
94 changes: 79 additions & 15 deletions tensordict/_lazy.py
@@ -1377,6 +1377,63 @@ def contiguous(self) -> T:
        )
        return out

    def densify(self, *, layout: torch.layout = torch.strided):
        """Attempts to represent the lazy stack with contiguous tensors (plain tensors or nested).

        Keyword Args:
            layout (torch.layout): the layout of the nested tensors, if any. Defaults to
                :class:`~torch.strided`.

        """
        result = TensorDict._new_unsafe(
            batch_size=self.batch_size, device=self.device, names=self.names
        )
        for key in self._exclusive_keys():
            list_of_entries = [
                td._get_str(key, default=None) for td in self.tensordicts
            ]
            is_tensor = all(
                isinstance(item, torch.Tensor) or item is None
                for item in list_of_entries
            )
            if is_tensor:
                shapes = {
                    tensor.shape if tensor is not None else None
                    for tensor in list_of_entries
                }
                if None in shapes:
                    # The key is missing from some sub-tensordicts: grab a
                    # reference shape from any entry that carries a value.
                    # There must be at least one non-None value.
                    a_shape = None
                    while a_shape is None:
                        a_shape = shapes.pop()
                    shapes.discard(None)
                    if not a_shape:
                        raise RuntimeError(
                            f"Cannot densify a tensordict with exclusive keys whose values have an empty shape (got shape {a_shape})."
                        )
                    # Missing entries are replaced with zero-width placeholders
                    # so that every sub-tensordict contributes a tensor.
                    none_shape = a_shape[:-1] + (0,)
                    for tensor in list_of_entries:
                        if tensor is not None:
                            a_tensor = tensor.new_zeros(none_shape)
                            break
                    list_of_entries = [
                        tensor if tensor is not None else a_tensor
                        for tensor in list_of_entries
                    ]
                    shapes.update({a_shape, none_shape})
                if len(shapes) == 1:
                    # Homogeneous shapes: a plain stacked tensor is enough.
                    tensor = torch.stack(list_of_entries, self.stack_dim)
                else:
                    if self.stack_dim == 0:
                        # Heterogeneous shapes: fall back to a nested tensor.
                        tensor = torch.nested.nested_tensor(
                            list_of_entries, layout=layout
                        )
                    else:
                        raise NotImplementedError(
                            "densify does not support heterogeneous shapes "
                            f"when stack_dim != 0 (got stack_dim={self.stack_dim})."
                        )
            else:
                # Sub-tensordicts are densified recursively.
                tensor = self._get_str(key).densify(layout=layout)
            result._set_str(key, tensor, validated=True, inplace=False)
        return result
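Not part of the diff: a minimal usage sketch of the new method, mirroring the pattern used in the tests below; the tensordicts, key names, and shapes here are illustrative.

import torch
from tensordict import TensorDict, LazyStackedTensorDict

td0 = TensorDict(a=torch.zeros(2), b=torch.zeros(3))
td1 = TensorDict(a=torch.ones(2))  # "b" exists only in td0
stack = LazyStackedTensorDict(td0, td1, stack_dim=0)

dense = stack.densify(layout=torch.jagged)
print(dense["a"].shape)      # torch.Size([2, 2]): matching shapes are plainly stacked
print(dense["b"].is_nested)  # True: the missing entry was filled with a shape-(0,) placeholder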

    def empty(
        self, recurse=False, *, batch_size=None, device=NO_DEFAULT, names=None
    ) -> T:
@@ -2702,23 +2759,25 @@ def insert(self, index: int, tensordict: T) -> None:
"Expected new value to be TensorDictBase instance but got "
f"{type(tensordict)} instead."
)
if self.tensordicts:
batch_size = self.tensordicts[0].batch_size
device = self.tensordicts[0].device

batch_size = self.tensordicts[0].batch_size
device = self.tensordicts[0].device

_batch_size = tensordict.batch_size
_device = tensordict.device
_batch_size = tensordict.batch_size
_device = tensordict.device

if device != _device:
raise ValueError(
f"Devices differ: stack has device={device}, new value has "
f"device={_device}."
)
if _batch_size != batch_size:
raise ValueError(
f"Batch sizes in tensordicts differs: stack has "
f"batch_size={batch_size}, new_value has batch_size={_batch_size}."
)
if device != _device:
raise ValueError(
f"Devices differ: stack has device={device}, new value has "
f"device={_device}."
)
if _batch_size != batch_size:
raise ValueError(
f"Batch sizes in tensordicts differs: stack has "
f"batch_size={batch_size}, new_value has batch_size={_batch_size}."
)
else:
batch_size = tensordict.batch_size

self.tensordicts.insert(index, tensordict)
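Not part of the diff: a hedged sketch of what the reworked guard permits, assuming an empty LazyStackedTensorDict is constructible (the key and values are illustrative).

import torch
from tensordict import TensorDict, LazyStackedTensorDict

stack = LazyStackedTensorDict(stack_dim=0)  # no tensordicts yet
stack.insert(0, TensorDict(a=torch.zeros(3), batch_size=[]))
# Before this change, insertion indexed self.tensordicts[0] on an empty list
# and raised an IndexError; the else branch now takes the batch size from the
# first inserted tensordict instead.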

@@ -2751,6 +2810,8 @@ def is_locked(self) -> bool:
            if not td.is_locked:
                return False
        else:
            if not self.tensordicts:
                return False
            # In this case, all tensordicts were locked before the lazy stack
            # was created and they were not locked through the lazy stack.
            # This means we cannot cache the value because this lazy stack
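Not part of the diff: a hedged sketch of the edge case the added guard covers, again assuming an empty LazyStackedTensorDict is constructible.

from tensordict import LazyStackedTensorDict

# With no constituent tensordicts there is nothing locked, so per the added
# guard the property reports False rather than treating the empty stack as locked.
assert not LazyStackedTensorDict(stack_dim=0).is_locked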
@@ -2826,6 +2887,9 @@ def __repr__(self):
        )
        return f"{type(self).__name__}(\n{string})"

    def _exclusive_keys(self):
        # Union of first-level keys across all stacked tensordicts, including
        # keys present in only a subset of them.
        return {key for td in self.tensordicts for key in td.keys()}

    def _repr_exclusive_fields(self):
        keys = set(self.keys())
        exclusive_keys = [
23 changes: 23 additions & 0 deletions tensordict/base.py
@@ -3210,6 +3210,29 @@ def _memmap_(
        share_non_tensor,
    ) -> T: ...

    def densify(self, *, layout: torch.layout = torch.strided):
        """Attempts to represent the lazy stack with contiguous tensors (plain tensors or nested).

        Keyword Args:
            layout (torch.layout): the layout of the nested tensors, if any. Defaults to
                :class:`~torch.strided`.

        """
        any_set = False
        out_dict = {}
        for key, val in self.items():
            if is_tensor_collection(val):
                # Recurse into sub-tensordicts; leaves are left untouched.
                val_dense = val.densify(layout=layout)
                any_set = any_set | (val_dense is not val)
                val = val_dense
            out_dict[key] = val
        if any_set:
            result = self.empty()
            for key, val in out_dict.items():
                result._set_str(key, val, validated=True, inplace=False)
            return result
        return self
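Not part of the diff: a sketch of the two paths of this base-class hook, the recursive path through sub-tensordicts and the identity fast path (names illustrative).

import torch
from tensordict import TensorDict, LazyStackedTensorDict

plain = TensorDict(a=torch.zeros(3), batch_size=[])
assert plain.densify() is plain  # nothing lazy inside: the same object is returned

stack = LazyStackedTensorDict(
    TensorDict(a=torch.zeros(2)), TensorDict(a=torch.zeros(3)), stack_dim=0
)
nested = TensorDict(inner=stack, batch_size=[2])
dense = nested.densify(layout=torch.jagged)  # recurses into "inner"
assert dense["inner", "a"].is_nested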

    @property
    def saved_path(self):
        """Returns the path where a memmap saved TensorDict is being stored.
1 change: 1 addition & 0 deletions tensordict/tensorclass.py
@@ -193,6 +193,7 @@ def __subclasscheck__(self, subclass):
"cosh_",
"cpu",
"cuda",
"densify",
"div",
"div_",
"empty",
34 changes: 34 additions & 0 deletions test/test_tensordict.py
@@ -113,6 +113,12 @@
    pytest.mark.filterwarnings(
        "ignore:Indexing an h5py.Dataset object with a boolean mask that needs broadcasting does not work directly"
    ),
    pytest.mark.filterwarnings(
        "ignore:The PyTorch API of nested tensors is in prototype"
    ),
    pytest.mark.filterwarnings(
        "ignore:Lazy modules are a new feature under heavy development so changes to the API or functionality"
    ),
]

mp_ctx = "fork" if (not torch.cuda.is_available() and not _IS_WINDOWS) else "spawn"
@@ -7794,6 +7800,34 @@ def test_create_empty(self):
        assert td.device == torch.device("cpu")
        assert td.shape == torch.Size([1, 0, 2])

    def test_densify(self):
        td0 = TensorDict(
            a=torch.zeros((1,)),
            b=torch.zeros((2,)),
            d=TensorDict(e=torch.zeros(())),
        )
        td1 = TensorDict(
            b=torch.ones((1,)), c=torch.ones((2,)), d=TensorDict(e=torch.ones(()))
        )
        td = LazyStackedTensorDict(td0, td1, stack_dim=0)
        # Jagged layout: heterogeneous entries become jagged nested tensors.
        td_jagged = td.densify(layout=torch.jagged)
        assert (td_jagged.exclude("c").unbind(0)[0] == 0).all()
        assert (td_jagged.exclude("a").unbind(0)[1] == 1).all()
        assert not td_jagged["d", "e"].is_nested
        # Strided layout: same content, strided nested tensors instead.
        td_strided = td.densify(layout=torch.strided)
        assert (td_strided.exclude("c")[0] == 0).all()
        assert (td_strided.exclude("a")[1] == 1).all()
        assert not td_strided["d", "e"].is_nested
        # densify recurses through regular tensordicts wrapping a lazy stack.
        td_nest = TensorDict(td=td, batch_size=[2])
        td_nest_jagged = td_nest.densify(layout=torch.jagged)
        assert (td_nest_jagged.exclude(("td", "c")).unbind(0)[0] == 0).all()
        assert (td_nest_jagged.exclude(("td", "a")).unbind(0)[1] == 1).all()
        assert not td_nest_jagged["td", "d", "e"].is_nested
        td_nest_strided = td_nest.densify(layout=torch.strided)
        assert (td_nest_strided.exclude(("td", "c"))[0] == 0).all()
        assert (td_nest_strided.exclude(("td", "a"))[1] == 1).all()
        assert not td_nest_strided["td", "d", "e"].is_nested

@pytest.mark.parametrize("pos1", range(8))
@pytest.mark.parametrize("pos2", range(8))
@pytest.mark.parametrize("pos3", range(8))