diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py
index 13415a47c..93724c99c 100644
--- a/tensordict/_lazy.py
+++ b/tensordict/_lazy.py
@@ -1377,6 +1377,63 @@ def contiguous(self) -> T:
             )
         return out

+    def densify(self, *, layout: torch.layout = torch.strided):
+        """Attempts to represent the lazy stack with contiguous tensors (plain tensors or nested).
+
+        Keyword Args:
+            layout (torch.layout): the layout of the nested tensors, if any. Defaults to
+                :class:`~torch.strided`.
+
+        """
+        result = TensorDict._new_unsafe(
+            batch_size=self.batch_size, device=self.device, names=self.names
+        )
+        for key in self._exclusive_keys():
+            list_of_entries = [
+                td._get_str(key, default=None) for td in self.tensordicts
+            ]
+            is_tensor = all(
+                isinstance(item, torch.Tensor) or item is None
+                for item in list_of_entries
+            )
+            if is_tensor:
+                shapes = {
+                    tensor.shape if tensor is not None else None
+                    for tensor in list_of_entries
+                }
+                if None in shapes:
+                    # There must be at least one non-None value
+                    a_shape = None
+                    while a_shape is None:
+                        a_shape = shapes.pop()
+                    if not a_shape:
+                        raise RuntimeError(
+                            f"Cannot densify a tensordict with values with empty shape and exclusive keys: got shape {a_shape}."
+                        )
+                    none_shape = a_shape[:-1] + (0,)
+                    for tensor in list_of_entries:
+                        if tensor is not None:
+                            a_tensor = tensor.new_zeros(none_shape)
+                            break
+                    list_of_entries = [
+                        tensor if tensor is not None else a_tensor
+                        for tensor in list_of_entries
+                    ]
+                    shapes.update({a_shape, none_shape})
+                if len(shapes) == 1:
+                    tensor = torch.stack(list_of_entries, self.stack_dim)
+                else:
+                    if self.stack_dim == 0:
+                        tensor = torch.nested.nested_tensor(
+                            list_of_entries, layout=layout
+                        )
+                    else:
+                        raise NotImplementedError
+            else:
+                tensor = self._get_str(key).densify(layout=layout)
+            result._set_str(key, tensor, validated=True, inplace=False)
+        return result
+
     def empty(
         self, recurse=False, *, batch_size=None, device=NO_DEFAULT, names=None
     ) -> T:
@@ -2702,23 +2759,25 @@ def insert(self, index: int, tensordict: T) -> None:
                 "Expected new value to be TensorDictBase instance but got "
                 f"{type(tensordict)} instead."
             )
+        if self.tensordicts:
+            batch_size = self.tensordicts[0].batch_size
+            device = self.tensordicts[0].device

-        batch_size = self.tensordicts[0].batch_size
-        device = self.tensordicts[0].device
-
-        _batch_size = tensordict.batch_size
-        _device = tensordict.device
+            _batch_size = tensordict.batch_size
+            _device = tensordict.device

-        if device != _device:
-            raise ValueError(
-                f"Devices differ: stack has device={device}, new value has "
-                f"device={_device}."
-            )
-        if _batch_size != batch_size:
-            raise ValueError(
-                f"Batch sizes in tensordicts differs: stack has "
-                f"batch_size={batch_size}, new_value has batch_size={_batch_size}."
-            )
+            if device != _device:
+                raise ValueError(
+                    f"Devices differ: stack has device={device}, new value has "
+                    f"device={_device}."
+                )
+            if _batch_size != batch_size:
+                raise ValueError(
+                    f"Batch sizes in tensordicts differ: stack has "
+                    f"batch_size={batch_size}, new value has batch_size={_batch_size}."
+                )
+        else:
+            batch_size = tensordict.batch_size

         self.tensordicts.insert(index, tensordict)

@@ -2751,6 +2810,8 @@ def is_locked(self) -> bool:
             if not td.is_locked:
                 return False
         else:
+            if not self.tensordicts:
+                return False
             # In this case, all tensordicts were locked before the lazy stack
             # was created and they were not locked through the lazy stack.
             # This means we cannot cache the value because this lazy stack
@@ -2826,6 +2887,9 @@ def __repr__(self):
         )
         return f"{type(self).__name__}(\n{string})"

