125 changes: 71 additions & 54 deletions tests/kernels/mamba/test_mamba_ssm_ssd.py
@@ -7,7 +7,7 @@
from einops import rearrange, repeat

from vllm.model_executor.layers.mamba.ops.ssd_combined import (
mamba_chunk_scan_combined)
mamba_chunk_scan_combined_varlen)
from vllm.platforms import current_platform
from vllm.v1.attention.backends.mamba2_attn import (
_query_start_loc_to_chunk_indices_offsets)
@@ -185,9 +185,14 @@ def end_boundary(n: int):
IND_S = [x % full_length for x in IND_E]
IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)]

# varlen has implicit batch=1
dt2 = dt2.squeeze(0)
X2 = X2.squeeze(0)
B2 = B2.squeeze(0)
C2 = C2.squeeze(0)
yield ([Y_min[s, IND_S[s]:IND_E[s]]
for s in range(num_examples)] if return_naive_ref else None,
cu_seqlens, seq_idx.unsqueeze(0), (A, dt2, X2, B2, C2))
cu_seqlens, seq_idx, (A, dt2, X2, B2, C2))


@pytest.mark.parametrize("itype",
@@ -198,7 +203,7 @@ def end_boundary(n: int):
def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
itype):

# this tests the kernels on a single example (no batching)
# this tests the kernels on a single example (bs=1)

# TODO: the bfloat16 case requires higher thresholds. To be investigated

@@ -219,23 +224,40 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,

Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt,
B, C, chunk_size)

cu_seqlens = torch.tensor((0, seqlen), device='cuda').cumsum(dim=0)
seq_idx = torch.zeros(seqlen, dtype=torch.int32, device=cu_seqlens.device)

chunk_indices, chunk_offsets = \
_query_start_loc_to_chunk_indices_offsets(
cu_seqlens, chunk_size, cu_seqlens[-1])

# varlen has implicit batch=1
X = X.squeeze(0)
dt = dt.squeeze(0)
A = A.squeeze(0)
B = B.squeeze(0)
C = C.squeeze(0)
Y = torch.empty_like(X)
final_state = mamba_chunk_scan_combined(X,
dt,
A,
B,
C,
chunk_size,
D=None,
return_final_states=True,
out=Y)
final_state = mamba_chunk_scan_combined_varlen(X,
dt,
A,
B,
C,
chunk_size,
D=None,
cu_seqlens=cu_seqlens,
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
out=Y)

# just test the last in sequence
torch.testing.assert_close(Y[:, -1], Y_min[:, -1], atol=atol, rtol=rtol)
torch.testing.assert_close(Y[-1], Y_min[0, -1], atol=atol, rtol=rtol)

# just test the last head
# NOTE, in the kernel we always cast states to fp32
torch.testing.assert_close(final_state[:, -1],
torch.testing.assert_close(final_state[:, -1].to(torch.float32),
final_state_min[:, -1].to(torch.float32),
atol=atol,
rtol=rtol)
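For orientation, here is a hedged sketch of driving the new varlen entry point on two packed sequences, mirroring the call above. Shapes and values are illustrative only, and a CUDA device is assumed since the kernel is implemented in Triton:

import torch
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
    mamba_chunk_scan_combined_varlen)
from vllm.v1.attention.backends.mamba2_attn import (
    _query_start_loc_to_chunk_indices_offsets)

seqlens = torch.tensor([35, 29], device='cuda')                    # two packed sequences
cu_seqlens = torch.cat([seqlens.new_zeros(1), seqlens]).cumsum(0)  # [0, 35, 64]
seq_idx = torch.repeat_interleave(
    torch.arange(len(seqlens), device='cuda', dtype=torch.int32), seqlens)
chunk_size = 32
chunk_indices, chunk_offsets = _query_start_loc_to_chunk_indices_offsets(
    cu_seqlens, chunk_size, cu_seqlens[-1])

total, nheads, headdim, ngroups, dstate = int(cu_seqlens[-1]), 4, 16, 1, 8
X = torch.randn(total, nheads, headdim, device='cuda')
dt = torch.rand(total, nheads, device='cuda')
A = -torch.rand(nheads, device='cuda')
B = torch.randn(total, ngroups, dstate, device='cuda')
C = torch.randn(total, ngroups, dstate, device='cuda')
Y = torch.empty_like(X)
final_states = mamba_chunk_scan_combined_varlen(
    X, dt, A, B, C, chunk_size, D=None,
    cu_seqlens=cu_seqlens, seq_idx=seq_idx,
    chunk_indices=chunk_indices, chunk_offsets=chunk_offsets, out=Y)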
@@ -300,7 +322,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
cu_seqlens, chunk_size, cu_seqlens[-1])

Y = torch.empty_like(X)
new_states = mamba_chunk_scan_combined(
new_states = mamba_chunk_scan_combined_varlen(
X,
dt,
A,
@@ -312,7 +334,6 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
return_varlen_states=True,
initial_states=states,
out=Y,
)
@@ -321,7 +342,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
for i in range(num_examples):

# just test one dim and dstate
Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
Y_eg = Y[cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
Y_min_eg = Y_min[i][:, 0, 0]
torch.testing.assert_close(Y_eg, Y_min_eg, atol=atol, rtol=rtol)
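With the flattened layout, per-sequence results are plain slices along the token axis, exactly as the loop above does; a tiny illustration with made-up shapes:

import torch

cu_seqlens = torch.tensor([0, 3, 8])   # two sequences, lengths 3 and 5
Y = torch.randn(8, 4, 16)              # packed output: (tokens, heads, headdim)

Y_seq0 = Y[cu_seqlens[0]:cu_seqlens[1]]   # first sequence's 3 tokens
Y_seq1 = Y[cu_seqlens[1]:cu_seqlens[2]]   # second sequence's 5 tokens
assert Y_seq0.shape[0] == 3 and Y_seq1.shape[0] == 5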

@@ -386,7 +407,7 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
_query_start_loc_to_chunk_indices_offsets(
cu_seqlens, chunk_size, cu_seqlens[-1])
Y_ref = torch.empty_like(X)
state_ref = mamba_chunk_scan_combined(
state_ref = mamba_chunk_scan_combined_varlen(
X,
dt,
A,
@@ -398,7 +419,6 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
return_varlen_states=True,
initial_states=None,
out=Y_ref,
)
@@ -414,27 +434,27 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
chunked_seq_idx = torch.repeat_interleave(
torch.arange(len(chunked_seqlens), device=device),
chunked_seqlens,
output_size=chunked_cu_seqlens[-1]).unsqueeze(0).to(torch.int32)
output_size=chunked_cu_seqlens[-1]).to(torch.int32)
chunked_input_seq_len = chunked_cu_seqlens[-1]
X_chunked = torch.zeros_like(X)[:, :chunked_input_seq_len, ...]
dt_chunked = torch.zeros_like(dt)[:, :chunked_input_seq_len, ...]
B_chunked = torch.zeros_like(B)[:, :chunked_input_seq_len, ...]
C_chunked = torch.zeros_like(C)[:, :chunked_input_seq_len, ...]
X_chunked = torch.zeros_like(X)[:chunked_input_seq_len, ...]
dt_chunked = torch.zeros_like(dt)[:chunked_input_seq_len, ...]
B_chunked = torch.zeros_like(B)[:chunked_input_seq_len, ...]
C_chunked = torch.zeros_like(C)[:chunked_input_seq_len, ...]
for i in range(num_sequences):
# fmt: off
chunk_f = lambda x, i: x[:, cu_seqlens[i]:cu_seqlens[i] + chunked_seqlens[i], ...] # noqa: E501
chunk_f = lambda x, i: x[cu_seqlens[i]:cu_seqlens[i] + chunked_seqlens[i], ...] # noqa: E501

X_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(X, i) # noqa: E501
dt_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(dt, i) # noqa: E501
B_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(B, i) # noqa: E501
C_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(C, i) # noqa: E501
X_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(X, i) # noqa: E501
dt_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(dt, i) # noqa: E501
B_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(B, i) # noqa: E501
C_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(C, i) # noqa: E501
# fmt: on

chunk_indices, chunk_offsets = \
_query_start_loc_to_chunk_indices_offsets(
chunked_cu_seqlens, chunk_size, chunked_cu_seqlens[-1])
Y_partial = torch.empty_like(X_chunked)
partial_state = mamba_chunk_scan_combined(
partial_state = mamba_chunk_scan_combined_varlen(
X_chunked,
dt_chunked,
A,
@@ -446,7 +466,6 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
seq_idx=chunked_seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
return_varlen_states=True,
initial_states=None,
out=Y_partial,
)
@@ -461,29 +480,28 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
remaining_chunked_seq_idx = torch.repeat_interleave(
torch.arange(len(remaining_chunked_seqlens), device=device),
remaining_chunked_seqlens,
output_size=remaining_chunked_cu_seqlens[-1]).unsqueeze(0).to(
torch.int32)
output_size=remaining_chunked_cu_seqlens[-1]).to(torch.int32)
remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1]
# fmt: off
remaining_X_chunked = torch.zeros_like(X)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501
remaining_dt_chunked = torch.zeros_like(dt)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501
remaining_B_chunked = torch.zeros_like(B)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501
remaining_C_chunked = torch.zeros_like(C)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501
remaining_X_chunked = torch.zeros_like(X)[:remaining_chunked_input_seq_len, ...] # noqa: E501
remaining_dt_chunked = torch.zeros_like(dt)[:remaining_chunked_input_seq_len, ...] # noqa: E501
remaining_B_chunked = torch.zeros_like(B)[:remaining_chunked_input_seq_len, ...] # noqa: E501
remaining_C_chunked = torch.zeros_like(C)[:remaining_chunked_input_seq_len, ...] # noqa: E501
for i in range(num_sequences):
remaining_chunk_f = lambda x, i: x[:, cu_seqlens[i] + chunked_seqlens[i]:cu_seqlens[i+1], ...] # noqa: E501
remaining_chunk_f = lambda x, i: x[cu_seqlens[i] + chunked_seqlens[i]:cu_seqlens[i+1], ...] # noqa: E501

remaining_X_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(X, i) # noqa: E501
remaining_dt_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(dt, i) # noqa: E501
remaining_B_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(B, i) # noqa: E501
remaining_C_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(C, i) # noqa: E501
remaining_X_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(X, i) # noqa: E501
remaining_dt_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(dt, i) # noqa: E501
remaining_B_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(B, i) # noqa: E501
remaining_C_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(C, i) # noqa: E501

# assert input chunking is correct
concat_chunk_f = lambda pt1, pt2, i: torch.cat([
pt1[:,chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1],...],
pt2[:,remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1],...],
pt1[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1],...],
pt2[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1],...],
],
dim=1)
concat_batch_f = lambda pt1, pt2: torch.cat([concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=1) # noqa: E501
dim=0)
concat_batch_f = lambda pt1, pt2: torch.cat([concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=0) # noqa: E501
# fmt: on

assert concat_batch_f(X_chunked, remaining_X_chunked).equal(X)
@@ -498,7 +516,7 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
remaining_chunked_cu_seqlens[-1])

Y_chunked = torch.empty_like(remaining_X_chunked)
state_chunked = mamba_chunk_scan_combined(
state_chunked = mamba_chunk_scan_combined_varlen(
remaining_X_chunked,
remaining_dt_chunked,
A,
@@ -510,25 +528,24 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
seq_idx=remaining_chunked_seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
return_varlen_states=True,
initial_states=partial_state,
out=Y_chunked,
)
Y = concat_batch_f(Y_partial, Y_chunked)

# kernel chunked is same as kernel overall
for i in range(num_sequences):
Y_seq = Y[:, cu_seqlens[i]:cu_seqlens[i + 1], ...]
Y_ref_seq = Y_ref[:, cu_seqlens[i]:cu_seqlens[i + 1], ...]
Y_seq = Y[cu_seqlens[i]:cu_seqlens[i + 1], ...]
Y_ref_seq = Y_ref[cu_seqlens[i]:cu_seqlens[i + 1], ...]
torch.testing.assert_close(
Y_seq[:, :chunked_seqlens[i], ...],
Y_ref_seq[:, :chunked_seqlens[i], ...],
Y_seq[:chunked_seqlens[i], ...],
Y_ref_seq[:chunked_seqlens[i], ...],
atol=atol,
rtol=rtol,
msg=lambda x: f"seq{i} output part1 " + x) # noqa: B023
torch.testing.assert_close(
Y_seq[:, chunked_seqlens[i]:, ...],
Y_ref_seq[:, chunked_seqlens[i]:, ...],
Y_seq[chunked_seqlens[i]:, ...],
Y_ref_seq[chunked_seqlens[i]:, ...],
atol=atol,
rtol=rtol,
msg=lambda x: f"seq{i} output part2 " + x) # noqa: B023
23 changes: 10 additions & 13 deletions vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -29,7 +29,7 @@
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_state_update)
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
mamba_chunk_scan_combined)
mamba_chunk_scan_combined_varlen)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import (
LoaderFunction, composed_weight_loader, sharded_weight_loader)
@@ -504,6 +504,7 @@ def forward_cuda(
seq_idx_p = attn_metadata.seq_idx_p
chunk_indices_p = attn_metadata.chunk_indices_p
chunk_offsets_p = attn_metadata.chunk_offsets_p
query_start_loc_p = attn_metadata.query_start_loc_p

# 1. Gated MLP's linear projection
projected_states, _ = self.in_proj(hidden_states)
@@ -545,6 +546,7 @@ def forward_cuda(
out, _ = self.out_proj(hidden_states)
return out

# NOTE: V0 put prefill before decode, v1 puts decode before prefill
num_prefills = attn_metadata.num_prefills # request count
num_decodes = attn_metadata.num_decode_tokens # token count (=request)
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
@@ -570,9 +572,6 @@ def forward_cuda(
[num_decodes, num_prefills],
dim=0,
)
query_start_loc_p = (
attn_metadata.query_start_loc[-num_prefills - 1:] -
num_decodes if has_prefill else None)

# Preallocate output tensor to avoid memcpy cost for merging prefill
# and decode outputs
@@ -620,15 +619,15 @@ def forward_cuda(
ssm_state[state_indices_tensor_p], 0)

# NOTE: final output is an in-place update of out tensor
varlen_state = mamba_chunk_scan_combined(
hidden_states_p.view(1, num_prefill_tokens,
varlen_states = mamba_chunk_scan_combined_varlen(
hidden_states_p.view(num_prefill_tokens,
self.num_heads // self.tp_size,
self.head_dim),
dt_p.unsqueeze(0),
dt_p,
self.A,
B_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
B_p.view(num_prefill_tokens, self.n_groups // self.tp_size,
-1),
C_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
C_p.view(num_prefill_tokens, self.n_groups // self.tp_size,
-1),
chunk_size=chunk_size,
D=self.D,
@@ -639,17 +638,15 @@ def forward_cuda(
chunk_offsets=chunk_offsets_p,
cu_seqlens=query_start_loc_p,
initial_states=initial_states,
return_varlen_states=True,
return_final_states=False,
dt_softplus=True,
dt_limit=(0.0, float("inf")),
out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1,
out=preallocated_ssm_out_p.view(num_prefill_tokens, -1,
self.head_dim),
state_dtype=ssm_state.dtype)

# update ssm states
# - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
ssm_state[state_indices_tensor_p] = varlen_state
ssm_state[state_indices_tensor_p] = varlen_states

# Process decode requests
if has_decode:
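The prefill path above now feeds the kernel flat (num_prefill_tokens, ...) tensors and reuses query_start_loc_p from the attention metadata rather than re-deriving it from query_start_loc; the only tensor-level change is dropping the dummy batch dimension. A small sketch of that equivalence (sizes are illustrative):

import torch

num_prefill_tokens, nheads, headdim = 8, 4, 16   # illustrative sizes
hidden_states_p = torch.randn(num_prefill_tokens, nheads * headdim)

# Old convention: a dummy batch dimension of 1 was prepended for the kernel.
x_batched = hidden_states_p.view(1, num_prefill_tokens, nheads, headdim)
# New varlen convention: tokens stay flat; sequence boundaries come from
# query_start_loc_p / seq_idx_p supplied by the attention metadata.
x_varlen = hidden_states_p.view(num_prefill_tokens, nheads, headdim)
assert torch.equal(x_batched.squeeze(0), x_varlen)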
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/mamba/ops/causal_conv1d.py
@@ -425,7 +425,7 @@ def causal_conv1d_fn(
batch_ptr = metadata.batch_ptr
token_chunk_offset_ptr = metadata.token_chunk_offset_ptr
else:
seqlens = np.diff(query_start_loc.to('cpu'))
seqlens = query_start_loc.diff().to('cpu')
args = seqlens
MAX_NUM_PROGRAMS = 1024
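The one-line change above swaps a NumPy round-trip for torch.Tensor.diff, so the sequence lengths stay a torch tensor and the subtraction runs on the tensor's own device before the copy to CPU. A quick sketch of the equivalence (values are made up):

import numpy as np
import torch

query_start_loc = torch.tensor([0, 3, 8, 12])    # illustrative cumulative lengths

# Old: round-trip through NumPy on the CPU copy -> numpy.ndarray([3, 5, 4])
seqlens_np = np.diff(query_start_loc.to('cpu'))
# New: difference computed with torch, then moved to CPU -> torch.Tensor([3, 5, 4])
seqlens_pt = query_start_loc.diff().to('cpu')
assert (seqlens_pt.numpy() == seqlens_np).all()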
