
Commit 680147b

gshtras authored and amitm02 committed
[Attention][V1] Toggle for v1 attention backend (vllm-project#18275)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: amit <amit.man@gmail.com>
1 parent 1e56527 commit 680147b

4 files changed, +20 −9 lines changed

vllm/attention/ops/chunked_prefill_paged_decode.py

Lines changed: 2 additions & 2 deletions
@@ -264,8 +264,8 @@ def chunked_prefill_paged_decode(
     # Conversion of FP8 Tensor from uint8 storage to
     # appropriate torch.dtype for interpretation by Triton
     if "fp8" in kv_cache_dtype:
-        assert key_cache.dtype == torch.uint8
-        assert value_cache.dtype == torch.uint8
+        assert key_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
+        assert value_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]

     if kv_cache_dtype in ("fp8", "fp8_e4m3"):
         target_dtype = current_platform.fp8_dtype()

vllm/attention/ops/prefix_prefill.py

Lines changed: 2 additions & 2 deletions
@@ -744,8 +744,8 @@ def context_attention_fwd(q,
     # Conversion of FP8 Tensor from uint8 storage to
     # appropriate torch.dtype for interpretation by Triton
     if "fp8" in kv_cache_dtype:
-        assert (k_cache.dtype == torch.uint8)
-        assert (v_cache.dtype == torch.uint8)
+        assert k_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
+        assert v_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]

     if kv_cache_dtype in ("fp8", "fp8_e4m3"):
         target_dtype = current_platform.fp8_dtype()
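
Note: both hunks above make the same change: the FP8 KV-cache check now accepts caches stored either as raw uint8 bytes or already typed as the platform's native FP8 dtype. A minimal sketch of the relaxed check, assuming current_platform is importable from vllm.platforms as in the files above and that fp8_dtype() returns the platform's FP8 torch dtype:

import torch

from vllm.platforms import current_platform


def check_fp8_kv_cache(key_cache: torch.Tensor, value_cache: torch.Tensor,
                       kv_cache_dtype: str) -> None:
    # Accept FP8 caches stored as raw uint8 bytes or as the native fp8 dtype.
    if "fp8" in kv_cache_dtype:
        allowed = (torch.uint8, current_platform.fp8_dtype())
        assert key_cache.dtype in allowed
        assert value_cache.dtype in allowed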

vllm/envs.py

Lines changed: 10 additions & 2 deletions
@@ -15,6 +15,7 @@
     VLLM_NCCL_SO_PATH: Optional[str] = None
     LD_LIBRARY_PATH: Optional[str] = None
     VLLM_USE_TRITON_FLASH_ATTN: bool = False
+    VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False
     VLLM_FLASH_ATTN_VERSION: Optional[int] = None
     LOCAL_RANK: int = 0
     CUDA_VISIBLE_DEVICES: Optional[str] = None
@@ -290,6 +291,13 @@ def get_vllm_port() -> Optional[int]:
     lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
              ("true", "1")),

+    # Use separate prefill and decode kernels for V1 attention instead of
+    # the unified triton kernel.
+    "VLLM_V1_USE_PREFILL_DECODE_ATTENTION":
+    lambda:
+    (os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "False").lower() in
+     ("true", "1")),
+
     # Force vllm to use a specific flash-attention version (2 or 3), only valid
     # when using the flash-attention backend.
     "VLLM_FLASH_ATTN_VERSION":
@@ -323,8 +331,8 @@ def get_vllm_port() -> Optional[int]:

     # Whether to log responses from API Server for debugging
     "VLLM_DEBUG_LOG_API_SERVER_RESPONSE":
-    lambda: os.environ.get("VLLM_DEBUG_LOG_API_SERVER_RESPONSE", "False").
-    lower() == "true",
+    lambda: os.environ.get("VLLM_DEBUG_LOG_API_SERVER_RESPONSE", "False"
+                           ).lower() == "true",

     # S3 access information, used for tensorizer to load model from S3
     "S3_ACCESS_KEY_ID":

vllm/v1/attention/backends/triton_attn.py

Lines changed: 6 additions & 3 deletions
@@ -5,6 +5,7 @@
 import torch

 from vllm import _custom_ops as ops
+from vllm import envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.ops.chunked_prefill_paged_decode import (
@@ -126,6 +127,8 @@ def __init__(
                 "TritonAttentionImpl")

         self.fp8_dtype = current_platform.fp8_dtype()
+        self.force_prefill_decode_attn = \
+            envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION

     def forward(
         self,
@@ -166,9 +169,9 @@ def forward(
         # performance to make sure it does not introduce any overhead.

         num_queries_per_kv = query.shape[1] // key.shape[1]
-        use_prefill_decode_attn = (num_queries_per_kv &
-                                   (num_queries_per_kv - 1)) != 0
-
+        num_q_is_pow2 = (num_queries_per_kv & (num_queries_per_kv - 1)) == 0
+        use_prefill_decode_attn = (self.force_prefill_decode_attn
+                                   or not num_q_is_pow2)
         num_actual_tokens = attn_metadata.num_actual_tokens

         if use_prefill_decode_attn:
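
Note: with this hunk the split prefill/decode path is used either when the new toggle is set or when num_queries_per_kv is not a power of two; (n & (n - 1)) == 0 is the standard power-of-two test for positive n. A self-contained sketch of the selection logic, lifted out of the method into a plain function for illustration:

def should_use_prefill_decode_attn(num_queries_per_kv: int,
                                   force_prefill_decode_attn: bool) -> bool:
    # Power-of-two check via the bit trick from the diff above.
    num_q_is_pow2 = (num_queries_per_kv & (num_queries_per_kv - 1)) == 0
    return force_prefill_decode_attn or not num_q_is_pow2


assert should_use_prefill_decode_attn(8, False) is False  # pow2 -> unified kernel
assert should_use_prefill_decode_attn(6, False) is True   # not pow2 -> split path
assert should_use_prefill_decode_attn(8, True) is True    # toggle forces split path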
