4 changes: 2 additions & 2 deletions vllm/attention/backends/abstract.py

@@ -357,8 +357,8 @@ def forward(
         self,
         layer: AttentionLayer,
         query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
+        key: torch.Tensor | None,
+        value: torch.Tensor | None,
         kv_cache: torch.Tensor,
         attn_metadata: T,
         output: torch.Tensor | None = None,
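
With `key` and `value` now typed `torch.Tensor | None` in the abstract `forward`, every implementation has to cope with the cross-attention case where the encoder K/V were cached on an earlier pass and nothing new is supplied. The toy class below is only a sketch of that contract (single head, invented names, no paging or scaling of the real kind), not vLLM's implementation:

import torch

class ToyCrossAttention:
    """Toy stand-in for an attention impl that accepts optional key/value.

    Illustrates the new contract only: key/value may be None once the
    (static) encoder K/V have been written to the cache.
    """

    def __init__(self, num_tokens: int, head_dim: int) -> None:
        # [2, num_tokens, head_dim]: slot 0 holds K, slot 1 holds V.
        self.kv_cache = torch.zeros(2, num_tokens, head_dim)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor | None,
        value: torch.Tensor | None,
    ) -> torch.Tensor:
        if key is not None and value is not None:
            # First pass: cache the encoder K/V; they never change.
            self.kv_cache[0, : key.shape[0]] = key
            self.kv_cache[1, : value.shape[0]] = value
        k, v = self.kv_cache[0], self.kv_cache[1]
        scores = (query @ k.T) / k.shape[-1] ** 0.5
        return scores.softmax(dim=-1) @ v

attn = ToyCrossAttention(num_tokens=4, head_dim=8)
enc_k, enc_v = torch.randn(4, 8), torch.randn(4, 8)
out1 = attn.forward(torch.randn(1, 8), enc_k, enc_v)  # first step: K/V cached
out2 = attn.forward(torch.randn(1, 8), None, None)    # later steps: reuse cache

The second call works because it reads the K/V cached by the first call, which is exactly the behavior the backend changes below guard for.
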
7 changes: 5 additions & 2 deletions vllm/v1/attention/backends/flash_attn.py

@@ -562,8 +562,8 @@ def forward(
         self,
         layer: torch.nn.Module,
         query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
+        key: torch.Tensor | None,
+        value: torch.Tensor | None,
         kv_cache: torch.Tensor,
         attn_metadata: FlashAttentionMetadata,
         output: torch.Tensor | None = None,
@@ -611,6 +611,8 @@ def forward(

        # Handle encoder attention differently - no KV cache needed
        if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
+            # key/value are only None with ENCODER_DECODER attention
+            assert key is not None and value is not None
            # For encoder attention,
            # we use direct Q, K, V tensors without caching
            return self._forward_encoder_attention(
@@ -670,6 +672,7 @@ def forward(
            descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)

        if self.dcp_world_size > 1:
+            assert key is not None and value is not None
            self._forward_with_dcp(
                query[:num_actual_tokens],
                key[:num_actual_tokens],
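
The added asserts do double duty: they fail fast if a `None` ever reaches a path that dereferences `key`/`value`, and they let static type checkers such as mypy narrow `torch.Tensor | None` back to `torch.Tensor` before the `[:num_actual_tokens]` slicing. A standalone sketch of the pattern (the helper name is made up for illustration):

import torch

def slice_kv(
    key: torch.Tensor | None,
    value: torch.Tensor | None,
    num_actual_tokens: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Only ENCODER_DECODER callers pass None, and they never reach this
    # path, so assert: it documents the invariant, raises early if it is
    # ever violated, and narrows both Optionals for the type checker.
    assert key is not None and value is not None
    return key[:num_actual_tokens], value[:num_actual_tokens]

k, v = slice_kv(torch.randn(8, 4), torch.randn(8, 4), num_actual_tokens=5)
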
49 changes: 30 additions & 19 deletions vllm/v1/attention/backends/flashinfer.py

@@ -294,6 +294,13 @@ def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
     def get_name() -> str:
         return "FLASHINFER"

+    @classmethod
+    def supports_attn_type(cls, attn_type: str) -> bool:
+        return attn_type in (
+            AttentionType.DECODER,
+            AttentionType.ENCODER_DECODER,
+        )
+
     @staticmethod
     def get_impl_cls() -> type["FlashInferImpl"]:
         return FlashInferImpl
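
`supports_attn_type` advertises that the backend now handles encoder/decoder cross-attention as well as decoder self-attention. A hedged sketch of how a caller might consult it; the import paths and the `FlashInferBackend` class name are assumptions here, and vLLM's real backend-selection logic may look different:

# Assumed import locations; adjust to the actual module layout.
from vllm.attention.backends.abstract import AttentionType
from vllm.v1.attention.backends.flashinfer import FlashInferBackend

for attn_type in (
    AttentionType.DECODER,
    AttentionType.ENCODER_DECODER,
    AttentionType.ENCODER_ONLY,
):
    if FlashInferBackend.supports_attn_type(attn_type):
        print(f"{attn_type}: handled by FLASHINFER")
    else:
        print(f"{attn_type}: pick another backend")
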
@@ -1019,12 +1026,9 @@ def __init__(

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

-        if attn_type != AttentionType.DECODER:
+        if attn_type not in (AttentionType.DECODER, AttentionType.ENCODER_DECODER):
            raise NotImplementedError(
-                "Encoder self-attention and "
-                "encoder/decoder cross-attention "
-                "are not implemented for "
-                "FlashInferImpl"
+                "Encoder self-attention is not implemented for FlashInferImpl"
            )

        self.sinks: torch.Tensor | None = None
@@ -1063,8 +1067,8 @@ def forward(
         self,
         layer: torch.nn.Module,
         query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
+        key: torch.Tensor | None,
+        value: torch.Tensor | None,
         kv_cache: torch.Tensor,
         attn_metadata: FlashInferMetadata,
         output: torch.Tensor | None = None,
@@ -1155,16 +1159,20 @@ def forward(
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
-            torch.ops._C_cache_ops.reshape_and_cache_flash(
-                key,
-                value,
-                kv_cache[:, 0],
-                kv_cache[:, 1],
-                attn_metadata.slot_mapping,
-                self.kv_cache_dtype,
-                layer._k_scale,
-                layer._v_scale,
-            )
+            if key is not None and value is not None:
+                # ENCODER_DECODER attention is called with key and value
+                # set to None after the first decode step. They are cached
+                # on the first pass and do not change.
+                torch.ops._C_cache_ops.reshape_and_cache_flash(
+                    key,
+                    value,
+                    kv_cache[:, 0],
+                    kv_cache[:, 1],
+                    attn_metadata.slot_mapping,
+                    self.kv_cache_dtype,
+                    layer._k_scale,
+                    layer._v_scale,
+                )

        # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
        # to process the cache when the kv_cache_dtype is fp8
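
The context comment above notes that `reshape_and_cache_flash` takes the number of real tokens from `slot_mapping`, which is why `key`/`value` can still carry CUDA-graph padding when they are written to the cache. A toy, pure-PyTorch stand-in for that slot-mapping-driven scatter, not the real custom op:

import torch

def toy_reshape_and_cache(
    key: torch.Tensor,           # [num_tokens_padded, head_dim]
    value: torch.Tensor,
    k_cache: torch.Tensor,       # [num_slots, head_dim]
    v_cache: torch.Tensor,
    slot_mapping: torch.Tensor,  # [num_actual_tokens]
) -> None:
    # The write count comes from slot_mapping, so trailing padding rows
    # in key/value are simply ignored.
    n = slot_mapping.shape[0]
    k_cache[slot_mapping] = key[:n]
    v_cache[slot_mapping] = value[:n]

k_cache, v_cache = torch.zeros(32, 8), torch.zeros(32, 8)
key, value = torch.randn(6, 8), torch.randn(6, 8)  # 4 real tokens + 2 padding rows
toy_reshape_and_cache(key, value, k_cache, v_cache, torch.tensor([3, 7, 9, 12]))

Only the four mapped slots are written; the padded rows at the end of `key`/`value` never touch the cache.
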
@@ -1176,8 +1184,9 @@

        # Inputs and outputs may be padded for CUDA graphs
        query = query[:num_actual_tokens]
-        key = key[:num_actual_tokens]
-        value = value[:num_actual_tokens]
+        if key is not None and value is not None:
+            key = key[:num_actual_tokens]
+            value = value[:num_actual_tokens]
        output_padded = output
        output = output[:num_actual_tokens]

@@ -1218,6 +1227,8 @@ def forward(
            )
            assert prefill_wrapper._new_tokens._sm_scale == self.scale
            assert prefill_wrapper._new_tokens._causal
+            assert key is not None
+            assert value is not None

            prefill_wrapper.run(
                layer,