Skip to content

Commit 86173ad

Browse files
[Kernel] Support decode context parallelism on Blackwell with CUTLASS MLA (#24385)
Signed-off-by: Ming Yang <minos.future@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
1 parent 795b695 commit 86173ad

File tree

5 files changed, with 63 additions and 32 deletions.

csrc/attention/mla/sm100_cutlass_mla_kernel.cu

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,7 @@ limitations under the License.
3636
#if !defined(CUDA_VERSION) || CUDA_VERSION < 12040
3737
void sm100_cutlass_mla_decode(
3838
torch::Tensor const& out,
39+
torch::Tensor const& lse,
3940
torch::Tensor const& q_nope,
4041
torch::Tensor const& q_pe,
4142
torch::Tensor const& kv_c_and_k_pe_cache,
@@ -99,6 +100,7 @@ struct MlaSm100 {
99100
template <typename T>
100101
typename T::Fmha::Arguments args_from_options(
101102
at::Tensor const& out,
103+
at::Tensor const& lse,
102104
at::Tensor const& q_nope,
103105
at::Tensor const& q_pe,
104106
at::Tensor const& kv_c_and_k_pe_cache,
@@ -162,7 +164,10 @@ typename T::Fmha::Arguments args_from_options(
162164
stride_PT,
163165
page_count_total,
164166
page_size},
165-
{static_cast<ElementOut*>(out.data_ptr()), stride_O, static_cast<ElementAcc*>(nullptr), stride_LSE},
167+
{static_cast<ElementOut*>(out.data_ptr()),
168+
stride_O,
169+
static_cast<ElementAcc*>(lse.defined() ? lse.data_ptr() : nullptr),
170+
stride_LSE},
166171
hw_info,
167172
// TODO(trevor-m): Change split_kv back to -1 when
168173
// https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will
@@ -181,6 +186,7 @@ typename T::Fmha::Arguments args_from_options(
181186
template <typename Element, typename ElementOut, bool IsPaged128, typename PersistenceOption>
182187
void runMla(
183188
at::Tensor const& out,
189+
at::Tensor const& lse,
184190
at::Tensor const& q_nope,
185191
at::Tensor const& q_pe,
186192
at::Tensor const& kv_c_and_k_pe_cache,
@@ -192,7 +198,7 @@ void runMla(
192198
cudaStream_t stream) {
193199
using MlaSm100Type = MlaSm100<Element, ElementOut, IsPaged128, PersistenceOption>;
194200
typename MlaSm100Type::Fmha fmha;
195-
auto arguments = args_from_options<MlaSm100Type>(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits);
201+
auto arguments = args_from_options<MlaSm100Type>(out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits);
196202

197203
CUTLASS_CHECK(fmha.can_implement(arguments));
198204

@@ -214,6 +220,7 @@ void runMla(
214220

215221
void sm100_cutlass_mla_decode(
216222
torch::Tensor const& out,
223+
torch::Tensor const& lse,
217224
torch::Tensor const& q_nope,
218225
torch::Tensor const& q_pe,
219226
torch::Tensor const& kv_c_and_k_pe_cache,
@@ -234,13 +241,13 @@ void sm100_cutlass_mla_decode(
234241
DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] {
235242
if (in_dtype == at::ScalarType::Half) {
236243
runMla<cutlass::half_t, cutlass::half_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
237-
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
244+
out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
238245
} else if (in_dtype == at::ScalarType::BFloat16) {
239246
runMla<cutlass::bfloat16_t, cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
240-
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
247+
out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
241248
} else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
242249
runMla<cutlass::float_e4m3_t, cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
243-
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
250+
out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
244251
} else {
245252
TORCH_CHECK(false, "Unsupported input data type of MLA");
246253
}

csrc/torch_bindings.cpp

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -516,10 +516,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
516516

517517
// SM100 CUTLASS MLA decode
518518
ops.def(
519-
"sm100_cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
520-
" Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
521-
" Tensor page_table, Tensor workspace, float "
522-
"scale,"
519+
"sm100_cutlass_mla_decode(Tensor! out, Tensor! lse, Tensor q_nope,"
520+
" Tensor q_pe, Tensor kv_c_and_k_pe_cache,"
521+
" Tensor seq_lens, Tensor page_table,"
522+
" Tensor workspace, float scale,"
523523
" int num_kv_splits) -> ()");
524524
// conditionally compiled so impl in source file
525525

tests/kernels/test_cutlass_mla_decode.py

Lines changed: 22 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
import math
44
import random
5+
from typing import Optional
56

67
import pytest
78
import torch
@@ -14,14 +15,20 @@
1415
def cal_diff(x: torch.Tensor,
1516
y: torch.Tensor,
1617
name: str,
17-
use_fp8: bool = False) -> None:
18+
use_fp8: bool = False,
19+
diff_threshold: Optional[float] = None) -> None:
1820
x, y = x.double(), y.double()
1921
cos_diff = 1 - 2 * (x * y).sum().item() / max(
2022
(x * x + y * y).sum().item(), 1e-12)
21-
if (use_fp8):
22-
assert cos_diff < 1e-4
23+
if diff_threshold is not None:
24+
# directly compare the cos_diff with the threshold
25+
assert cos_diff < diff_threshold
2326
else:
24-
assert cos_diff < 1e-5
27+
# use the default threshold
28+
if (use_fp8):
29+
assert cos_diff < 1e-4
30+
else:
31+
assert cos_diff < 1e-5
2532

2633

2734
CUTLASS_MLA_UNSUPPORTED_REASON = \
@@ -118,11 +125,13 @@ def cutlass_mla():
118125
dtype=torch.uint8)
119126

120127
out_ans = torch.empty(b, MAX_HEADS, dv, dtype=init_dtype)
121-
122-
ops.sm100_cutlass_mla_decode(out_ans, q_nope, q_pe, kv_cache_flat,
123-
cache_seqlens, block_table, workspace,
124-
scale, 1)
125-
return out_ans[:, :h_q].contiguous()
128+
output_lse = torch.empty((b, MAX_HEADS),
129+
dtype=torch.float32,
130+
device=q_nope.device)
131+
ops.sm100_cutlass_mla_decode(out_ans, output_lse, q_nope, q_pe,
132+
kv_cache_flat, cache_seqlens, block_table,
133+
workspace, scale, 1)
134+
return out_ans[:, :h_q].contiguous(), output_lse[:, :h_q].contiguous()
126135

127136
def scaled_dot_product_attention(query, key, value, is_causal=False):
128137
query = query.float()
@@ -165,11 +174,14 @@ def ref_mla():
165174
lse[i] = lse_i
166175
return out, lse
167176

168-
out_cutlass = cutlass_mla()
177+
out_cutlass, lse_cutlass = cutlass_mla()
169178
out_torch, lse_torch = ref_mla()
170179
# Extract the single token (s_q=1) slice to match cutlass output shape
171180
out_torch_slice = out_torch[:, 0, :, :] # [b, h_q, dv]
181+
lse_torch_slice = lse_torch[:, 0, :] # [b, h_q]
172182
cal_diff(out_cutlass, out_torch_slice, "out", use_fp8)
183+
# lse has larger numerical error, so use a larger threshold
184+
cal_diff(lse_cutlass, lse_torch_slice, "lse", use_fp8, diff_threshold=1e-3)
173185

174186
t = triton.testing.do_bench(cutlass_mla)
175187
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2

vllm/_custom_ops.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1833,13 +1833,13 @@ def cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor,
18331833
return out
18341834

18351835

1836-
def sm100_cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor,
1837-
q_pe: torch.Tensor,
1836+
def sm100_cutlass_mla_decode(out: torch.Tensor, lse: torch.Tensor,
1837+
q_nope: torch.Tensor, q_pe: torch.Tensor,
18381838
kv_c_and_k_pe_cache: torch.Tensor,
18391839
seq_lens: torch.Tensor, page_table: torch.Tensor,
18401840
workspace: torch.Tensor, scale: float,
18411841
num_kv_splits: int) -> torch.Tensor:
1842-
torch.ops._C.sm100_cutlass_mla_decode(out, q_nope, q_pe,
1842+
torch.ops._C.sm100_cutlass_mla_decode(out, lse, q_nope, q_pe,
18431843
kv_c_and_k_pe_cache, seq_lens,
18441844
page_table, workspace, scale,
18451845
num_kv_splits)

vllm/v1/attention/backends/mla/cutlass_mla.py

Lines changed: 22 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -76,6 +76,7 @@ def ensure_size(self, attn_metadata: MLACommonMetadata,
7676

7777

7878
class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
79+
can_return_lse_for_decode: bool = True
7980

8081
def __init__(
8182
self,
@@ -138,7 +139,7 @@ def _sm100_cutlass_mla_decode(
138139
workspace: torch.Tensor,
139140
sm_scale: float,
140141
num_kv_splits: int,
141-
) -> torch.Tensor:
142+
) -> tuple[torch.Tensor, torch.Tensor]:
142143
assert (q_nope.ndim == 3
143144
), f"q_nope must be a 3D tensor, but got {q_nope.ndim}"
144145
assert (
@@ -193,9 +194,13 @@ def _sm100_cutlass_mla_decode(
193194
dtype = (torch.bfloat16 if is_quantized_kv_cache(self.kv_cache_dtype)
194195
else q_nope.dtype)
195196
out = q_nope.new_empty((B_q, MAX_HEADS, D_latent), dtype=dtype)
197+
lse = (torch.empty(
198+
(B_q, MAX_HEADS), dtype=torch.float32, device=q_nope.device)
199+
if self.need_to_return_lse_for_decode else torch.Tensor())
196200

197201
ops.sm100_cutlass_mla_decode(
198202
out,
203+
lse,
199204
q_nope,
200205
q_pe,
201206
kv_c_and_k_pe_cache,
@@ -205,15 +210,17 @@ def _sm100_cutlass_mla_decode(
205210
sm_scale,
206211
num_kv_splits,
207212
)
208-
return out[:, :H].contiguous()
213+
returned_lse = lse[:, :H].contiguous(
214+
) if self.need_to_return_lse_for_decode else lse
215+
return out[:, :H].contiguous(), returned_lse
209216

210217
def _sm100_forward_decode(
211218
self,
212219
q_nope: torch.Tensor,
213220
q_pe: torch.Tensor,
214221
kv_c_and_k_pe_cache: torch.Tensor,
215222
attn_metadata: MLACommonMetadata,
216-
) -> torch.Tensor:
223+
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
217224
assert kv_c_and_k_pe_cache.numel() > 0
218225
assert attn_metadata.decode is not None
219226

@@ -226,13 +233,18 @@ def _sm100_forward_decode(
226233
q_nope = q_nope.clone()
227234
q_pe = q_pe.clone()
228235

229-
o = self._sm100_cutlass_mla_decode(q_nope, q_pe, kv_c_and_k_pe_cache,
230-
attn_metadata.decode.seq_lens,
231-
attn_metadata.decode.block_table,
232-
self._workspace.get_buf(),
233-
self.scale, self._num_kv_splits)
236+
o, lse = self._sm100_cutlass_mla_decode(
237+
q_nope,
238+
q_pe,
239+
kv_c_and_k_pe_cache,
240+
attn_metadata.decode.seq_lens,
241+
attn_metadata.decode.block_table,
242+
self._workspace.get_buf(),
243+
self.scale,
244+
self._num_kv_splits,
245+
)
234246

235-
return o
247+
return o, (lse if self.need_to_return_lse_for_decode else None)
236248

237249
# TODO: Currently we leave it here only for backup in case something is
238250
# wrong with the new SM100 CUTLASS MLA kernel
@@ -286,4 +298,4 @@ def _forward_decode(
286298
attn_metadata), None
287299

288300
return self._sm100_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache,
289-
attn_metadata), None
301+
attn_metadata)

0 commit comments

Comments
 (0)