@@ -588,15 +588,20 @@ def test_ne(self):
588588 assert (memmap != ~ memmap ).all ()
589589
590590
591+ @pytest .mark .parametrize ("layout" , [torch .jagged , torch .strided ])
591592class TestNestedTensor :
592- shape = torch .tensor ([[2 , 3 ], [2 , 4 ], [3 , 2 ]])
593+ def shape (self , layout ):
594+ if layout is torch .strided :
595+ return torch .tensor ([[2 , 3 ], [2 , 4 ], [3 , 2 ]])
596+ return torch .tensor ([[2 , 3 ], [3 , 3 ], [4 , 3 ]])
593597
594598 @pytest .mark .skipif (not HAS_NESTED_TENSOR , reason = "Nested tensor incomplete" )
595- def test_with_filename (self , tmpdir ):
599+ def test_with_filename (self , tmpdir , layout ):
596600 filename = tmpdir + "/test_file2.memmap"
597601 tensor = MemoryMappedTensor .empty (
598- self .shape , filename = filename , dtype = torch .int
602+ self .shape ( layout ) , filename = filename , dtype = torch .int , layout = layout ,
599603 )
604+ assert tensor .layout is layout
600605 assert isinstance (tensor , MemoryMappedTensor )
601606 assert tensor .dtype == torch .int
602607 tensor .fill_ (2 )
@@ -605,22 +610,24 @@ def test_with_filename(self, tmpdir):
605610
606611 filename = tmpdir + "/test_file0.memmap"
607612 tensor = MemoryMappedTensor .zeros (
608- self .shape , filename = filename , dtype = torch .bool
613+ self .shape ( layout ) , filename = filename , dtype = torch .bool , layout = layout ,
609614 )
615+ assert tensor .layout is layout
610616 assert isinstance (tensor , MemoryMappedTensor )
611617 assert tensor .dtype == torch .bool
612618 assert tensor .filename is not None
613619
614620 filename = tmpdir + "/test_file1.memmap"
615- tensor = MemoryMappedTensor .ones (self .shape , filename = filename , dtype = torch .int )
621+ tensor = MemoryMappedTensor .ones (self .shape (layout ), filename = filename , dtype = torch .int , layout = layout )
622+ assert tensor .layout is layout
616623 assert type (tensor ) is MemoryMappedTensor
617624 assert tensor .dtype == torch .int
618625 assert (tensor [0 ] == 1 ).all ()
619626 assert tensor .filename is not None
620627
621628 filename = tmpdir + "/test_file3.memmap"
622629 tensor = torch .nested .nested_tensor (
623- [torch .zeros (shape .tolist ()) + i for i , shape in enumerate (self .shape )]
630+ [torch .zeros (shape .tolist ()) + i for i , shape in enumerate (self .shape ( layout ) )]
624631 )
625632 memmap_tensor = MemoryMappedTensor .from_tensor (tensor , filename = filename )
626633 assert type (memmap_tensor ) is MemoryMappedTensor
@@ -629,35 +636,35 @@ def test_with_filename(self, tmpdir):
629636 assert (t1 == t2 ).all ()
630637
631638 memmap_tensor2 = MemoryMappedTensor .from_filename (
632- filename , dtype = memmap_tensor .dtype , shape = self .shape
639+ filename , dtype = memmap_tensor .dtype , shape = self .shape ( layout )
633640 )
634641 assert type (memmap_tensor2 ) is MemoryMappedTensor
635642 for t1 , t2 in zip (memmap_tensor2 , memmap_tensor ):
636643 assert t1 .dtype == t2 .dtype
637644 assert (t1 == t2 ).all ()
638645
639646 @pytest .mark .skipif (not HAS_NESTED_TENSOR , reason = "Nested tensor incomplete" )
640- def test_with_handler (self ):
641- tensor = MemoryMappedTensor .empty (self .shape , dtype = torch .int )
647+ def test_with_handler (self , layout ):
648+ tensor = MemoryMappedTensor .empty (self .shape ( layout ) , dtype = torch .int , layout = layout )
642649 assert isinstance (tensor , MemoryMappedTensor )
643650 assert tensor .dtype == torch .int
644651 tensor .fill_ (2 )
645652 assert (tensor [0 ] == 2 ).all ()
646653 assert tensor ._handler is not None
647654
648- tensor = MemoryMappedTensor .zeros (self .shape , dtype = torch .bool )
655+ tensor = MemoryMappedTensor .zeros (self .shape ( layout ) , dtype = torch .bool , layout = layout )
649656 assert isinstance (tensor , MemoryMappedTensor )
650657 assert tensor .dtype == torch .bool
651658 assert tensor ._handler is not None
652659
653- tensor = MemoryMappedTensor .ones (self .shape , dtype = torch .int )
660+ tensor = MemoryMappedTensor .ones (self .shape ( layout ) , dtype = torch .int , layout = layout )
654661 assert type (tensor ) is MemoryMappedTensor
655662 assert tensor .dtype == torch .int
656663 assert (tensor [0 ] == 1 ).all ()
657664 assert tensor ._handler is not None
658665
659666 tensor = torch .nested .nested_tensor (
660- [torch .zeros (shape .tolist ()) + i for i , shape in enumerate (self .shape )]
667+ [torch .zeros (shape .tolist ()) + i for i , shape in enumerate (self .shape ( layout ) )]
661668 )
662669 memmap_tensor = MemoryMappedTensor .from_tensor (tensor )
663670 assert type (memmap_tensor ) is MemoryMappedTensor
@@ -666,7 +673,7 @@ def test_with_handler(self):
666673 assert (t1 == t2 ).all ()
667674
668675 memmap_tensor2 = MemoryMappedTensor .from_handler (
669- memmap_tensor ._handler , dtype = memmap_tensor .dtype , shape = self .shape
676+ memmap_tensor ._handler , dtype = memmap_tensor .dtype , shape = self .shape ( layout ), layout = layout
670677 )
671678 assert type (memmap_tensor2 ) is MemoryMappedTensor
672679 for t1 , t2 in zip (memmap_tensor2 , memmap_tensor ):
@@ -675,34 +682,34 @@ def test_with_handler(self):
675682
676683 @pytest .mark .skipif (not HAS_NESTED_TENSOR , reason = "Nested tensor incomplete" )
677684 @pytest .mark .parametrize ("with_filename" , [False , True ])
678- def test_from_storage (self , with_filename , tmpdir ):
685+ def test_from_storage (self , with_filename , tmpdir , layout ):
679686 if with_filename :
680687 filename = Path (tmpdir ) / "file.memmap"
681688 filename = str (filename )
682689 else :
683690 filename = None
684691 a = MemoryMappedTensor .from_tensor (
685- torch .arange (10 , dtype = torch .float64 ), filename = filename
692+ torch .arange (10 , dtype = torch .float64 ), filename = filename , layout = layout ,
686693 )
687694 assert type (a ) is MemoryMappedTensor
688695 shape = torch .tensor ([[2 , 2 ], [2 , 3 ]])
689696 b = MemoryMappedTensor .from_storage (
690- a .untyped_storage (), filename = filename , shape = shape , dtype = a .dtype
697+ a .untyped_storage (), filename = filename , shape = shape , dtype = a .dtype , layout = layout ,
691698 )
692699 assert type (b ) is MemoryMappedTensor
693700 assert (b ._nested_tensor_size () == shape ).all ()
694701 assert (b [0 ] == torch .arange (4 ).view (2 , 2 )).all ()
695702 assert (b [1 ] == torch .arange (4 , 10 ).view (2 , 3 )).all ()
696703
697704 @pytest .mark .skipif (not HAS_NESTED_TENSOR , reason = "Nested tensor incomplete" )
698- def test_save_td_with_nested (self , tmpdir ):
705+ def test_save_td_with_nested (self , tmpdir , layout ):
699706 td = TensorDict (
700707 {
701708 "a" : torch .nested .nested_tensor (
702709 [
703710 torch .arange (12 , dtype = torch .float64 ).view (3 , 4 ),
704711 torch .arange (15 , dtype = torch .float64 ).view (3 , 5 ),
705- ]
712+ ], layout = layout ,
706713 )
707714 },
708715 batch_size = [2 , 3 ],
0 commit comments