Skip to content

Commit 9b8a761

Browse files
committed
fix unit test
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
1 parent 732a3d4 commit 9b8a761

File tree

1 file changed: +33 additions, -37 deletions

tests/kernels/mamba/test_mamba_ssm_ssd.py

Lines changed: 33 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
407407
_query_start_loc_to_chunk_indices_offsets(
408408
cu_seqlens, chunk_size, cu_seqlens[-1])
409409
Y_ref = torch.empty_like(X)
410-
state_ref = mamba_chunk_scan_combined(
410+
state_ref = mamba_chunk_scan_combined_varlen(
411411
X,
412412
dt,
413413
A,
@@ -419,7 +419,6 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
419419
seq_idx=seq_idx,
420420
chunk_indices=chunk_indices,
421421
chunk_offsets=chunk_offsets,
422-
return_varlen_states=True,
423422
initial_states=None,
424423
out=Y_ref,
425424
)
@@ -435,27 +434,27 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
435434
chunked_seq_idx = torch.repeat_interleave(
436435
torch.arange(len(chunked_seqlens), device=device),
437436
chunked_seqlens,
438-
output_size=chunked_cu_seqlens[-1]).unsqueeze(0).to(torch.int32)
437+
output_size=chunked_cu_seqlens[-1]).to(torch.int32)
439438
chunked_input_seq_len = chunked_cu_seqlens[-1]
440-
X_chunked = torch.zeros_like(X)[:, :chunked_input_seq_len, ...]
441-
dt_chunked = torch.zeros_like(dt)[:, :chunked_input_seq_len, ...]
442-
B_chunked = torch.zeros_like(B)[:, :chunked_input_seq_len, ...]
443-
C_chunked = torch.zeros_like(C)[:, :chunked_input_seq_len, ...]
439+
X_chunked = torch.zeros_like(X)[:chunked_input_seq_len, ...]
440+
dt_chunked = torch.zeros_like(dt)[:chunked_input_seq_len, ...]
441+
B_chunked = torch.zeros_like(B)[:chunked_input_seq_len, ...]
442+
C_chunked = torch.zeros_like(C)[:chunked_input_seq_len, ...]
444443
for i in range(num_sequences):
445444
# fmt: off
446-
chunk_f = lambda x, i: x[:, cu_seqlens[i]:cu_seqlens[i] + chunked_seqlens[i], ...] # noqa: E501
445+
chunk_f = lambda x, i: x[cu_seqlens[i]:cu_seqlens[i] + chunked_seqlens[i], ...] # noqa: E501
447446

448-
X_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(X, i) # noqa: E501
449-
dt_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(dt, i) # noqa: E501
450-
B_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(B, i) # noqa: E501
451-
C_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(C, i) # noqa: E501
447+
X_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(X, i) # noqa: E501
448+
dt_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(dt, i) # noqa: E501
449+
B_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(B, i) # noqa: E501
450+
C_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(C, i) # noqa: E501
452451
# fmt: on
453452

454453
chunk_indices, chunk_offsets = \
455454
_query_start_loc_to_chunk_indices_offsets(
456455
chunked_cu_seqlens, chunk_size, chunked_cu_seqlens[-1])
457456
Y_partial = torch.empty_like(X_chunked)
458-
partial_state = mamba_chunk_scan_combined(
457+
partial_state = mamba_chunk_scan_combined_varlen(
459458
X_chunked,
460459
dt_chunked,
461460
A,
@@ -467,7 +466,6 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
467466
seq_idx=chunked_seq_idx,
468467
chunk_indices=chunk_indices,
469468
chunk_offsets=chunk_offsets,
470-
return_varlen_states=True,
471469
initial_states=None,
472470
out=Y_partial,
473471
)
@@ -482,29 +480,28 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
482480
remaining_chunked_seq_idx = torch.repeat_interleave(
483481
torch.arange(len(remaining_chunked_seqlens), device=device),
484482
remaining_chunked_seqlens,
485-
output_size=remaining_chunked_cu_seqlens[-1]).unsqueeze(0).to(
486-
torch.int32)
483+
output_size=remaining_chunked_cu_seqlens[-1]).to(torch.int32)
487484
remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1]
488485
# fmt: off
489-
remaining_X_chunked = torch.zeros_like(X)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501
490-
remaining_dt_chunked = torch.zeros_like(dt)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501
491-
remaining_B_chunked = torch.zeros_like(B)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501
492-
remaining_C_chunked = torch.zeros_like(C)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501
486+
remaining_X_chunked = torch.zeros_like(X)[:remaining_chunked_input_seq_len, ...] # noqa: E501
487+
remaining_dt_chunked = torch.zeros_like(dt)[:remaining_chunked_input_seq_len, ...] # noqa: E501
488+
remaining_B_chunked = torch.zeros_like(B)[:remaining_chunked_input_seq_len, ...] # noqa: E501
489+
remaining_C_chunked = torch.zeros_like(C)[:remaining_chunked_input_seq_len, ...] # noqa: E501
493490
for i in range(num_sequences):
494-
remaining_chunk_f = lambda x, i: x[:, cu_seqlens[i] + chunked_seqlens[i]:cu_seqlens[i+1], ...] # noqa: E501
491+
remaining_chunk_f = lambda x, i: x[cu_seqlens[i] + chunked_seqlens[i]:cu_seqlens[i+1], ...] # noqa: E501
495492

