@@ -47,6 +47,10 @@ def test_encode_decode():
4747 torch .rand ((1 , 10 ), dtype = torch .float32 ),
4848 torch .rand ((3 , 5 , 4000 ), dtype = torch .float64 ),
4949 torch .tensor (1984 ), # test scalar too
50+ # Make sure to test bf16 which numpy doesn't support.
51+ torch .rand ((3 , 5 , 1000 ), dtype = torch .bfloat16 ),
52+ torch .tensor ([float ("-inf" ), float ("inf" )] * 1024 ,
53+ dtype = torch .bfloat16 ),
5054 ],
5155 numpy_array = np .arange (512 ),
5256 unrecognized = UnrecognizedType (33 ),
@@ -64,7 +68,7 @@ def test_encode_decode():
6468 # There should be the main buffer + 4 large tensor buffers
6569 # + 1 large numpy array. "large" is <= 512 bytes.
6670 # The two small tensors are encoded inline.
67- assert len (encoded ) == 6
71+ assert len (encoded ) == 8
6872
6973 decoded : MyType = decoder .decode (encoded )
7074
@@ -76,7 +80,7 @@ def test_encode_decode():
7680
7781 encoded2 = encoder .encode_into (obj , preallocated )
7882
79- assert len (encoded2 ) == 6
83+ assert len (encoded2 ) == 8
8084 assert encoded2 [0 ] is preallocated
8185
8286 decoded2 : MyType = decoder .decode (encoded2 )
0 commit comments