Skip to content

Commit 577c72a

Browse files
kfhfar and yewentao256 authored
[CI Perf]Prune Tests in kernel/mamba (#26538)
Signed-off-by: Fardin Hoque <kfhfar@amazon.com>
Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
1 parent 314285d commit 577c72a

File tree

4 files changed

+21
-36
lines changed

4 files changed

+21
-36
lines changed

tests/kernels/mamba/test_causal_conv1d.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, ity
183183
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
184184

185185

186-
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
186+
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
187187
@pytest.mark.parametrize("silu_activation", [False, True])
188188
@pytest.mark.parametrize("has_bias", [False, True])
189189
@pytest.mark.parametrize("seqlen", [1, 3])
@@ -265,7 +265,7 @@ def test_causal_conv1d_update_with_batch_gather(
265265
@pytest.mark.parametrize("silu_activation", [True])
266266
@pytest.mark.parametrize("has_bias", [True])
267267
@pytest.mark.parametrize("width", [4])
268-
@pytest.mark.parametrize("seqlen", [8, 30, 249, 2049, 4096])
268+
@pytest.mark.parametrize("seqlen", [8, 249, 4096])
269269
@pytest.mark.parametrize("dim", [64, 4096])
270270
@pytest.mark.parametrize("with_padding", [True, False])
271271
@pytest.mark.parametrize("batch", [4, 10])

tests/kernels/mamba/test_mamba_mixer2.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
(64, 1),
2626
(64, 2),
2727
(64, 4), # hidden_size be divisible by num_gpus
28-
(100, 5), # and n_groups must divide hidden_size
2928
],
3029
)
3130
@pytest.mark.parametrize("dtype", [torch.float16])