496-
remaining_X_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(X, i) # noqa: E501
497-
remaining_dt_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(dt, i) # noqa: E501
498-
remaining_B_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(B, i) # noqa: E501
499-
remaining_C_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(C, i) # noqa: E501
493+
remaining_X_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(X, i) # noqa: E501
494+
remaining_dt_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(dt, i) # noqa: E501
495+
remaining_B_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(B, i) # noqa: E501
496+
remaining_C_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(C, i) # noqa: E501
500497

501498
# assert input chunking is correct
502499
concat_chunk_f = lambda pt1, pt2, i: torch.cat([
503-
pt1[:,chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1],...],
504-
pt2[:,remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1],...],
500+
pt1[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1],...],
501+
pt2[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1],...],
505502
],
506-
dim=1)
507-
concat_batch_f = lambda pt1, pt2: torch.cat([concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=1) # noqa: E501
503+
dim=0)
504+
concat_batch_f = lambda pt1, pt2: torch.cat([concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=0) # noqa: E501
508505
# fmt: on
509506

510507
assert concat_batch_f(X_chunked, remaining_X_chunked).equal(X)
@@ -519,7 +516,7 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
519516
remaining_chunked_cu_seqlens[-1])
520517

521518
Y_chunked = torch.empty_like(remaining_X_chunked)
522-
state_chunked = mamba_chunk_scan_combined(
519+
state_chunked = mamba_chunk_scan_combined_varlen(
523520
remaining_X_chunked,
524521
remaining_dt_chunked,
525522
A,
@@ -531,25 +528,24 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
531528
seq_idx=remaining_chunked_seq_idx,
532529
chunk_indices=chunk_indices,
533530
chunk_offsets=chunk_offsets,
534-
return_varlen_states=True,
535531
initial_states=partial_state,
536532
out=Y_chunked,
537533
)
538534
Y = concat_batch_f(Y_partial, Y_chunked)
539535

540536
# kernel chunked is same as kernel overall
541537
for i in range(num_sequences):
542-
Y_seq = Y[:, cu_seqlens[i]:cu_seqlens[i + 1], ...]
543-
Y_ref_seq = Y_ref[:, cu_seqlens[i]:cu_seqlens[i + 1], ...]
538+
Y_seq = Y[cu_seqlens[i]:cu_seqlens[i + 1], ...]
539+
Y_ref_seq = Y_ref[cu_seqlens[i]:cu_seqlens[i + 1], ...]
544540
torch.testing.assert_close(
545-
Y_seq[:, :chunked_seqlens[i], ...],
546-
Y_ref_seq[:, :chunked_seqlens[i], ...],
541+
Y_seq[:chunked_seqlens[i], ...],
542+
Y_ref_seq[:chunked_seqlens[i], ...],
547543
atol=atol,
548544
rtol=rtol,
549545
msg=lambda x: f"seq{i} output part1 " + x) # noqa: B023
550546
torch.testing.assert_close(
551-
Y_seq[:, chunked_seqlens[i]:, ...],
552-
Y_ref_seq[:, chunked_seqlens[i]:, ...],
547+
Y_seq[chunked_seqlens[i]:, ...],
548+
Y_ref_seq[chunked_seqlens[i]:, ...],
553549
atol=atol,
554550
rtol=rtol,
555551
msg=lambda x: f"seq{i} output part2 " + x) # noqa: B023

0 commit comments

Comments (0)