3 changes: 2 additions & 1 deletion docs/design/cuda_graphs.md
@@ -177,8 +177,9 @@ The following table lists backends that support full CUDA Graphs at the time of
| FlashAttention v3 | `ALWAYS` | has unified routine for both batches, so `FULL` mode is good |
| Triton Attention | `ALWAYS` | prefer `FULL_AND_PIECEWISE` since it has different kernels for prefill/mixed and pure decode batches |
| AITER FlashAttention | `UNIFORM_BATCH`| |
| FlashInfer | `UNIFORM_SINGLE_TOKEN_DECODE` | |
| FlashInfer | `UNIFORM_SINGLE_TOKEN_DECODE` | Will be set to `UNIFORM_BATCH` when using TRTLLM attention on Blackwell |
| FlashMLA | `UNIFORM_BATCH` | |
| FlashInferMLA | `UNIFORM_BATCH` | |
| AITER MLA | `UNIFORM_SINGLE_TOKEN_DECODE` | |
| CUTLASS MLA | `UNIFORM_SINGLE_TOKEN_DECODE` | |
| Mamba attention| `UNIFORM_SINGLE_TOKEN_DECODE` | |
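
For reference, here is a minimal sketch of how a metadata builder declares one of these support levels under the pattern this PR adopts, where `cudagraph_support` becomes an instance attribute set in `__init__` instead of a `ClassVar`. The class name is hypothetical and the import paths are assumptions, not taken from the PR.

```python
# Hypothetical builder (not part of this PR) showing the per-instance
# pattern: the support level can depend on runtime configuration such as
# kernel version or speculative-decoding settings.
from vllm.v1.attention.backends.utils import (  # assumed import path
    AttentionCGSupport,
    AttentionMetadataBuilder,
)


class ExampleMetadataBuilder(AttentionMetadataBuilder):
    def __init__(self, kv_cache_spec, layer_names, vllm_config, device):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        # Capture full CUDA graphs only for pure single-token decode
        # batches; other batch shapes fall back to piecewise graphs.
        self.cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
```

Backends whose decode kernels also handle uniform multi-token batches (for example with speculative decoding) would assign `AttentionCGSupport.UNIFORM_BATCH` instead, as the FlashInfer change below does when TRTLLM decode is forced.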
12 changes: 5 additions & 7 deletions tests/v1/attention/test_mla_backends.py
@@ -459,17 +459,15 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
is_decode = []
for backend_idx, backend in enumerate(BACKENDS_TO_TEST):
builder_cls, _ = try_get_attention_backend(backend)
dummy_builder = builder_cls(kv_cache_spec, [], vllm_config, device)
query_len_support = getattr(
dummy_builder, "query_len_support", QueryLenSupport.SINGLE_ONLY
)
if is_spec_decode_test:
query_len_support = getattr(
builder_cls, "query_len_support", QueryLenSupport.SINGLE_ONLY
)
supports_spec = query_len_support != QueryLenSupport.SINGLE_ONLY
is_decode.append(supports_spec)
else:
threshold = getattr(builder_cls, "reorder_batch_threshold", None)
query_len_support = getattr(
builder_cls, "query_len_support", QueryLenSupport.SINGLE_ONLY
)
threshold = getattr(dummy_builder, "reorder_batch_threshold", None)
within_threshold = q_len <= threshold if threshold else False
if (
within_threshold
4 changes: 0 additions & 4 deletions vllm/attention/layers/chunked_local_attention.py
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from typing import ClassVar

import torch

@@ -12,7 +11,6 @@
from vllm.config.vllm import VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
CommonAttentionMetadata,
make_local_attention_virtual_batches,
subclass_attention_backend,
@@ -33,8 +31,6 @@ def create_chunked_local_attention_backend(
underlying_builder = underlying_attn_backend.get_builder_cls()

class ChunkedLocalAttentionBuilder(underlying_builder): # type: ignore
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER

def build(
self,
common_prefix_len: int,
2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/cpu_attn.py
@@ -348,7 +348,7 @@ def get_seq_len_block_table_args(


class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
reorder_batch_threshold: int = 1
reorder_batch_threshold: int
Collaborator:
nit: I still think this needs to be set?

Collaborator (Author):
It is set in the constructor: `_init_reorder_batch_threshold(1, False)`.

The type annotation is left to indicate that it will never be `None` on this class and its subclasses. This is a common pattern in the changes in this PR.


def __init__(
self,
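Below is a minimal sketch of the annotate-then-initialize pattern described in the review comment above. The class name is illustrative and not part of the PR; `AttentionMetadataBuilder` and `_init_reorder_batch_threshold` are the base class and helper referenced throughout this change set, with the import path assumed.

```python
# Illustrative subclass (not from the PR): the bare annotation tells type
# checkers that reorder_batch_threshold is always an int on this class and
# its subclasses, while the concrete value is assigned in the constructor.
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder  # assumed path


class ExampleDecodeOnlyBuilder(AttentionMetadataBuilder):
    # Annotation only: never None for this class or its subclasses.
    reorder_batch_threshold: int

    def __init__(self, kv_cache_spec, layer_names, vllm_config, device):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        # Requests with query length <= 1 are reordered to the decode
        # pathway; speculative tokens are not treated as decodes here.
        self._init_reorder_batch_threshold(1, supports_spec_as_decode=False)
```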
49 changes: 24 additions & 25 deletions vllm/v1/attention/backends/flash_attn.py
@@ -175,30 +175,6 @@ def _get_sliding_window_configs(


class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetadata]):
# FA3:
# Supports full cudagraphs for all cases.
#
# FA2:
# For FA2, a graph is captured with max_query_len=1, (which is what we
# capture by default for num_tokens <= max_num_seqs when there is no
# spec-decode) then these graphs will not work for mixed prefill-decode
# (unlike FA3). This is due to special max_query_len=1 packed-GQA handling
# in FA2.
# In summary if we are running with spec decodes the graphs would
# work for mixed prefill-decode and uniform-decode. But for non-spec decodes
# the graphs would not work for mixed prefill-decode; sorta the inverse
# of UNIFORM_SINGLE_TOKEN_DECODE.
# There's probably a better way to describe this using `AttentionCGSupport`
# but for now just set it to `UNIFORM_BATCH` to get use to drop down
# to FULL_AND_PIECEWISE.
# TODO(luka, lucas): audit FA2 as part of:
# https://github.com/vllm-project/vllm/issues/22945
cudagraph_support = (
AttentionCGSupport.ALWAYS
if get_flash_attn_version() == 3
else AttentionCGSupport.UNIFORM_BATCH
)

def __init__(
self,
kv_cache_spec: AttentionSpec,
@@ -220,8 +196,31 @@ def __init__(
self.headdim = self.model_config.get_head_size()
self.block_size = kv_cache_spec.block_size

# FA3:
# Supports full cudagraphs for all cases.
#
# FA2:
# For FA2, a graph is captured with max_query_len=1, (which is what we
# capture by default for num_tokens <= max_num_seqs when there is no
# spec-decode) then these graphs will not work for mixed prefill-decode
# (unlike FA3). This is due to special max_query_len=1 packed-GQA handling
# in FA2.
# In summary if we are running with spec decodes the graphs would
# work for mixed prefill-decode and uniform-decode. But for non-spec decodes
# the graphs would not work for mixed prefill-decode; sorta the inverse
# of UNIFORM_SINGLE_TOKEN_DECODE.
# There's probably a better way to describe this using `AttentionCGSupport`
# but for now just set it to `UNIFORM_BATCH` to get use to drop down
# to FULL_AND_PIECEWISE.
# TODO(luka, lucas): audit FA2 as part of:
# https://github.com/vllm-project/vllm/issues/22945
is_fa3 = get_flash_attn_version() == 3
self.cudagraph_support = (
AttentionCGSupport.ALWAYS if is_fa3 else AttentionCGSupport.UNIFORM_BATCH
)

self.max_num_splits = 0 # No upper bound on the number of splits.
self.aot_schedule = get_flash_attn_version() == 3
self.aot_schedule = is_fa3

try:
from vllm.distributed.parallel_state import get_dcp_group
21 changes: 11 additions & 10 deletions vllm/v1/attention/backends/flashinfer.py
@@ -3,7 +3,6 @@
"""Attention layer with FlashInfer."""

from dataclasses import dataclass
from typing import ClassVar

import numpy as np
import torch
@@ -271,11 +270,7 @@ class FlashInferMetadata:


class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
cudagraph_support: ClassVar[AttentionCGSupport] = (
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
)

reorder_batch_threshold: int = 1
reorder_batch_threshold: int

def __init__(
self,
@@ -312,6 +307,7 @@ def __init__(
if speculative_config is not None
else 0
)
self.has_spec_decode = num_spec_tokens > 0
self.enable_cuda_graph = (
self.compilation_config.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
)
@@ -353,7 +349,13 @@ def __init__(
else:
self.q_data_type = self.model_config.dtype

# If using trtllm attention, we can support uniform_batch speculative decoding
self._init_reorder_batch_threshold(1, supports_spec_as_decode=can_use_trtllm)
self.must_use_trtllm_decode = can_use_trtllm and self.has_spec_decode
if self.must_use_trtllm_decode:
self.cudagraph_support = AttentionCGSupport.UNIFORM_BATCH
else:
self.cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE

self._cascade_wrapper = None # Wrapper for cascade attention

@@ -550,7 +552,6 @@ def build(
paged_kv_last_page_len_np,
)

uses_spec_reorder = self.reorder_batch_threshold > 1
prefill_use_trtllm = use_trtllm_attention(
self.num_qo_heads,
self.num_kv_heads,
@@ -560,9 +561,9 @@
self.q_data_type,
is_prefill=True,
has_sinks=self.has_sinks,
has_spec=uses_spec_reorder,
has_spec=self.has_spec_decode,
)
decode_use_trtllm = use_trtllm_attention(
decode_use_trtllm = self.must_use_trtllm_decode or use_trtllm_attention(
self.num_qo_heads,
self.num_kv_heads,
num_decode_tokens,
@@ -571,7 +572,7 @@
self.q_data_type,
is_prefill=False,
has_sinks=self.has_sinks,
has_spec=uses_spec_reorder,
has_spec=self.has_spec_decode,
)

if not (prefill_use_trtllm and decode_use_trtllm):
5 changes: 1 addition & 4 deletions vllm/v1/attention/backends/gdn_attn.py
@@ -59,10 +59,6 @@ class GDNAttentionMetadata:


class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]):
cudagraph_support = AttentionCGSupport.UNIFORM_BATCH

reorder_batch_threshold: int = 1

def __init__(
self,
kv_cache_spec: AttentionSpec,
@@ -81,6 +77,7 @@ def __init__(
self.num_spec = 0
self.use_spec_decode = self.num_spec > 0
self._init_reorder_batch_threshold(1, self.use_spec_decode)
self.cudagraph_support = AttentionCGSupport.UNIFORM_BATCH

self.use_full_cuda_graph = (
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
3 changes: 2 additions & 1 deletion vllm/v1/attention/backends/linear_attn.py
@@ -33,7 +33,7 @@ class LinearAttentionMetadata:


class LinearAttentionMetadataBuilder(AttentionMetadataBuilder[LinearAttentionMetadata]):
reorder_batch_threshold: int = 1
reorder_batch_threshold: int

def __init__(
self,
@@ -44,6 +44,7 @@
):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
assert isinstance(kv_cache_spec, MambaSpec)
self._init_reorder_batch_threshold(1)

def build(
self,
9 changes: 4 additions & 5 deletions vllm/v1/attention/backends/mamba_attn.py
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import abc
from typing import ClassVar, TypeVar
from typing import TypeVar

import torch

@@ -18,10 +18,7 @@


class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
reorder_batch_threshold: int = 1
cudagraph_support: ClassVar[AttentionCGSupport] = (
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
)
reorder_batch_threshold: int

def __init__(
self,
@@ -43,6 +40,8 @@ def __init__(
dtype=torch.int32,
device=device,
)
self._init_reorder_batch_threshold(1)
self.cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE

def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
26 changes: 16 additions & 10 deletions vllm/v1/attention/backends/mla/common.py
@@ -195,6 +195,7 @@

import torch
from tqdm import tqdm
from typing_extensions import override

import vllm.envs as envs
from vllm import _custom_ops as ops
@@ -495,7 +496,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
# query length <= threshold are classified as decode requests.
# Use `query_len_support` (above) to set this automatically
# when speculative decoding is enabled.
reorder_batch_threshold: int = 1
reorder_batch_threshold: int

@staticmethod
def determine_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int:
@@ -620,17 +621,22 @@ def __init__(
device=device,
)

supports_spec_decode = self.query_len_support != QueryLenSupport.SINGLE_ONLY
self._init_reorder_batch_threshold(
self.reorder_batch_threshold, supports_spec_decode
)
# This can be overwritten by subclasses where applicable
self._init_reorder_batch_threshold(1)

# Validate consistency between query_len_support and reorder_batch_threshold
if self.query_len_support == QueryLenSupport.SINGLE_ONLY:
assert self.reorder_batch_threshold == 1, (
f"reorder_batch_threshold must be 1 when query_len_support is "
f"SINGLE_ONLY, got {self.reorder_batch_threshold}"
@override
def _init_reorder_batch_threshold(
self,
reorder_batch_threshold: int = 1,
supports_spec_as_decode: bool | None = None,
) -> None:
if supports_spec_as_decode is None:
supports_spec_as_decode = (
self.query_len_support != QueryLenSupport.SINGLE_ONLY
)
super()._init_reorder_batch_threshold(
reorder_batch_threshold, supports_spec_as_decode
)

def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
qo_indptr = prefill.query_start_loc
10 changes: 5 additions & 5 deletions vllm/v1/attention/backends/mla/cutlass_mla.py
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
from typing import ClassVar

import torch

@@ -26,10 +25,11 @@


class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
# enable full CUDA Graph support for decode-only capture
cudagraph_support: ClassVar[AttentionCGSupport] = (
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# enable full CUDA Graph support for decode-only capture
self.cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE


class CutlassMLABackend(MLACommonBackend):
7 changes: 5 additions & 2 deletions vllm/v1/attention/backends/mla/flashattn_mla.py
@@ -69,9 +69,7 @@ class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]):


class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]):
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.VARLEN
reorder_batch_threshold: int = 512 # process small prefills with decode pathway

def __init__(
self,
@@ -113,6 +111,11 @@ def __init__(
if vllm_is_batch_invariant():
self.max_num_splits = 1

# process small prefills with decode pathway
self._init_reorder_batch_threshold(512)

self.cudagraph_support = AttentionCGSupport.UNIFORM_BATCH

def _schedule_decode(
self, num_reqs, cu_query_lens, max_query_len, seqlens, max_seq_len, causal
):
12 changes: 11 additions & 1 deletion vllm/v1/attention/backends/mla/flashinfer_mla.py
@@ -23,9 +23,19 @@


class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM

def __init__(
self,
*args,
**kwargs,
):
super().__init__(
*args,
**kwargs,
)
self.cudagraph_support = AttentionCGSupport.UNIFORM_BATCH


class FlashInferMLABackend(MLACommonBackend):
@staticmethod