diff --git a/tensordict/_td.py b/tensordict/_td.py
index 13d2c4d04..c9098d021 100644
--- a/tensordict/_td.py
+++ b/tensordict/_td.py
@@ -59,7 +59,8 @@
     _LOCK_ERROR,
     _maybe_correct_neg_dim,
     _mismatch_keys,
+    _nested_tensor_shape,
     _NON_STR_KEY_ERR,
     _NON_STR_KEY_TUPLE_ERR,
     _parse_to,
     _pass_through,
@@ -4701,7 +4702,19 @@ def _save_metadata(data: TensorDictBase, prefix: Path, metadata=None):
 def _populate_memmap(*, dest, value, key, copy_existing, prefix, like, existsok):
     filename = None if prefix is None else str(prefix / f"{key}.memmap")
     if value.is_nested:
-        shape = value._nested_tensor_size()
+        if value.layout is torch.strided:
+            shape = value._nested_tensor_size()
+        else:
+            offsets = value.offsets()
+            if offsets is None:
+                lengths = value.lengths()
+            else:
+                lengths = offsets.diff()
+            shapes = [lengths]
+            for s in value.shape[2:]:
+                shapes.append(torch.full_like(lengths, s))
+            shape = torch.stack(shapes, -1)
+            value = value.values()
         # Make the shape a memmap tensor too
         if prefix is not None:
             shape_filename = Path(filename)
@@ -4713,6 +4726,7 @@ def _populate_memmap(*, dest, value, key, copy_existing, prefix, like, existsok)
                 existsok=existsok,
                 copy_data=True,
             )
+
     else:
         shape = None
     memmap_tensor = MemoryMappedTensor.from_tensor(
@@ -4795,7 +4809,7 @@ def _update_metadata(*, metadata, key, value, is_collection):
         "shape": (
             list(value.shape)
             if not value.is_nested
-            else list(value._nested_tensor_size().shape)
+            else list(_nested_tensor_shape(value).shape)
         ),
         "dtype": str(value.dtype),
         "is_nested": value.is_nested,
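Note on the jagged branch in `_populate_memmap` above: for a jagged nested tensor of logical shape `(B, j, *dims)`, the patch serializes an integer shape tensor with one `(length, *dims)` row per component, recovering the per-row lengths from `offsets().diff()` (or from `lengths()` when no offsets are stored) and broadcasting the constant trailing dims. A minimal sketch of that encoding, assuming a recent PyTorch build with jagged nested-tensor support (illustrative, not part of the patch):

import torch

# Two components of lengths 2 and 3, constant trailing dim 4.
nt = torch.nested.nested_tensor(
    [torch.zeros(2, 4), torch.zeros(3, 4)], layout=torch.jagged
)

offsets = nt.offsets()
lengths = nt.lengths() if offsets is None else offsets.diff()
shapes = [lengths]
for s in nt.shape[2:]:  # trailing dims are constant across components
    shapes.append(torch.full_like(lengths, s))
shape = torch.stack(shapes, -1)
print(shape)  # tensor([[2, 4], [3, 4]]): one (length, *dims) row per component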
""" - shape, device, dtype, _, filename = _proc_args_const(*args, **kwargs) + shape, device, dtype, _, filename, layout = _proc_args_const(*args, **kwargs) if device is not None: device = torch.device(device) if device.type != "cpu": @@ -573,11 +575,19 @@ def empty(cls, *args, **kwargs): else: raise RuntimeError(NESTED_TENSOR_ERR) result = torch.frombuffer(memoryview(handler.buffer), dtype=dtype) - result = torch._nested_view_from_buffer( - result, - shape, - *offsets_strides, - ) + if layout in (None, torch.strided): + result = torch._nested_view_from_buffer( + result, + shape, + *offsets_strides, + layout=layout, + ) + else: + result = result.view((-1, *shape[0].tolist())) + result = torch.nested.nested_tensor_from_jagged( + result, + lengths=result[:, 0], + ) result = cls(result) result._handler = handler return result @@ -597,11 +607,20 @@ def empty(cls, *args, **kwargs): offsets_strides = func_offset_stride(shape) else: raise RuntimeError(NESTED_TENSOR_ERR) - result = torch._nested_view_from_buffer( - result, - shape, - *offsets_strides, - ) + if layout in (None, torch.strided): + result = torch._nested_view_from_buffer( + result, + shape, + *offsets_strides, + ) + else: + # TODO: we should not assume that the 2nd dim is the ragged one + result = result.view((-1, *shape[0, 1:].tolist())) + result = torch.nested.nested_tensor_from_jagged( + result, + lengths=result[:, 0], + ) + result = cls(result) result.filename = filename return result @@ -1030,6 +1049,7 @@ def _proc_args_const(*args, **kwargs): kwargs.pop("dtype", None), kwargs.pop("fill_value", None), kwargs.pop("filename", None), + kwargs.pop("layout", None), ) diff --git a/tensordict/utils.py b/tensordict/utils.py index afa54853b..60d5bb2e5 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -3028,3 +3028,18 @@ def _check_is_unflatten(new_shape, old_shape, return_flatten_dim=False): # j = len(new_shape) - j - 1 return out, (i, j) return out + +def _nested_tensor_shape(value): + if value.layout is torch.strided: + shape = value._nested_tensor_size() + else: + offsets = value.offsets() + if offsets is None: + lengths = value.lengths() + else: + lengths = offsets.diff() + shapes = [lengths] + for s in value.shape[2:]: + shapes.append(torch.full_like(lengths, s)) + shape = torch.stack(shapes, -1) + return shape \ No newline at end of file diff --git a/test/test_memmap.py b/test/test_memmap.py index 44f4a88f3..0d7e81d01 100644 --- a/test/test_memmap.py +++ b/test/test_memmap.py @@ -588,15 +588,20 @@ def test_ne(self): assert (memmap != ~memmap).all() +@pytest.mark.parametrize("layout", [torch.jagged, torch.strided]) class TestNestedTensor: - shape = torch.tensor([[2, 3], [2, 4], [3, 2]]) + def shape(self, layout): + if layout is torch.strided: + return torch.tensor([[2, 3], [2, 4], [3, 2]]) + return torch.tensor([[2, 3], [3, 3], [4, 3]]) @pytest.mark.skipif(not HAS_NESTED_TENSOR, reason="Nested tensor incomplete") - def test_with_filename(self, tmpdir): + def test_with_filename(self, tmpdir, layout): filename = tmpdir + "/test_file2.memmap" tensor = MemoryMappedTensor.empty( - self.shape, filename=filename, dtype=torch.int + self.shape(layout), filename=filename, dtype=torch.int, layout=layout, ) + assert tensor.layout is layout assert isinstance(tensor, MemoryMappedTensor) assert tensor.dtype == torch.int tensor.fill_(2) @@ -605,14 +610,16 @@ def test_with_filename(self, tmpdir): filename = tmpdir + "/test_file0.memmap" tensor = MemoryMappedTensor.zeros( - self.shape, filename=filename, dtype=torch.bool + 
diff --git a/test/test_memmap.py b/test/test_memmap.py
index 44f4a88f3..0d7e81d01 100644
--- a/test/test_memmap.py
+++ b/test/test_memmap.py
@@ -588,15 +588,20 @@ def test_ne(self):
         assert (memmap != ~memmap).all()
 
 
+@pytest.mark.parametrize("layout", [torch.jagged, torch.strided])
 class TestNestedTensor:
-    shape = torch.tensor([[2, 3], [2, 4], [3, 2]])
+    def shape(self, layout):
+        if layout is torch.strided:
+            return torch.tensor([[2, 3], [2, 4], [3, 2]])
+        return torch.tensor([[2, 3], [3, 3], [4, 3]])
 
     @pytest.mark.skipif(not HAS_NESTED_TENSOR, reason="Nested tensor incomplete")
-    def test_with_filename(self, tmpdir):
+    def test_with_filename(self, tmpdir, layout):
         filename = tmpdir + "/test_file2.memmap"
         tensor = MemoryMappedTensor.empty(
-            self.shape, filename=filename, dtype=torch.int
+            self.shape(layout), filename=filename, dtype=torch.int, layout=layout
         )
+        assert tensor.layout is layout
         assert isinstance(tensor, MemoryMappedTensor)
         assert tensor.dtype == torch.int
         tensor.fill_(2)
@@ -605,14 +610,18 @@ def test_with_filename(self, tmpdir, layout):
 
         filename = tmpdir + "/test_file0.memmap"
         tensor = MemoryMappedTensor.zeros(
-            self.shape, filename=filename, dtype=torch.bool
+            self.shape(layout), filename=filename, dtype=torch.bool, layout=layout
         )
+        assert tensor.layout is layout
         assert isinstance(tensor, MemoryMappedTensor)
         assert tensor.dtype == torch.bool
         assert tensor.filename is not None
 
         filename = tmpdir + "/test_file1.memmap"
-        tensor = MemoryMappedTensor.ones(self.shape, filename=filename, dtype=torch.int)
+        tensor = MemoryMappedTensor.ones(
+            self.shape(layout), filename=filename, dtype=torch.int, layout=layout
+        )
+        assert tensor.layout is layout
         assert type(tensor) is MemoryMappedTensor
         assert tensor.dtype == torch.int
         assert (tensor[0] == 1).all()
@@ -620,7 +629,7 @@ def test_with_filename(self, tmpdir, layout):
 
         filename = tmpdir + "/test_file3.memmap"
         tensor = torch.nested.nested_tensor(
-            [torch.zeros(shape.tolist()) + i for i, shape in enumerate(self.shape)]
+            [torch.zeros(shape.tolist()) + i for i, shape in enumerate(self.shape(layout))]
         )
         memmap_tensor = MemoryMappedTensor.from_tensor(tensor, filename=filename)
         assert type(memmap_tensor) is MemoryMappedTensor
@@ -629,7 +638,7 @@ def test_with_filename(self, tmpdir, layout):
             assert (t1 == t2).all()
 
         memmap_tensor2 = MemoryMappedTensor.from_filename(
-            filename, dtype=memmap_tensor.dtype, shape=self.shape
+            filename, dtype=memmap_tensor.dtype, shape=self.shape(layout)
         )
         assert type(memmap_tensor2) is MemoryMappedTensor
         for t1, t2 in zip(memmap_tensor2, memmap_tensor):
@@ -637,27 +646,27 @@ def test_with_filename(self, tmpdir, layout):
             assert (t1 == t2).all()
 
     @pytest.mark.skipif(not HAS_NESTED_TENSOR, reason="Nested tensor incomplete")
-    def test_with_handler(self):
-        tensor = MemoryMappedTensor.empty(self.shape, dtype=torch.int)
+    def test_with_handler(self, layout):
+        tensor = MemoryMappedTensor.empty(self.shape(layout), dtype=torch.int, layout=layout)
         assert isinstance(tensor, MemoryMappedTensor)
         assert tensor.dtype == torch.int
         tensor.fill_(2)
         assert (tensor[0] == 2).all()
         assert tensor._handler is not None
 
-        tensor = MemoryMappedTensor.zeros(self.shape, dtype=torch.bool)
+        tensor = MemoryMappedTensor.zeros(self.shape(layout), dtype=torch.bool, layout=layout)
         assert isinstance(tensor, MemoryMappedTensor)
         assert tensor.dtype == torch.bool
         assert tensor._handler is not None
 
-        tensor = MemoryMappedTensor.ones(self.shape, dtype=torch.int)
+        tensor = MemoryMappedTensor.ones(self.shape(layout), dtype=torch.int, layout=layout)
         assert type(tensor) is MemoryMappedTensor
         assert tensor.dtype == torch.int
         assert (tensor[0] == 1).all()
         assert tensor._handler is not None
 
         tensor = torch.nested.nested_tensor(
-            [torch.zeros(shape.tolist()) + i for i, shape in enumerate(self.shape)]
+            [torch.zeros(shape.tolist()) + i for i, shape in enumerate(self.shape(layout))]
         )
         memmap_tensor = MemoryMappedTensor.from_tensor(tensor)
         assert type(memmap_tensor) is MemoryMappedTensor
@@ -666,7 +675,10 @@ def test_with_handler(self, layout):
             assert (t1 == t2).all()
 
         memmap_tensor2 = MemoryMappedTensor.from_handler(
-            memmap_tensor._handler, dtype=memmap_tensor.dtype, shape=self.shape
+            memmap_tensor._handler,
+            dtype=memmap_tensor.dtype,
+            shape=self.shape(layout),
+            layout=layout,
         )
         assert type(memmap_tensor2) is MemoryMappedTensor
         for t1, t2 in zip(memmap_tensor2, memmap_tensor):
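The parametrized tests above drive both layouts through the same factory functions; note that the jagged variant of `shape` keeps the trailing dim constant (3) and varies only the leading, ragged dim. For reference, a usage sketch of the new `layout` argument, assuming the patch above is applied (illustrative, not part of the patch):

import torch
from tensordict import MemoryMappedTensor

# With layout=torch.jagged, all rows of the shape tensor must share the same
# trailing dims; only the leading (ragged) dim may vary per component.
shape = torch.tensor([[2, 3], [3, 3], [4, 3]])
mmap = MemoryMappedTensor.empty(shape, dtype=torch.float32, layout=torch.jagged)
mmap.fill_(1.0)
assert mmap.layout is torch.jagged
assert mmap[0].shape == (2, 3)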
@@ -675,19 +687,19 @@ def test_with_handler(self, layout):
 
     @pytest.mark.skipif(not HAS_NESTED_TENSOR, reason="Nested tensor incomplete")
     @pytest.mark.parametrize("with_filename", [False, True])
-    def test_from_storage(self, with_filename, tmpdir):
+    def test_from_storage(self, with_filename, tmpdir, layout):
         if with_filename:
             filename = Path(tmpdir) / "file.memmap"
             filename = str(filename)
         else:
             filename = None
         a = MemoryMappedTensor.from_tensor(
-            torch.arange(10, dtype=torch.float64), filename=filename
+            torch.arange(10, dtype=torch.float64), filename=filename, layout=layout
         )
         assert type(a) is MemoryMappedTensor
         shape = torch.tensor([[2, 2], [2, 3]])
         b = MemoryMappedTensor.from_storage(
-            a.untyped_storage(), filename=filename, shape=shape, dtype=a.dtype
+            a.untyped_storage(), filename=filename, shape=shape, dtype=a.dtype, layout=layout
         )
         assert type(b) is MemoryMappedTensor
         assert (b._nested_tensor_size() == shape).all()
@@ -695,14 +707,15 @@ def test_from_storage(self, with_filename, tmpdir, layout):
         assert (b[1] == torch.arange(4, 10).view(2, 3)).all()
 
     @pytest.mark.skipif(not HAS_NESTED_TENSOR, reason="Nested tensor incomplete")
-    def test_save_td_with_nested(self, tmpdir):
+    def test_save_td_with_nested(self, tmpdir, layout):
         td = TensorDict(
             {
                 "a": torch.nested.nested_tensor(
                     [
                         torch.arange(12, dtype=torch.float64).view(3, 4),
                         torch.arange(15, dtype=torch.float64).view(3, 5),
-                    ]
+                    ],
+                    layout=layout,
                 )
             },
             batch_size=[2, 3],
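test_save_td_with_nested exercises the end-to-end path this patch enables: memory-mapping a TensorDict that holds a nested tensor. A hedged round-trip sketch for the jagged case, assuming the patch is applied; the ragged dim is taken to be the first one after the batch dim (per the TODO in memmap.py), and the target directory is illustrative:

import tempfile

import torch
from tensordict import TensorDict

nt = torch.nested.nested_tensor(
    [torch.ones(2, 4), torch.ones(3, 4)], layout=torch.jagged
)
td = TensorDict({"a": nt}, batch_size=[2])

prefix = tempfile.mkdtemp()  # illustrative location
td.memmap_(prefix)  # writes a.memmap plus the per-row shape tensor

loaded = TensorDict.load_memmap(prefix)
assert (loaded["a"][1] == nt[1]).all()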