1 change: 1 addition & 0 deletions pyproject.toml
@@ -137,6 +137,7 @@ files = [
"vllm/transformers_utils",
"vllm/triton_utils",
"vllm/usage",
"vllm/v1/attention",
]
# TODO(woosuk): Include the code from Megatron and HuggingFace.
exclude = [
50 changes: 28 additions & 22 deletions vllm/attention/backends/abstract.py
@@ -3,6 +3,7 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, fields
from enum import Enum
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional,
Protocol, Set, Tuple, Type, TypeVar)

@@ -16,7 +17,7 @@
ModelRunnerInputBuilderBase)


class AttentionType:
class AttentionType(str, Enum):
"""
Attention type.
Use string to be compatible with `torch.compile`.
@@ -31,7 +32,12 @@ class AttentionType:
ENCODER_DECODER = "encoder_decoder"


class AttentionBackend(ABC):
AttentionMetadataType = TypeVar("AttentionMetadataType")
AttentionMetadataBuilderType = TypeVar("AttentionMetadataBuilderType")


class AttentionBackend(ABC, Generic[AttentionMetadataType,
AttentionMetadataBuilderType]):
"""Abstract class for attention backends."""
# For some attention backends, we allocate an output tensor before
# calling the custom op. When piecewise cudagraph is enabled, this
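A side note on the hunk above: making `AttentionType` inherit from both `str` and `Enum` keeps every member usable wherever a plain string is expected, which is what the docstring's remark about `torch.compile` compatibility relies on. A minimal self-contained sketch of the pattern (toy names, not the vLLM class):

```python
from enum import Enum


class ToyAttentionType(str, Enum):
    """Toy copy of the str-mixin pattern used in this diff."""
    DECODER = "decoder"
    ENCODER_ONLY = "encoder_only"


# A str-mixin member still behaves like the raw string it wraps,
# so call sites that pass or compare plain strings keep working.
assert ToyAttentionType.DECODER == "decoder"
assert isinstance(ToyAttentionType.DECODER, str)
assert ToyAttentionType("decoder") is ToyAttentionType.DECODER
```

Because the members are real `str` instances, existing comparisons such as `attn_type == "decoder"` continue to hold at call sites.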
@@ -50,7 +56,7 @@ def get_impl_cls() -> Type["AttentionImpl"]:

@staticmethod
@abstractmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
def get_metadata_cls() -> Type[AttentionMetadataType]:
raise NotImplementedError

@staticmethod
@@ -59,12 +65,12 @@ def get_state_cls() -> Type["AttentionState"]:
raise NotImplementedError

@classmethod
def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
def make_metadata(cls, *args, **kwargs) -> AttentionMetadataType:
return cls.get_metadata_cls()(*args, **kwargs)

@staticmethod
@abstractmethod
def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
def get_builder_cls() -> Type[AttentionMetadataBuilderType]:
raise NotImplementedError

@staticmethod
@@ -105,7 +111,7 @@ def advance_step(self, model_input: "ModelRunnerInputBase",


@dataclass
class AttentionMetadata:
class AttentionMetadata(Generic[AttentionMetadataType]):
"""Attention metadata for prefill and decode batched together."""
# Total number of prefill requests.
num_prefills: int
@@ -135,14 +141,14 @@ class AttentionMetadata:

@property
@abstractmethod
def prefill_metadata(self) -> Optional["AttentionMetadata"]:
def prefill_metadata(self) -> Optional[AttentionMetadataType]:
"""Return the attention metadata that's required to run prefill
attention."""
pass

@property
@abstractmethod
def decode_metadata(self) -> Optional["AttentionMetadata"]:
def decode_metadata(self) -> Optional[AttentionMetadataType]:
"""Return the attention metadata that's required to run decode
attention."""
pass
@@ -161,10 +167,7 @@ def asdict_zerocopy(self,
}


T = TypeVar("T", bound=AttentionMetadata)


class AttentionState(ABC, Generic[T]):
class AttentionState(ABC, Generic[AttentionMetadataType]):
"""Holds attention backend-specific objects reused during the
lifetime of the model runner."""

@@ -179,22 +182,23 @@ def graph_capture(self, max_batch_size: int):
yield

@abstractmethod
def graph_clone(self, batch_size: int) -> "AttentionState[T]":
def graph_clone(
self, batch_size: int) -> "AttentionState[AttentionMetadataType]":
"""Clone attention state to save in CUDA graph metadata."""
...

@abstractmethod
def graph_capture_get_metadata_for_batch(
self,
batch_size: int,
is_encoder_decoder_model: bool = False) -> T:
is_encoder_decoder_model: bool = False) -> AttentionMetadataType:
"""Get attention metadata for CUDA graph capture of batch_size."""
...

@abstractmethod
def get_graph_input_buffers(
self,
attn_metadata: T,
attn_metadata: AttentionMetadataType,
is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
"""Get attention-specific input buffers for CUDA graph capture."""
...
@@ -203,7 +207,7 @@ def get_graph_input_buffers(
def prepare_graph_input_buffers(
self,
input_buffers: Dict[str, Any],
attn_metadata: T,
attn_metadata: AttentionMetadataType,
is_encoder_decoder_model: bool = False) -> None:
"""In-place modify input buffers dict for CUDA graph replay."""
...
@@ -214,7 +218,7 @@ def begin_forward(self, model_input: "ModelRunnerInputBase") -> None:
...


class AttentionMetadataBuilder(ABC, Generic[T]):
class AttentionMetadataBuilder(ABC, Generic[AttentionMetadataType]):
"""Abstract class for attention metadata builders."""

@abstractmethod
@@ -229,7 +233,8 @@ def prepare(self) -> None:

@abstractmethod
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int) -> T:
cuda_graph_pad_size: int,
batch_size: int) -> AttentionMetadataType:
"""Build attention metadata with on-device tensors."""
raise NotImplementedError

@@ -254,7 +259,7 @@ def forward(
...


class AttentionImpl(ABC, Generic[T]):
class AttentionImpl(ABC, Generic[AttentionMetadataType]):

@abstractmethod
def __init__(
@@ -280,13 +285,14 @@ def forward(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: T,
attn_metadata: AttentionMetadataType,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError


class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
class MLAAttentionImpl(AttentionImpl[AttentionMetadataType],
Generic[AttentionMetadataType]):

@abstractmethod
def forward(
@@ -296,7 +302,7 @@ def forward(
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: T,
attn_metadata: AttentionMetadataType,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
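One way to read the new type parameters is through a toy subclass: parameterizing the backend over its metadata and builder classes lets the return types of `get_metadata_cls()` and `get_builder_cls()` be checked against what the subclass declares. The sketch below is self-contained and uses hypothetical `Toy*` names rather than the real vLLM classes; the actual parameterization is visible in `PlaceholderAttentionBackend` and `FlashAttentionBackend` further down.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, Type, TypeVar

# Self-contained sketch of the Generic pattern introduced above;
# all names here are toy stand-ins, not the real vLLM classes.
MetadataT = TypeVar("MetadataT")
BuilderT = TypeVar("BuilderT")


class ToyAttentionBackend(ABC, Generic[MetadataT, BuilderT]):

    @staticmethod
    @abstractmethod
    def get_metadata_cls() -> Type[MetadataT]:
        ...

    @staticmethod
    @abstractmethod
    def get_builder_cls() -> Type[BuilderT]:
        ...


@dataclass
class ToyMetadata:
    num_prefills: int


class ToyMetadataBuilder:

    def build(self) -> ToyMetadata:
        return ToyMetadata(num_prefills=0)


# Pinning the type arguments lets a type checker verify that the
# classes returned below match what the subclass declares.
class ToyBackend(ToyAttentionBackend[ToyMetadata, ToyMetadataBuilder]):

    @staticmethod
    def get_metadata_cls() -> Type[ToyMetadata]:
        return ToyMetadata

    @staticmethod
    def get_builder_cls() -> Type[ToyMetadataBuilder]:
        return ToyMetadataBuilder
```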
7 changes: 5 additions & 2 deletions vllm/attention/backends/placeholder_attn.py
@@ -22,7 +22,9 @@
# lack attention.


class PlaceholderAttentionBackend(AttentionBackend):
class PlaceholderAttentionBackend(
AttentionBackend["PlaceholderAttentionMetadata",
"PlaceholderAttentionMetadataBuilder"]):
"""Placeholder backend for when no attention is needed."""

@staticmethod
@@ -71,7 +73,8 @@ def copy_blocks(


@dataclass
class PlaceholderAttentionMetadata(AttentionMetadata):
class PlaceholderAttentionMetadata(
AttentionMetadata["PlaceholderAttentionMetadata"]):
"""Attention metadata for prefill and decode batched together."""
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/rotary_embedding.py
@@ -1329,8 +1329,8 @@ def _omni_get_input_positions_tensor(
audio_llm_pos_ids_list = (torch.arange(
min(t_ntoken_per_chunk, pure_audio_len -
added_audio_len)).expand(3, -1) +
audio_start_idx).split(
1, dim=1)
audio_start_idx).split(1,
dim=1)
else:
audio_llm_pos_ids_list = []
added_audio_len += min(t_ntoken_per_chunk,
24 changes: 14 additions & 10 deletions vllm/v1/attention/backends/flash_attn.py
@@ -8,7 +8,7 @@

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType,
AttentionType,
is_quantized_kv_cache)
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.logger import init_logger
@@ -23,13 +23,14 @@
from vllm.v1.worker.gpu_model_runner import GPUModelRunner

if current_platform.is_cuda():
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
get_scheduler_metadata)
from vllm.vllm_flash_attn import ( # type: ignore[attr-defined]
flash_attn_varlen_func, get_scheduler_metadata)

logger = init_logger(__name__)


class FlashAttentionBackend(AttentionBackend):
class FlashAttentionBackend(AttentionBackend["FlashAttentionMetadata",
"FlashAttentionMetadataBuilder"]):

accept_output_buffer: bool = True

@@ -46,7 +47,7 @@
return FlashAttentionImpl

@staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
def get_metadata_cls() -> type["FlashAttentionMetadata"]:
return FlashAttentionMetadata

@staticmethod
@@ -447,9 +448,9 @@
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
if sliding_window is None:
self.sliding_window = (-1, -1)
self.sliding_window = [-1, -1]
else:
self.sliding_window = (sliding_window - 1, 0)
self.sliding_window = [sliding_window - 1, 0]
self.kv_cache_dtype = kv_cache_dtype
if logits_soft_cap is None:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
@@ -503,6 +504,8 @@
We use torch's .expand() to avoid duplicating values
"""
assert output is not None, "Output tensor must be provided."
assert self.vllm_flash_attn_version is not None, (
"vLLM FlashAttention version is unknown.")

if attn_metadata is None:
# Profiling run.
@@ -569,7 +572,7 @@

descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])

flash_attn_varlen_func(

Check failure on line 575 in vllm/v1/attention/backends/flash_attn.py (GitHub Actions / pre-commit): No overload variant of "flash_attn_varlen_func" matches argument types "Any", "Any", "Any", "Any", "Any", "int", "Any", "int", "float", "bool", "Optional[list[float]]", "list[int]", "Any", "Union[int, float]", "Optional[Any]", "int", "Any", "Any", "Any" [call-overload]
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
@@ -703,7 +706,7 @@
max_kv_len: int,
softmax_scale: float,
alibi_slopes: Optional[torch.Tensor],
sliding_window: tuple[int, int],
sliding_window: list[int],
logits_soft_cap: float,
block_table: torch.Tensor,
common_prefix_len: int,
@@ -716,8 +719,8 @@
) -> torch.Tensor:
assert alibi_slopes is None, ("Cascade attention does not support ALiBi.")
# TODO: Support sliding window.
assert sliding_window == (-1, -1), (
"Cascade attention does not support sliding window.")
assert sliding_window == [
-1, -1
], ("Cascade attention does not support sliding window.")

num_tokens = query.shape[0]
block_size = key_cache.shape[-3]
@@ -727,7 +731,7 @@
descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2])

# Process shared prefix.
prefix_output, prefix_lse = flash_attn_varlen_func(

Check failure on line 734 in vllm/v1/attention/backends/flash_attn.py (GitHub Actions / pre-commit): No overload variant of "flash_attn_varlen_func" matches argument types "Any", "Any", "Any", "Any", "Any", "Any", "int", "float", "bool", "list[int]", "Any", "float", "bool", "Optional[Any]", "int", "Any", "Any", "Any" [call-overload]
q=query,
k=key_cache,
v=value_cache,
@@ -754,7 +758,7 @@
descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])

# Process suffix per query.
suffix_output, suffix_lse = flash_attn_varlen_func(

Check failure on line 761 in vllm/v1/attention/backends/flash_attn.py (GitHub Actions / pre-commit): No overload variant of "flash_attn_varlen_func" matches argument types "Any", "Any", "Any", "Any", "Any", "int", "int", "float", "bool", "list[int]", "Any", "float", "bool", "Optional[Any]", "int", "Any", "Any", "Any" [call-overload]
q=query,
k=key_cache,
v=value_cache,
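On the `sliding_window` handling in this file: the values follow flash-attn's `window_size = (left, right)` convention, where `-1` disables the limit on that side, so `[-1, -1]` means no sliding window and `[sliding_window - 1, 0]` keeps a causal left window. A small sketch of that mapping, using a hypothetical helper name:

```python
from typing import Optional


def to_window_size(sliding_window: Optional[int]) -> list[int]:
    """Hypothetical helper mirroring the logic in FlashAttentionImpl.

    flash-attn takes (left, right) window sizes; -1 disables the limit
    on that side. A causal sliding window of N tokens attends to N - 1
    tokens on the left and nothing on the right.
    """
    if sliding_window is None:
        return [-1, -1]  # no sliding window
    return [sliding_window - 1, 0]


assert to_window_size(None) == [-1, -1]
assert to_window_size(4096) == [4095, 0]
```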
3 changes: 2 additions & 1 deletion vllm/v1/attention/backends/flashinfer.py
@@ -28,7 +28,8 @@
logger = init_logger(__name__)


class FlashInferBackend(AttentionBackend):
class FlashInferBackend(AttentionBackend["FlashInferMetadata",
"FlashInferMetadataBuilder"]):

accept_output_buffer: bool = True

20 changes: 12 additions & 8 deletions vllm/v1/attention/backends/mla/common.py
@@ -193,7 +193,6 @@

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
AttentionMetadata,
MLAAttentionImpl)
from vllm.attention.backends.utils import get_mla_dims
from vllm.attention.ops.merge_attn_states import merge_attn_states
@@ -207,11 +206,13 @@
from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version

try:
from vllm.vllm_flash_attn import flash_attn_varlen_func
import vllm.vllm_flash_attn as vfa
flash_attn_varlen_func = vfa.flash_attn_varlen_func # type: ignore[attr-defined]
is_vllm_fa = True
except ImportError:
# For rocm use upstream flash attention
from flash_attn import flash_attn_varlen_func
import flash_attn as fa
flash_attn_varlen_func = fa.flash_attn_varlen_func
is_vllm_fa = False

if TYPE_CHECKING:
@@ -222,7 +223,8 @@
logger = init_logger(__name__)


class MLACommonBackend(AttentionBackend):
class MLACommonBackend(AttentionBackend["MLACommonMetadata",
"MLACommonMetadataBuilder"]):

accept_output_buffer: bool = True

@@ -231,7 +233,7 @@
return "TRITON_MLA_VLLM_V1"

@staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
def get_metadata_cls() -> type["MLACommonMetadata"]:
return MLACommonMetadata

@staticmethod
@@ -640,16 +642,17 @@
self.vllm_flash_attn_version = get_flash_attn_version()
if self.vllm_flash_attn_version is not None:
self.flash_attn_varlen_func = \
functools.partial(flash_attn_varlen_func,

Check failure on line 645 in vllm/v1/attention/backends/mla/common.py (GitHub Actions / pre-commit): Incompatible types in assignment (expression has type "partial[tuple[int, int, int]]", variable has type overloaded function) [assignment]
fa_version=self.vllm_flash_attn_version)

# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim for attention backends that do
# not support different headdims
# We don't need to pad V if we are on a hopper system with FA3
device_capability = current_platform.get_device_capability()
self._pad_v = self.vllm_flash_attn_version is None or not (
self.vllm_flash_attn_version == 3
and current_platform.get_device_capability()[0] == 9)
self.vllm_flash_attn_version == 3 and device_capability is not None
and device_capability[0] == 9)

def _flash_attn_varlen_diff_headdims(self,
q,
@@ -679,7 +682,7 @@

# unpad if necessary
if self._pad_v:
attn_out = attn_out[..., :v.shape[-1]]

Check failure on line 685 in vllm/v1/attention/backends/mla/common.py (GitHub Actions / pre-commit): No overload variant of "__getitem__" of "tuple" matches argument type "tuple[ellipsis, slice]" [call-overload]

# Remain consistent with old `flash_attn_varlen_func` where there
# is only one output tensor if `return_softmax_lse` is False.
@@ -721,7 +724,8 @@
f" {WEIGHT_NAMES}.")

def get_and_maybe_dequant_weights(layer: LinearBase):
if not isinstance(layer.quant_method, UnquantizedLinearMethod):
if not (layer.quant_method is None or isinstance(
layer.quant_method, UnquantizedLinearMethod)):
# NOTE: This should only be used offline, since it's O(N^3)
eye = torch.eye(layer.input_size_per_partition,
dtype=act_dtype,
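On the reworked `_pad_v` expression in this file: it reads as a predicate that pads V out to the QK head dim unless FlashAttention 3 is available on a compute-capability-9.x (Hopper) device, with an unknown capability treated conservatively as "pad". A standalone restatement with hypothetical names:

```python
from typing import Optional, Tuple


def needs_v_padding(fa_version: Optional[int],
                    device_capability: Optional[Tuple[int, int]]) -> bool:
    """Hypothetical restatement of the _pad_v logic in the diff.

    V is zero-padded to the QK head dim unless FlashAttention 3 runs on
    a compute-capability-9.x (Hopper) GPU; if the capability cannot be
    queried (None), the code falls back to padding.
    """
    return fa_version is None or not (fa_version == 3
                                      and device_capability is not None
                                      and device_capability[0] == 9)


assert needs_v_padding(None, (9, 0)) is True   # no vLLM FA available
assert needs_v_padding(3, None) is True        # unknown device capability
assert needs_v_padding(3, (8, 0)) is True      # Ampere: pad
assert needs_v_padding(3, (9, 0)) is False     # Hopper + FA3: no padding
```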