Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
66e4d6d
fix
baonudesifeizhai Sep 18, 2025
1b215ee
yapf
baonudesifeizhai Sep 18, 2025
1f8c6e3
Merge branch 'main' into whisper-cudagraphs-support
baonudesifeizhai Sep 18, 2025
4fda4fc
Update vllm/v1/worker/gpu_model_runner.py
baonudesifeizhai Sep 18, 2025
e8d12d7
yapf
baonudesifeizhai Sep 18, 2025
88059d3
fix
baonudesifeizhai Sep 19, 2025
28800dc
fix and debug
baonudesifeizhai Sep 19, 2025
51e8742
yapf fix
baonudesifeizhai Sep 19, 2025
573ffcd
Merge branch 'main' into whisper-cudagraphs-support
baonudesifeizhai Sep 19, 2025
ebf1d39
Merge branch 'main' into whisper-cudagraphs-support
baonudesifeizhai Sep 19, 2025
1195357
add log
baonudesifeizhai Sep 19, 2025
95954c1
Merge branch 'whisper-cudagraphs-support' of https://github.com/baonu…
baonudesifeizhai Sep 19, 2025
b28a68a
remove logger
baonudesifeizhai Sep 19, 2025
59583c0
Merge branch 'vllm-project:main' into whisper-cudagraphs-support
baonudesifeizhai Sep 19, 2025
83bdfc3
Merge branch 'vllm-project:main' into whisper-cudagraphs-support
baonudesifeizhai Sep 22, 2025
47f9cd4
Merge branch 'main' into whisper-cudagraphs-support
baonudesifeizhai Sep 23, 2025
342f5a8
Merge branch 'main' into whisper-cudagraphs-support
baonudesifeizhai Sep 24, 2025
89af950
Merge branch 'vllm-project:main' into whisper-cudagraphs-support
baonudesifeizhai Sep 25, 2025
6092e13
Merge branch 'vllm-project:main' into whisper-cudagraphs-support
baonudesifeizhai Sep 26, 2025
d8eb97d
Merge branch 'vllm-project:main' into whisper-cudagraphs-support
baonudesifeizhai Oct 11, 2025
2f4e230
fix format error and add tracker in encoder lengths
baonudesifeizhai Oct 11, 2025
f8dd813
Merge branch 'main' into whisper-cudagraphs-support
baonudesifeizhai Oct 16, 2025
159c66f
Merge branch 'main' into whisper-cudagraphs-support
baonudesifeizhai Oct 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions vllm/attention/layers/cross_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def _get_cross_slot_mapping(
) -> torch.Tensor:
"""Get cross-attention slot mappings."""

encoder_seq_lens = np.atleast_1d(encoder_seq_lens)

block_size = kv_cache_spec.block_size
slot_mappings = []

Expand Down
73 changes: 68 additions & 5 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ def __init__(

# mm_hash -> encoder_output
self.encoder_cache: dict[str, torch.Tensor] = {}
self._encoder_cudagraph_buffers: dict[tuple[Any, ...], torch.Tensor] = {}

self.use_aux_hidden_state_outputs = False
# Set up speculative decoding.
Expand Down Expand Up @@ -1023,11 +1024,19 @@ def _get_encoder_seq_lens(
return None

# Build encoder_seq_lens array mapping request indices to
# encoder lengths for inputs scheduled in this batch
encoder_seq_lens = np.zeros(num_reqs, dtype=np.int32)
for req_id in scheduler_output.scheduled_encoder_inputs:
req_index = self.input_batch.req_id_to_index[req_id]
encoder_seq_lens[req_index] = self.max_encoder_len
# encoder lengths for all requests with encoder inputs.
# Note: This must include ALL requests with encoder features,
# not just those being scheduled in this step, because cross-attention
# needs encoder lengths during decode phase for CUDA graph compatibility.
encoder_seq_lens = np.zeros((int(num_reqs),), dtype=np.int32)

# Iterate through all active requests in the batch
for req_id in self.input_batch.req_ids[:num_reqs]:
req_state = self.requests.get(req_id)
# Check if this request has encoder inputs (multimodal features)
if req_state and req_state.mm_features:
req_index = self.input_batch.req_id_to_index[req_id]
encoder_seq_lens[req_index] = self.max_encoder_len

return encoder_seq_lens

Expand Down Expand Up @@ -1893,8 +1902,62 @@ def _extract_encoder_inputs(
# input_features=...)
encoder_features.update(mm_kwargs_group)

if self._should_use_encoder_cudagraph_buffers():
encoder_features = self._prepare_encoder_inputs_for_cudagraph(
encoder_features
)

return encoder_features

def _should_use_encoder_cudagraph_buffers(self) -> bool:
    """Return True when encoder inputs must be staged into persistent buffers.

    Staging is required whenever CUDA graphs are enabled, i.e. the
    compilation config's cudagraph mode is anything other than ``NONE``.
    """
    mode = self.compilation_config.cudagraph_mode
    return mode != CUDAGraphMode.NONE

def _prepare_encoder_inputs_for_cudagraph(
self,
encoder_inputs: dict[str, Any],
) -> dict[str, Any]:
input_features = encoder_inputs.get("input_features")
if input_features is None:
return encoder_inputs

encoder_inputs["input_features"] = self._copy_to_cudagraph_buffer(
("input_features",), input_features
)
return encoder_inputs

def _copy_to_cudagraph_buffer(
self,
key_prefix: tuple[Any, ...],
value: Any,
) -> Any:
if isinstance(value, torch.Tensor):
key = key_prefix + (tuple(value.shape), value.dtype, value.device)
buffer = self._encoder_cudagraph_buffers.get(key)
if buffer is None:
buffer = torch.empty_like(value)
self._encoder_cudagraph_buffers[key] = buffer
else:
assert (
buffer.shape == value.shape
and buffer.dtype == value.dtype
and buffer.device == value.device
), "CUDAGraph buffer mismatch for encoder inputs."
buffer.copy_(value)
return buffer

if isinstance(value, list):
return [
self._copy_to_cudagraph_buffer(key_prefix + (idx,), item)
for idx, item in enumerate(value)
]
if isinstance(value, tuple):
return tuple(
self._copy_to_cudagraph_buffer(key_prefix + (idx,), item)
for idx, item in enumerate(value)
)

return value

def get_model(self) -> nn.Module:
# get raw model out of the cudagraph wrapper.
if isinstance(self.model, (CUDAGraphWrapper, UBatchWrapper)):
Expand Down