@@ -229,16 +229,16 @@ def selective_scan_opcheck_fn(
229229
230230
231231@pytest .mark .parametrize ("wtype" , [torch .float32 ])
232- @pytest .mark .parametrize ("itype" , [torch .float32 , torch .float16 , torch . bfloat16 ])
233- @pytest .mark .parametrize ("seqlen" , [128 , 256 , 512 , 1024 , 2048 , 4096 ])
232+ @pytest .mark .parametrize ("itype" , [torch .float32 , torch .bfloat16 ])
233+ @pytest .mark .parametrize ("seqlen" , [128 , 1024 , 4096 ])
234234@pytest .mark .parametrize ("has_delta_bias" , [True ])
235235@pytest .mark .parametrize ("delta_softplus" , [True ])
236236@pytest .mark .parametrize ("has_z" , [True ])
237237@pytest .mark .parametrize ("has_D" , [True ])
238238@pytest .mark .parametrize ("varBC_groups" , [1 , 2 ])
239239@pytest .mark .parametrize ("is_variable_C" , [True ])
240240@pytest .mark .parametrize ("is_variable_B" , [True ])
241- @pytest .mark .parametrize ("scan_chunks" , [1 , 2 , 3 ])
241+ @pytest .mark .parametrize ("scan_chunks" , [1 , 3 ])
242242def test_selective_scan (
243243 is_variable_B ,
244244 is_variable_C ,
@@ -375,9 +375,9 @@ def test_selective_scan(
375375 )
376376
377377
378- @pytest .mark .parametrize ("itype" , [torch .float32 , torch .float16 , torch . bfloat16 ])
378+ @pytest .mark .parametrize ("itype" , [torch .float32 , torch .bfloat16 ])
379379@pytest .mark .parametrize ("has_z" , [False , True ])
380- @pytest .mark .parametrize ("dstate" , [16 , 32 , 64 ])
380+ @pytest .mark .parametrize ("dstate" , [16 , 64 ])
381381@pytest .mark .parametrize ("dim" , [2048 , 2048 + 16 , 4096 ])
382382def test_selective_state_update (dim , dstate , has_z , itype ):
383383 device = "cuda"
@@ -413,7 +413,7 @@ def test_selective_state_update(dim, dstate, has_z, itype):
413413
414414@pytest .mark .parametrize ("wtype" , [torch .float32 ])
415415@pytest .mark .parametrize ("itype" , [torch .float32 ])
416- @pytest .mark .parametrize ("seqlen" , [1 , 128 , 129 , 256 , 512 , 1024 , 2048 , 4096 ])
416+ @pytest .mark .parametrize ("seqlen" , [1 , 256 , 1024 , 4096 ])
417417@pytest .mark .parametrize ("return_last_state" , [True ])
418418@pytest .mark .parametrize ("has_delta_bias" , [True ])
419419@pytest .mark .parametrize ("delta_softplus" , [True ])
@@ -589,9 +589,9 @@ def test_selective_scan_varlen(
589589 )
590590
591591
592- @pytest .mark .parametrize ("itype" , [torch .float32 , torch .float16 , torch . bfloat16 ])
592+ @pytest .mark .parametrize ("itype" , [torch .float32 , torch .bfloat16 ])
593593@pytest .mark .parametrize ("has_z" , [True ])
594- @pytest .mark .parametrize ("dstate" , [16 , 32 , 64 ])
594+ @pytest .mark .parametrize ("dstate" , [16 , 64 ])
595595@pytest .mark .parametrize ("dim" , [2048 , 2048 + 16 , 4096 ])
596596# tests correctness when only a subset of the sequences is padded
597597@pytest .mark .parametrize ("with_padding" , [True , False ])
@@ -679,11 +679,11 @@ def test_selective_state_update_with_batch_indices(
679679 assert torch .allclose (out [:batch_size ], out_ref , rtol = rtol , atol = atol )
680680
681681
682- @pytest .mark .parametrize ("itype" , [torch .float32 , torch .float16 , torch . bfloat16 ])
682+ @pytest .mark .parametrize ("itype" , [torch .float32 , torch .bfloat16 ])
683683@pytest .mark .parametrize ("has_z" , [False , True ])
684684@pytest .mark .parametrize ("tie_hdim" , [False , True ])
685- @pytest .mark .parametrize ("ngroups" , [1 , 2 , 4 ])
686- @pytest .mark .parametrize ("dstate" , [16 , 32 , 64 ])
685+ @pytest .mark .parametrize ("ngroups" , [1 , 4 ])
686+ @pytest .mark .parametrize ("dstate" , [16 , 64 ])
687687@pytest .mark .parametrize ("dim" , [2048 , 4096 ])
688688def test_selective_state_update_with_heads_with_batch_indices (
689689 dim , dstate , ngroups , has_z , tie_hdim , itype
0 commit comments