7 changes: 5 additions & 2 deletions vllm/envs.py
@@ -1223,9 +1223,12 @@ def get_vllm_port() -> Optional[int]:
"VLLM_USE_CUDNN_PREFILL":
lambda: bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))),

# If set to 1, use the TRTLLM attention backend in flashinfer.
# If set to 1/True, use the TRTLLM attention backend in flashinfer.
# If set to 0/False, use the default attention backend in flashinfer.
# If not set, auto-detect the attention backend in flashinfer.
"VLLM_USE_TRTLLM_ATTENTION":
lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None),
lambda: (None if "VLLM_USE_TRTLLM_ATTENTION" not in os.environ else
os.environ["VLLM_USE_TRTLLM_ATTENTION"].lower() in ("1", "true")),

# If set to 1, when we use fp8 kv, we do not quantize Q to fp8
"VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION":
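A side note on the tri-state parsing above: the variable now resolves to `None` when unset (auto-detect), `True` for "1"/"true", and `False` for anything else. A minimal standalone sketch of that behavior using only the stdlib (the function name `read_use_trtllm_attention` is hypothetical, not part of vLLM):

```python
import os
from typing import Optional


def read_use_trtllm_attention() -> Optional[bool]:
    """Mirror the lambda above: unset -> None, "1"/"true" -> True, else False."""
    if "VLLM_USE_TRTLLM_ATTENTION" not in os.environ:
        return None  # not set: let the backend auto-detect
    return os.environ["VLLM_USE_TRTLLM_ATTENTION"].lower() in ("1", "true")


if __name__ == "__main__":
    os.environ.pop("VLLM_USE_TRTLLM_ATTENTION", None)
    assert read_use_trtllm_attention() is None   # unset -> auto-detect
    os.environ["VLLM_USE_TRTLLM_ATTENTION"] = "True"
    assert read_use_trtllm_attention() is True   # forced on
    os.environ["VLLM_USE_TRTLLM_ATTENTION"] = "0"
    assert read_use_trtllm_attention() is False  # any other value -> forced off
```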
70 changes: 48 additions & 22 deletions vllm/utils/flashinfer.py
@@ -154,28 +154,31 @@ def has_nvidia_artifactory() -> bool:


@functools.cache
def supports_trtllm_attention() -> tuple[bool, Optional[str]]:
"""Cache result which only depends on the environment"""
# This is a lambda, call it once
env_value = envs.VLLM_USE_TRTLLM_ATTENTION

def supports_trtllm_attention() -> bool:
"""
TRTLLM attention is supported if the platform is SM100 and
NVIDIA artifactory is accessible
"""
# Requires SM100 and NVIDIA artifactory to be accessible to download cubins
if not (current_platform.is_device_capability(100)
and has_nvidia_artifactory()):
return False, env_value
return current_platform.is_device_capability(
100) and has_nvidia_artifactory()


@functools.cache
def _force_use_trtllm_attention(env_value: Optional[bool]) -> Optional[bool]:
"""Cache the env value for VLLM_USE_TRTLLM_ATTENTION"""
if env_value is not None:
logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
# Environment variable is set - respect it
# Making the conditional check for zero because
# the path is automatically enabled if the batch size condition
# is satisfied.
use_trtllm = (env_value == "1")
if use_trtllm:
logger.info_once("Using TRTLLM attention.")
return use_trtllm, env_value
return env_value

return True, None

def force_use_trtllm_attention() -> Optional[bool]:
Collaborator:

You shouldn't read envs in cached functions; please separate into force_use_trtllm_attention (uncached) and _force_use_trtllm_attention (cached, takes the env var as input).

Contributor Author:

Addressed the comments above.

> Does this mean that when has_sinks is enabled, attention+quant fusion won't work?

  • If kv=auto, everything works fine; it will use the TRTLLM BF16-qkv BF16-out kernel.
  • If kv=fp8, by default it will always quantize the query and use the TRTLLM FP8-qkv kernels.
    • In this case, we found some accuracy issues in the FP8-qkv BF16-out sinks kernel. That has been fixed in TRTLLM upstream and needs to be propagated to FlashInfer and vLLM. For now we just raise an error and suggest the user use kv=auto.
    • There is a WAR for using the BF16-q FP8-kv BF16-out kernel by setting VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION, introduced by [flashinfer] [kernel] support for fp8 kv cache for trtllm prefill attention #24197.

Back to attention+quant fusion: AFAIR, only gpt-oss needs attn sinks, and we haven't started quantizing the attn output for it. If we want to enable attention+quant fusion for gpt-oss, we have to at least:

  1. Quantize the gpt-oss model.
  2. Ensure the FP8-qkv FP8/NVFP4-out attn sinks kernels work and have good accuracy on gpt-oss.

So for now, we need to fix the FP8-qkv BF16-out sinks kernel first, and then verify the FP8-qkv FP8/NVFP4-out kernels if we need them.

"""
Return ``None`` if VLLM_USE_TRTLLM_ATTENTION is not set,
return ``True`` if TRTLLM attention is forced to be used,
return ``False`` if TRTLLM attention is forced to be not used.
"""
return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)
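
The new structure above follows the reviewer's suggestion: the cached helper takes the env value as an argument, while the thin uncached wrapper is the only place the environment is read. A generic sketch of that pattern in isolation (names like `MY_FLAG` and `resolve_flag` are illustrative, not vLLM code):

```python
import functools
import os
from typing import Optional


@functools.cache
def _resolve_flag(env_value: Optional[bool]) -> Optional[bool]:
    # Cached on the value itself, so any logging or derived work runs at most
    # once per distinct value.
    if env_value is not None:
        print(f"MY_FLAG forced to {env_value}")
    return env_value


def resolve_flag() -> Optional[bool]:
    # Uncached wrapper: the environment is re-read on every call, so tests and
    # late configuration changes are still picked up.
    raw = os.environ.get("MY_FLAG")
    return _resolve_flag(None if raw is None else raw.lower() in ("1", "true"))
```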


def use_trtllm_attention(
@@ -185,18 +188,38 @@ def use_trtllm_attention(
max_seq_len: int,
kv_cache_dtype: str,
q_dtype: torch.dtype,
is_prefill: bool,
has_sinks: bool = False,
) -> bool:
use_trtllm, env_value = supports_trtllm_attention()
if not use_trtllm:
"""Return ``True`` if TRTLLM attention is used."""
force_use_trtllm = force_use_trtllm_attention()

# Environment variable is set to 0 - respect it
if force_use_trtllm is not None and not force_use_trtllm:
return False

# The platform is not supported
if not supports_trtllm_attention():
if force_use_trtllm:
logger.warning_once(
"TRTLLM attention is not supported on this platform, "
"but VLLM_USE_TRTLLM_ATTENTION is set to 1")
return False

# The combination of query and key heads is not supported
if num_qo_heads % num_kv_heads != 0:
if force_use_trtllm:
logger.warning_once(
"TRTLLM attention is not supported for this combination of "
"query and key heads, but VLLM_USE_TRTLLM_ATTENTION is set to 1"
)
return False

# Must use TRTLLM attention if query is FP8 quantized
if q_dtype == current_platform.fp8_dtype():
if has_sinks:
raise RuntimeError(
"TRTLLM FP8-qkv kernel is not supported for attention sinks. "
"Use kv_cache_dtype=auto for now.")
logger.info_once("Using TRTLLM attention (query is quantized).")
return True

@@ -207,15 +230,17 @@ def use_trtllm_attention(
"Using TRTLLM attention (required for attention sinks).")
return True

if env_value is None:
if force_use_trtllm is None:
# Environment variable not set - use auto-detection
use_trtllm = (num_tokens <= 256 and max_seq_len < 131072
use_trtllm = (num_tokens <= 256 and max_seq_len <= 131072
and kv_cache_dtype == "auto")
if use_trtllm:
logger.warning_once("Using TRTLLM attention (auto-detected).")
return use_trtllm

# Environment variable is set to 1 - respect it
logger.info_once(
"Using TRTLLM attention (VLLM_USE_TRTLLM_ATTENTION is set to 1)")
return True


@@ -367,6 +392,7 @@ def flashinfer_disable_q_quantization() -> bool:
"has_nvidia_artifactory",
"supports_trtllm_attention",
"use_trtllm_attention",
"flashinfer_disable_q_quantization",
"flashinfer_scaled_fp4_mm",
"flashinfer_scaled_fp8_mm",
]
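
To summarize the kv-cache/sinks discussion in the review thread above, the dtype combinations roughly map to kernel choices as sketched below. This is my reading of the author's bullets, not actual vLLM logic; the function and its return strings are illustrative only:

```python
def sketch_trtllm_kernel_choice(kv_cache_dtype: str,
                                disable_q_quant: bool,
                                has_sinks: bool) -> str:
    """Rough decision table for the TRTLLM attention paths discussed above."""
    if kv_cache_dtype == "auto":
        # Works everywhere: BF16 query/key/value in, BF16 out.
        return "TRTLLM BF16-qkv, BF16-out"
    if kv_cache_dtype.startswith("fp8"):
        if disable_q_quant:
            # WAR from #24197: keep the query in BF16.
            return "TRTLLM BF16-q, FP8-kv, BF16-out"
        if has_sinks:
            # FP8-qkv sinks kernel had accuracy issues; the PR raises an error.
            return "unsupported: use kv_cache_dtype=auto for now"
        return "TRTLLM FP8-qkv, BF16-out"
    return "default FlashInfer path"
```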
17 changes: 12 additions & 5 deletions vllm/v1/attention/backends/flashinfer.py
@@ -282,7 +282,11 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
assert self.kv_cache_spec.dtype == self.model_config.dtype
self.kv_cache_dtype = self.kv_cache_spec.dtype

if supports_trtllm_attention()[0] and \
# Use model dtype as q dtype when TRTLLM attn is not supported, or
# VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION is set to 1. Otherwise, try to
# use fp8 q if kv cache is fp8, and will fall back to model dtype
# if TRTLLM attention kernel is not used when building attn metadata
if supports_trtllm_attention() and \
not flashinfer_disable_q_quantization():
self.q_data_type = self.kv_cache_dtype
else:
@@ -298,7 +302,7 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
self.window_left = self.global_hyperparameters.window_left
self.logits_soft_cap = self.global_hyperparameters.logits_soft_cap
self.has_sinks = self.global_hyperparameters.has_sinks
if self.has_sinks and not supports_trtllm_attention()[0]:
if self.has_sinks and not supports_trtllm_attention():
raise NotImplementedError(
"FlashInfer backend currently does not support attention "
"sinks, please use trtllm on blackwell or flash attention on "
@@ -477,28 +481,31 @@ def build(self,
paged_kv_last_page_len_np,
)

# Check if any layer uses sinks (requires TRTLLM attention)
prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads,
self.num_kv_heads,
num_prefill_tokens,
max_seq_len,
self.cache_dtype,
self.q_data_type,
is_prefill=True,
has_sinks=self.has_sinks)
decode_use_trtllm = use_trtllm_attention(self.num_qo_heads,
self.num_kv_heads,
num_decode_tokens,
max_seq_len,
self.cache_dtype,
self.q_data_type,
is_prefill=False,
has_sinks=self.has_sinks)
if self.has_sinks and not (prefill_use_trtllm and decode_use_trtllm):
raise NotImplementedError(
"FlashInfer backend currently does not support attention "
"sinks, please use trtllm on blackwell or flash attention on "
"earlier GPUs.")

# If TRTLLM attention is not used, the q quantization is not supported.
# Fall back to use model dtype.
if not (prefill_use_trtllm and decode_use_trtllm):
self.q_data_type = self.model_config.dtype

attn_metadata = FlashInferMetadata(
num_actual_tokens=num_actual_tokens,
q_data_type=self.q_data_type,
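One last note on the q dtype handling: `__init__` optimistically selects an FP8 query dtype when the kv cache is FP8 and query quantization isn't disabled, and `build()` reverts to the model dtype whenever prefill or decode ends up not using the TRTLLM kernels. A condensed sketch of that two-step decision (a hypothetical helper, not the backend's actual code):

```python
import torch


def choose_q_dtype(model_dtype: torch.dtype,
                   kv_cache_dtype: torch.dtype,
                   supports_trtllm: bool,
                   disable_q_quant: bool,
                   prefill_use_trtllm: bool,
                   decode_use_trtllm: bool) -> torch.dtype:
    # Step 1 (__init__): try FP8 q when TRTLLM attention could be used and
    # query quantization is not disabled.
    q_dtype = (kv_cache_dtype
               if supports_trtllm and not disable_q_quant else model_dtype)
    # Step 2 (build): fall back to the model dtype if either phase skips TRTLLM.
    if not (prefill_use_trtllm and decode_use_trtllm):
        q_dtype = model_dtype
    return q_dtype
```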