
Commit f8a08cb

Authored by Isotr0py and Woosuk Kwon
[V1] Enable Triton(ROCm) Attention backend for Nvidia GPUs (#14071)
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
1 parent b15fd2b commit f8a08cb

File tree (5 files changed: +23 −16 lines changed)

  vllm/engine/arg_utils.py
  vllm/platforms/cuda.py
  vllm/platforms/interface.py
  vllm/platforms/rocm.py
  vllm/v1/attention/backends/rocm_attn.py → vllm/v1/attention/backends/triton_attn.py

vllm/engine/arg_utils.py

Lines changed: 1 addition & 1 deletion

@@ -1588,7 +1588,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
         # No FlashInfer or XFormers so far.
         V1_BACKENDS = [
             "FLASH_ATTN_VLLM_V1", "FLASH_ATTN", "PALLAS", "PALLAS_VLLM_V1",
-            "TRITON_MLA", "FLASHMLA"
+            "TRITON_ATTN_VLLM_V1", "TRITON_MLA", "FLASHMLA"
         ]
         if (envs.is_set("VLLM_ATTENTION_BACKEND")
                 and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
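With this change, pinning the new backend by name passes the V1 support oracle instead of forcing a fallback. A minimal usage sketch, assuming only the env vars that appear in this diff; the model name is an arbitrary placeholder:

    import os

    # Illustration only: select the V1 engine and the new Triton backend by name.
    os.environ["VLLM_USE_V1"] = "1"
    os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_ATTN_VLLM_V1"

    from vllm import LLM

    llm = LLM(model="facebook/opt-125m")  # placeholder model
    out = llm.generate("Hello, my name is")
    print(out[0].outputs[0].text)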

vllm/platforms/cuda.py

Lines changed: 8 additions & 3 deletions

@@ -213,9 +213,14 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
             return ("vllm.attention.backends."
                     "flashmla.FlashMLABackend")
         if use_v1:
-            logger.info_once("Using Flash Attention backend on V1 engine.")
-            return ("vllm.v1.attention.backends.flash_attn."
-                    "FlashAttentionBackend")
+            if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
+                logger.info_once("Using Triton backend on V1 engine.")
+                return ("vllm.v1.attention.backends."
+                        "triton_attn.TritonAttentionBackend")
+            if cls.has_device_capability(80):
+                logger.info_once("Using Flash Attention backend on V1 engine.")
+                return ("vllm.v1.attention.backends."
+                        "flash_attn.FlashAttentionBackend")
         if selected_backend == _Backend.FLASHINFER:
             logger.info("Using FlashInfer backend.")
             return "vllm.attention.backends.flashinfer.FlashInferBackend"

vllm/platforms/interface.py

Lines changed: 1 addition & 0 deletions

@@ -29,6 +29,7 @@ def in_wsl() -> bool:
 class _Backend(enum.Enum):
     FLASH_ATTN = enum.auto()
     FLASH_ATTN_VLLM_V1 = enum.auto()
+    TRITON_ATTN_VLLM_V1 = enum.auto()
     XFORMERS = enum.auto()
     ROCM_FLASH = enum.auto()
     TORCH_SDPA = enum.auto()
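The new enum member is what lets the TRITON_ATTN_VLLM_V1 string resolve to a backend on the platform side. A hedged sketch using plain Enum name lookup; vLLM's own env-var-to-enum helper may differ:

    from vllm.platforms.interface import _Backend

    # Name lookup mirrors the VLLM_ATTENTION_BACKEND string.
    backend = _Backend["TRITON_ATTN_VLLM_V1"]
    assert backend is _Backend.TRITON_ATTN_VLLM_V1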

vllm/platforms/rocm.py

Lines changed: 3 additions & 2 deletions

@@ -120,8 +120,9 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
         selected_backend = (_Backend.ROCM_FLASH if selected_backend
                             == _Backend.FLASH_ATTN else selected_backend)
         if envs.VLLM_USE_V1:
-            logger.info("Using ROCm Attention backend on V1 engine.")
-            return "vllm.v1.attention.backends.rocm_attn.ROCmAttentionBackend"
+            logger.info("Using Triton Attention backend on V1 engine.")
+            return ("vllm.v1.attention.backends."
+                    "triton_attn.TritonAttentionBackend")
         if selected_backend == _Backend.ROCM_FLASH:
             if not cls.has_device_capability(90):
                 # not Instinct series GPUs.
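Both the CUDA and ROCm platforms now return the same dotted class path. A minimal sketch of how such a path can be resolved to a class with importlib; vLLM ships its own resolver utility, so the helper below is illustrative only:

    import importlib

    def load_by_qualname(qualname: str):
        # Split "package.module.ClassName" into module path and attribute name.
        module_name, _, class_name = qualname.rpartition(".")
        return getattr(importlib.import_module(module_name), class_name)

    backend_cls = load_by_qualname(
        "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend")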

vllm/v1/attention/backends/rocm_attn.py renamed to vllm/v1/attention/backends/triton_attn.py

Lines changed: 10 additions & 10 deletions

@@ -1,5 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
-"""Attention layer with PagedAttention on rocm"""
+"""Attention layer with PagedAttention and Triton prefix prefill."""
 from typing import Any, Optional
 
 import torch
@@ -16,7 +16,7 @@
 logger = init_logger(__name__)
 
 
-class ROCmAttentionBackend(AttentionBackend):
+class TritonAttentionBackend(AttentionBackend):
 
     accept_output_buffer: bool = True
 
@@ -26,11 +26,11 @@ def get_supported_head_sizes() -> list[int]:
 
     @staticmethod
     def get_name() -> str:
-        return "ROCM_ATTN_VLLM_V1"
+        return "TRITON_ATTN_VLLM_V1"
 
     @staticmethod
-    def get_impl_cls() -> type["ROCmAttentionImpl"]:
-        return ROCmAttentionImpl
+    def get_impl_cls() -> type["TritonAttentionImpl"]:
+        return TritonAttentionImpl
 
     @staticmethod
     def get_metadata_cls() -> type["AttentionMetadata"]:
@@ -56,7 +56,7 @@ def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
         return FlashAttentionMetadataBuilder
 
 
-class ROCmAttentionImpl(AttentionImpl):
+class TritonAttentionImpl(AttentionImpl):
 
     def __init__(
         self,
@@ -73,7 +73,7 @@ def __init__(
     ) -> None:
         if blocksparse_params is not None:
             raise ValueError(
-                "ROCmAttention does not support block-sparse attention.")
+                "TritonAttention does not support block-sparse attention.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
@@ -90,17 +90,17 @@ def __init__(
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
-        support_head_sizes = ROCmAttentionBackend.get_supported_head_sizes()
+        support_head_sizes = TritonAttentionBackend.get_supported_head_sizes()
         if head_size not in support_head_sizes:
             raise ValueError(
-                f"Head size {head_size} is not supported by ROCmAttention. "
+                f"Head size {head_size} is not supported by TritonAttention. "
                 f"Supported head sizes are: {support_head_sizes}.")
 
         if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
-                                      "ROCmAttentionImpl")
+                                      "TritonAttentionImpl")
 
     def forward(
         self,
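After the rename, the backend registers under the name TRITON_ATTN_VLLM_V1 while keeping the same static interface. A quick hedged check against the renamed module; the paths and class names are taken from this diff and require a vLLM build containing this commit:

    from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend

    assert TritonAttentionBackend.get_name() == "TRITON_ATTN_VLLM_V1"
    print(TritonAttentionBackend.get_supported_head_sizes())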