tests/kernels/mamba/test_mamba_ssm.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -229,16 +229,16 @@ def selective_scan_opcheck_fn(
229229

230230

231231
@pytest.mark.parametrize("wtype", [torch.float32])
232-
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
233-
@pytest.mark.parametrize("seqlen", [128, 256, 512, 1024, 2048, 4096])
232+
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
233+
@pytest.mark.parametrize("seqlen", [128, 1024, 4096])
234234
@pytest.mark.parametrize("has_delta_bias", [True])
235235
@pytest.mark.parametrize("delta_softplus", [True])
236236
@pytest.mark.parametrize("has_z", [True])
237237
@pytest.mark.parametrize("has_D", [True])
238238
@pytest.mark.parametrize("varBC_groups", [1, 2])
239239
@pytest.mark.parametrize("is_variable_C", [True])
240240
@pytest.mark.parametrize("is_variable_B", [True])
241-
@pytest.mark.parametrize("scan_chunks", [1, 2, 3])
241+
@pytest.mark.parametrize("scan_chunks", [1, 3])
242242
def test_selective_scan(
243243
is_variable_B,
244244
is_variable_C,
@@ -375,9 +375,9 @@ def test_selective_scan(
375375
)
376376

377377

378-
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
378+
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
379379
@pytest.mark.parametrize("has_z", [False, True])
380-
@pytest.mark.parametrize("dstate", [16, 32, 64])
380+
@pytest.mark.parametrize("dstate", [16, 64])
381381
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
382382
def test_selective_state_update(dim, dstate, has_z, itype):
383383
device = "cuda"
@@ -413,7 +413,7 @@ def test_selective_state_update(dim, dstate, has_z, itype):
413413

414414
@pytest.mark.parametrize("wtype", [torch.float32])
415415
@pytest.mark.parametrize("itype", [torch.float32])
416-
@pytest.mark.parametrize("seqlen", [1, 128, 129, 256, 512, 1024, 2048, 4096])
416+
@pytest.mark.parametrize("seqlen", [1, 256, 1024, 4096])
417417
@pytest.mark.parametrize("return_last_state", [True])
418418
@pytest.mark.parametrize("has_delta_bias", [True])
419419
@pytest.mark.parametrize("delta_softplus", [True])
@@ -589,9 +589,9 @@ def test_selective_scan_varlen(
589589
)
590590

591591

592-
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
592+
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
593593
@pytest.mark.parametrize("has_z", [True])
594-
@pytest.mark.parametrize("dstate", [16, 32, 64])
594+
@pytest.mark.parametrize("dstate", [16, 64])
595595
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
596596
# tests correctness in case subset of the sequences are padded
597597
@pytest.mark.parametrize("with_padding", [True, False])
@@ -679,11 +679,11 @@ def test_selective_state_update_with_batch_indices(
679679
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
680680

681681

682-
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
682+
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
683683
@pytest.mark.parametrize("has_z", [False, True])
684684
@pytest.mark.parametrize("tie_hdim", [False, True])
685-
@pytest.mark.parametrize("ngroups", [1, 2, 4])
686-
@pytest.mark.parametrize("dstate", [16, 32, 64])
685+
@pytest.mark.parametrize("ngroups", [1, 4])
686+
@pytest.mark.parametrize("dstate", [16, 64])
687687
@pytest.mark.parametrize("dim", [2048, 4096])
688688
def test_selective_state_update_with_heads_with_batch_indices(
689689
dim, dstate, ngroups, has_z, tie_hdim, itype

tests/kernels/mamba/test_mamba_ssm_ssd.py

Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -188,9 +188,9 @@ def end_boundary(n: int):
188188
)
189189

190190

191-
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
192-
@pytest.mark.parametrize("n_heads", [3, 4, 11, 16, 32])
193-
@pytest.mark.parametrize("d_head", [5, 8, 19, 32, 128])
191+
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
192+
@pytest.mark.parametrize("n_heads", [4, 16, 32])
193+
@pytest.mark.parametrize("d_head", [5, 8, 32, 128])
194194
@pytest.mark.parametrize("seq_len_chunk_size", [(112, 16), (128, 32)])
195195
def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, itype):
196196
# this tests the kernels on a single example (bs=1)
@@ -254,32 +254,22 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, it
254254
)
255255

256256

257-
@pytest.mark.parametrize("itype", [torch.float32, torch.float16])
258-
@pytest.mark.parametrize("n_heads", [4, 8, 13])
259-
@pytest.mark.parametrize("d_head", [5, 16, 21, 32])
257+
@pytest.mark.parametrize("itype", [torch.float32])
258+
@pytest.mark.parametrize("n_heads", [4, 8])
259+
@pytest.mark.parametrize("d_head", [5, 16, 32])
260260
@pytest.mark.parametrize(
261261
"seq_len_chunk_size_cases",
262262
[
263263
# small-ish chunk_size (8)
264264
(64, 8, 2, [(64, 32), (64, 32)]),
265-
(64, 8, 2, [(32, 32), (32, 32), (32, 32)]),
266265
(64, 8, 2, [(8, 8), (8, 8), (8, 8)]), # chunk size boundary
267266
(
268267
64,
269268
8,
270269
2,
271270
[(4, 4), (4, 4), (4, 4), (4, 4)],
272271
), # chunk_size larger than cont batches
273-
(
274-
64,
275-
8,
276-
5,
277-
[
278-
(64, 32, 16, 8, 8),
279-
(8, 16, 32, 16, 8),
280-
(8, 8, 16, 32, 16),
281-
],
282-
), # mode examples with varied lengths
272+
(64, 8, 5, [(64, 32, 16, 8, 8)]),
283273
# large-ish chunk_size (256)
284274
(64, 256, 1, [(5,), (1,), (1,), (1,)]), # irregular sizes with small sequences
285275
(
@@ -359,11 +349,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
359349
@pytest.mark.parametrize("chunk_size", [8, 256])
360350
@pytest.mark.parametrize(
361351
"seqlens",
362-
[
363-
(16, 2, 8, 13),
364-
(270, 88, 212, 203),
365-
(16, 20),
366-
],
352+
[(16, 20), (270, 88, 212, 203)],
367353
)
368354
def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
369355
# This test verifies the correctness of the chunked prefill implementation

0 commit comments

Comments (0)