@@ -1836,6 +1836,7 @@ def test_dense_to_jagged(
     outer_dense_size=st.integers(0, 5),
     inner_dense_size=st.integers(0, 5),
     padding_value=st.sampled_from([0, -1e-8]),
+    dtype=st.sampled_from([torch.float, torch.half, torch.bfloat16, torch.double]),
     use_cpu=st.booleans() if gpu_available else st.just(True),
 )
 @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None)
@@ -1845,8 +1846,12 @@ def test_jagged_to_padded_dense(
     outer_dense_size: int,
     inner_dense_size: int,
     padding_value: float,
+    dtype: torch.dtype,
     use_cpu: bool,
 ) -> None:
+    # CPU doesn't support bfloat16
+    assume(not use_cpu or dtype != torch.bfloat16)
+
     # Testing with a basic crafted example.
     # dense representation is
     # [[[[0, 1], [ 0, 0], [0, 0]],
@@ -2006,7 +2011,7 @@ def mul_func(*args) -> torch.Tensor:
     H=st.integers(1, 3),
     max_L=st.integers(1, 32),
     D=st.integers(0, 32),
-    dtype=st.sampled_from([torch.float, torch.half, torch.double]),
+    dtype=st.sampled_from([torch.float, torch.half, torch.bfloat16, torch.double]),
     use_cpu=st.booleans() if gpu_available else st.just(True),
 )
 def test_batched_dense_vec_jagged_2d_mul(
@@ -2019,6 +2024,9 @@ def test_batched_dense_vec_jagged_2d_mul(
     use_cpu: bool,
 ) -> None:
     assume(H == 1 or B != 0)
+    # CPU doesn't support bfloat16
+    assume(not use_cpu or dtype != torch.bfloat16)
+
     device = torch.device("cpu" if use_cpu else "cuda")
     torch.backends.cuda.matmul.allow_tf32 = False
 
0 commit comments