diff --git a/tests/kernels/mamba/test_mamba_ssm_ssd.py b/tests/kernels/mamba/test_mamba_ssm_ssd.py index fc60d5ac82b2..927af32588e6 100644 --- a/tests/kernels/mamba/test_mamba_ssm_ssd.py +++ b/tests/kernels/mamba/test_mamba_ssm_ssd.py @@ -7,7 +7,7 @@ from einops import rearrange, repeat from vllm.model_executor.layers.mamba.ops.ssd_combined import ( - mamba_chunk_scan_combined) + mamba_chunk_scan_combined_varlen) from vllm.platforms import current_platform from vllm.v1.attention.backends.mamba2_attn import ( _query_start_loc_to_chunk_indices_offsets) @@ -185,9 +185,14 @@ def end_boundary(n: int): IND_S = [x % full_length for x in IND_E] IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)] + # varlen has implicit batch=1 + dt2 = dt2.squeeze(0) + X2 = X2.squeeze(0) + B2 = B2.squeeze(0) + C2 = C2.squeeze(0) yield ([Y_min[s, IND_S[s]:IND_E[s]] for s in range(num_examples)] if return_naive_ref else None, - cu_seqlens, seq_idx.unsqueeze(0), (A, dt2, X2, B2, C2)) + cu_seqlens, seq_idx, (A, dt2, X2, B2, C2)) @pytest.mark.parametrize("itype", @@ -198,7 +203,7 @@ def end_boundary(n: int): def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, itype): - # this tests the kernels on a single example (no batching) + # this tests the kernels on a single example (bs=1) # TODO: the bfloat16 case requires higher thresholds. To be investigated @@ -219,23 +224,40 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt, B, C, chunk_size) + + cu_seqlens = torch.tensor((0, seqlen), device='cuda').cumsum(dim=0) + seq_idx = torch.zeros(seqlen, dtype=torch.int32, device=cu_seqlens.device) + + chunk_indices, chunk_offsets = \ + _query_start_loc_to_chunk_indices_offsets( + cu_seqlens, chunk_size, cu_seqlens[-1]) + + # varlen has implicit batch=1 + X = X.squeeze(0) + dt = dt.squeeze(0) + A = A.squeeze(0) + B = B.squeeze(0) + C = C.squeeze(0) Y = torch.empty_like(X) - final_state = mamba_chunk_scan_combined(X, - dt, - A, - B, - C, - chunk_size, - D=None, - return_final_states=True, - out=Y) + final_state = mamba_chunk_scan_combined_varlen(X, + dt, + A, + B, + C, + chunk_size, + D=None, + cu_seqlens=cu_seqlens, + seq_idx=seq_idx, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, + out=Y) # just test the last in sequence - torch.testing.assert_close(Y[:, -1], Y_min[:, -1], atol=atol, rtol=rtol) + torch.testing.assert_close(Y[-1], Y_min[0, -1], atol=atol, rtol=rtol) # just test the last head # NOTE, in the kernel we always cast states to fp32 - torch.testing.assert_close(final_state[:, -1], + torch.testing.assert_close(final_state[:, -1].to(torch.float32), final_state_min[:, -1].to(torch.float32), atol=atol, rtol=rtol) @@ -300,7 +322,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, cu_seqlens, chunk_size, cu_seqlens[-1]) Y = torch.empty_like(X) - new_states = mamba_chunk_scan_combined( + new_states = mamba_chunk_scan_combined_varlen( X, dt, A, @@ -312,7 +334,6 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, seq_idx=seq_idx, chunk_indices=chunk_indices, chunk_offsets=chunk_offsets, - return_varlen_states=True, initial_states=states, out=Y, ) @@ -321,7 +342,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, for i in range(num_examples): # just test one dim and dstate - Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0] + Y_eg = Y[cu_seqlens[i]:cu_seqlens[i + 1], 0, 0] Y_min_eg = 
Y_min[i][:, 0, 0] torch.testing.assert_close(Y_eg, Y_min_eg, atol=atol, rtol=rtol) @@ -386,7 +407,7 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens): _query_start_loc_to_chunk_indices_offsets( cu_seqlens, chunk_size, cu_seqlens[-1]) Y_ref = torch.empty_like(X) - state_ref = mamba_chunk_scan_combined( + state_ref = mamba_chunk_scan_combined_varlen( X, dt, A, @@ -398,7 +419,6 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens): seq_idx=seq_idx, chunk_indices=chunk_indices, chunk_offsets=chunk_offsets, - return_varlen_states=True, initial_states=None, out=Y_ref, ) @@ -414,27 +434,27 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens): chunked_seq_idx = torch.repeat_interleave( torch.arange(len(chunked_seqlens), device=device), chunked_seqlens, - output_size=chunked_cu_seqlens[-1]).unsqueeze(0).to(torch.int32) + output_size=chunked_cu_seqlens[-1]).to(torch.int32) chunked_input_seq_len = chunked_cu_seqlens[-1] - X_chunked = torch.zeros_like(X)[:, :chunked_input_seq_len, ...] - dt_chunked = torch.zeros_like(dt)[:, :chunked_input_seq_len, ...] - B_chunked = torch.zeros_like(B)[:, :chunked_input_seq_len, ...] - C_chunked = torch.zeros_like(C)[:, :chunked_input_seq_len, ...] + X_chunked = torch.zeros_like(X)[:chunked_input_seq_len, ...] + dt_chunked = torch.zeros_like(dt)[:chunked_input_seq_len, ...] + B_chunked = torch.zeros_like(B)[:chunked_input_seq_len, ...] + C_chunked = torch.zeros_like(C)[:chunked_input_seq_len, ...] for i in range(num_sequences): # fmt: off - chunk_f = lambda x, i: x[:, cu_seqlens[i]:cu_seqlens[i] + chunked_seqlens[i], ...] # noqa: E501 + chunk_f = lambda x, i: x[cu_seqlens[i]:cu_seqlens[i] + chunked_seqlens[i], ...] # noqa: E501 - X_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(X, i) # noqa: E501 - dt_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(dt, i) # noqa: E501 - B_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(B, i) # noqa: E501 - C_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(C, i) # noqa: E501 + X_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(X, i) # noqa: E501 + dt_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(dt, i) # noqa: E501 + B_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(B, i) # noqa: E501 + C_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] 
= chunk_f(C, i) # noqa: E501 # fmt: on chunk_indices, chunk_offsets = \ _query_start_loc_to_chunk_indices_offsets( chunked_cu_seqlens, chunk_size, chunked_cu_seqlens[-1]) Y_partial = torch.empty_like(X_chunked) - partial_state = mamba_chunk_scan_combined( + partial_state = mamba_chunk_scan_combined_varlen( X_chunked, dt_chunked, A, @@ -446,7 +466,6 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens): seq_idx=chunked_seq_idx, chunk_indices=chunk_indices, chunk_offsets=chunk_offsets, - return_varlen_states=True, initial_states=None, out=Y_partial, ) @@ -461,29 +480,28 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens): remaining_chunked_seq_idx = torch.repeat_interleave( torch.arange(len(remaining_chunked_seqlens), device=device), remaining_chunked_seqlens, - output_size=remaining_chunked_cu_seqlens[-1]).unsqueeze(0).to( - torch.int32) + output_size=remaining_chunked_cu_seqlens[-1]).to(torch.int32) remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1] # fmt: off - remaining_X_chunked = torch.zeros_like(X)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501 - remaining_dt_chunked = torch.zeros_like(dt)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501 - remaining_B_chunked = torch.zeros_like(B)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501 - remaining_C_chunked = torch.zeros_like(C)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501 + remaining_X_chunked = torch.zeros_like(X)[:remaining_chunked_input_seq_len, ...] # noqa: E501 + remaining_dt_chunked = torch.zeros_like(dt)[:remaining_chunked_input_seq_len, ...] # noqa: E501 + remaining_B_chunked = torch.zeros_like(B)[:remaining_chunked_input_seq_len, ...] # noqa: E501 + remaining_C_chunked = torch.zeros_like(C)[:remaining_chunked_input_seq_len, ...] # noqa: E501 for i in range(num_sequences): - remaining_chunk_f = lambda x, i: x[:, cu_seqlens[i] + chunked_seqlens[i]:cu_seqlens[i+1], ...] # noqa: E501 + remaining_chunk_f = lambda x, i: x[cu_seqlens[i] + chunked_seqlens[i]:cu_seqlens[i+1], ...] # noqa: E501 - remaining_X_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(X, i) # noqa: E501 - remaining_dt_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(dt, i) # noqa: E501 - remaining_B_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(B, i) # noqa: E501 - remaining_C_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(C, i) # noqa: E501 + remaining_X_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(X, i) # noqa: E501 + remaining_dt_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(dt, i) # noqa: E501 + remaining_B_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(B, i) # noqa: E501 + remaining_C_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] 
= remaining_chunk_f(C, i) # noqa: E501 # assert input chunking is correct concat_chunk_f = lambda pt1, pt2, i: torch.cat([ - pt1[:,chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1],...], - pt2[:,remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1],...], + pt1[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1],...], + pt2[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1],...], ], - dim=1) - concat_batch_f = lambda pt1, pt2: torch.cat([concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=1) # noqa: E501 + dim=0) + concat_batch_f = lambda pt1, pt2: torch.cat([concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=0) # noqa: E501 # fmt: on assert concat_batch_f(X_chunked, remaining_X_chunked).equal(X) @@ -498,7 +516,7 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens): remaining_chunked_cu_seqlens[-1]) Y_chunked = torch.empty_like(remaining_X_chunked) - state_chunked = mamba_chunk_scan_combined( + state_chunked = mamba_chunk_scan_combined_varlen( remaining_X_chunked, remaining_dt_chunked, A, @@ -510,7 +528,6 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens): seq_idx=remaining_chunked_seq_idx, chunk_indices=chunk_indices, chunk_offsets=chunk_offsets, - return_varlen_states=True, initial_states=partial_state, out=Y_chunked, ) @@ -518,17 +535,17 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens): # kernel chunked is same as kernel overall for i in range(num_sequences): - Y_seq = Y[:, cu_seqlens[i]:cu_seqlens[i + 1], ...] - Y_ref_seq = Y_ref[:, cu_seqlens[i]:cu_seqlens[i + 1], ...] + Y_seq = Y[cu_seqlens[i]:cu_seqlens[i + 1], ...] + Y_ref_seq = Y_ref[cu_seqlens[i]:cu_seqlens[i + 1], ...] torch.testing.assert_close( - Y_seq[:, :chunked_seqlens[i], ...], - Y_ref_seq[:, :chunked_seqlens[i], ...], + Y_seq[:chunked_seqlens[i], ...], + Y_ref_seq[:chunked_seqlens[i], ...], atol=atol, rtol=rtol, msg=lambda x: f"seq{i} output part1 " + x) # noqa: B023 torch.testing.assert_close( - Y_seq[:, chunked_seqlens[i]:, ...], - Y_ref_seq[:, chunked_seqlens[i]:, ...], + Y_seq[chunked_seqlens[i]:, ...], + Y_ref_seq[chunked_seqlens[i]:, ...], atol=atol, rtol=rtol, msg=lambda x: f"seq{i} output part2 " + x) # noqa: B023 diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 908ea6e0025f..6dd09fad7a90 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -29,7 +29,7 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( selective_state_update) from vllm.model_executor.layers.mamba.ops.ssd_combined import ( - mamba_chunk_scan_combined) + mamba_chunk_scan_combined_varlen) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import ( LoaderFunction, composed_weight_loader, sharded_weight_loader) @@ -504,6 +504,7 @@ def forward_cuda( seq_idx_p = attn_metadata.seq_idx_p chunk_indices_p = attn_metadata.chunk_indices_p chunk_offsets_p = attn_metadata.chunk_offsets_p + query_start_loc_p = attn_metadata.query_start_loc_p # 1. 
Gated MLP's linear projection projected_states, _ = self.in_proj(hidden_states) @@ -545,6 +546,7 @@ def forward_cuda( out, _ = self.out_proj(hidden_states) return out + # NOTE: V0 put prefill before decode, v1 puts decode before prefill num_prefills = attn_metadata.num_prefills # request count num_decodes = attn_metadata.num_decode_tokens # token count (=request) num_prefill_tokens = attn_metadata.num_prefill_tokens # token count @@ -570,9 +572,6 @@ def forward_cuda( [num_decodes, num_prefills], dim=0, ) - query_start_loc_p = ( - attn_metadata.query_start_loc[-num_prefills - 1:] - - num_decodes if has_prefill else None) # Preallocate output tensor to avoid memcpy cost for merging prefill # and decode outputs @@ -620,15 +619,15 @@ def forward_cuda( ssm_state[state_indices_tensor_p], 0) # NOTE: final output is an in-place update of out tensor - varlen_state = mamba_chunk_scan_combined( - hidden_states_p.view(1, num_prefill_tokens, + varlen_states = mamba_chunk_scan_combined_varlen( + hidden_states_p.view(num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim), - dt_p.unsqueeze(0), + dt_p, self.A, - B_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size, + B_p.view(num_prefill_tokens, self.n_groups // self.tp_size, -1), - C_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size, + C_p.view(num_prefill_tokens, self.n_groups // self.tp_size, -1), chunk_size=chunk_size, D=self.D, @@ -639,17 +638,15 @@ def forward_cuda( chunk_offsets=chunk_offsets_p, cu_seqlens=query_start_loc_p, initial_states=initial_states, - return_varlen_states=True, - return_final_states=False, dt_softplus=True, dt_limit=(0.0, float("inf")), - out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1, + out=preallocated_ssm_out_p.view(num_prefill_tokens, -1, self.head_dim), state_dtype=ssm_state.dtype) # update ssm states # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor - ssm_state[state_indices_tensor_p] = varlen_state + ssm_state[state_indices_tensor_p] = varlen_states # Process decode requests if has_decode: diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index 010fcdda156c..a9eedd11767e 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -425,7 +425,7 @@ def causal_conv1d_fn( batch_ptr = metadata.batch_ptr token_chunk_offset_ptr = metadata.token_chunk_offset_ptr else: - seqlens = np.diff(query_start_loc.to('cpu')) + seqlens = query_start_loc.diff().to('cpu') args = seqlens MAX_NUM_PROGRAMS = 1024 diff --git a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py index 11ca1255ebfb..601b71ab2a51 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py @@ -99,34 +99,28 @@ def _bmm_chunk_fwd_kernel( seq_idx_ptr, # Matrix dimensions seqlen, - chunk_size, - K, - ngroups, - stride_a_batch, - stride_a_seqlen, - stride_a_head, - stride_ak, - stride_b_batch, - stride_b_seqlen, - stride_b_head, - stride_bk, - stride_out_batch, - stride_out_chunk, - stride_out_head, - stride_outm, - stride_outn, - stride_seq_idx_batch, - stride_seq_idx_seqlen, + chunk_size: tl.constexpr, + K: tl.constexpr, + ngroups: tl.constexpr, + stride_a_seqlen: tl.int64, + stride_a_head: tl.int64, + stride_ak: tl.constexpr, + stride_b_seqlen: tl.int64, + stride_b_head: tl.int64, + stride_bk: tl.constexpr, + stride_out_chunk: tl.int64, + stride_out_head: tl.int64, + 
stride_outm: tl.int64, + stride_outn: tl.constexpr, + stride_seq_idx_seqlen: tl.constexpr, # Meta-parameters IS_CAUSAL: tl.constexpr, dot_dtype: tl.constexpr, - HAS_SEQ_IDX: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ): - pid_b = tl.program_id(axis=1) - pid_ch = tl.program_id(axis=2).to(tl.int64) + pid_ch = tl.program_id(axis=1).to(tl.int64) pid_c = pid_ch // ngroups pid_h = pid_ch - pid_c * ngroups num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N) @@ -135,10 +129,10 @@ def _bmm_chunk_fwd_kernel( if IS_CAUSAL: if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M: return - a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head - b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head - if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + a_ptr += pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head + b_ptr += pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head + + seq_idx_ptr += pid_c * chunk_size * stride_seq_idx_seqlen offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) @@ -150,6 +144,8 @@ def _bmm_chunk_fwd_kernel( chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # compute a * b.T for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): a = tl.load(a_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & @@ -165,18 +161,19 @@ def _bmm_chunk_fwd_kernel( offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - if HAS_SEQ_IDX: - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, - mask=offs_m < chunk_size_limit, - other=-1) - seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, - mask=offs_n < chunk_size_limit, - other=-2) - acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0) - out = acc.to(out_ptr.dtype.element_ty) - out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head + # Zero out the results that are not from the same request + # in the varlen batch + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, + mask=offs_m < chunk_size_limit, + other=-1) + seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, + mask=offs_n < chunk_size_limit, + other=-2) + acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0) + + out = acc.to(out_ptr.dtype.element_ty) + out_ptr += pid_c * stride_out_chunk + pid_h * stride_out_head out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn) tl.store(out_ptrs, @@ -185,78 +182,61 @@ def _bmm_chunk_fwd_kernel( (offs_n[None, :] < chunk_size)) -def _bmm_chunk_fwd(a, - b, - chunk_size, - seq_idx=None, - causal=False, - output_dtype=None): +def _bmm_chunk_fwd(a, b, chunk_size, seq_idx, causal=False, output_dtype=None): """ Argument: - a: (batch, seqlen, k) or (batch, seqlen, ngroups, k) - b: (batch, seqlen, k) or (batch, seqlen, ngroups, k) - seq_idx: (batch, seqlen) or None. out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out. + a: (seqlen, ngroups, k) + b: (seqlen, ngroups, k) + seq_idx: (seqlen,). out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out. 
causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are guaranteed to be correct. Return: - out: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size) + out: (nchunks, ngroups, chunk_size, chunk_size) """ - # Check constraints. - has_groups = a.dim() == 4 - if not has_groups: - batch, seqlen, k = a.shape - else: - batch, seqlen, ngroups, k = a.shape + seqlen, ngroups, k = a.shape assert b.shape == a.shape - if seq_idx is not None: - assert seq_idx.shape == (batch, seqlen) - if a.stride(-1) != 1 and a.stride(1) != 1: + assert seq_idx is not None + assert seq_idx.shape == (seqlen, ) + if a.stride(-1) != 1 and a.stride(0) != 1: a = a.contiguous() - if b.stride(-1) != 1 and b.stride(1) != 1: + if b.stride(-1) != 1 and b.stride(0) != 1: b = b.contiguous() + nchunks = math.ceil(seqlen / chunk_size) # Allocates output. out_dtype = a.dtype if output_dtype is None else output_dtype - out = torch.empty( - (batch, nchunks, chunk_size, chunk_size) if not has_groups else - (batch, nchunks, ngroups, chunk_size, chunk_size), - device=a.device, - dtype=out_dtype) + out = torch.empty((nchunks, ngroups, chunk_size, chunk_size), + device=a.device, + dtype=out_dtype) dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else (tl.float16 if a.dtype == torch.float16 or b.dtype == torch.float16 else tl.float32)) grid = lambda META: (triton.cdiv( chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv( - chunk_size, META['BLOCK_SIZE_N']), batch, nchunks - if not has_groups else nchunks * ngroups) + chunk_size, META['BLOCK_SIZE_N']), nchunks * ngroups) with torch.cuda.device(a.device.index): _bmm_chunk_fwd_kernel[grid]( - a, - b, - out, - seq_idx, - seqlen, - chunk_size, - k, - ngroups if has_groups else 1, - a.stride(0), - a.stride(1), - 0 if not has_groups else a.stride(2), - a.stride(-1), - b.stride(0), - b.stride(1), - 0 if not has_groups else b.stride(2), - b.stride(-1), - out.stride(0), - out.stride(1), - 0 if not has_groups else out.stride(2), - out.stride(-2), - out.stride(-1), - *((seq_idx.stride(0), - seq_idx.stride(1)) if seq_idx is not None else (0, 0)), - causal, - dot_dtype, - HAS_SEQ_IDX=seq_idx is not None, + a_ptr=a, + b_ptr=b, + out_ptr=out, + seq_idx_ptr=seq_idx, + seqlen=seqlen, + chunk_size=chunk_size, + K=k, + ngroups=ngroups, + stride_a_seqlen=a.stride(0), + stride_a_head=a.stride(1), + stride_ak=a.stride(2), + stride_b_seqlen=b.stride(0), + stride_b_head=b.stride(1), + stride_bk=b.stride(2), + stride_out_chunk=out.stride(0), + stride_out_head=out.stride(1), + stride_outm=out.stride(-2), + stride_outn=out.stride(-1), + stride_seq_idx_seqlen=seq_idx.stride(0), + IS_CAUSAL=causal, + dot_dtype=dot_dtype, ) return out diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index fb8350e191c9..add72617fcea 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -6,7 +6,6 @@ # ruff: noqa: E501,SIM102 -import torch from packaging import version from vllm.triton_utils import tl, triton @@ -114,7 +113,6 @@ def _chunk_scan_fwd_kernel( x_ptr, z_ptr, out_ptr, - out_x_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, @@ -126,60 +124,49 @@ def _chunk_scan_fwd_kernel( chunk_offsets_ptr, chunk_meta_num, # Matrix dimensions - chunk_size, - hdim, - dstate, - batch, + chunk_size: tl.constexpr, + hdim: tl.constexpr, + dstate: tl.constexpr, seqlen, - nheads_ngroups_ratio, + 
nheads_ngroups_ratio: tl.constexpr, # Strides - stride_cb_batch, - stride_cb_chunk, - stride_cb_head, - stride_cb_csize_m, - stride_cb_csize_k, - stride_x_batch, - stride_x_seqlen, - stride_x_head, - stride_x_hdim, - stride_z_batch, - stride_z_seqlen, - stride_z_head, - stride_z_hdim, - stride_out_batch, - stride_out_seqlen, - stride_out_head, - stride_out_hdim, - stride_dt_batch, - stride_dt_chunk, - stride_dt_head, - stride_dt_csize, - stride_dA_cs_batch, - stride_dA_cs_chunk, - stride_dA_cs_head, - stride_dA_cs_csize, - stride_seq_idx_batch, - stride_seq_idx_seqlen, - stride_C_batch, - stride_C_seqlen, - stride_C_head, - stride_C_dstate, - stride_states_batch, - stride_states_chunk, - stride_states_head, - stride_states_hdim, - stride_states_dstate, - stride_init_states_batch, - stride_init_states_head, - stride_init_states_hdim, - stride_init_states_dstate, - stride_D_head, + stride_cb_chunk: tl.int64, + stride_cb_head: tl.int64, + stride_cb_csize_m: tl.int64, + stride_cb_csize_k: tl.constexpr, + stride_x_seqlen: tl.int64, + stride_x_head: tl.int64, + stride_x_hdim: tl.constexpr, + stride_z_seqlen: tl.int64, + stride_z_head: tl.int64, + stride_z_hdim: tl.constexpr, + stride_out_seqlen: tl.int64, + stride_out_head: tl.int64, + stride_out_hdim: tl.constexpr, + stride_dt_chunk: tl.int64, + stride_dt_head: tl.int64, + stride_dt_csize: tl.constexpr, + stride_dA_cs_chunk: tl.int64, + stride_dA_cs_head: tl.int64, + stride_dA_cs_csize: tl.constexpr, + stride_seq_idx_seqlen: tl.constexpr, + stride_C_seqlen: tl.int64, + stride_C_head: tl.int64, + stride_C_dstate: tl.constexpr, + stride_states_chunk: tl.int64, + stride_states_head: tl.int64, + stride_states_hdim: tl.int64, + stride_states_dstate: tl.constexpr, + stride_init_states_batch: tl.int64, + stride_init_states_head: tl.int64, + stride_init_states_hdim: tl.int64, + stride_init_states_dstate: tl.constexpr, + stride_D_head: tl.constexpr, # Meta-parameters IS_CAUSAL: tl.constexpr, HAS_D: tl.constexpr, D_HAS_HDIM: tl.constexpr, HAS_Z: tl.constexpr, - HAS_SEQ_IDX: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, @@ -187,9 +174,7 @@ def _chunk_scan_fwd_kernel( IS_TRITON_22: tl.constexpr, HAS_INITSTATES: tl.constexpr, ): - pid_bc = tl.program_id(axis=1).to(tl.int64) - pid_c = pid_bc // batch - pid_b = pid_bc - pid_c * batch + pid_c = tl.program_id(axis=1).to(tl.int64) if not HAS_INITSTATES: c_idx = pid_c c_off = 0 @@ -201,53 +186,51 @@ def _chunk_scan_fwd_kernel( num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) pid_m = tl.program_id(axis=0) // num_pid_n pid_n = tl.program_id(axis=0) % num_pid_n - cb_ptr += pid_b * stride_cb_batch + c_idx * stride_cb_chunk + ( - pid_h // nheads_ngroups_ratio) * stride_cb_head - x_ptr += pid_b * stride_x_batch + c_idx * chunk_size * stride_x_seqlen + pid_h * stride_x_head - dt_ptr += pid_b * stride_dt_batch + c_idx * stride_dt_chunk + pid_h * stride_dt_head - dA_cumsum_ptr += pid_b * stride_dA_cs_batch + c_idx * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - C_ptr += pid_b * stride_C_batch + c_idx * chunk_size * stride_C_seqlen + ( + cb_ptr += c_idx * stride_cb_chunk + (pid_h // + nheads_ngroups_ratio) * stride_cb_head + x_ptr += c_idx * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += c_idx * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += c_idx * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + C_ptr += c_idx * chunk_size * stride_C_seqlen + ( pid_h // nheads_ngroups_ratio) * stride_C_head # M-block offsets and prev states # - logic in next 
block may override these if there is an active offset offs_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M) - prev_states_ptr = states_ptr + pid_b * stride_states_batch + c_idx * stride_states_chunk + pid_h * stride_states_head + prev_states_ptr = states_ptr + c_idx * stride_states_chunk + pid_h * stride_states_head prev_states_hdim = stride_states_hdim prev_states_dstate = stride_states_dstate chunk_size_limit = min(chunk_size, seqlen - c_idx * chunk_size) - if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch + c_idx * chunk_size * stride_seq_idx_seqlen - - # - we only need seq_idx_prev to be aligned to chunk boundary - seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, - mask=c_idx >= 1, - other=0) - - if HAS_INITSTATES: - # if there are init states, we only need seq_idx_m to point - # what is the current seq_idx - - # get current seq idx - if (pid_m * BLOCK_SIZE_M + c_off) < chunk_size_limit: - seq_idx_m = tl.load( - seq_idx_ptr + - (pid_m * BLOCK_SIZE_M + c_off) * stride_seq_idx_seqlen, ) - - # - recall that in ssd_state_passing, for the case c_off == 0 - # i.e., the very first sequence, we made states_ptr hold its initial state - # so this edge case is taken care of - if ((c_off == 0) and - (seq_idx_prev != seq_idx_m - ) # if a seq is changed exactly on boundary - or (c_off > 0) # implies a new example (pseudo chunk) - ): - - # - replace prev_states_ptr with init_states - prev_states_ptr = initstates_ptr + seq_idx_m * stride_init_states_batch + pid_h * stride_init_states_head - prev_states_hdim = stride_init_states_hdim # override strides - prev_states_dstate = stride_init_states_dstate + + seq_idx_ptr += c_idx * chunk_size * stride_seq_idx_seqlen + # - we only need seq_idx_prev to be aligned to chunk boundary + seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, + mask=c_idx >= 1, + other=0) + + if HAS_INITSTATES: + # if there are init states, we only need seq_idx_m to point + # what is the current seq_idx + + # get current seq idx + if (pid_m * BLOCK_SIZE_M + c_off) < chunk_size_limit: + seq_idx_m = tl.load( + seq_idx_ptr + + (pid_m * BLOCK_SIZE_M + c_off) * stride_seq_idx_seqlen, ) + + # - recall that in ssd_state_passing, for the case c_off == 0 + # i.e., the very first sequence, we made states_ptr hold its initial state + # so this edge case is taken care of + if ((c_off == 0) and (seq_idx_prev != seq_idx_m + ) # if a seq is changed exactly on boundary + or (c_off > 0) # implies a new example (pseudo chunk) + ): + + # - replace prev_states_ptr with init_states + prev_states_ptr = initstates_ptr + seq_idx_m * stride_init_states_batch + pid_h * stride_init_states_head + prev_states_hdim = stride_init_states_hdim # override strides + prev_states_dstate = stride_init_states_dstate offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, @@ -256,7 +239,6 @@ def _chunk_scan_fwd_kernel( # - handle chunk state limit if HAS_INITSTATES: - # have to split this if otherwise compilation will have problems dA_cs_m_boundary = 0.0 @@ -296,13 +278,11 @@ def _chunk_scan_fwd_kernel( dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize, mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)), other=0.0).to(tl.float32) - - if HAS_SEQ_IDX: + else: # - handle seq idx when HAS_INITSTATES==False - if not HAS_INITSTATES: - seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, - mask=offs_m < chunk_size_limit, - other=-1) + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, + 
mask=offs_m < chunk_size_limit, + other=-1) acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) @@ -319,18 +299,15 @@ def _chunk_scan_fwd_kernel( prev_states_ptrs = prev_states_ptr + ( offs_n[None, :] * prev_states_hdim + offs_k_dstate[:, None] * prev_states_dstate) - if HAS_SEQ_IDX: - - if not HAS_INITSTATES: - # - this is for continuous batching where there is no init states - scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), - 0.0) - else: - # - if there is initstates, we will rely on prev_states, no zeroing - # required. - scale_m = tl.exp(dA_cs_m - dA_cs_m_boundary) + + if not HAS_INITSTATES: + # - this is for continuous batching where there is no init states + scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) else: - scale_m = tl.exp(dA_cs_m) + # - if there is initstates, we will rely on prev_states, no zeroing + # required. + scale_m = tl.exp(dA_cs_m - dA_cs_m_boundary) + if BLOCK_SIZE_DSTATE <= 128: C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & @@ -416,15 +393,7 @@ def _chunk_scan_fwd_kernel( acc += x_residual * D if HAS_Z: - out_x_ptr += pid_b * stride_out_batch + c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head - out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + - offs_out_n[None, :]) - tl.store(out_x_ptrs, - acc, - mask=(offs_out_m[:, None] < chunk_size_limit) & - (offs_out_n[None, :] < hdim)) - - z_ptr += pid_b * stride_z_batch + c_idx * chunk_size * stride_z_seqlen + pid_h * stride_z_head + z_ptr += c_idx * chunk_size * stride_z_seqlen + pid_h * stride_z_head z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :]) z = tl.load(z_ptrs, @@ -433,7 +402,7 @@ def _chunk_scan_fwd_kernel( other=0.0).to(tl.float32) acc *= z * tl.sigmoid(z) - out_ptr += pid_b * stride_out_batch + c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head + out_ptr += c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim) tl.store(out_ptrs, @@ -449,126 +418,110 @@ def _chunk_scan_fwd( dA_cumsum, C, states, + out, + seq_idx, D=None, z=None, - seq_idx=None, chunk_indices=None, chunk_offsets=None, initial_states=None, - out=None, ): - batch, seqlen, nheads, headdim = x.shape - _, _, nchunks, chunk_size = dt.shape - _, _, ngroups, dstate = C.shape + assert seq_idx is not None, "this implementation requires seq_idx" + + seqlen, nheads, headdim = x.shape + _, nchunks, chunk_size = dt.shape + _, ngroups, dstate = C.shape assert nheads % ngroups == 0 - assert C.shape == (batch, seqlen, ngroups, dstate) - assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) - if z is not None: - assert z.shape == x.shape + assert C.shape == (seqlen, ngroups, dstate) + assert cb.shape == (nchunks, ngroups, chunk_size, chunk_size) if D is not None: assert D.shape == (nheads, headdim) or D.shape == (nheads, ) - assert dt.shape == (batch, nheads, nchunks, chunk_size) - assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) - assert states.shape == (batch, nchunks, nheads, headdim, dstate) - - if seq_idx is not None: - assert seq_idx.shape == (batch, seqlen) - - if initial_states is not None: - # with initial states, we need to take care of how - # seq_idx crosses the boundaries - assert batch == 1, "chunk scan only supports initial states with batch 1" - assert chunk_indices is not None and chunk_offsets is not None, \ - "chunk_indices and chunk_offsets should 
have been set" - else: - chunk_indices, chunk_offsets = None, None - else: - chunk_indices, chunk_offsets = None, None - - assert out.shape == x.shape - if z is not None: - out_x = torch.empty_like(x) - assert out_x.stride() == out.stride() + assert z.shape == x.shape + assert dt.shape == (nheads, nchunks, chunk_size) + assert dA_cumsum.shape == (nheads, nchunks, chunk_size) + assert states.shape == (nchunks, nheads, headdim, dstate) + assert seq_idx.shape == (seqlen, ) + + if initial_states is not None: + # with initial states, we need to take care of how + # seq_idx crosses the boundaries + assert chunk_indices is not None and chunk_offsets is not None, \ + "chunk_indices and chunk_offsets should have been set" else: - out_x = None + chunk_indices, chunk_offsets = None, None grid = lambda META: ( triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv( - headdim, META['BLOCK_SIZE_N']), batch * nchunks + headdim, META['BLOCK_SIZE_N']), nchunks if chunk_offsets is None else len(chunk_offsets), nheads) - z_strides = ((z.stride(0), z.stride(1), z.stride(2), - z.stride(3)) if z is not None else (0, 0, 0, 0)) + + z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else + (0, 0, 0)) + initial_states_strides = ((initial_states.stride(0), + initial_states.stride(1), + initial_states.stride(2), + initial_states.stride(3)) + if initial_states is not None else (0, 0, 0, 0)) + _chunk_scan_fwd_kernel[grid]( - cb, - x, - z, - out, - out_x, - dt, - dA_cumsum, - seq_idx, - C, - states, - D, - initial_states, - chunk_indices, - chunk_offsets, - len(chunk_indices) if chunk_indices is not None else 0, - chunk_size, - headdim, - dstate, - batch, - seqlen, - nheads // ngroups, - cb.stride(0), - cb.stride(1), - cb.stride(2), - cb.stride(3), - cb.stride(4), - x.stride(0), - x.stride(1), - x.stride(2), - x.stride(3), - z_strides[0], - z_strides[1], - z_strides[2], - z_strides[3], - out.stride(0), - out.stride(1), - out.stride(2), - out.stride(3), - dt.stride(0), - dt.stride(2), - dt.stride(1), - dt.stride(3), - dA_cumsum.stride(0), - dA_cumsum.stride(2), - dA_cumsum.stride(1), - dA_cumsum.stride(3), - *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else - (0, 0)), - C.stride(0), - C.stride(1), - C.stride(2), - C.stride(3), - states.stride(0), - states.stride(1), - states.stride(2), - states.stride(3), - states.stride(4), - *((initial_states.stride(0), initial_states.stride(1), - initial_states.stride(2), - initial_states.stride(3)) if initial_states is not None else - (0, 0, 0, 0)), - D.stride(0) if D is not None else 0, - True, - D is not None, - D.dim() == 2 if D is not None else True, - BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), + cb_ptr=cb, + x_ptr=x, + z_ptr=z, + out_ptr=out, + dt_ptr=dt, + dA_cumsum_ptr=dA_cumsum, + seq_idx_ptr=seq_idx, + C_ptr=C, + states_ptr=states, + D_ptr=D, + initstates_ptr=initial_states, + chunk_indices_ptr=chunk_indices, + chunk_offsets_ptr=chunk_offsets, + chunk_meta_num=len(chunk_indices) if chunk_indices is not None else 0, + chunk_size=chunk_size, + hdim=headdim, + dstate=dstate, + seqlen=seqlen, + nheads_ngroups_ratio=nheads // ngroups, + stride_cb_chunk=cb.stride(0), + stride_cb_head=cb.stride(1), + stride_cb_csize_m=cb.stride(2), + stride_cb_csize_k=cb.stride(3), + stride_x_seqlen=x.stride(0), + stride_x_head=x.stride(1), + stride_x_hdim=x.stride(2), + stride_z_seqlen=z_strides[0], + stride_z_head=z_strides[1], + stride_z_hdim=z_strides[2], + stride_out_seqlen=out.stride(0), + stride_out_head=out.stride(1), + 
stride_out_hdim=out.stride(2), + stride_dt_chunk=dt.stride(1), + stride_dt_head=dt.stride(0), + stride_dt_csize=dt.stride(2), + stride_dA_cs_chunk=dA_cumsum.stride(1), + stride_dA_cs_head=dA_cumsum.stride(0), + stride_dA_cs_csize=dA_cumsum.stride(2), + stride_seq_idx_seqlen=seq_idx.stride(0), + stride_C_seqlen=C.stride(0), + stride_C_head=C.stride(1), + stride_C_dstate=C.stride(2), + stride_states_chunk=states.stride(0), + stride_states_head=states.stride(1), + stride_states_hdim=states.stride(2), + stride_states_dstate=states.stride(3), + stride_init_states_batch=initial_states_strides[0], + stride_init_states_head=initial_states_strides[1], + stride_init_states_hdim=initial_states_strides[2], + stride_init_states_dstate=initial_states_strides[3], + stride_D_head=D.stride(0) if D is not None else 0, + IS_CAUSAL=True, + HAS_D=D is not None, + D_HAS_HDIM=D.dim() == 2 if D is not None else True, HAS_Z=z is not None, - HAS_SEQ_IDX=seq_idx is not None, + BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), IS_TRITON_22=TRITON_22, HAS_INITSTATES=initial_states is not None, ) - return out_x + return diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index 2e657426143b..8ee41f2cbc1b 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -35,41 +35,35 @@ def _chunk_cumsum_fwd_kernel( dt_out_ptr, dA_cumsum_ptr, # Matrix dimension - batch, seqlen, - nheads, - chunk_size, - dt_min, - dt_max, + nheads: tl.constexpr, + chunk_size: tl.constexpr, + dt_min: tl.constexpr, + dt_max: tl.constexpr, # Strides - stride_dt_batch, - stride_dt_seqlen, - stride_dt_head, - stride_A_head, - stride_dt_bias_head, - stride_dt_out_batch, - stride_dt_out_chunk, - stride_dt_out_head, - stride_dt_out_csize, - stride_dA_cs_batch, - stride_dA_cs_chunk, - stride_dA_cs_head, - stride_dA_cs_csize, + stride_dt_seqlen: tl.int64, + stride_dt_head: tl.constexpr, + stride_A_head: tl.constexpr, + stride_dt_bias_head: tl.constexpr, + stride_dt_out_head: tl.int64, + stride_dt_out_chunk: tl.int64, + stride_dt_out_csize: tl.constexpr, + stride_dA_cs_head: tl.int64, + stride_dA_cs_chunk: tl.int64, + stride_dA_cs_csize: tl.constexpr, # Meta-parameters DT_SOFTPLUS: tl.constexpr, HAS_DT_BIAS: tl.constexpr, BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr, ): - pid_b = tl.program_id(axis=0) - # if dt is long, may cause problems, so use 64 bit # https://github.com/triton-lang/triton/issues/1058 - pid_c = tl.program_id(axis=1).to(tl.int64) - pid_h = tl.program_id(axis=2) - dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen - dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk - dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_c = tl.program_id(axis=0).to(tl.int64) + pid_h = tl.program_id(axis=1) + dt_ptr += pid_c * chunk_size * stride_dt_seqlen + dt_out_ptr += pid_c * stride_dt_out_chunk + dA_cumsum_ptr += pid_c * stride_dA_cs_chunk offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) offs_c = tl.arange(0, BLOCK_SIZE_CHUNK) @@ -93,9 +87,8 @@ def _chunk_cumsum_fwd_kernel( dt += dt_bias[:, None] if DT_SOFTPLUS: dt = tl.where(dt <= 20.0, softplus(dt), dt) - # As of Triton 2.2.0, tl.clamp is not available yet - # dt = tl.clamp(dt, dt_min, dt_max) - dt = tl.minimum(tl.maximum(dt, dt_min), dt_max) + + dt = tl.clamp(dt, dt_min, dt_max) dt = tl.where( (offs_h[:, None] < nheads) & (offs_c[None, :] < 
chunk_size_limit), dt, 0.0) @@ -197,56 +190,46 @@ def _chunk_state_fwd_kernel( dA_cumsum_ptr, seq_idx_ptr, # Matrix dimensions - hdim, - dstate, - chunk_size, - batch, + hdim: tl.constexpr, + dstate: tl.constexpr, + chunk_size: tl.constexpr, seqlen, - nheads_ngroups_ratio, + nheads_ngroups_ratio: tl.constexpr, # Strides - stride_x_batch, - stride_x_seqlen, - stride_x_head, - stride_x_hdim, - stride_b_batch, - stride_b_seqlen, - stride_b_head, - stride_b_dstate, - stride_states_batch, - stride_states_chunk, - stride_states_head, - stride_states_hdim, - stride_states_dstate, - stride_dt_batch, - stride_dt_chunk, - stride_dt_head, - stride_dt_csize, - stride_dA_cs_batch, - stride_dA_cs_chunk, - stride_dA_cs_head, - stride_dA_cs_csize, - stride_seq_idx_batch, - stride_seq_idx_seqlen, + stride_x_seqlen: tl.int64, + stride_x_head: tl.int64, + stride_x_hdim: tl.constexpr, + stride_b_seqlen: tl.int64, + stride_b_head: tl.int64, + stride_b_dstate: tl.constexpr, + stride_states_chunk: tl.int64, + stride_states_head: tl.int64, + stride_states_hdim: tl.int64, + stride_states_dstate: tl.constexpr, + stride_dt_head: tl.int64, + stride_dt_chunk: tl.int64, + stride_dt_csize: tl.constexpr, + stride_dA_cs_head: tl.int64, + stride_dA_cs_chunk: tl.int64, + stride_dA_cs_csize: tl.constexpr, + stride_seq_idx_seqlen: tl.constexpr, # Meta-parameters - HAS_SEQ_IDX: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ): - pid_bc = tl.program_id(axis=1).to(tl.int64) - pid_c = pid_bc // batch - pid_b = pid_bc - pid_c * batch + pid_c = tl.program_id(axis=1).to(tl.int64) pid_h = tl.program_id(axis=2) num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) pid_m = tl.program_id(axis=0) // num_pid_n pid_n = tl.program_id(axis=0) % num_pid_n - b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + ( + b_ptr += pid_c * chunk_size * stride_b_seqlen + ( pid_h // nheads_ngroups_ratio) * stride_b_head - x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head - dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head - dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + + seq_idx_ptr += pid_c * chunk_size * stride_seq_idx_seqlen offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) @@ -259,13 +242,11 @@ def _chunk_state_fwd_kernel( dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize - if HAS_SEQ_IDX: - seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen + seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - if HAS_SEQ_IDX: - seq_idx_last = tl.load(seq_idx_ptr + - (chunk_size_limit - 1) * stride_seq_idx_seqlen) + seq_idx_last = tl.load(seq_idx_ptr + + (chunk_size_limit - 1) * stride_seq_idx_seqlen) acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, chunk_size_limit, BLOCK_SIZE_K): @@ -280,29 +261,28 @@ def _chunk_state_fwd_kernel( dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < 
chunk_size_limit - k, other=0.0).to(tl.float32) - if HAS_SEQ_IDX: - seq_idx_k = tl.load(seq_idx_ptrs, - mask=offs_k < chunk_size_limit - k, - other=-1) + + seq_idx_k = tl.load(seq_idx_ptrs, + mask=offs_k < chunk_size_limit - k, + other=-1) dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) - if not HAS_SEQ_IDX: - scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k - else: - scale = tl.where(seq_idx_k == seq_idx_last, - tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0) + + scale = tl.where(seq_idx_k == seq_idx_last, + tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0) b *= scale[:, None] b = b.to(x_ptr.dtype.element_ty) acc += tl.dot(x, b) + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen b_ptrs += BLOCK_SIZE_K * stride_b_seqlen dt_ptrs += BLOCK_SIZE_K * stride_dt_csize dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize - if HAS_SEQ_IDX: - seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen + seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen + states = acc.to(states_ptr.dtype.element_ty) - states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head + states_ptr += pid_c * stride_states_chunk + pid_h * stride_states_head offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + @@ -400,36 +380,35 @@ def _chunk_state_varlen_kernel( states_ptr, initstates_ptr, # Matrix dimensions - hdim, - dstate, - chunk_size, - seqlen, - nheads_ngroups_ratio, + hdim: tl.constexpr, + dstate: tl.constexpr, + chunk_size: tl.constexpr, + nheads_ngroups_ratio: tl.constexpr, # Strides - stride_x_seqlen, - stride_x_head, - stride_x_hdim, - stride_b_seqlen, - stride_b_head, - stride_b_dstate, - stride_dt_chunk, - stride_dt_head, - stride_dt_csize, - stride_dA_cs_chunk, - stride_dA_cs_head, - stride_dA_cs_csize, - stride_chunk_states_chunk, - stride_chunk_states_head, - stride_chunk_states_hdim, - stride_chunk_states_dstate, - stride_states_batch, - stride_states_head, - stride_states_hdim, - stride_states_dstate, - stride_init_states_batch, - stride_init_states_head, - stride_init_states_hdim, - stride_init_states_dstate, + stride_x_seqlen: tl.int64, + stride_x_head: tl.int64, + stride_x_hdim: tl.constexpr, + stride_b_seqlen: tl.int64, + stride_b_head: tl.int64, + stride_b_dstate: tl.constexpr, + stride_dt_head: tl.int64, + stride_dt_chunk: tl.int64, + stride_dt_csize: tl.constexpr, + stride_dA_cs_head: tl.int64, + stride_dA_cs_chunk: tl.int64, + stride_dA_cs_csize: tl.constexpr, + stride_chunk_states_chunk: tl.int64, + stride_chunk_states_head: tl.int64, + stride_chunk_states_hdim: tl.int64, + stride_chunk_states_dstate: tl.constexpr, + stride_states_batch: tl.int64, + stride_states_head: tl.int64, + stride_states_hdim: tl.int64, + stride_states_dstate: tl.constexpr, + stride_init_states_batch: tl.int64, + stride_init_states_head: tl.int64, + stride_init_states_hdim: tl.int64, + stride_init_states_dstate: tl.constexpr, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, @@ -558,52 +537,47 @@ def _chunk_cumsum_fwd(dt, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))): - batch, seqlen, nheads = dt.shape + seqlen, nheads = dt.shape assert A.shape == (nheads, ) if dt_bias is not None: assert dt_bias.shape == (nheads, ) nchunks = math.ceil(seqlen / chunk_size) - dt_out = torch.empty(batch, - nheads, + dt_out = torch.empty(nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32) - dA_cumsum = 
torch.empty(batch, - nheads, + dA_cumsum = torch.empty(nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32) - grid_chunk_cs = lambda META: (batch, nchunks, + grid_chunk_cs = lambda META: (nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H'])) with torch.cuda.device(dt.device.index): _chunk_cumsum_fwd_kernel[grid_chunk_cs]( - dt, - A, - dt_bias, - dt_out, - dA_cumsum, - batch, - seqlen, - nheads, - chunk_size, - dt_limit[0], - dt_limit[1], - dt.stride(0), - dt.stride(1), - dt.stride(2), - A.stride(0), - dt_bias.stride(0) if dt_bias is not None else 0, - dt_out.stride(0), - dt_out.stride(2), - dt_out.stride(1), - dt_out.stride(3), - dA_cumsum.stride(0), - dA_cumsum.stride(2), - dA_cumsum.stride(1), - dA_cumsum.stride(3), - dt_softplus, + dt_ptr=dt, + A_ptr=A, + dt_bias_ptr=dt_bias, + dt_out_ptr=dt_out, + dA_cumsum_ptr=dA_cumsum, + seqlen=seqlen, + nheads=nheads, + chunk_size=chunk_size, + dt_min=dt_limit[0], + dt_max=dt_limit[1], + stride_dt_seqlen=dt.stride(0), + stride_dt_head=dt.stride(1), + stride_A_head=A.stride(0), + stride_dt_bias_head=dt_bias.stride(0) + if dt_bias is not None else 0, + stride_dt_out_head=dt_out.stride(0), + stride_dt_out_chunk=dt_out.stride(1), + stride_dt_out_csize=dt_out.stride(2), + stride_dA_cs_head=dA_cumsum.stride(0), + stride_dA_cs_chunk=dA_cumsum.stride(1), + stride_dA_cs_csize=dA_cumsum.stride(2), + DT_SOFTPLUS=dt_softplus, HAS_DT_BIAS=dt_bias is not None, BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size), ) @@ -617,63 +591,57 @@ def _chunk_state_fwd(B, seq_idx=None, states=None, states_in_fp32=True): - batch, seqlen, nheads, headdim = x.shape - _, _, nchunks, chunk_size = dt.shape - _, _, ngroups, dstate = B.shape + seqlen, nheads, headdim = x.shape + _, nchunks, chunk_size = dt.shape + _, ngroups, dstate = B.shape assert nheads % ngroups == 0 - assert B.shape == (batch, seqlen, ngroups, dstate) - assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert B.shape == (seqlen, ngroups, dstate) + assert dt.shape == (nheads, nchunks, chunk_size) assert dA_cumsum.shape == dt.shape - if seq_idx is not None: - assert seq_idx.shape == (batch, seqlen) + + assert seq_idx is not None + assert seq_idx.shape == (seqlen, ) + if states is not None: - assert states.shape == (batch, nchunks, nheads, headdim, dstate) + assert states.shape == (nchunks, nheads, headdim, dstate) else: states_dtype = torch.float32 if states_in_fp32 else B.dtype - states = torch.empty((batch, nchunks, nheads, headdim, dstate), + states = torch.empty((nchunks, nheads, headdim, dstate), device=x.device, dtype=states_dtype) - grid = lambda META: ( - triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv( - dstate, META['BLOCK_SIZE_N']), batch * nchunks, nheads) + + grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton. 
+ cdiv(dstate, META['BLOCK_SIZE_N']), nchunks, nheads) with torch.cuda.device(x.device.index): _chunk_state_fwd_kernel[grid]( - x, - B, - states, - dt, - dA_cumsum, - seq_idx, - headdim, - dstate, - chunk_size, - batch, - seqlen, - nheads // ngroups, - x.stride(0), - x.stride(1), - x.stride(2), - x.stride(3), - B.stride(0), - B.stride(1), - B.stride(2), - B.stride(-1), - states.stride(0), - states.stride(1), - states.stride(2), - states.stride(3), - states.stride(4), - dt.stride(0), - dt.stride(2), - dt.stride(1), - dt.stride(3), - dA_cumsum.stride(0), - dA_cumsum.stride(2), - dA_cumsum.stride(1), - dA_cumsum.stride(3), - *((seq_idx.stride(0), - seq_idx.stride(1)) if seq_idx is not None else (0, 0)), - HAS_SEQ_IDX=seq_idx is not None, + x_ptr=x, + b_ptr=B, + states_ptr=states, + dt_ptr=dt, + dA_cumsum_ptr=dA_cumsum, + seq_idx_ptr=seq_idx, + hdim=headdim, + dstate=dstate, + chunk_size=chunk_size, + seqlen=seqlen, + nheads_ngroups_ratio=nheads // ngroups, + stride_x_seqlen=x.stride(0), + stride_x_head=x.stride(1), + stride_x_hdim=x.stride(2), + stride_b_seqlen=B.stride(0), + stride_b_head=B.stride(1), + stride_b_dstate=B.stride(2), + stride_states_chunk=states.stride(0), + stride_states_head=states.stride(1), + stride_states_hdim=states.stride(2), + stride_states_dstate=states.stride(3), + stride_dt_head=dt.stride(0), + stride_dt_chunk=dt.stride(1), + stride_dt_csize=dt.stride(2), + stride_dA_cs_head=dA_cumsum.stride(0), + stride_dA_cs_chunk=dA_cumsum.stride(1), + stride_dA_cs_csize=dA_cumsum.stride(2), + stride_seq_idx_seqlen=seq_idx.stride(0), ) return states @@ -705,46 +673,52 @@ def chunk_state_varlen(B, dstate, dtype=chunk_states.dtype, device=chunk_states.device) + + initial_states_strides = ((initial_states.stride(0), + initial_states.stride(1), + initial_states.stride(2), + initial_states.stride(3)) + if initial_states is not None else (0, 0, 0, 0)) + grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton. 
cdiv(dstate, META['BLOCK_SIZE_N']), batch, nheads) with torch.cuda.device(x.device.index): _chunk_state_varlen_kernel[grid]( - x, - B, - dt, - dA_cumsum, - chunk_states, - cu_seqlens, - states, - initial_states, - headdim, - dstate, - chunk_size, - total_seqlen, - nheads // ngroups, - x.stride(0), - x.stride(1), - x.stride(2), - B.stride(0), - B.stride(1), - B.stride(2), - dt.stride(1), - dt.stride(0), - dt.stride(2), - dA_cumsum.stride(1), - dA_cumsum.stride(0), - dA_cumsum.stride(2), - chunk_states.stride(0), - chunk_states.stride(1), - chunk_states.stride(2), - chunk_states.stride(3), - states.stride(0), - states.stride(1), - states.stride(2), - states.stride(3), - *((initial_states.stride(0), initial_states.stride(1), - initial_states.stride(2), - initial_states.stride(3)) if initial_states is not None else - (0, 0, 0, 0)), + x_ptr=x, + b_ptr=B, + dt_ptr=dt, + dA_cumsum_ptr=dA_cumsum, + chunk_states_ptr=chunk_states, + cu_seqlens_ptr=cu_seqlens, + states_ptr=states, + initstates_ptr=initial_states, + hdim=headdim, + dstate=dstate, + chunk_size=chunk_size, + nheads_ngroups_ratio=nheads // ngroups, + stride_x_seqlen=x.stride(0), + stride_x_head=x.stride(1), + stride_x_hdim=x.stride(2), + stride_b_seqlen=B.stride(0), + stride_b_head=B.stride(1), + stride_b_dstate=B.stride(2), + stride_dt_head=dt.stride(0), + stride_dt_chunk=dt.stride(1), + stride_dt_csize=dt.stride(2), + stride_dA_cs_head=dA_cumsum.stride(0), + stride_dA_cs_chunk=dA_cumsum.stride(1), + stride_dA_cs_csize=dA_cumsum.stride(2), + stride_chunk_states_chunk=chunk_states.stride(0), + stride_chunk_states_head=chunk_states.stride(1), + stride_chunk_states_hdim=chunk_states.stride(2), + stride_chunk_states_dstate=chunk_states.stride(3), + stride_states_batch=states.stride(0), + stride_states_head=states.stride(1), + stride_states_hdim=states.stride(2), + stride_states_dstate=states.stride(3), + stride_init_states_batch=initial_states_strides[0], + stride_init_states_head=initial_states_strides[1], + stride_init_states_hdim=initial_states_strides[2], + stride_init_states_dstate=initial_states_strides[3], HAS_INITSTATES=initial_states is not None) return states diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index fcc5c905bf77..37d6c2870812 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -31,6 +31,7 @@ def _mamba_chunk_scan_combined_fwd(x, B, C, chunk_size, + out, D=None, z=None, dt_bias=None, @@ -41,14 +42,13 @@ def _mamba_chunk_scan_combined_fwd(x, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf")), - state_dtype=None, - out=None): + state_dtype=None): assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2" - batch, seqlen, nheads, headdim = x.shape - _, _, ngroups, dstate = B.shape + seqlen, nheads, headdim = x.shape + _, ngroups, dstate = B.shape assert nheads % ngroups == 0 - assert B.shape == (batch, seqlen, ngroups, dstate) - assert dt.shape == (batch, seqlen, nheads) + assert B.shape == (seqlen, ngroups, dstate) + assert dt.shape == (seqlen, nheads) assert A.shape == (nheads, ) assert C.shape == B.shape if z is not None: @@ -56,25 +56,24 @@ def _mamba_chunk_scan_combined_fwd(x, if D is not None: assert D.shape == (nheads, headdim) or D.shape == (nheads, ) if seq_idx is not None: - assert seq_idx.shape == (batch, seqlen) + assert seq_idx.shape == (seqlen, ) if B.stride(-1) != 1: B = B.contiguous() if C.stride(-1) != 1: C = C.contiguous() if 
x.stride(-1) != 1 and x.stride( - 1) != 1: # Either M or K dimension should be contiguous + 0) != 1: # Either M or K dimension should be contiguous x = x.contiguous() if z is not None and z.stride(-1) != 1 and z.stride( - 1) != 1: # Either M or K dimension should be contiguous + 0) != 1: # Either M or K dimension should be contiguous z = z.contiguous() if D is not None and D.stride(-1) != 1: D = D.contiguous() + assert cu_seqlens is not None, "Assuming varlen input - must supply cu_seqlens" + if initial_states is not None: - if cu_seqlens is None: - assert initial_states.shape == (batch, nheads, headdim, dstate) - else: - assert initial_states.shape == (len(cu_seqlens) - 1, nheads, - headdim, dstate) + assert initial_states.shape == (len(cu_seqlens) - 1, nheads, headdim, + dstate) # This function executes 5 sub-functions for computing mamba # - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/ @@ -114,18 +113,16 @@ def _mamba_chunk_scan_combined_fwd(x, # - this will ensure that states will be updated with the rightmost flushed seq_idx # of the previous chunk. This implies that the first chunk of states is either 0 # or equal to init_states of the first example. - states, final_states = _state_passing_fwd( + states = _state_passing_fwd( rearrange(states, "... p n -> ... (p n)"), - dA_cumsum, + dA_cumsum, # (nheads, nchunks, chunk_size) initial_states=rearrange(initial_states, "... p n -> ... (p n)") - if initial_states is not None else None, + if initial_states is not None else + None, # (batch, nheads, headdim*dstate) seq_idx=seq_idx, - chunk_size=chunk_size, out_dtype=state_dtype if state_dtype is not None else C.dtype, - is_cont_batched=cu_seqlens is not None, chunk_offsets=chunk_offsets) - states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate) - for t in [states, final_states]) + states = rearrange(states, "... (p n) -> ... p n", n=dstate) # 4. Compute batched matrix multiply for C_j^T B_i terms CB = _bmm_chunk_fwd(C, @@ -144,87 +141,88 @@ def _mamba_chunk_scan_combined_fwd(x, # - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had # a seq_idx change, in which case we take states information from # init_states. 
- out_x = _chunk_scan_fwd( + _chunk_scan_fwd( CB, x, dt, dA_cumsum, C, states, + out, # in-place update + seq_idx, D=D, z=z, - seq_idx=seq_idx, chunk_indices=chunk_indices, chunk_offsets=chunk_offsets, initial_states=initial_states, - out=out, ) - if cu_seqlens is None: - return out_x, dt, dA_cumsum, states, final_states - else: - assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1" - varlen_states = chunk_state_varlen( - B.squeeze(0), - x.squeeze(0), - dt.squeeze(0), - dA_cumsum.squeeze(0), - cu_seqlens, - states.squeeze(0), - initial_states=initial_states, - ) - return out_x, dt, dA_cumsum, states, final_states, varlen_states - - -def mamba_chunk_scan_combined(x, - dt, - A, - B, - C, - chunk_size, - D=None, - z=None, - dt_bias=None, - initial_states=None, - seq_idx=None, - chunk_indices=None, - chunk_offsets=None, - cu_seqlens=None, - dt_softplus=False, - dt_limit=(0.0, float("inf")), - out=None, - return_final_states=False, - return_varlen_states=False, - state_dtype=None): + + varlen_states = chunk_state_varlen( + B, + x, + dt, + dA_cumsum, + cu_seqlens, + states, + initial_states=initial_states, + ) + + return varlen_states + + +def mamba_chunk_scan_combined_varlen( + x, + dt, + A, + B, + C, + chunk_size, + cu_seqlens, + seq_idx, + out, + D=None, + z=None, + dt_bias=None, + initial_states=None, + chunk_indices=None, + chunk_offsets=None, + dt_softplus=False, + dt_limit=(0.0, float("inf")), + state_dtype=None, +): """ Argument: - x: (batch, seqlen, nheads, headdim) - dt: (batch, seqlen, nheads) + x: (seqlen, nheads, headdim) + dt: (seqlen, nheads) A: (nheads) - B: (batch, seqlen, ngroups, dstate) - C: (batch, seqlen, ngroups, dstate) + B: (seqlen, ngroups, dstate) + C: (seqlen, ngroups, dstate) chunk_size: int + seq_idx: (seqlen) + cu_seqlens: (batch + 1) + out: (seqlen, nheads, headdim) preallocated output tensor D: (nheads, headdim) or (nheads,) - z: (batch, seqlen, nheads, headdim) + z: (seqlen, nheads, headdim) dt_bias: (nheads,) initial_states: (batch, nheads, headdim, dstate) - seq_idx: (batch, seqlen) - cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True dt_softplus: Whether to apply softplus to dt - out: Preallocated output tensor + out: (seqlen, nheads, headdim) preallocated output tensor state_dtype: The data type of the ssm state + Return: + varlen_states: (batch, nheads, headdim, dstate) """ - if not return_varlen_states: - cu_seqlens = None - else: - assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True" - out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd( + assert cu_seqlens is not None, "cu_seqlens must be provided assuming varlen input" + assert seq_idx is not None + + varlen_states = _mamba_chunk_scan_combined_fwd( x, dt, A, B, C, chunk_size, + out, D=D, z=z, dt_bias=dt_bias, @@ -235,14 +233,6 @@ def mamba_chunk_scan_combined(x, cu_seqlens=cu_seqlens, dt_softplus=dt_softplus, dt_limit=dt_limit, - out=out, state_dtype=state_dtype) - if not return_varlen_states: - if not return_final_states: - return - else: - return final_states - else: - varlen_states = rest[0] - return (varlen_states) if not return_final_states else (final_states, - varlen_states) + + return varlen_states diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index d61c3a8cdbe9..71a8a4b0a1c8 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ 
b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -27,64 +27,46 @@ def _state_passing_fwd_kernel( # Pointers to matrices states_ptr, out_ptr, - final_states_ptr, dA_cs_ptr, initstates_ptr, seq_idx_ptr, chunk_offsets_ptr, chunk_meta_num, # Matrix dimensions - dim, + dim: tl.constexpr, nchunks, seqlen, - chunk_size, + chunk_size: tl.constexpr, # Strides - stride_states_batch, - stride_states_chunk, - stride_states_head, - stride_states_dim, - stride_out_batch, - stride_out_chunk, - stride_out_head, - stride_out_dim, - stride_final_states_batch, - stride_final_states_head, - stride_final_states_dim, - stride_dA_cs_batch, - stride_dA_cs_chunk, - stride_dA_cs_head, - stride_dA_cs_csize, - stride_initstates_batch, - stride_initstates_head, - stride_initstates_dim, - stride_seq_idx_batch, - stride_seq_idx_seqlen, + stride_states_chunk: tl.int64, + stride_states_head: tl.int64, + stride_states_dim: tl.constexpr, + stride_out_chunk: tl.int64, + stride_out_head: tl.int64, + stride_out_dim: tl.constexpr, + stride_dA_cs_head: tl.int64, + stride_dA_cs_chunk: tl.int64, + stride_dA_cs_csize: tl.constexpr, + stride_initstates_batch: tl.int64, + stride_initstates_head: tl.int64, + stride_initstates_dim: tl.constexpr, + stride_seq_idx_seqlen: tl.constexpr, # Meta-parameters HAS_INITSTATES: tl.constexpr, - HAS_SEQ_IDX: tl.constexpr, - IS_CONT_BATCHED: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): - pid_b = tl.program_id(axis=1) - pid_h = tl.program_id(axis=2) + pid_h = tl.program_id(axis=1) pid_m = tl.program_id(axis=0) - states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head - dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + ( - chunk_size - 1) * stride_dA_cs_csize - out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head - final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head + states_ptr += pid_h * stride_states_head + dA_cs_ptr += pid_h * stride_dA_cs_head + (chunk_size - + 1) * stride_dA_cs_csize + out_ptr += pid_h * stride_out_head if HAS_INITSTATES: initstates_ptr += pid_h * stride_initstates_head - if not IS_CONT_BATCHED: - initstates_ptr += pid_b * stride_initstates_batch - - if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) states_ptrs = states_ptr + offs_m * stride_states_dim out_ptrs = out_ptr + offs_m * stride_out_dim - final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim # - states will be the past state of the sequence that continues on the current check if not HAS_INITSTATES: @@ -101,65 +83,63 @@ def _state_passing_fwd_kernel( out_ptrs += stride_out_chunk prev_seq_idx_chunk_end = 0 logical_chunk_idx = 0 - for c in range(nchunks): + for c in range(nchunks - 1): new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) dA_cs = tl.load(dA_cs_ptr).to(tl.float32) scale_mask = True - if HAS_SEQ_IDX: - # - the seq to pass forward is the one that is flushed to the right - # boundary. - # - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk. - seq_idx_chunk_end = tl.load(seq_idx_ptr + (min( - (c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen) - if HAS_INITSTATES: - if IS_CONT_BATCHED and prev_seq_idx_chunk_end != seq_idx_chunk_end: - # this means in the current chunk the rightmost flushed seq - # has changed. 
- # - so we do not propagate the state from previous chunk - # - but rather we load that sequence's init state - initstates_ptrs = initstates_ptr + seq_idx_chunk_end * stride_initstates_batch - - # - update state with seq_idx_new's init state - states = tl.load(initstates_ptrs, - mask=offs_m < dim, - other=0.0).to(tl.float32) - - # - we need to consider the cumsum only of the last sequence in the chunk - # - find its starting position (given by c_off of the logical chunk index) - # - and subtract the cumsum just before that position from the total cumsum - # - first, update the logical chunk index (add the number of sequences in the current physical chunk): - # sequence index at the start of the current chunk - seq_idx_chunk_start = tl.load(seq_idx_ptr + - min(c * chunk_size, seqlen) * - stride_seq_idx_seqlen) - logical_chunk_idx += seq_idx_chunk_end - seq_idx_chunk_start - # - load the chunk offset: - c_off = tl.load(chunk_offsets_ptr + logical_chunk_idx, - mask=logical_chunk_idx < chunk_meta_num, - other=0) - # - if offset is 0, then the sequence starts at the beginning of the chunk, and we don't need to subtract anything - if c_off > 0: - # - dA_cs_ptr currently points to the cumsum at the end of the chunk - subtract the chunk size and add the offset - dA_cs_boundary = tl.load( - dA_cs_ptr - (chunk_size - 1) * stride_dA_cs_csize + - (c_off - 1) * stride_dA_cs_csize, - mask=(c_off - 1) > -1 and c_off < chunk_size, - other=0.0) - dA_cs -= dA_cs_boundary - - # - increment logical chunk index for every physical chunk - logical_chunk_idx += 1 - else: - scale_mask = seq_idx_chunk_end == prev_seq_idx_chunk_end - prev_seq_idx_chunk_end = seq_idx_chunk_end + # - the seq to pass forward is the one that is flushed to the right + # boundary. + # - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk. + seq_idx_chunk_end = tl.load(seq_idx_ptr + + (min((c + 1) * chunk_size, seqlen) - 1) * + stride_seq_idx_seqlen) + + if HAS_INITSTATES: + if prev_seq_idx_chunk_end != seq_idx_chunk_end: + # this means in the current chunk the rightmost flushed seq + # has changed. 
+ # - so we do not propagate the state from previous chunk + # - but rather we load that sequence's init state + initstates_ptrs = initstates_ptr + seq_idx_chunk_end * stride_initstates_batch + + # - update state with seq_idx_new's init state + states = tl.load(initstates_ptrs, mask=offs_m < dim, + other=0.0).to(tl.float32) + + # - we need to consider the cumsum only of the last sequence in the chunk + # - find its starting position (given by c_off of the logical chunk index) + # - and subtract the cumsum just before that position from the total cumsum + # - first, update the logical chunk index (add the number of sequences in the current physical chunk): + # sequence index at the start of the current chunk + seq_idx_chunk_start = tl.load(seq_idx_ptr + + min(c * chunk_size, seqlen) * + stride_seq_idx_seqlen) + logical_chunk_idx += seq_idx_chunk_end - seq_idx_chunk_start + # - load the chunk offset: + c_off = tl.load(chunk_offsets_ptr + logical_chunk_idx, + mask=logical_chunk_idx < chunk_meta_num, + other=0) + # - if offset is 0, then the sequence starts at the beginning of the chunk, and we don't need to subtract anything + if c_off > 0: + # - dA_cs_ptr currently points to the cumsum at the end of the chunk - subtract the chunk size and add the offset + dA_cs_boundary = tl.load( + dA_cs_ptr - (chunk_size - 1) * stride_dA_cs_csize + + (c_off - 1) * stride_dA_cs_csize, + mask=(c_off - 1) > -1 and c_off < chunk_size, + other=0.0) + dA_cs -= dA_cs_boundary + + # - increment logical chunk index for every physical chunk + logical_chunk_idx += 1 + else: + scale_mask = seq_idx_chunk_end == prev_seq_idx_chunk_end + prev_seq_idx_chunk_end = seq_idx_chunk_end scale = tl.where(scale_mask, tl.exp(dA_cs), 0.0) states = scale * states + new_states - if c < nchunks - 1: - tl.store(out_ptrs, states, mask=offs_m < dim) - else: - tl.store(final_states_ptrs, states, mask=offs_m < dim) + tl.store(out_ptrs, states, mask=offs_m < dim) + states_ptrs += stride_states_chunk dA_cs_ptr += stride_dA_cs_chunk out_ptrs += stride_out_chunk @@ -168,81 +148,53 @@ def _state_passing_fwd_kernel( def _state_passing_fwd( states, dA_cumsum, + seq_idx, + chunk_offsets, initial_states=None, - seq_idx=None, - chunk_size=None, out_dtype=None, - is_cont_batched=False, - chunk_offsets=None, ): - batch, nchunks, nheads, dim = states.shape - if chunk_size is None: - chunk_size = dA_cumsum.shape[-1] - else: - assert chunk_size == dA_cumsum.shape[-1] - assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) - if initial_states is not None: - if is_cont_batched: - # - if cu_seqlens is provided, then the initial states - # are used for continuous batching. In which case we - # require seq_idx to be provided - assert seq_idx is not None, "seq_idx must be provided for continuous batching" - # - we also need chunk_offsets to be provided, to account - # for computation of dA_cumsum from the start of the - # sequence - assert chunk_offsets is not None, "chunk_offsets must be provided for continuous batching" - else: - # - this is the regular batching case, where initial - # states are used are for each example of the batch. 
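The deleted branch above covered the padded-batch path; the wrapper is now varlen-only and the kernel no longer writes a separate final_states tensor, since out[c] simply holds the state entering physical chunk c. A rough pure-Python reference of that recurrence is sketched below (the helper name and scalar setup are hypothetical, it tracks a single head and state element, and it ignores the chunk_offsets correction for sequences that start mid-chunk):

import math

def state_passing_ref(new_states, dA_cs_last, seq_idx_chunk_end, init_states=None):
    # new_states[c]       : local state contributed by physical chunk c
    # dA_cs_last[c]       : cumulative dA at the end of chunk c
    # seq_idx_chunk_end[c]: sequence index of the last token in chunk c
    # Returns out, where out[c] is the state entering chunk c.
    nchunks = len(new_states)
    state = init_states[0] if init_states is not None else 0.0
    out = [state]
    prev_seq = 0
    for c in range(nchunks - 1):
        seq_end = seq_idx_chunk_end[c]
        if init_states is not None and seq_end != prev_seq:
            # The rightmost flushed sequence changed inside chunk c:
            # restart from that sequence's initial state instead of propagating.
            state = init_states[seq_end]
        same_seq = seq_end == prev_seq
        scale = math.exp(dA_cs_last[c]) if (same_seq or init_states is not None) else 0.0
        state = scale * state + new_states[c]
        out.append(state)
        prev_seq = seq_end
    return out

Per-sequence final states are no longer returned from the state-passing step at all; they are produced afterwards by chunk_state_varlen and surfaced to the caller as varlen_states.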
- assert initial_states.shape == (batch, nheads, dim) - - if seq_idx is not None: - seqlen = seq_idx.shape[-1] - assert seq_idx.shape == (batch, seqlen) + nchunks, nheads, dim = states.shape + chunk_size = dA_cumsum.shape[-1] + assert dA_cumsum.shape == (nheads, nchunks, chunk_size) + seqlen = seq_idx.shape[-1] out_dtype = states.dtype if out_dtype is None else out_dtype - out = torch.empty((batch, nchunks, nheads, dim), + out = torch.empty((nchunks, nheads, dim), device=states.device, dtype=out_dtype) - final_states = torch.empty((batch, nheads, dim), - device=states.device, - dtype=torch.float32) - grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads) + + initial_states_strides = ((initial_states.stride(0), + initial_states.stride(1), + initial_states.stride(2)) + if initial_states is not None else (0, 0, 0)) + + grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), nheads) with torch.cuda.device(states.device.index): _state_passing_fwd_kernel[grid]( - states, - out, - final_states, - dA_cumsum, - initial_states, - seq_idx, - chunk_offsets, - len(chunk_offsets) if chunk_offsets is not None else 0, - dim, - nchunks, - seqlen if seq_idx is not None else 0, - chunk_size, - states.stride(0), - states.stride(1), - states.stride(2), - states.stride(3), - out.stride(0), - out.stride(1), - out.stride(2), - out.stride(3), - final_states.stride(0), - final_states.stride(1), - final_states.stride(2), - dA_cumsum.stride(0), - dA_cumsum.stride(2), - dA_cumsum.stride(1), - dA_cumsum.stride(3), - *((initial_states.stride(0), initial_states.stride(1), - initial_states.stride(2)) if initial_states is not None else - (0, 0, 0)), - *((seq_idx.stride(0), - seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + states_ptr=states, + out_ptr=out, + dA_cs_ptr=dA_cumsum, + initstates_ptr=initial_states, + seq_idx_ptr=seq_idx, + chunk_offsets_ptr=chunk_offsets, + chunk_meta_num=len(chunk_offsets) + if chunk_offsets is not None else 0, + dim=dim, + nchunks=nchunks, + seqlen=seqlen if seq_idx is not None else 0, + chunk_size=chunk_size if seq_idx is not None else 0, + stride_states_chunk=states.stride(0), + stride_states_head=states.stride(1), + stride_states_dim=states.stride(2), + stride_out_chunk=out.stride(0), + stride_out_head=out.stride(1), + stride_out_dim=out.stride(2), + stride_dA_cs_head=dA_cumsum.stride(0), + stride_dA_cs_chunk=dA_cumsum.stride(1), + stride_dA_cs_csize=dA_cumsum.stride(2), + stride_initstates_batch=initial_states_strides[0], + stride_initstates_head=initial_states_strides[1], + stride_initstates_dim=initial_states_strides[2], + stride_seq_idx_seqlen=seq_idx.stride(0), HAS_INITSTATES=initial_states is not None, - HAS_SEQ_IDX=seq_idx is not None, - IS_CONT_BATCHED=is_cont_batched, ) - return out, final_states + return out diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index a7acf64f302b..03265b13de50 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -35,7 +35,7 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( selective_state_update) from vllm.model_executor.layers.mamba.ops.ssd_combined import ( - mamba_chunk_scan_combined) + mamba_chunk_scan_combined_varlen) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -262,6 +262,7 @@ def forward_cuda( seq_idx_p = attn_metadata.seq_idx_p chunk_indices_p = 
attn_metadata.chunk_indices_p chunk_offsets_p = attn_metadata.chunk_offsets_p + query_start_loc_p = attn_metadata.query_start_loc_p # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states) @@ -302,9 +303,6 @@ def forward_cuda( [num_decodes, num_prefills], dim=0, ) - query_start_loc_p = ( - attn_metadata.query_start_loc[-num_prefills - 1:] - - num_decodes if has_prefill else None) # Preallocate output tensor to avoid memcpy cost for merging prefill # and decode outputs @@ -356,17 +354,17 @@ def forward_cuda( has_initial_states_p[:, None, None, None], ssm_state[state_indices_tensor_p], 0) - varlen_state = mamba_chunk_scan_combined( - hidden_states_p.view(1, num_prefill_tokens, + varlen_state = mamba_chunk_scan_combined_varlen( + hidden_states_p.view(num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim), - dt.unsqueeze(0), + dt, self.A, - B.view(1, num_prefill_tokens, 1, -1), - C.view(1, num_prefill_tokens, 1, -1), + B.view(num_prefill_tokens, 1, -1), + C.view(num_prefill_tokens, 1, -1), chunk_size=chunk_size, D=self.D, - z=gate_p.view(1, num_prefill_tokens, + z=gate_p.view(num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim), dt_bias=self.dt_bias, seq_idx=seq_idx_p, @@ -374,11 +372,9 @@ def forward_cuda( chunk_offsets=chunk_offsets_p, cu_seqlens=query_start_loc_p, initial_states=initial_states, - return_varlen_states=True, - return_final_states=False, dt_softplus=True, dt_limit=(0.0, float("inf")), - out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1, + out=preallocated_ssm_out_p.view(num_prefill_tokens, -1, self.head_dim), state_dtype=ssm_state.dtype, ) diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index f45fc75334a2..6f16fda962ae 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -115,7 +115,7 @@ class Mamba2AttentionMetadata: num_prefill_tokens: int num_decodes: int num_decode_tokens: int - query_start_loc: torch.Tensor + query_start_loc_p: torch.Tensor seq_lens: torch.Tensor prep_initial_states: bool @@ -151,7 +151,7 @@ def build(self, common_attn_metadata: CommonAttentionMetadata, fast_build: bool = False) -> Mamba2AttentionMetadata: num_reqs = common_attn_metadata.num_reqs - query_start_loc = common_attn_metadata.query_start_loc + query_start_loc_p = None seq_lens = common_attn_metadata.seq_lens seq_idx_p = None @@ -179,7 +179,7 @@ def build(self, num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0) prep_initial_states = torch.any(has_initial_states_cpu).item() has_initial_states_p = has_initial_states_cpu.to( - query_start_loc.device) + common_attn_metadata.query_start_loc.device) query_start_loc_p = common_attn_metadata.query_start_loc[ -num_prefills - 1:] - num_decode_tokens @@ -190,7 +190,6 @@ def build(self, device=query_start_loc_p.device), query_start_loc_p.diff(), output_size=num_prefill_tokens) - seq_idx_p.unsqueeze_(0) # We compute metadata for chunked prefill once at the top level # model forward and reuse them in mamba layers. If not needed, @@ -217,7 +216,7 @@ def build(self, num_prefill_tokens=num_prefill_tokens, num_decodes=num_decodes, num_decode_tokens=num_decode_tokens, - query_start_loc=query_start_loc, + query_start_loc_p=query_start_loc_p, seq_lens=seq_lens, prep_initial_states=prep_initial_states, chunk_size=self.chunk_size,
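With these builder changes, mamba layers receive a prefill-only query_start_loc_p and a 1-D seq_idx_p directly from the metadata. A small worked example of what those tensors contain for a mixed decode + prefill batch (the request sizes here are hypothetical, not taken from the patch):

import torch

# 2 decode requests (1 token each) followed by 2 prefill requests (4 and 6 tokens).
num_decode_tokens = 2
num_prefills, num_prefill_tokens = 2, 10
query_start_loc = torch.tensor([0, 1, 2, 6, 12])  # cumulative over the whole mixed batch

# Prefill-only cu_seqlens, rebased to start at 0 (same expression as in build() above).
query_start_loc_p = query_start_loc[-num_prefills - 1:] - num_decode_tokens
assert query_start_loc_p.tolist() == [0, 4, 10]

# Per-token sequence index for the prefill tokens; note there is no batch
# dimension and no unsqueeze(0) anymore.
seq_idx_p = torch.repeat_interleave(
    torch.arange(num_prefills),
    query_start_loc_p.diff(),
    output_size=num_prefill_tokens).to(torch.int32)
assert seq_idx_p.tolist() == [0, 0, 0, 0, 1, 1, 1, 1, 1, 1]

These two tensors are exactly what mamba_chunk_scan_combined_varlen consumes as cu_seqlens and seq_idx in the plamo2.py call above.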