diff --git a/vllm/attention/ops/chunked_prefill_paged_decode.py b/vllm/attention/ops/chunked_prefill_paged_decode.py
index 785799b6bf68..6ca2a64145bd 100644
--- a/vllm/attention/ops/chunked_prefill_paged_decode.py
+++ b/vllm/attention/ops/chunked_prefill_paged_decode.py
@@ -264,8 +264,8 @@ def chunked_prefill_paged_decode(
     # Conversion of FP8 Tensor from uint8 storage to
     # appropriate torch.dtype for interpretation by Triton
     if "fp8" in kv_cache_dtype:
-        assert key_cache.dtype == torch.uint8
-        assert value_cache.dtype == torch.uint8
+        assert key_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
+        assert value_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
 
         if kv_cache_dtype in ("fp8", "fp8_e4m3"):
             target_dtype = current_platform.fp8_dtype()
diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py
index 86d256b630bf..729b61b02906 100644
--- a/vllm/attention/ops/prefix_prefill.py
+++ b/vllm/attention/ops/prefix_prefill.py
@@ -744,8 +744,8 @@ def context_attention_fwd(q,
     # Conversion of FP8 Tensor from uint8 storage to
     # appropriate torch.dtype for interpretation by Triton
     if "fp8" in kv_cache_dtype:
-        assert (k_cache.dtype == torch.uint8)
-        assert (v_cache.dtype == torch.uint8)
+        assert k_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
+        assert v_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
 
         if kv_cache_dtype in ("fp8", "fp8_e4m3"):
             target_dtype = current_platform.fp8_dtype()
diff --git a/vllm/envs.py b/vllm/envs.py
index b007bf8c59b7..bd9104afa4aa 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -15,6 +15,7 @@
     VLLM_NCCL_SO_PATH: Optional[str] = None
     LD_LIBRARY_PATH: Optional[str] = None
     VLLM_USE_TRITON_FLASH_ATTN: bool = False
+    VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False
     VLLM_FLASH_ATTN_VERSION: Optional[int] = None
     LOCAL_RANK: int = 0
     CUDA_VISIBLE_DEVICES: Optional[str] = None
@@ -290,6 +291,13 @@ def get_vllm_port() -> Optional[int]:
     lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
              ("true", "1")),
 
+    # Use separate prefill and decode kernels for V1 attention instead of
+    # the unified triton kernel.
+    "VLLM_V1_USE_PREFILL_DECODE_ATTENTION":
+    lambda:
+    (os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "False").lower() in
+     ("true", "1")),
+
     # Force vllm to use a specific flash-attention version (2 or 3), only valid
     # when using the flash-attention backend.
     "VLLM_FLASH_ATTN_VERSION":
@@ -323,8 +331,8 @@ def get_vllm_port() -> Optional[int]:
 
     # Whether to log responses from API Server for debugging
     "VLLM_DEBUG_LOG_API_SERVER_RESPONSE":
-    lambda: os.environ.get("VLLM_DEBUG_LOG_API_SERVER_RESPONSE", "False").
-    lower() == "true",
+    lambda: os.environ.get("VLLM_DEBUG_LOG_API_SERVER_RESPONSE", "False"
+                           ).lower() == "true",
 
     # S3 access information, used for tensorizer to load model from S3
     "S3_ACCESS_KEY_ID":
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index 4000f93984d3..a97bb85004f6 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -5,6 +5,7 @@
 import torch
 
 from vllm import _custom_ops as ops
+from vllm import envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.ops.chunked_prefill_paged_decode import (
@@ -126,6 +127,8 @@ def __init__(
                 "TritonAttentionImpl")
 
         self.fp8_dtype = current_platform.fp8_dtype()
+        self.force_prefill_decode_attn = \
+            envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
 
     def forward(
         self,
@@ -166,9 +169,9 @@ def forward(
         # performance to make sure it does not introduce any overhead.
 
         num_queries_per_kv = query.shape[1] // key.shape[1]
-        use_prefill_decode_attn = (num_queries_per_kv &
-                                   (num_queries_per_kv - 1)) != 0
-
+        num_q_is_pow2 = (num_queries_per_kv & (num_queries_per_kv - 1)) == 0
+        use_prefill_decode_attn = (self.force_prefill_decode_attn
+                                   or not num_q_is_pow2)
         num_actual_tokens = attn_metadata.num_actual_tokens
 
         if use_prefill_decode_attn: