File tree Expand file tree Collapse file tree 1 file changed +12
-1
lines changed
vllm/v1/attention/backends/mla Expand file tree Collapse file tree 1 file changed +12
-1
lines changed Original file line number Diff line number Diff line change 11# SPDX-License-Identifier: Apache-2.0
22# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
4- from typing import Optional , Union
4+ from typing import ClassVar , Optional , Union
55
66import torch
77from flashinfer .decode import trtllm_batch_decode_with_kv_cache_mla
1212 MLACommonBackend ,
1313 MLACommonImpl ,
1414 MLACommonMetadata ,
15+ MLACommonMetadataBuilder ,
1516)
17+ from vllm .v1 .attention .backends .utils import AttentionCGSupport
1618
1719logger = init_logger (__name__ )
1820
1921FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
2022
2123
24+ class FlashInferMLAMetadataBuilder (MLACommonMetadataBuilder [MLACommonMetadata ]):
25+ # enable full CUDA Graph support for decode-only capture
26+ cudagraph_support : ClassVar [AttentionCGSupport ] = AttentionCGSupport .UNIFORM_BATCH
27+
28+
2229class FlashInferMLABackend (MLACommonBackend ):
2330 @staticmethod
2431 def get_name () -> str :
@@ -28,6 +35,10 @@ def get_name() -> str:
2835 def get_impl_cls () -> type ["FlashInferMLAImpl" ]:
2936 return FlashInferMLAImpl
3037
38+ @staticmethod
39+ def get_builder_cls () -> type ["FlashInferMLAMetadataBuilder" ]:
40+ return FlashInferMLAMetadataBuilder
41+
3142
3243g_fi_workspace = torch .zeros (
3344 FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE ,
You can’t perform that action at this time.
0 commit comments