
Commit b421701

elvischenv authored and mgoin committed
[Bugfix] Refactor Flashinfer TRTLLM attention kernel selection logic (vllm-project#24600)
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
1 parent 2951709 commit b421701

File tree

3 files changed: +65 -29 lines changed


vllm/envs.py

Lines changed: 5 additions & 2 deletions
@@ -1223,9 +1223,12 @@ def get_vllm_port() -> Optional[int]:
     "VLLM_USE_CUDNN_PREFILL":
     lambda: bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))),
 
-    # If set to 1, use the TRTLLM attention backend in flashinfer.
+    # If set to 1/True, use the TRTLLM attention backend in flashinfer.
+    # If set to 0/False, use the default attention backend in flashinfer.
+    # If not set, auto-detect the attention backend in flashinfer.
     "VLLM_USE_TRTLLM_ATTENTION":
-    lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None),
+    lambda: (None if "VLLM_USE_TRTLLM_ATTENTION" not in os.environ else
+             os.environ["VLLM_USE_TRTLLM_ATTENTION"].lower() in ("1", "true")),
 
     # If set to 1, when we use fp8 kv, we do not quantize Q to fp8
     "VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION":

vllm/utils/flashinfer.py

Lines changed: 48 additions & 22 deletions
@@ -154,28 +154,31 @@ def has_nvidia_artifactory() -> bool:
 
 
 @functools.cache
-def supports_trtllm_attention() -> tuple[bool, Optional[str]]:
-    """Cache result which only depends on the environment"""
-    # This is a lambda, call it once
-    env_value = envs.VLLM_USE_TRTLLM_ATTENTION
-
+def supports_trtllm_attention() -> bool:
+    """
+    TRTLLM attention is supported if the platform is SM100 and
+    NVIDIA artifactory is accessible
+    """
     # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
-    if not (current_platform.is_device_capability(100)
-            and has_nvidia_artifactory()):
-        return False, env_value
+    return current_platform.is_device_capability(
+        100) and has_nvidia_artifactory()
 
+
+@functools.cache
+def _force_use_trtllm_attention(env_value: Optional[bool]) -> Optional[bool]:
+    """Cache the env value for VLLM_USE_TRTLLM_ATTENTION"""
     if env_value is not None:
         logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
-        # Environment variable is set - respect it
-        # Making the conditional check for zero because
-        # the path is automatically enabled if the batch size condition
-        # is satisfied.
-        use_trtllm = (env_value == "1")
-        if use_trtllm:
-            logger.info_once("Using TRTLLM attention.")
-        return use_trtllm, env_value
+    return env_value
 
-    return True, None
+
+def force_use_trtllm_attention() -> Optional[bool]:
+    """
+    Return ``None`` if VLLM_USE_TRTLLM_ATTENTION is not set,
+    return ``True`` if TRTLLM attention is forced to be used,
+    return ``False`` if TRTLLM attention is forced to be not used.
+    """
+    return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)
 
 
 def use_trtllm_attention(
@@ -185,18 +188,38 @@ def use_trtllm_attention(
     max_seq_len: int,
     kv_cache_dtype: str,
     q_dtype: torch.dtype,
-    is_prefill: bool,
     has_sinks: bool = False,
 ) -> bool:
-    use_trtllm, env_value = supports_trtllm_attention()
-    if not use_trtllm:
+    """Return ``True`` if TRTLLM attention is used."""
+    force_use_trtllm = force_use_trtllm_attention()
+
+    # Environment variable is set to 0 - respect it
+    if force_use_trtllm is not None and not force_use_trtllm:
         return False
 
+    # The platform is not supported
+    if not supports_trtllm_attention():
+        if force_use_trtllm:
+            logger.warning_once(
+                "TRTLLM attention is not supported on this platform, "
+                "but VLLM_USE_TRTLLM_ATTENTION is set to 1")
+        return False
+
+    # The combination of query and key heads is not supported
     if num_qo_heads % num_kv_heads != 0:
+        if force_use_trtllm:
+            logger.warning_once(
+                "TRTLLM attention is not supported for this combination of "
+                "query and key heads, but VLLM_USE_TRTLLM_ATTENTION is set to 1"
+            )
         return False
 
     # Must use TRTLLM attention if query is FP8 quantized
     if q_dtype == current_platform.fp8_dtype():
+        if has_sinks:
+            raise RuntimeError(
+                "TRTLLM FP8-qkv kernel is not supported for attention sinks. "
+                "Use kv_cache_dtype=auto for now.")
         logger.info_once("Using TRTLLM attention (query is quantized).")
         return True
 
@@ -207,15 +230,17 @@ def use_trtllm_attention(
             "Using TRTLLM attention (required for attention sinks).")
         return True
 
-    if env_value is None:
+    if force_use_trtllm is None:
         # Environment variable not set - use auto-detection
-        use_trtllm = (num_tokens <= 256 and max_seq_len < 131072
+        use_trtllm = (num_tokens <= 256 and max_seq_len <= 131072
                       and kv_cache_dtype == "auto")
         if use_trtllm:
            logger.warning_once("Using TRTLLM attention (auto-detected).")
         return use_trtllm
 
     # Environment variable is set to 1 - respect it
+    logger.info_once(
+        "Using TRTLLM attention (VLLM_USE_TRTLLM_ATTENTION is set to 1)")
     return True
 
 
@@ -367,6 +392,7 @@ def flashinfer_disable_q_quantization() -> bool:
     "has_nvidia_artifactory",
     "supports_trtllm_attention",
     "use_trtllm_attention",
+    "flashinfer_disable_q_quantization",
     "flashinfer_scaled_fp4_mm",
     "flashinfer_scaled_fp8_mm",
 ]
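
The refactor replaces the old tuple-returning supports_trtllm_attention() with three pieces: a pure capability check, a cached read of the env-var override, and the per-batch decision in use_trtllm_attention(). A hedged sketch of the new call flow; the argument values are illustrative and passed positionally in the same order the metadata builder uses:

import torch
from vllm.utils.flashinfer import (force_use_trtllm_attention,
                                   supports_trtllm_attention,
                                   use_trtllm_attention)

# Tri-state override from VLLM_USE_TRTLLM_ATTENTION: None / True / False.
override = force_use_trtllm_attention()

# Pure capability check: SM100 device and reachable NVIDIA artifactory.
platform_ok = supports_trtllm_attention()

# Final per-batch decision, combining the override, platform support, the
# head layout, dtypes, sinks, and the auto-detection heuristic
# (num_tokens <= 256, max_seq_len <= 131072, kv_cache_dtype == "auto").
use_trtllm = use_trtllm_attention(
    32,              # num_qo_heads (illustrative)
    8,               # num_kv_heads (illustrative)
    128,             # tokens in this batch (illustrative)
    4096,            # max_seq_len (illustrative)
    "auto",          # kv_cache_dtype
    torch.bfloat16,  # q_dtype
    has_sinks=False)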

vllm/v1/attention/backends/flashinfer.py

Lines changed: 12 additions & 5 deletions
@@ -282,7 +282,11 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
             assert self.kv_cache_spec.dtype == self.model_config.dtype
             self.kv_cache_dtype = self.kv_cache_spec.dtype
 
-        if supports_trtllm_attention()[0] and \
+        # Use model dtype as q dtype when TRTLLM attn is not supported, or
+        # VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION is set to 1. Otherwise, try to
+        # use fp8 q if kv cache is fp8, and will fall back to model dtype
+        # if TRTLLM attention kernel is not used when building attn metadata
+        if supports_trtllm_attention() and \
                 not flashinfer_disable_q_quantization():
             self.q_data_type = self.kv_cache_dtype
         else:
@@ -298,7 +302,7 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         self.window_left = self.global_hyperparameters.window_left
         self.logits_soft_cap = self.global_hyperparameters.logits_soft_cap
         self.has_sinks = self.global_hyperparameters.has_sinks
-        if self.has_sinks and not supports_trtllm_attention()[0]:
+        if self.has_sinks and not supports_trtllm_attention():
             raise NotImplementedError(
                 "FlashInfer backend currently does not support attention "
                 "sinks, please use trtllm on blackwell or flash attention on "
@@ -477,28 +481,31 @@ def build(self,
             paged_kv_last_page_len_np,
         )
 
-        # Check if any layer uses sinks (requires TRTLLM attention)
         prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                                   self.num_kv_heads,
                                                   num_prefill_tokens,
                                                   max_seq_len,
                                                   self.cache_dtype,
                                                   self.q_data_type,
-                                                  is_prefill=True,
                                                   has_sinks=self.has_sinks)
         decode_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                                  self.num_kv_heads,
                                                  num_decode_tokens,
                                                  max_seq_len,
                                                  self.cache_dtype,
                                                  self.q_data_type,
-                                                 is_prefill=False,
                                                  has_sinks=self.has_sinks)
         if self.has_sinks and not (prefill_use_trtllm and decode_use_trtllm):
             raise NotImplementedError(
                 "FlashInfer backend currently does not support attention "
                 "sinks, please use trtllm on blackwell or flash attention on "
                 "earlier GPUs.")
+
+        # If TRTLLM attention is not used, the q quantization is not supported.
+        # Fall back to use model dtype.
+        if not (prefill_use_trtllm and decode_use_trtllm):
+            self.q_data_type = self.model_config.dtype
+
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
             q_data_type=self.q_data_type,
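
With this change the builder chooses the query dtype optimistically in __init__ and corrects it in build() once the actual kernel choice is known. A minimal sketch of that two-step decision, written as a hypothetical standalone helper rather than the builder's attributes:

import torch

def pick_q_dtype(trtllm_supported: bool, q_quant_disabled: bool,
                 kv_cache_dtype: torch.dtype, model_dtype: torch.dtype,
                 prefill_use_trtllm: bool,
                 decode_use_trtllm: bool) -> torch.dtype:
    # Hypothetical helper mirroring the builder's logic.
    # Step 1 (__init__): optimistically take the (possibly fp8) KV-cache
    # dtype for Q when TRTLLM is supported and Q quantization is enabled.
    q_dtype = (kv_cache_dtype if trtllm_supported and not q_quant_disabled
               else model_dtype)
    # Step 2 (build): the non-TRTLLM FlashInfer path does not accept an
    # FP8-quantized query, so fall back to the model dtype if either the
    # prefill or the decode phase ends up not using the TRTLLM kernel.
    if not (prefill_use_trtllm and decode_use_trtllm):
        q_dtype = model_dtype
    return q_dtype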
