Skip to content

Commit 22269e2

Browse files
author
Vincent Moens
committed
[Feature] memory mapped jagged tensors
ghstack-source-id: d5fb34e Pull Request resolved: #1291
1 parent c61d045 commit 22269e2

File tree

4 files changed

+89
-34
lines changed

4 files changed

+89
-34
lines changed

tensordict/_td.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
_LOCK_ERROR,
6060
_maybe_correct_neg_dim,
6161
_mismatch_keys,
62-
_NON_STR_KEY_ERR,
62+
_NON_STR_KEY_ERR,_nested_tensor_shape,
6363
_NON_STR_KEY_TUPLE_ERR,
6464
_parse_to,
6565
_pass_through,
@@ -4701,7 +4701,19 @@ def _save_metadata(data: TensorDictBase, prefix: Path, metadata=None):
47014701
def _populate_memmap(*, dest, value, key, copy_existing, prefix, like, existsok):
47024702
filename = None if prefix is None else str(prefix / f"{key}.memmap")
47034703
if value.is_nested:
4704-
shape = value._nested_tensor_size()
4704+
if value.layout is torch.strided:
4705+
shape = value._nested_tensor_size()
4706+
else:
4707+
offsets = value.offsets()
4708+
if offsets is None:
4709+
lengths = value.lengths()
4710+
else:
4711+
lengths = offsets.diff()
4712+
shapes = [lengths]
4713+
for s in value.shape[2:]:
4714+
shapes.append(torch.full_like(lengths, s))
4715+
shape = torch.stack(shapes, -1)
4716+
value = value.values()
47054717
# Make the shape a memmap tensor too
47064718
if prefix is not None:
47074719
shape_filename = Path(filename)
@@ -4713,6 +4725,7 @@ def _populate_memmap(*, dest, value, key, copy_existing, prefix, like, existsok)
47134725
existsok=existsok,
47144726
copy_data=True,
47154727
)
4728+
47164729
else:
47174730
shape = None
47184731
memmap_tensor = MemoryMappedTensor.from_tensor(
@@ -4795,7 +4808,7 @@ def _update_metadata(*, metadata, key, value, is_collection):
47954808
"shape": (
47964809
list(value.shape)
47974810
if not value.is_nested
4798-
else list(value._nested_tensor_size().shape)
4811+
else list(_nested_tensor_shape(value).shape)
47994812
),
48004813
"dtype": str(value.dtype),
48014814
"is_nested": value.is_nested,

tensordict/memmap.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -517,11 +517,11 @@ def zeros(cls, *args, **kwargs):
517517

518518
@classmethod
519519
@overload
520-
def empty(cls, *size, dtype=None, device=None, filename=None): ...
520+
def empty(cls, *size, dtype=None, device=None, filename=None, layout=None): ...
521521

522522
@classmethod
523523
@overload
524-
def empty(cls, shape, *, dtype=None, device=None, filename=None): ...
524+
def empty(cls, shape, *, dtype=None, device=None, filename=None, layout=None): ...
525525

526526
@classmethod
527527
def empty(cls, *args, **kwargs):
@@ -539,8 +539,10 @@ def empty(cls, *args, **kwargs):
539539
is provided, a handler is used.
540540
existsok (bool, optional): whether it is ok to overwrite an existing file.
541541
Defaults to ``False``.
542+
layout (torch.layout): the layout of the tensor if nested. Only `None` (default), `torch.jagged` and
543+
`torch.strided` are accepted.
542544
"""
543-
shape, device, dtype, _, filename = _proc_args_const(*args, **kwargs)
545+
shape, device, dtype, _, filename, layout = _proc_args_const(*args, **kwargs)
544546
if device is not None:
545547
device = torch.device(device)
546548
if device.type != "cpu":
@@ -573,11 +575,19 @@ def empty(cls, *args, **kwargs):
573575
else:
574576
raise RuntimeError(NESTED_TENSOR_ERR)
575577
result = torch.frombuffer(memoryview(handler.buffer), dtype=dtype)
576-
result = torch._nested_view_from_buffer(
577-
result,
578-
shape,
579-
*offsets_strides,
580-
)
578+
if layout in (None, torch.strided):
579+
result = torch._nested_view_from_buffer(
580+
result,
581+
shape,
582+
*offsets_strides,
583+
layout=layout,
584+
)
585+
else:
586+
result = result.view((-1, *shape[0].tolist()))
587+
result = torch.nested.nested_tensor_from_jagged(
588+
result,
589+
lengths=result[:, 0],
590+
)
581591
result = cls(result)
582592
result._handler = handler
583593
return result
@@ -597,11 +607,20 @@ def empty(cls, *args, **kwargs):
597607
offsets_strides = func_offset_stride(shape)
598608
else:
599609
raise RuntimeError(NESTED_TENSOR_ERR)
600-
result = torch._nested_view_from_buffer(
601-
result,
602-
shape,
603-
*offsets_strides,
604-
)
610+
if layout in (None, torch.strided):
611+
result = torch._nested_view_from_buffer(
612+
result,
613+
shape,
614+
*offsets_strides,
615+
)
616+
else:
617+
# TODO: we should not assume that the 2nd dim is the ragged one
618+
result = result.view((-1, *shape[0, 1:].tolist()))
619+
result = torch.nested.nested_tensor_from_jagged(
620+
result,
621+
lengths=result[:, 0],
622+
)
623+
605624
result = cls(result)
606625
result.filename = filename
607626
return result
@@ -1030,6 +1049,7 @@ def _proc_args_const(*args, **kwargs):
10301049
kwargs.pop("dtype", None),
10311050
kwargs.pop("fill_value", None),
10321051
kwargs.pop("filename", None),
1052+
kwargs.pop("layout", None),
10331053
)
10341054

10351055

tensordict/utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3028,3 +3028,18 @@ def _check_is_unflatten(new_shape, old_shape, return_flatten_dim=False):
30283028
# j = len(new_shape) - j - 1
30293029
return out, (i, j)
30303030
return out
3031+
3032+
def _nested_tensor_shape(value):
3033+
if value.layout is torch.strided:
3034+
shape = value._nested_tensor_size()
3035+
else:
3036+
offsets = value.offsets()
3037+
if offsets is None:
3038+
lengths = value.lengths()
3039+
else:
3040+
lengths = offsets.diff()
3041+
shapes = [lengths]
3042+
for s in value.shape[2:]:
3043+
shapes.append(torch.full_like(lengths, s))
3044+
shape = torch.stack(shapes, -1)
3045+
return shape

test/test_memmap.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -588,15 +588,20 @@ def test_ne(self):
588588
assert (memmap != ~memmap).all()
589589

590590

591+
@pytest.mark.parametrize("layout", [torch.jagged, torch.strided])
591592
class TestNestedTensor:
592-
shape = torch.tensor([[2, 3], [2, 4], [3, 2]])
593+
def shape(self, layout):
594+
if layout is torch.strided:
595+
return torch.tensor([[2, 3], [2, 4], [3, 2]])
596+
return torch.tensor([[2, 3], [3, 3], [4, 3]])
593597

594598
@pytest.mark.skipif(not HAS_NESTED_TENSOR, reason="Nested tensor incomplete")
595-
def test_with_filename(self, tmpdir):
599+
def test_with_filename(self, tmpdir, layout):
596600
filename = tmpdir + "/test_file2.memmap"
597601
tensor = MemoryMappedTensor.empty(
598-
self.shape, filename=filename, dtype=torch.int
602+
self.shape(layout), filename=filename, dtype=torch.int, layout=layout,
599603
)
604+
assert tensor.layout is layout
600605
assert isinstance(tensor, MemoryMappedTensor)
601606
assert tensor.dtype == torch.int
602607
tensor.fill_(2)
@@ -605,22 +610,24 @@ def test_with_filename(self, tmpdir):
605610

606611
filename = tmpdir + "/test_file0.memmap"
607612
tensor = MemoryMappedTensor.zeros(
608-
self.shape, filename=filename, dtype=torch.bool
613+
self.shape(layout), filename=filename, dtype=torch.bool, layout=layout,
609614
)
615+
assert tensor.layout is layout
610616
assert isinstance(tensor, MemoryMappedTensor)
611617
assert tensor.dtype == torch.bool
612618
assert tensor.filename is not None
613619

614620
filename = tmpdir + "/test_file1.memmap"
615-
tensor = MemoryMappedTensor.ones(self.shape, filename=filename, dtype=torch.int)
621+
tensor = MemoryMappedTensor.ones(self.shape(layout), filename=filename, dtype=torch.int, layout=layout)
622+
assert tensor.layout is layout
616623
assert type(tensor) is MemoryMappedTensor
617624
assert tensor.dtype == torch.int
618625
assert (tensor[0] == 1).all()
619626
assert tensor.filename is not None
620627

621628
filename = tmpdir + "/test_file3.memmap"
622629
tensor = torch.nested.nested_tensor(
623-
[torch.zeros(shape.tolist()) + i for i, shape in enumerate(self.shape)]
630+
[torch.zeros(shape.tolist()) + i for i, shape in enumerate(self.shape(layout))]
624631
)
625632
memmap_tensor = MemoryMappedTensor.from_tensor(tensor, filename=filename)
626633
assert type(memmap_tensor) is MemoryMappedTensor
@@ -629,35 +636,35 @@ def test_with_filename(self, tmpdir):
629636
assert (t1 == t2).all()
630637

631638
memmap_tensor2 = MemoryMappedTensor.from_filename(
632-
filename, dtype=memmap_tensor.dtype, shape=self.shape
639+
filename, dtype=memmap_tensor.dtype, shape=self.shape(layout)
633640
)
634641
assert type(memmap_tensor2) is MemoryMappedTensor
635642
for t1, t2 in zip(memmap_tensor2, memmap_tensor):
636643
assert t1.dtype == t2.dtype
637644
assert (t1 == t2).all()
638645

639646
@pytest.mark.skipif(not HAS_NESTED_TENSOR, reason="Nested tensor incomplete")
640-
def test_with_handler(self):
641-
tensor = MemoryMappedTensor.empty(self.shape, dtype=torch.int)
647+
def test_with_handler(self, layout):
648+
tensor = MemoryMappedTensor.empty(self.shape(layout), dtype=torch.int, layout=layout)
642649
assert isinstance(tensor, MemoryMappedTensor)
643650
assert tensor.dtype == torch.int
644651
tensor.fill_(2)
645652
assert (tensor[0] == 2).all()
646653
assert tensor._handler is not None
647654

648-
tensor = MemoryMappedTensor.zeros(self.shape, dtype=torch.bool)
655+
tensor = MemoryMappedTensor.zeros(self.shape(layout), dtype=torch.bool, layout=layout)
649656
assert isinstance(tensor, MemoryMappedTensor)
650657
assert tensor.dtype == torch.bool
651658
assert tensor._handler is not None
652659

653-
tensor = MemoryMappedTensor.ones(self.shape, dtype=torch.int)
660+
tensor = MemoryMappedTensor.ones(self.shape(layout), dtype=torch.int, layout=layout)
654661
assert type(tensor) is MemoryMappedTensor
655662
assert tensor.dtype == torch.int
656663
assert (tensor[0] == 1).all()
657664
assert tensor._handler is not None
658665

659666
tensor = torch.nested.nested_tensor(
660-
[torch.zeros(shape.tolist()) + i for i, shape in enumerate(self.shape)]
667+
[torch.zeros(shape.tolist()) + i for i, shape in enumerate(self.shape(layout))]
661668
)
662669
memmap_tensor = MemoryMappedTensor.from_tensor(tensor)
663670
assert type(memmap_tensor) is MemoryMappedTensor
@@ -666,7 +673,7 @@ def test_with_handler(self):
666673
assert (t1 == t2).all()
667674

668675
memmap_tensor2 = MemoryMappedTensor.from_handler(
669-
memmap_tensor._handler, dtype=memmap_tensor.dtype, shape=self.shape
676+
memmap_tensor._handler, dtype=memmap_tensor.dtype, shape=self.shape(layout), layout=layout
670677
)
671678
assert type(memmap_tensor2) is MemoryMappedTensor
672679
for t1, t2 in zip(memmap_tensor2, memmap_tensor):
@@ -675,34 +682,34 @@ def test_with_handler(self):
675682

676683
@pytest.mark.skipif(not HAS_NESTED_TENSOR, reason="Nested tensor incomplete")
677684
@pytest.mark.parametrize("with_filename", [False, True])
678-
def test_from_storage(self, with_filename, tmpdir):
685+
def test_from_storage(self, with_filename, tmpdir, layout):
679686
if with_filename:
680687
filename = Path(tmpdir) / "file.memmap"
681688
filename = str(filename)
682689
else:
683690
filename = None
684691
a = MemoryMappedTensor.from_tensor(
685-
torch.arange(10, dtype=torch.float64), filename=filename
692+
torch.arange(10, dtype=torch.float64), filename=filename, layout=layout,
686693
)
687694
assert type(a) is MemoryMappedTensor
688695
shape = torch.tensor([[2, 2], [2, 3]])
689696
b = MemoryMappedTensor.from_storage(
690-
a.untyped_storage(), filename=filename, shape=shape, dtype=a.dtype
697+
a.untyped_storage(), filename=filename, shape=shape, dtype=a.dtype, layout=layout,
691698
)
692699
assert type(b) is MemoryMappedTensor
693700
assert (b._nested_tensor_size() == shape).all()
694701
assert (b[0] == torch.arange(4).view(2, 2)).all()
695702
assert (b[1] == torch.arange(4, 10).view(2, 3)).all()
696703

697704
@pytest.mark.skipif(not HAS_NESTED_TENSOR, reason="Nested tensor incomplete")
698-
def test_save_td_with_nested(self, tmpdir):
705+
def test_save_td_with_nested(self, tmpdir, layout):
699706
td = TensorDict(
700707
{
701708
"a": torch.nested.nested_tensor(
702709
[
703710
torch.arange(12, dtype=torch.float64).view(3, 4),
704711
torch.arange(15, dtype=torch.float64).view(3, 5),
705-
]
712+
], layout=layout,
706713
)
707714
},
708715
batch_size=[2, 3],

0 commit comments

Comments
 (0)