+    def _exclusive_keys(self):
+        return {key for td in self.tensordicts for key in td.keys()}
+
     def _repr_exclusive_fields(self):
         keys = set(self.keys())
         exclusive_keys = [
diff --git a/tensordict/base.py b/tensordict/base.py
index 3487ebb25..3710d3bc7 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -3210,6 +3210,29 @@ def _memmap_(
         share_non_tensor,
     ) -> T: ...

+    def densify(self, layout: torch.layout = torch.strided):
+        """Attempts to represent the lazy stack with contiguous tensors (plain tensors or nested).
+
+        Keyword Args:
+            layout (torch.layout): the layout of the nested tensors, if any. Defaults to
+                :class:`~torch.strided`.
+
+        """
+        any_set = False
+        out_dict = {}
+        for key, val in self.items():
+            if is_tensor_collection(val):
+                val_dense = val.densify(layout=layout)
+                any_set = any_set | (val_dense is not val)
+                val = val_dense
+            out_dict[key] = val
+        if any_set:
+            result = self.empty()
+            for key, val in out_dict.items():
+                result._set_str(key, val, validated=True, inplace=False)
+            return result
+        return self
+
     @property
     def saved_path(self):
         """Returns the path where a memmap saved TensorDict is being stored.
diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py
index 698a0b412..c043e5ad2 100644
--- a/tensordict/tensorclass.py
+++ b/tensordict/tensorclass.py
@@ -193,6 +193,7 @@ def __subclasscheck__(self, subclass):
     "cosh_",
     "cpu",
     "cuda",
+    "densify",
     "div",
     "div_",
     "empty",
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index b6b51822f..f31c8988e 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -113,6 +113,12 @@
     pytest.mark.filterwarnings(
         "ignore:Indexing an h5py.Dataset object with a boolean mask that needs broadcasting does not work directly"
     ),
+    pytest.mark.filterwarnings(
+        "ignore:The PyTorch API of nested tensors is in prototype"
+    ),
+    pytest.mark.filterwarnings(
+        "ignore:Lazy modules are a new feature under heavy development so changes to the API or functionality"
+    ),
 ]

 mp_ctx = "fork" if (not torch.cuda.is_available() and not _IS_WINDOWS) else "spawn"
@@ -7794,6 +7800,34 @@ def test_create_empty(self):
         assert td.device == torch.device("cpu")
         assert td.shape == torch.Size([1, 0, 2])

+    def test_densify(self):
+        td0 = TensorDict(
+            a=torch.zeros((1,)),
+            b=torch.zeros((2,)),
+            d=TensorDict(e=torch.zeros(())),
+        )
+        td1 = TensorDict(
+            b=torch.ones((1,)), c=torch.ones((2,)), d=TensorDict(e=torch.ones(()))
+        )
+        td = LazyStackedTensorDict(td0, td1, stack_dim=0)
+        td_jagged = td.densify(layout=torch.jagged)
+        assert (td_jagged.exclude("c").unbind(0)[0] == 0).all()
+        assert (td_jagged.exclude("a").unbind(0)[1] == 1).all()
+        assert not td_jagged["d", "e"].is_nested
+        td_strided = td.densify(layout=torch.strided)
+        assert (td_strided.exclude("c")[0] == 0).all()
+        assert (td_strided.exclude("a")[1] == 1).all()
+        assert not td_strided["d", "e"].is_nested
+        td_nest = TensorDict(td=td, batch_size=[2])
+        td_nest_jagged = td_nest.densify(layout=torch.jagged)
+        assert (td_nest_jagged.exclude(("td", "c")).unbind(0)[0] == 0).all()
+        assert (td_nest_jagged.exclude(("td", "a")).unbind(0)[1] == 1).all()
+        assert not td_nest_jagged["td", "d", "e"].is_nested
+        td_nest_strided = td_nest.densify(layout=torch.strided)
+        assert (td_nest_strided.exclude(("td", "c"))[0] == 0).all()
+        assert (td_nest_strided.exclude(("td", "a"))[1] == 1).all()
+        assert not td_nest_strided["td", "d", "e"].is_nested
+
     @pytest.mark.parametrize("pos1", range(8))
     @pytest.mark.parametrize("pos2", range(8))
     @pytest.mark.parametrize("pos3", range(8))