diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index 84cca8e68607..b60cfd5c50b4 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -357,8 +357,8 @@ def forward(
         self,
         layer: AttentionLayer,
         query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
+        key: torch.Tensor | None,
+        value: torch.Tensor | None,
         kv_cache: torch.Tensor,
         attn_metadata: T,
         output: torch.Tensor | None = None,
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index f5ad98cf2125..6fa554743f0c 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -562,8 +562,8 @@ def forward(
         self,
         layer: torch.nn.Module,
         query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
+        key: torch.Tensor | None,
+        value: torch.Tensor | None,
         kv_cache: torch.Tensor,
         attn_metadata: FlashAttentionMetadata,
         output: torch.Tensor | None = None,
@@ -611,6 +611,8 @@ def forward(
 
         # Handle encoder attention differently - no KV cache needed
         if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
+            # key/value are only None with ENCODER_DECODER attention
+            assert key is not None and value is not None
             # For encoder attention,
             # we use direct Q, K, V tensors without caching
            return self._forward_encoder_attention(
@@ -670,6 +672,7 @@ def forward(
             descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
 
             if self.dcp_world_size > 1:
+                assert key is not None and value is not None
                 self._forward_with_dcp(
                     query[:num_actual_tokens],
                     key[:num_actual_tokens],
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 8e9d764e4a12..d67e3d068e00 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -294,6 +294,13 @@ def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
     def get_name() -> str:
         return "FLASHINFER"
 
+    @classmethod
+    def supports_attn_type(cls, attn_type: str) -> bool:
+        return attn_type in (
+            AttentionType.DECODER,
+            AttentionType.ENCODER_DECODER,
+        )
+
     @staticmethod
     def get_impl_cls() -> type["FlashInferImpl"]:
         return FlashInferImpl
@@ -1019,12 +1026,9 @@ def __init__(
 
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
-        if attn_type != AttentionType.DECODER:
+        if attn_type not in (AttentionType.DECODER, AttentionType.ENCODER_DECODER):
             raise NotImplementedError(
-                "Encoder self-attention and "
-                "encoder/decoder cross-attention "
-                "are not implemented for "
-                "FlashInferImpl"
+                "Encoder self-attention is not implemented for FlashInferImpl"
             )
 
         self.sinks: torch.Tensor | None = None
@@ -1063,8 +1067,8 @@ def forward(
         self,
         layer: torch.nn.Module,
         query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
+        key: torch.Tensor | None,
+        value: torch.Tensor | None,
         kv_cache: torch.Tensor,
         attn_metadata: FlashInferMetadata,
         output: torch.Tensor | None = None,
@@ -1155,16 +1159,20 @@ def forward(
             # and value[:num_actual_tokens] because the reshape_and_cache_flash
             # op uses the slot_mapping's shape to determine the number of
             # actual tokens.
-            torch.ops._C_cache_ops.reshape_and_cache_flash(
-                key,
-                value,
-                kv_cache[:, 0],
-                kv_cache[:, 1],
-                attn_metadata.slot_mapping,
-                self.kv_cache_dtype,
-                layer._k_scale,
-                layer._v_scale,
-            )
+            if key is not None and value is not None:
+                # ENCODER_DECODER attention is called with key and value
+                # set to None after the first decode step. They are cached
+                # on the first pass and do not change.
+                torch.ops._C_cache_ops.reshape_and_cache_flash(
+                    key,
+                    value,
+                    kv_cache[:, 0],
+                    kv_cache[:, 1],
+                    attn_metadata.slot_mapping,
+                    self.kv_cache_dtype,
+                    layer._k_scale,
+                    layer._v_scale,
+                )
 
         # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
         # to process the cache when the kv_cache_dtype is fp8
@@ -1176,8 +1184,9 @@ def forward(
 
         # Inputs and outputs may be padded for CUDA graphs
         query = query[:num_actual_tokens]
-        key = key[:num_actual_tokens]
-        value = value[:num_actual_tokens]
+        if key is not None and value is not None:
+            key = key[:num_actual_tokens]
+            value = value[:num_actual_tokens]
         output_padded = output
         output = output[:num_actual_tokens]
 
@@ -1218,6 +1227,8 @@ def forward(
             )
             assert prefill_wrapper._new_tokens._sm_scale == self.scale
             assert prefill_wrapper._new_tokens._causal
+            assert key is not None
+            assert value is not None
             prefill_wrapper.run(
                 layer,