From ca504a47297c1d1ff0f759c90a7e1ad56cebc4f3 Mon Sep 17 00:00:00 2001
From: Russell Bryant
Date: Wed, 17 Sep 2025 16:53:51 -0400
Subject: [PATCH 1/6] [Core] Make Whisper work with b200 + flashinfer

These changes were necessary to get Whisper working on a B200 machine
with the flashinfer attention backend.

There are three changes:

1. Make flashinfer not reject `ENCODER_DECODER` attention.

2. Make flashinfer handle the case where `key` and `value` are None.
   With cross attention (`ENCODER_DECODER`), `key` and `value` are only
   set on the first pass through the decoder for a given request. They
   are then cached in the kv cache for subsequent passes.

3. In the GPU model runner, this configuration enables a code path where
   `force_attention` is set to `True` in `_dummy_run()`, so we need to
   pass a non-None `encoder_seq_lens` to the cross attention metadata
   builder.

Signed-off-by: Russell Bryant
---
 vllm/v1/attention/backends/flashinfer.py | 27 ++++++++++++++++-----------
 vllm/v1/worker/gpu_model_runner.py       |  8 +++++++-
 2 files changed, 23 insertions(+), 12 deletions(-)

diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 98a4cf38bc19..a11e70195691 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -674,7 +674,8 @@ def __init__(
 
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
-        if attn_type != AttentionType.DECODER:
+        if attn_type not in (AttentionType.DECODER,
+                             AttentionType.ENCODER_DECODER):
             raise NotImplementedError("Encoder self-attention and "
                                       "encoder/decoder cross-attention "
                                       "are not implemented for "
@@ -795,16 +796,20 @@ def forward(
             # and value[:num_actual_tokens] because the reshape_and_cache_flash
             # op uses the slot_mapping's shape to determine the number of
             # actual tokens.
-            torch.ops._C_cache_ops.reshape_and_cache_flash(
-                key,
-                value,
-                kv_cache[:, 0],
-                kv_cache[:, 1],
-                attn_metadata.slot_mapping,
-                self.kv_cache_dtype,
-                layer._k_scale,
-                layer._v_scale,
-            )
+            if key is not None and value is not None:
+                # ENCODER_DECODER attention is called with key and value
+                # set to None after the first decode step. They are cached
+                # on the first pass and do not change.
+                torch.ops._C_cache_ops.reshape_and_cache_flash(
+                    key,
+                    value,
+                    kv_cache[:, 0],
+                    kv_cache[:, 1],
+                    attn_metadata.slot_mapping,
+                    self.kv_cache_dtype,
+                    layer._k_scale,
+                    layer._v_scale,
+                )
 
         # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
         # to process the cache when the kv_cache_dtype is fp8

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index f256dc160a6b..e27bf82024e2 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -2889,6 +2889,10 @@ def _dummy_run(
 
         for kv_cache_group_id, kv_cache_group_spec in enumerate(
                 self.kv_cache_config.kv_cache_groups):
+            encoder_seq_lens: Optional[np.ndarray] = None
+            if self.model_config.is_encoder_decoder:
+                encoder_seq_lens = np.zeros(num_reqs, dtype=np.int32)
+                encoder_seq_lens[0] = self.max_encoder_len
             common_attn_metadata = CommonAttentionMetadata(
                 query_start_loc=self.query_start_loc.gpu[:num_reqs + 1],
                 query_start_loc_cpu=self.query_start_loc.cpu[:num_reqs +
@@ -2905,7 +2909,9 @@ def _dummy_run(
                 block_table[kv_cache_group_id].get_device_tensor(num_reqs),
                 slot_mapping=self.input_batch.block_table[
                     kv_cache_group_id].slot_mapping.gpu[:num_tokens],
-                causal=True)
+                causal=True,
+                encoder_seq_lens=encoder_seq_lens,
+            )
             for attn_group in self.attn_groups[kv_cache_group_id]:
                 if ubatch_slices is not None:
                     common_attn_metadata_list = split_attn_metadata(

From 3ba44ed871ec353ada5806275972252ad7089278 Mon Sep 17 00:00:00 2001
From: Russell Bryant
Date: Thu, 18 Sep 2025 21:14:31 -0400
Subject: [PATCH 2/6] Update type hints for key/value in FlashAttention and
 FlashInfer

Signed-off-by: Russell Bryant
---
 vllm/v1/attention/backends/flash_attn.py | 6 ++++--
 vllm/v1/attention/backends/flashinfer.py | 4 ++--
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 20f1904b3be6..3ca286ec27e7 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -423,8 +423,8 @@ def forward(
         self,
         layer: torch.nn.Module,
         query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
+        key: Optional[torch.Tensor],
+        value: Optional[torch.Tensor],
         kv_cache: torch.Tensor,
         attn_metadata: FlashAttentionMetadata,
         output: Optional[torch.Tensor] = None,
@@ -472,6 +472,8 @@ def forward(
 
         # Handle encoder attention differently - no KV cache needed
         if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
+            # key/value are only None with ENCODER_DECODER attention
+            assert key is not None and value is not None
             # For encoder attention,
             # we use direct Q, K, V tensors without caching
             return self._forward_encoder_attention(query[:num_actual_tokens],

diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index fa4ff390bbc0..46d12174e2f8 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -712,8 +712,8 @@ def forward(
         self,
         layer: torch.nn.Module,
         query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
+        key: Optional[torch.Tensor],
+        value: Optional[torch.Tensor],
         kv_cache: torch.Tensor,
         attn_metadata: FlashInferMetadata,
         output: Optional[torch.Tensor] = None,

From 12af7f2307905c6f28471fc315985829b58a0f19 Mon Sep 17 00:00:00 2001
From: Russell Bryant
Date: Thu, 20 Nov 2025 20:26:44 +0000
Subject: [PATCH 3/6] restore code that was lost in bad merge

Signed-off-by: Russell Bryant
---
 vllm/v1/worker/gpu_model_runner.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 97a2b42409a9..a7ffdcd1b1cf 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -3340,6 +3340,9 @@ def _dummy_run(
                     kv_cache_group_id
                 ].slot_mapping.gpu[:num_tokens],
                 causal=True,
+                dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs]
+                if self.dcp_world_size > 1
+                else None,
                 encoder_seq_lens=encoder_seq_lens,
             )
             for attn_group in self.attn_groups[kv_cache_group_id]:

From 301666d047e3663bcbd6b89b37705bc172d40f63 Mon Sep 17 00:00:00 2001
From: Russell Bryant
Date: Thu, 20 Nov 2025 20:28:14 +0000
Subject: [PATCH 4/6] Update AttentionImpl to note key and value may be None

Signed-off-by: Russell Bryant
---
 vllm/attention/backends/abstract.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index 421b0c4beb37..1f2560daf4b6 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -178,8 +178,8 @@ def forward(
         self,
         layer: AttentionLayer,
         query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
+        key: torch.Tensor | None,
+        value: torch.Tensor | None,
         kv_cache: torch.Tensor,
         attn_metadata: T,
         output: torch.Tensor | None = None,

From cbc2e49fc180fd1b89fa7caf72e1b1dd1161995b Mon Sep 17 00:00:00 2001
From: Russell Bryant
Date: Thu, 20 Nov 2025 20:45:25 +0000
Subject: [PATCH 5/6] fix another spot where key/value could be None

Signed-off-by: Russell Bryant
---
 vllm/v1/attention/backends/flashinfer.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index ad17351772be..ffb6fd4908c3 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -1188,8 +1188,9 @@ def forward(
 
         # Inputs and outputs may be padded for CUDA graphs
        query = query[:num_actual_tokens]
-        key = key[:num_actual_tokens]
-        value = value[:num_actual_tokens]
+        if key is not None and value is not None:
+            key = key[:num_actual_tokens]
+            value = value[:num_actual_tokens]
         output_padded = output
         output = output[:num_actual_tokens]
 
@@ -1230,6 +1231,8 @@ def forward(
             )
 
             assert prefill_wrapper._new_tokens._sm_scale == self.scale
             assert prefill_wrapper._new_tokens._causal
+            assert key is not None
+            assert value is not None
             prefill_wrapper.run(
                 layer,

From 83cebf44cbc5254fb0bf56df5f351bfbf4ccefbb Mon Sep 17 00:00:00 2001
From: Russell Bryant
Date: Mon, 1 Dec 2025 21:14:06 +0000
Subject: [PATCH 6/6] Add supports_attn_type to flashinfer backend

Signed-off-by: Russell Bryant
---
 vllm/v1/attention/backends/flashinfer.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index ffb6fd4908c3..d989eb2d0f1b 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -290,6 +290,13 @@ class FlashInferBackend(AttentionBackend):
     def get_name() -> str:
         return "FLASHINFER"
 
+    @classmethod
+    def supports_attn_type(cls, attn_type: str) -> bool:
+        return attn_type in (
+            AttentionType.DECODER,
+            AttentionType.ENCODER_DECODER,
+        )
+
     @staticmethod
     def get_impl_cls() -> type["FlashInferImpl"]:
         return FlashInferImpl
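
A note for reviewers on the cross-attention flow these patches accommodate:
with `ENCODER_DECODER` attention, the encoder-derived `key`/`value` are
written to the kv cache on the first pass through the decoder for a request,
and the attention layer is then invoked with `key=None` / `value=None` on
subsequent decode steps. The standalone sketch below only illustrates that
calling pattern with plain PyTorch; the names `ToyCrossAttnCache` and
`write_or_reuse` are invented for the illustration and do not exist in vLLM.

    # Standalone illustration of the cross-attention caching pattern handled
    # by the patches above. Not vLLM code.
    import torch

    class ToyCrossAttnCache:
        """Caches encoder K/V once; later steps pass key/value as None."""

        def __init__(self, max_encoder_len: int, num_kv_heads: int, head_dim: int):
            self.k = torch.zeros(max_encoder_len, num_kv_heads, head_dim)
            self.v = torch.zeros(max_encoder_len, num_kv_heads, head_dim)
            self.encoder_len = 0

        def write_or_reuse(self, key, value):
            # Mirrors the backend change: only write to the cache when
            # key/value are provided (first pass); otherwise reuse what
            # was cached earlier.
            if key is not None and value is not None:
                self.encoder_len = key.shape[0]
                self.k[: self.encoder_len] = key
                self.v[: self.encoder_len] = value
            return self.k[: self.encoder_len], self.v[: self.encoder_len]

    if __name__ == "__main__":
        cache = ToyCrossAttnCache(max_encoder_len=1500, num_kv_heads=8, head_dim=64)

        # First decoder pass: encoder K/V are present and get written to the cache.
        k1, v1 = cache.write_or_reuse(torch.randn(1500, 8, 64), torch.randn(1500, 8, 64))

        # Later decode steps: key/value arrive as None; cached K/V are reused.
        k2, v2 = cache.write_or_reuse(None, None)
        assert torch.equal(k1, k2) and torch.equal(v1, v2)

The `supports_attn_type()` classmethod added in the last patch exposes the
same capability check up front, so callers can ask the backend class whether
`ENCODER_DECODER` is supported instead of hitting the `NotImplementedError`
raised in the impl constructor.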