diff --git a/vllm/envs.py b/vllm/envs.py
index 385d2a7c51f2..eeed7771f045 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -1223,9 +1223,12 @@ def get_vllm_port() -> Optional[int]:
     "VLLM_USE_CUDNN_PREFILL":
     lambda: bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))),
 
-    # If set to 1, use the TRTLLM attention backend in flashinfer.
+    # If set to 1/True, use the TRTLLM attention backend in flashinfer.
+    # If set to 0/False, use the default attention backend in flashinfer.
+    # If not set, auto-detect the attention backend in flashinfer.
     "VLLM_USE_TRTLLM_ATTENTION":
-    lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None),
+    lambda: (None if "VLLM_USE_TRTLLM_ATTENTION" not in os.environ else
+             os.environ["VLLM_USE_TRTLLM_ATTENTION"].lower() in ("1", "true")),
 
     # If set to 1, when we use fp8 kv, we do not quantize Q to fp8
     "VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION":
diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py
index 83ec65c9b459..2179bddae243 100644
--- a/vllm/utils/flashinfer.py
+++ b/vllm/utils/flashinfer.py
@@ -154,28 +154,31 @@ def has_nvidia_artifactory() -> bool:
 
 
 @functools.cache
-def supports_trtllm_attention() -> tuple[bool, Optional[str]]:
-    """Cache result which only depends on the environment"""
-    # This is a lambda, call it once
-    env_value = envs.VLLM_USE_TRTLLM_ATTENTION
-
+def supports_trtllm_attention() -> bool:
+    """
+    TRTLLM attention is supported if the platform is SM100 and
+    NVIDIA artifactory is accessible
+    """
     # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
-    if not (current_platform.is_device_capability(100)
-            and has_nvidia_artifactory()):
-        return False, env_value
+    return current_platform.is_device_capability(
+        100) and has_nvidia_artifactory()
+
 
+@functools.cache
+def _force_use_trtllm_attention(env_value: Optional[bool]) -> Optional[bool]:
+    """Cache the env value for VLLM_USE_TRTLLM_ATTENTION"""
     if env_value is not None:
         logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
-        # Environment variable is set - respect it
-        # Making the conditional check for zero because
-        # the path is automatically enabled if the batch size condition
-        # is satisfied.
-        use_trtllm = (env_value == "1")
-        if use_trtllm:
-            logger.info_once("Using TRTLLM attention.")
-        return use_trtllm, env_value
+    return env_value
 
-    return True, None
+
+def force_use_trtllm_attention() -> Optional[bool]:
+    """
+    Return ``None`` if VLLM_USE_TRTLLM_ATTENTION is not set,
+    return ``True`` if TRTLLM attention is forced to be used,
+    return ``False`` if TRTLLM attention is forced not to be used.
+    """
+    return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)
 
 
 def use_trtllm_attention(
@@ -185,18 +188,38 @@ def use_trtllm_attention(
     max_seq_len: int,
     kv_cache_dtype: str,
     q_dtype: torch.dtype,
-    is_prefill: bool,
     has_sinks: bool = False,
 ) -> bool:
-    use_trtllm, env_value = supports_trtllm_attention()
-    if not use_trtllm:
+    """Return ``True`` if TRTLLM attention is used."""
+    force_use_trtllm = force_use_trtllm_attention()
+
+    # Environment variable is set to 0 - respect it
+    if force_use_trtllm is not None and not force_use_trtllm:
         return False
 
+    # The platform is not supported
+    if not supports_trtllm_attention():
+        if force_use_trtllm:
+            logger.warning_once(
+                "TRTLLM attention is not supported on this platform, "
+                "but VLLM_USE_TRTLLM_ATTENTION is set to 1")
+        return False
+
+    # The combination of query and key heads is not supported
     if num_qo_heads % num_kv_heads != 0:
+        if force_use_trtllm:
+            logger.warning_once(
+                "TRTLLM attention is not supported for this combination of "
+                "query and key heads, but VLLM_USE_TRTLLM_ATTENTION is set to 1"
+            )
         return False
 
     # Must use TRTLLM attention if query is FP8 quantized
     if q_dtype == current_platform.fp8_dtype():
+        if has_sinks:
+            raise RuntimeError(
+                "TRTLLM FP8-qkv kernel is not supported for attention sinks. "
+                "Use kv_cache_dtype=auto for now.")
         logger.info_once("Using TRTLLM attention (query is quantized).")
         return True
 
@@ -207,15 +230,17 @@ def use_trtllm_attention(
             "Using TRTLLM attention (required for attention sinks).")
         return True
 
-    if env_value is None:
+    if force_use_trtllm is None:
         # Environment variable not set - use auto-detection
-        use_trtllm = (num_tokens <= 256 and max_seq_len < 131072
+        use_trtllm = (num_tokens <= 256 and max_seq_len <= 131072
                       and kv_cache_dtype == "auto")
         if use_trtllm:
             logger.warning_once("Using TRTLLM attention (auto-detected).")
         return use_trtllm
 
     # Environment variable is set to 1 - respect it
+    logger.info_once(
+        "Using TRTLLM attention (VLLM_USE_TRTLLM_ATTENTION is set to 1)")
     return True
 
 
@@ -367,6 +392,7 @@ def flashinfer_disable_q_quantization() -> bool:
     "has_nvidia_artifactory",
     "supports_trtllm_attention",
     "use_trtllm_attention",
+    "flashinfer_disable_q_quantization",
     "flashinfer_scaled_fp4_mm",
     "flashinfer_scaled_fp8_mm",
 ]
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 98a4cf38bc19..dda6dd4fbea7 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -282,7 +282,11 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
             assert self.kv_cache_spec.dtype == self.model_config.dtype
             self.kv_cache_dtype = self.kv_cache_spec.dtype
 
-        if supports_trtllm_attention()[0] and \
+        # Use model dtype as q dtype when TRTLLM attn is not supported, or
+        # VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION is set to 1. Otherwise, try to
+        # use fp8 q if kv cache is fp8, and fall back to model dtype
+        # if TRTLLM attention kernel is not used when building attn metadata
+        if supports_trtllm_attention() and \
                 not flashinfer_disable_q_quantization():
             self.q_data_type = self.kv_cache_dtype
         else:
@@ -298,7 +302,7 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         self.window_left = self.global_hyperparameters.window_left
         self.logits_soft_cap = self.global_hyperparameters.logits_soft_cap
         self.has_sinks = self.global_hyperparameters.has_sinks
-        if self.has_sinks and not supports_trtllm_attention()[0]:
+        if self.has_sinks and not supports_trtllm_attention():
             raise NotImplementedError(
                 "FlashInfer backend currently does not support attention "
                 "sinks, please use trtllm on blackwell or flash attention on "
@@ -477,14 +481,12 @@ def build(self,
             paged_kv_last_page_len_np,
         )
 
-        # Check if any layer uses sinks (requires TRTLLM attention)
         prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                                   self.num_kv_heads,
                                                   num_prefill_tokens,
                                                   max_seq_len,
                                                   self.cache_dtype,
                                                   self.q_data_type,
-                                                  is_prefill=True,
                                                   has_sinks=self.has_sinks)
         decode_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                                  self.num_kv_heads,
@@ -492,13 +494,18 @@ def build(self,
                                                  max_seq_len,
                                                  self.cache_dtype,
                                                  self.q_data_type,
-                                                 is_prefill=False,
                                                  has_sinks=self.has_sinks)
         if self.has_sinks and not (prefill_use_trtllm and decode_use_trtllm):
             raise NotImplementedError(
                 "FlashInfer backend currently does not support attention "
                 "sinks, please use trtllm on blackwell or flash attention on "
                 "earlier GPUs.")
+
+        # If TRTLLM attention is not used, q quantization is not supported.
+        # Fall back to the model dtype.
+        if not (prefill_use_trtllm and decode_use_trtllm):
+            self.q_data_type = self.model_config.dtype
+
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
             q_data_type=self.q_data_type,
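
Note on the new semantics of VLLM_USE_TRTLLM_ATTENTION: the envs.py hunk above turns the variable into a tri-state value. Below is a minimal, self-contained sketch of the same parsing rule; the helper name parse_use_trtllm_attention and the injectable env argument are illustrative only and do not exist in the diff.

    import os
    from collections.abc import Mapping
    from typing import Optional


    def parse_use_trtllm_attention(
            env: Optional[Mapping[str, str]] = None) -> Optional[bool]:
        """Mirror of the new lambda: unset -> None, "1"/"true" -> True, else False."""
        env = os.environ if env is None else env
        if "VLLM_USE_TRTLLM_ATTENTION" not in env:
            return None  # unset: let use_trtllm_attention() auto-detect later
        return env["VLLM_USE_TRTLLM_ATTENTION"].lower() in ("1", "true")


    # The three states the new parsing distinguishes:
    assert parse_use_trtllm_attention({}) is None
    assert parse_use_trtllm_attention({"VLLM_USE_TRTLLM_ATTENTION": "True"}) is True
    assert parse_use_trtllm_attention({"VLLM_USE_TRTLLM_ATTENTION": "0"}) is False

Any value other than "1"/"true" (including "0") now explicitly forces the TRTLLM path off, whereas leaving the variable unset defers to auto-detection.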
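For review convenience, here is a condensed view of the decision order that the reworked use_trtllm_attention() follows. This is a standalone sketch: the function name trtllm_decision and its boolean parameters (force, platform_supported, heads_ok, q_is_fp8) are made up for illustration, and it omits the logging as well as the new RuntimeError raised for an FP8 query combined with attention sinks.

    from typing import Optional


    def trtllm_decision(force: Optional[bool], platform_supported: bool,
                        heads_ok: bool, q_is_fp8: bool, has_sinks: bool,
                        num_tokens: int, max_seq_len: int,
                        kv_cache_dtype: str) -> bool:
        if force is False:                  # VLLM_USE_TRTLLM_ATTENTION=0: forced off
            return False
        if not platform_supported or not heads_ok:
            return False                    # the real code warns when force is True
        if q_is_fp8 or has_sinks:           # FP8 query or sinks require TRTLLM
            return True
        if force is None:                   # unset: auto-detect by workload shape
            return (num_tokens <= 256 and max_seq_len <= 131072
                    and kv_cache_dtype == "auto")
        return True                         # VLLM_USE_TRTLLM_ATTENTION=1: forced on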
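Finally, the flashinfer.py backend hunks change how q_data_type is chosen in two stages: optimistically at __init__ time, then with a fallback at build() time once prefill_use_trtllm and decode_use_trtllm are known. A small sketch of that two-stage selection, assuming it can be expressed as a pure function (pick_q_dtype is a hypothetical name, not part of the diff):

    from typing import Optional

    import torch


    def pick_q_dtype(model_dtype: torch.dtype,
                     kv_cache_dtype: torch.dtype,
                     trtllm_supported: bool,
                     q_quant_disabled: bool,
                     prefill_use_trtllm: Optional[bool] = None,
                     decode_use_trtllm: Optional[bool] = None) -> torch.dtype:
        # Stage 1 (__init__): assume q can share the kv-cache dtype when the
        # platform supports TRTLLM and q quantization is not disabled.
        q_dtype = (kv_cache_dtype
                   if trtllm_supported and not q_quant_disabled else model_dtype)
        # Stage 2 (build): once the per-batch decision is known, fall back to the
        # model dtype unless both the prefill and decode paths use TRTLLM.
        if prefill_use_trtllm is not None and decode_use_trtllm is not None:
            if not (prefill_use_trtllm and decode_use_trtllm):
                q_dtype = model_dtype
        return q_dtype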