2 changes: 1 addition & 1 deletion vllm/engine/arg_utils.py

@@ -1417,7 +1417,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
             "PALLAS_VLLM_V1",
             "TRITON_ATTN_VLLM_V1",
             "TRITON_MLA",
-            "CUTLASS_MLA_VLLM_V1",
+            "CUTLASS_MLA",
             "FLASHMLA",
             "FLASHINFER",
             "FLASHINFER_VLLM_V1",
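For context, the list above is the V1 engine's allowlist of attention backend names, which is why the rename has to land here as well. A minimal sketch of the membership check, using a hypothetical helper name (vLLM's actual logic lives in _is_v1_supported_oracle):

# Hypothetical sketch: a backend passes the V1 oracle only if its
# name appears on the allowlist, so stale names silently stop matching.
V1_SUPPORTED_BACKENDS = {
    "PALLAS_VLLM_V1", "TRITON_ATTN_VLLM_V1", "TRITON_MLA",
    "CUTLASS_MLA", "FLASHMLA", "FLASHINFER", "FLASHINFER_VLLM_V1",
}

def backend_is_v1_supported(name: str) -> bool:
    return name in V1_SUPPORTED_BACKENDS

assert backend_is_v1_supported("CUTLASS_MLA")
assert not backend_is_v1_supported("CUTLASS_MLA_VLLM_V1")  # old name no longer matches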
10 changes: 5 additions & 5 deletions vllm/platforms/cuda.py

@@ -162,15 +162,15 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
                 if cls.is_device_capability(100):
                     # Blackwell => Force CutlassMLA.
                     use_cutlass_mla = True
-                    envs.VLLM_ATTENTION_BACKEND = "CUTLASS_MLA_VLLM_V1"
+                    envs.VLLM_ATTENTION_BACKEND = "CUTLASS_MLA"
                 else:
                     # Not Blackwell
                     use_flashmla = True
             else:
                 # Forced case
                 use_flashmla = (envs.VLLM_ATTENTION_BACKEND == "FLASHMLA")
                 use_cutlass_mla = (
-                    envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA_VLLM_V1")
+                    envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA")

             from vllm.attention.ops.flashmla import is_flashmla_supported
             if use_flashmla and is_flashmla_supported()[0] \
@@ -182,7 +182,7 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
             if use_cutlass_mla and cache_config.block_size != 128:
                 cache_config.block_size = 128
                 logger.info("Forcing kv cache block size to 128 for "
-                            "CUTLASS_MLA_VLLM_V1 backend.")
+                            "CUTLASS_MLA backend.")

         compilation_config = vllm_config.compilation_config
         if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
@@ -211,9 +211,9 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                              kv_cache_dtype, block_size, use_v1,
                              use_mla) -> str:
         if use_mla:
             # TODO(lucas): refactor to be more concise
             # we should probably consider factoring out V1 here
-            if selected_backend == _Backend.CUTLASS_MLA_VLLM_V1:
+            if selected_backend == _Backend.CUTLASS_MLA:
                 if use_v1:
                     logger.info_once("Using Cutlass MLA backend on V1 engine.")
                     return ("vllm.v1.attention.backends.mla."
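Taken together, the branches above auto-select CutlassMLA on Blackwell and otherwise honor a forced VLLM_ATTENTION_BACKEND, now compared against the new name. A minimal sketch of that behavior (the function below is hypothetical, mirroring check_and_update_config's MLA branches, not vLLM's actual API):

def pick_mla_backend(device_capability: int, forced_backend: str | None) -> str:
    # Hypothetical mirror of the selection logic sketched above.
    if forced_backend is None:
        # Auto case: Blackwell (device capability 100) forces CutlassMLA.
        return "CUTLASS_MLA" if device_capability == 100 else "FLASHMLA"
    # Forced case: the env var is matched verbatim against the new name.
    return forced_backend

assert pick_mla_backend(100, None) == "CUTLASS_MLA"
assert pick_mla_backend(90, None) == "FLASHMLA"
assert pick_mla_backend(100, "FLASHMLA") == "FLASHMLA"

Note that the same hunk also keeps forcing cache_config.block_size to 128 whenever the CUTLASS_MLA backend is in use.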
2 changes: 1 addition & 1 deletion vllm/platforms/interface.py

@@ -53,7 +53,7 @@ class _Backend(enum.Enum):
     TRITON_MLA_VLLM_V1 = enum.auto()
     FLASHMLA_VLLM_V1 = enum.auto()
     FLASHMLA = enum.auto()  # Supported by V1
-    CUTLASS_MLA_VLLM_V1 = enum.auto()
+    CUTLASS_MLA = enum.auto()
     PALLAS = enum.auto()
     PALLAS_VLLM_V1 = enum.auto()
     IPEX = enum.auto()
2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/mla/cutlass_mla.py

@@ -21,7 +21,7 @@ class CutlassMLABackend(MLACommonBackend):

     @staticmethod
     def get_name() -> str:
-        return "CUTLASS_MLA_VLLM_V1"
+        return "CUTLASS_MLA"

     @staticmethod
     def get_impl_cls() -> type["CutlassMLAImpl"]:
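The string returned by get_name() is what ties these four files together: it must match the _Backend enum member in interface.py, the env-var comparisons in cuda.py, and the V1 allowlist in arg_utils.py. A trimmed, illustrative sketch of that invariant (assuming name-based enum lookup, which is why the rename must be applied in lockstep):

import enum

class _Backend(enum.Enum):  # trimmed to the member touched by this change
    CUTLASS_MLA = enum.auto()

class CutlassMLABackend:
    @staticmethod
    def get_name() -> str:
        return "CUTLASS_MLA"

# Looking the backend up by name only works if every occurrence was renamed:
assert _Backend[CutlassMLABackend.get_name()] is _Backend.CUTLASS_MLA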