2 changes: 1 addition & 1 deletion vllm/engine/arg_utils.py
@@ -1570,7 +1570,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
# No FlashInfer or XFormers so far.
V1_BACKENDS = [
"FLASH_ATTN_VLLM_V1", "FLASH_ATTN", "PALLAS", "PALLAS_VLLM_V1",
"TRITON_MLA", "FLASHMLA"
"TRITON_ATTN_VLLM_V1", "TRITON_MLA", "FLASHMLA"
]
if (envs.is_set("VLLM_ATTENTION_BACKEND")
and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
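For context, a minimal usage sketch of the name whitelisted above: the V1 oracle only checks that VLLM_ATTENTION_BACKEND, when set, is one of the V1 backend names, so a user could opt into the new Triton backend roughly as follows. This is an illustration, not part of the diff; the model name is arbitrary, and both environment variables (VLLM_ATTENTION_BACKEND, VLLM_USE_V1) appear in the hunks in this PR.

```python
# Hedged sketch: opt into the new V1 Triton attention backend via the
# environment variables checked by vLLM (set them before importing vllm).
import os

os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_ATTN_VLLM_V1"
os.environ["VLLM_USE_V1"] = "1"

from vllm import LLM, SamplingParams

# The model choice below is illustrative only.
llm = LLM(model="facebook/opt-125m")
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```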
11 changes: 8 additions & 3 deletions vllm/platforms/cuda.py
@@ -212,9 +212,14 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
return ("vllm.attention.backends."
"flashmla.FlashMLABackend")
if use_v1:
logger.info_once("Using Flash Attention backend on V1 engine.")
return ("vllm.v1.attention.backends.flash_attn."
"FlashAttentionBackend")
if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
logger.info_once("Using Triton backend on V1 engine.")
return ("vllm.v1.attention.backends."
"triton_attn.TritonAttentionBackend")
if cls.has_device_capability(80):
@robertgshaw2-redhat (Collaborator) commented on Mar 2, 2025:

I think we should move these logs to debug.

Reply (Collaborator):

Why should we? I think people care about this log (although I really want to provide only one option per hardware).

logger.info_once("Using Flash Attention backend on V1 engine.")
return ("vllm.v1.attention.backends."
"flash_attn.FlashAttentionBackend")
if selected_backend == _Backend.FLASHINFER:
logger.info("Using FlashInfer backend.")
return "vllm.attention.backends.flashinfer.FlashInferBackend"
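As a reading aid, here is a simplified, standalone sketch of the V1 selection order this cuda.py hunk implements: an explicit TRITON_ATTN_VLLM_V1 request wins, otherwise compute capability 8.0 or newer defaults to FlashAttention. The function and enum below are local stand-ins for illustration, not vLLM's actual code; only the branch order and the returned class-path strings mirror the diff.

```python
from enum import Enum, auto
from typing import Optional


class Backend(Enum):  # local stand-in for _Backend
    TRITON_ATTN_VLLM_V1 = auto()
    FLASH_ATTN_VLLM_V1 = auto()


def pick_v1_backend(selected: Optional[Backend], device_capability: int) -> str:
    # An explicit request for the Triton backend wins first.
    if selected is Backend.TRITON_ATTN_VLLM_V1:
        return "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
    # Otherwise, Ampere (capability 80) or newer defaults to FlashAttention.
    if device_capability >= 80:
        return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
    # The pre-Ampere fallback is outside the lines shown in this hunk.
    raise NotImplementedError("fallback path not shown in this diff")


print(pick_v1_backend(None, 90))                          # FlashAttention
print(pick_v1_backend(Backend.TRITON_ATTN_VLLM_V1, 90))   # explicit Triton
```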
1 change: 1 addition & 0 deletions vllm/platforms/interface.py
@@ -29,6 +29,7 @@ def in_wsl() -> bool:
class _Backend(enum.Enum):
FLASH_ATTN = enum.auto()
FLASH_ATTN_VLLM_V1 = enum.auto()
TRITON_ATTN_VLLM_V1 = enum.auto()
XFORMERS = enum.auto()
ROCM_FLASH = enum.auto()
TORCH_SDPA = enum.auto()
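Since _Backend uses enum.auto(), the member's string name is what matters when a backend name (for example from an env var) is mapped back to the enum. A small self-contained check, using a trimmed copy of the enum purely for illustration:

```python
import enum


class _Backend(enum.Enum):  # trimmed copy for illustration only
    FLASH_ATTN = enum.auto()
    FLASH_ATTN_VLLM_V1 = enum.auto()
    TRITON_ATTN_VLLM_V1 = enum.auto()


# A backend-name string resolves to the member by name.
selected = _Backend["TRITON_ATTN_VLLM_V1"]
assert selected is _Backend.TRITON_ATTN_VLLM_V1
print(selected.name)  # -> TRITON_ATTN_VLLM_V1
```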
5 changes: 3 additions & 2 deletions vllm/platforms/rocm.py
@@ -120,8 +120,9 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend)
if envs.VLLM_USE_V1:
logger.info("Using ROCm Attention backend on V1 engine.")
return "vllm.v1.attention.backends.rocm_attn.ROCmAttentionBackend"
logger.info("Using Triton Attention backend on V1 engine.")
return ("vllm.v1.attention.backends."
"triton_attn.TritonAttentionBackend")
if selected_backend == _Backend.ROCM_FLASH:
if not cls.has_device_capability(90):
# not Instinct series GPUs.
vllm/v1/attention/backends/rocm_attn.py → vllm/v1/attention/backends/triton_attn.py
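Note that after this hunk, ROCm's V1 path and CUDA's explicit TRITON_ATTN_VLLM_V1 path resolve to the same backend class path. A trivial check of that invariant, with the string literals copied from the two hunks above:

```python
# Both platform hunks above now resolve V1 attention to the same class path.
ROCM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
CUDA_V1_TRITON = ("vllm.v1.attention.backends."
                  "triton_attn.TritonAttentionBackend")
assert ROCM_V1 == CUDA_V1_TRITON
```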
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
"""Attention layer with PagedAttention on rocm"""
"""Attention layer with PagedAttention and Triton prefix prefill."""
from typing import Any, Optional

import torch
@@ -16,7 +16,7 @@
logger = init_logger(__name__)


class ROCmAttentionBackend(AttentionBackend):
class TritonAttentionBackend(AttentionBackend):

accept_output_buffer: bool = True

@@ -26,11 +26,11 @@ def get_supported_head_sizes() -> list[int]:

@staticmethod
def get_name() -> str:
return "ROCM_ATTN_VLLM_V1"
return "TRITON_ATTN_VLLM_V1"

@staticmethod
def get_impl_cls() -> type["ROCmAttentionImpl"]:
return ROCmAttentionImpl
def get_impl_cls() -> type["TritonAttentionImpl"]:
return TritonAttentionImpl

@staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
@@ -56,7 +56,7 @@ def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
return FlashAttentionMetadataBuilder


class ROCmAttentionImpl(AttentionImpl):
class TritonAttentionImpl(AttentionImpl):

def __init__(
self,
@@ -73,7 +73,7 @@ def __init__(
) -> None:
if blocksparse_params is not None:
raise ValueError(
"ROCmAttention does not support block-sparse attention.")
"TritonAttention does not support block-sparse attention.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
@@ -90,17 +90,17 @@ def __init__(
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

support_head_sizes = ROCmAttentionBackend.get_supported_head_sizes()
support_head_sizes = TritonAttentionBackend.get_supported_head_sizes()
if head_size not in support_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by ROCmAttention. "
f"Head size {head_size} is not supported by TritonAttention. "
f"Supported head sizes are: {support_head_sizes}.")

if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"ROCmAttentionImpl")
"TritonAttentionImpl")

def forward(
self,
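A hypothetical smoke test (not part of this PR) tying the rename together, assuming the module path used by the platform files above: the renamed backend should report exactly the name whitelisted for V1 in arg_utils.py, and it keeps the accept_output_buffer flag shown in the diff.

```python
# Hypothetical test sketch; assumes vLLM with this PR applied is installed.
from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend


def test_triton_attn_backend_name() -> None:
    assert TritonAttentionBackend.get_name() == "TRITON_ATTN_VLLM_V1"
    assert TritonAttentionBackend.accept_output_buffer is True
```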