@@ -154,28 +154,31 @@ def has_nvidia_artifactory() -> bool:
 
 
 @functools.cache
-def supports_trtllm_attention() -> tuple[bool, Optional[str]]:
-    """Cache result which only depends on the environment"""
-    # This is a lambda, call it once
-    env_value = envs.VLLM_USE_TRTLLM_ATTENTION
-
+def supports_trtllm_attention() -> bool:
+    """
+    TRTLLM attention is supported if the platform is SM100 and
+    NVIDIA artifactory is accessible
+    """
     # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
-    if not (current_platform.is_device_capability(100)
-            and has_nvidia_artifactory()):
-        return False, env_value
+    return current_platform.is_device_capability(
+        100) and has_nvidia_artifactory()
 
+
+@functools.cache
+def _force_use_trtllm_attention(env_value: Optional[bool]) -> Optional[bool]:
+    """Cache the env value for VLLM_USE_TRTLLM_ATTENTION"""
     if env_value is not None:
         logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
-        # Environment variable is set - respect it
-        # Making the conditional check for zero because
-        # the path is automatically enabled if the batch size condition
-        # is satisfied.
-        use_trtllm = (env_value == "1")
-        if use_trtllm:
-            logger.info_once("Using TRTLLM attention.")
-        return use_trtllm, env_value
+    return env_value
 
-    return True, None
+
+def force_use_trtllm_attention() -> Optional[bool]:
+    """
+    Return ``None`` if VLLM_USE_TRTLLM_ATTENTION is not set,
+    return ``True`` if TRTLLM attention is forced to be used,
+    return ``False`` if TRTLLM attention is forced to be not used.
+    """
+    return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)
 
 
 def use_trtllm_attention(
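
A minimal sketch (not part of the diff) of the contract the new helpers establish: `_force_use_trtllm_attention` is memoized on the value it receives, so the "is set to" log fires at most once per distinct value, while the un-cached `force_use_trtllm_attention` wrapper re-reads the environment on every call. The stub below assumes the env value has already been parsed to `Optional[bool]` in vLLM's `envs` module, which this diff does not show; the stub names are hypothetical.

import functools
from typing import Optional

@functools.cache
def _force_stub(env_value: Optional[bool]) -> Optional[bool]:
    # Memoized on the parsed env value: the log line runs once per value.
    if env_value is not None:
        print(f"VLLM_USE_TRTLLM_ATTENTION is set to {env_value}")
    return env_value

def force_stub(current_value: Optional[bool]) -> Optional[bool]:
    # Un-cached wrapper: re-reads the (hypothetical) env value each call,
    # but repeated calls with the same value hit the cache and stay silent.
    return _force_stub(current_value)

assert force_stub(None) is None    # unset -> fall back to auto-detection
assert force_stub(True) is True    # logs once, forces TRTLLM attention on
assert force_stub(True) is True    # cache hit, no second log
assert force_stub(False) is False  # forces TRTLLM attention off
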
@@ -185,18 +188,38 @@ def use_trtllm_attention(
     max_seq_len: int,
     kv_cache_dtype: str,
     q_dtype: torch.dtype,
-    is_prefill: bool,
     has_sinks: bool = False,
 ) -> bool:
-    use_trtllm, env_value = supports_trtllm_attention()
-    if not use_trtllm:
+    """Return ``True`` if TRTLLM attention is used."""
+    force_use_trtllm = force_use_trtllm_attention()
+
+    # Environment variable is set to 0 - respect it
+    if force_use_trtllm is not None and not force_use_trtllm:
         return False
 
+    # The platform is not supported
+    if not supports_trtllm_attention():
+        if force_use_trtllm:
+            logger.warning_once(
+                "TRTLLM attention is not supported on this platform, "
+                "but VLLM_USE_TRTLLM_ATTENTION is set to 1")
+        return False
+
+    # The combination of query and key heads is not supported
     if num_qo_heads % num_kv_heads != 0:
+        if force_use_trtllm:
+            logger.warning_once(
+                "TRTLLM attention is not supported for this combination of "
+                "query and key heads, but VLLM_USE_TRTLLM_ATTENTION is set to 1"
+            )
         return False
 
     # Must use TRTLLM attention if query is FP8 quantized
     if q_dtype == current_platform.fp8_dtype():
+        if has_sinks:
+            raise RuntimeError(
+                "TRTLLM FP8-qkv kernel is not supported for attention sinks. "
+                "Use kv_cache_dtype=auto for now.")
         logger.info_once("Using TRTLLM attention (query is quantized).")
         return True
 
@@ -207,15 +230,17 @@ def use_trtllm_attention(
             "Using TRTLLM attention (required for attention sinks).")
         return True
 
-    if env_value is None:
+    if force_use_trtllm is None:
         # Environment variable not set - use auto-detection
-        use_trtllm = (num_tokens <= 256 and max_seq_len < 131072
+        use_trtllm = (num_tokens <= 256 and max_seq_len <= 131072
                       and kv_cache_dtype == "auto")
         if use_trtllm:
             logger.warning_once("Using TRTLLM attention (auto-detected).")
         return use_trtllm
 
     # Environment variable is set to 1 - respect it
+    logger.info_once(
+        "Using TRTLLM attention (VLLM_USE_TRTLLM_ATTENTION is set to 1)")
     return True
 
 
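
A hedged usage sketch of the updated entry point (illustrative values, not taken from the diff): the caller no longer passes `is_prefill`, the auto-detection boundary on `max_seq_len` is now inclusive (`<= 131072`), and an FP8 query combined with attention sinks now raises instead of silently choosing a kernel. The import path is an assumption, and the names of the first three parameters are inferred from the function body shown above.

import torch
# Assumed import path for illustration; the diff only shows the module body.
from vllm.utils.flashinfer import use_trtllm_attention

use_trtllm = use_trtllm_attention(
    num_qo_heads=32,         # must be divisible by num_kv_heads
    num_kv_heads=8,
    num_tokens=128,          # <= 256, so auto-detection can still pick TRTLLM
    max_seq_len=8192,
    kv_cache_dtype="auto",
    q_dtype=torch.bfloat16,  # not FP8, so the new has_sinks guard is not hit
    has_sinks=False,
)
# On non-SM100 hardware (or without artifactory access) this returns False;
# with VLLM_USE_TRTLLM_ATTENTION=0 it returns False regardless.
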
@@ -367,6 +392,7 @@ def flashinfer_disable_q_quantization() -> bool:
     "has_nvidia_artifactory",
     "supports_trtllm_attention",
     "use_trtllm_attention",
+    "flashinfer_disable_q_quantization",
     "flashinfer_scaled_fp4_mm",
     "flashinfer_scaled_fp8_mm",
 ]