
Commit 732a3d4

RishiAstra authored and cyang49 committed
fix mamba2 ssd tests for varlen refactor (#7)
merging

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
1 parent 0ef677a commit 732a3d4

File tree

1 file changed: +25 -7 lines changed


tests/kernels/mamba/test_mamba_ssm_ssd.py

Lines changed: 25 additions & 7 deletions
@@ -7,7 +7,7 @@
 from einops import rearrange, repeat
 
 from vllm.model_executor.layers.mamba.ops.ssd_combined import (
-    mamba_chunk_scan_combined)
+    mamba_chunk_scan_combined_varlen)
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.mamba2_attn import (
     _query_start_loc_to_chunk_indices_offsets)
@@ -185,9 +185,14 @@ def end_boundary(n: int):
         IND_S = [x % full_length for x in IND_E]
         IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)]
 
+        # varlen has implicit batch=1
+        dt2 = dt2.squeeze(0)
+        X2 = X2.squeeze(0)
+        B2 = B2.squeeze(0)
+        C2 = C2.squeeze(0)
         yield ([Y_min[s, IND_S[s]:IND_E[s]]
                 for s in range(num_examples)] if return_naive_ref else None,
-                cu_seqlens, seq_idx.unsqueeze(0), (A, dt2, X2, B2, C2))
+                cu_seqlens, seq_idx, (A, dt2, X2, B2, C2))
 
 
 @pytest.mark.parametrize("itype",
@@ -219,6 +224,20 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
 
     Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt,
                                                   B, C, chunk_size)
+
+    cu_seqlens = torch.tensor((0, seqlen), device='cuda').cumsum(dim=0)
+    seq_idx = torch.zeros(seqlen, dtype=torch.int32, device=cu_seqlens.device)
+
+    chunk_indices, chunk_offsets = \
+        _query_start_loc_to_chunk_indices_offsets(
+            cu_seqlens, chunk_size, cu_seqlens[-1])
+
+    # varlen has implicit batch=1
+    X = X.squeeze(0)
+    dt = dt.squeeze(0)
+    A = A.squeeze(0)
+    B = B.squeeze(0)
+    C = C.squeeze(0)
     Y = torch.empty_like(X)
     final_state = mamba_chunk_scan_combined_varlen(X,
                                                    dt,
@@ -234,11 +253,11 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
                                                    out=Y)
 
     # just test the last in sequence
-    torch.testing.assert_close(Y[:, -1], Y_min[:, -1], atol=atol, rtol=rtol)
+    torch.testing.assert_close(Y[-1], Y_min[0, -1], atol=atol, rtol=rtol)
 
     # just test the last head
     # NOTE, in the kernel we always cast states to fp32
-    torch.testing.assert_close(final_state[:, -1],
+    torch.testing.assert_close(final_state[:, -1].to(torch.float32),
                                final_state_min[:, -1].to(torch.float32),
                                atol=atol,
                                rtol=rtol)
@@ -303,7 +322,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
            cu_seqlens, chunk_size, cu_seqlens[-1])
 
         Y = torch.empty_like(X)
-        new_states = mamba_chunk_scan_combined(
+        new_states = mamba_chunk_scan_combined_varlen(
            X,
            dt,
            A,
@@ -315,7 +334,6 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
            seq_idx=seq_idx,
            chunk_indices=chunk_indices,
            chunk_offsets=chunk_offsets,
-           return_varlen_states=True,
            initial_states=states,
            out=Y,
        )
@@ -324,7 +342,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
        for i in range(num_examples):
 
            # just test one dim and dstate
-           Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
+           Y_eg = Y[cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
            Y_min_eg = Y_min[i][:, 0, 0]
            torch.testing.assert_close(Y_eg, Y_min_eg, atol=atol, rtol=rtol)
 
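For context on the layout change the diff relies on: the varlen path takes sequence-flattened tensors with an implicit batch of 1, plus cu_seqlens and seq_idx metadata describing how the sequences are packed. The sketch below only illustrates that layout; the tensor names and sizes are illustrative assumptions, not code from this test file or the vLLM kernel API.

import torch

heads, d_head = 4, 64
seq_lens = [5, 3]                 # two sequences packed back to back
total_tokens = sum(seq_lens)

# batched layout (old tests): [batch=1, seqlen, heads, d_head]
X_batched = torch.randn(1, total_tokens, heads, d_head)

# varlen layout (new tests): drop the batch dim -> [seqlen, heads, d_head]
X_varlen = X_batched.squeeze(0)

# cu_seqlens marks sequence boundaries; seq_idx maps each token to its sequence
cu_seqlens = torch.tensor([0] + seq_lens).cumsum(dim=0)   # tensor([0, 5, 8])
seq_idx = torch.repeat_interleave(
    torch.arange(len(seq_lens), dtype=torch.int32),
    torch.tensor(seq_lens))                               # [0, 0, 0, 0, 0, 1, 1, 1]

print(X_varlen.shape, cu_seqlens.tolist(), seq_idx.tolist())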
