@@ -22,6 +22,10 @@ class MyType:
2222 list_of_tensors : list [torch .Tensor ]
2323 numpy_array : np .ndarray
2424 unrecognized : UnrecognizedType
25+ small_f_contig_tensor : torch .Tensor
26+ large_f_contig_tensor : torch .Tensor
27+ small_non_contig_tensor : torch .Tensor
28+ large_non_contig_tensor : torch .Tensor
2529
2630
2731def test_encode_decode ():
@@ -40,17 +44,21 @@ def test_encode_decode():
4044 ],
4145 numpy_array = np .arange (512 ),
4246 unrecognized = UnrecognizedType (33 ),
47+ small_f_contig_tensor = torch .rand (5 , 4 ).t (),
48+ large_f_contig_tensor = torch .rand (1024 , 4 ).t (),
49+ small_non_contig_tensor = torch .rand (2 , 4 )[:, 1 :3 ],
50+ large_non_contig_tensor = torch .rand (1024 , 512 )[:, 10 :20 ],
4351 )
4452
4553 encoder = MsgpackEncoder ()
4654 decoder = MsgpackDecoder (MyType )
4755
4856 encoded = encoder .encode (obj )
4957
50- # There should be the main buffer + 2 large tensor buffers
51- # + 1 large numpy array. "large" is <= 256 bytes.
58+ # There should be the main buffer + 4 large tensor buffers
59+ # + 1 large numpy array. "large" is <= 512 bytes.
5260 # The two small tensors are encoded inline.
53- assert len (encoded ) == 4
61+ assert len (encoded ) == 6
5462
5563 decoded : MyType = decoder .decode (encoded )
5664
@@ -62,7 +70,7 @@ def test_encode_decode():
6270
6371 encoded2 = encoder .encode_into (obj , preallocated )
6472
65- assert len (encoded2 ) == 4
73+ assert len (encoded2 ) == 6
6674 assert encoded2 [0 ] is preallocated
6775
6876 decoded2 : MyType = decoder .decode (encoded2 )
@@ -78,3 +86,9 @@ def assert_equal(obj1: MyType, obj2: MyType):
7886 for a , b in zip (obj1 .list_of_tensors , obj2 .list_of_tensors ))
7987 assert np .array_equal (obj1 .numpy_array , obj2 .numpy_array )
8088 assert obj1 .unrecognized .an_int == obj2 .unrecognized .an_int
89+ assert torch .equal (obj1 .small_f_contig_tensor , obj2 .small_f_contig_tensor )
90+ assert torch .equal (obj1 .large_f_contig_tensor , obj2 .large_f_contig_tensor )
91+ assert torch .equal (obj1 .small_non_contig_tensor ,
92+ obj2 .small_non_contig_tensor )
93+ assert torch .equal (obj1 .large_non_contig_tensor ,
94+ obj2 .large_non_contig_tensor )
0 commit comments