Skip to content

Commit 250fb1b

Browse files
[Bugfix] fixes the decoding metadata of dense mla's fp8 kvcache. (#27144)
Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent 647214f commit 250fb1b

File tree

3 files changed

+10
-1
lines changed

3 files changed

+10
-1
lines changed

cmake/external_projects/flashmla.cmake

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ else()
1919
FetchContent_Declare(
2020
flashmla
2121
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA
22-
GIT_TAG 5f65b85703c7ed75fda01e06495077caad207c3f
22+
GIT_TAG 28417e516fcbf6257a422ba117ef5b6f44da5682
2323
GIT_PROGRESS TRUE
2424
CONFIGURE_COMMAND ""
2525
BUILD_COMMAND ""
@@ -66,6 +66,7 @@ if(FLASH_MLA_ARCHS)
6666
${flashmla_SOURCE_DIR}/csrc/extension/torch_api.cpp
6767
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/pybind.cpp
6868
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/flash_fwd_mla_fp8_sm90.cu
69+
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/flash_fwd_mla_metadata.cu
6970
)
7071

7172
set(FlashMLA_INCLUDES

vllm/attention/ops/flashmla.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,12 @@ def get_mla_metadata(
102102
(num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
103103
- num_splits: (batch_size + 1), dtype torch.int32.
104104
"""
105+
if is_fp8_kvcache and topk is None:
106+
return torch.ops._flashmla_extension_C.get_mla_decoding_metadata_dense_fp8(
107+
cache_seqlens,
108+
num_q_tokens_per_head_k,
109+
num_heads_k,
110+
)
105111
return torch.ops._flashmla_C.get_mla_decoding_metadata(
106112
cache_seqlens,
107113
num_q_tokens_per_head_k,

vllm/v1/attention/backends/mla/flashmla.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ def __init__(
9191

9292
self.cg_buf_tile_scheduler_metadata = None
9393
self.cg_buf_num_splits = None
94+
self.is_fp8_kvcache = vllm_config.cache_config.cache_dtype.startswith("fp8")
9495

9596
device_properties = torch.cuda.get_device_properties(self.device)
9697
num_sms = device_properties.multi_processor_count
@@ -123,6 +124,7 @@ def _build_decode(
123124
seq_lens_device,
124125
self.num_q_heads,
125126
1, # MQA for the decode path
127+
is_fp8_kvcache=self.is_fp8_kvcache,
126128
)
127129

128130
# TODO: we can disambiguate between decode and mixed-prefill decode here

0 commit comments

Comments (0)