diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py
index b74ae09e6112..313f941ebf93 100644
--- a/vllm/attention/backends/registry.py
+++ b/vllm/attention/backends/registry.py
@@ -21,6 +21,7 @@ class _Backend(enum.Enum):
     TRITON_MLA = enum.auto()
     CUTLASS_MLA = enum.auto()
     FLASHMLA = enum.auto()
+    FLASHMLA_SPARSE = enum.auto()
     FLASH_ATTN_MLA = enum.auto()
     PALLAS = enum.auto()
     IPEX = enum.auto()
@@ -43,6 +44,7 @@ class _Backend(enum.Enum):
     _Backend.TRITON_MLA: "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend",  # noqa: E501
     _Backend.CUTLASS_MLA: "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend",  # noqa: E501
     _Backend.FLASHMLA: "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend",  # noqa: E501
+    _Backend.FLASHMLA_SPARSE: "vllm.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend",  # noqa: E501
     _Backend.FLASH_ATTN_MLA: "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend",  # noqa: E501
     _Backend.PALLAS: "vllm.v1.attention.backends.pallas.PallasAttentionBackend",  # noqa: E501
     _Backend.FLEX_ATTENTION: "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend",  # noqa: E501
diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py
index 49c29de35da1..144e46d5e953 100644
--- a/vllm/v1/attention/backends/mla/flashmla_sparse.py
+++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py
@@ -55,7 +55,7 @@ class FlashMLASparseBackend(AttentionBackend):
 
     @staticmethod
     def get_name() -> str:
-        return "FLASHMLA_SPARSE_VLLM_V1"
+        return "FLASHMLA_SPARSE"
 
     @staticmethod
     def get_metadata_cls() -> type[AttentionMetadata]: