
Conversation

@baonudesifeizhai
Contributor

@baonudesifeizhai baonudesifeizhai commented Sep 18, 2025

Purpose

Fix Whisper’s CUDA graph warmup by ensuring encoder lengths are always returned as 1-D arrays.
Guard cross-attention slot mapping against scalar inputs so cudagraph capture no longer crashes for encoder-decoder models.
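For context, a minimal sketch of the kind of guard described above; the helper name `_as_1d_encoder_lens` is hypothetical and is not the actual code in `cross_attention.py` or `gpu_model_runner.py`:

```python
import numpy as np

def _as_1d_encoder_lens(encoder_seq_lens) -> np.ndarray:
    # np.asarray over a single Python int yields a 0-d scalar array, which is
    # the shape that breaks slot-mapping indexing during cudagraph capture;
    # np.atleast_1d guarantees a 1-D array for both scalar and list inputs.
    return np.atleast_1d(np.asarray(encoder_seq_lens, dtype=np.int32))

print(np.asarray(1500).shape)            # () -- 0-d scalar, the failing case
print(_as_1d_encoder_lens(1500).shape)   # (1,) -- safe to index and iterate
print(_as_1d_encoder_lens([1500, 750]))  # [1500  750]
```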

Test Plan
Launch server with CUDA graphs enabled:
python -m vllm.entrypoints.openai.api_server --model openai/whisper-large-v3 --served-model-name whisper-large-v3 --compilation-config '{"cudagraph_mode": "FULL"}'
Issue a transcription via examples/online_serving/openai_transcription_client.py with a valid audio clip.
Run the parallel stress script (using a directory of well-formed .wav files) to observe concurrent throughput.
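For reference, a minimal transcription client in the spirit of examples/online_serving/openai_transcription_client.py (a sketch, not that script; the base URL and audio path are placeholders):

```python
from openai import OpenAI

# Assumes the server launched above is listening on localhost:8000.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

with open("sample.wav", "rb") as audio:
    result = client.audio.transcriptions.create(
        model="whisper-large-v3",  # matches --served-model-name above
        file=audio,
    )
print("transcription result:", result.text)
```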

Test Result
Server starts, captures CUDA graphs without _get_cross_slot_mapping failures.
Single transcription request returns correct text.
Parallel run completes; invalid SciPy test WAVs trigger expected “No 'data' chunk” errors, while valid files succeed.

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: baonudesifeizhai <baonudesifeizhai@gmail.com>
@baonudesifeizhai baonudesifeizhai force-pushed the whisper-cudagraphs-support branch from d4bfd9f to 1b215ee on September 18, 2025 21:44
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request aims to fix CUDA graph warmup for Whisper by ensuring encoder lengths are always 1D arrays and guarding against scalar inputs in cross-attention. The changes involve adding a defensive np.atleast_1d call in cross_attention.py and modifying array creation in gpu_model_runner.py. While the changes are functionally correct, I've identified a piece of dead code in gpu_model_runner.py that can be removed to improve code clarity and maintainability.

baonudesifeizhai and others added 2 commits September 18, 2025 17:49
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: baonudesifeizhai <85092850+baonudesifeizhai@users.noreply.github.com>
@russellb
Member

Thank you for working on this! I will give it a try tomorrow.

@Sugar-zsg
Contributor

Sugar-zsg commented Sep 19, 2025

Thank you. I ran some tests and found that the output is always the same default value. Using the script examples/online_serving/openai_transcription_client.py gives the same result, using the v1 engine.

With this PR:
transcription result: Thank you.
transcription result: Thank you.

Before:
transcription result: The first words I spoke in the original phonograph a little piece of practical poetry. Mary had her little lamb it sleet was white as snow and everywhere that Mary went, the Lamb would sure to go!

transcription result: And the old one pitch on the way to Edgar Martinez swung on the line down the left field line for a base hit. Here comes Joy. Here is Junior to third base. They're going to wave him in. The throw to the plate will be late. The Mariners are going to play for the American League Championship. I don't believe it. It just continues. My oh my.

@baonudesifeizhai
Contributor Author

[screenshot attached]

> Thank you. I ran some tests and found that the output is always the same default value. […]

@Sugar-zsg
Contributor

Sugar-zsg commented Sep 19, 2025

@baonudesifeizhai Thank you for the reply. I tested with the latest code. After removing the cudagraph_mode configuration, I was able to get the correct result, but the latency did not change. When using cudagraph_mode=FULL, the output issue still exists.

Could you please clarify how the launch parameters should be configured in order to enable CUDA Graph correctly?

@russellb
Member

During development, that is the output I would get if something broke in the Encoder path -- either the encoder didn't run at all, or the output didn't get passed to the decoder properly. Just a tip in case that helps with debugging.

@baonudesifeizhai
Contributor Author

[screenshot attached] With cudagraph_mode=FULL the outputs seem fine. I will find a way to test the token output, since the current wav files are very short.

> @baonudesifeizhai Thank you for the reply. I tested with the latest code. […]

@Sugar-zsg
Contributor

I used the same script and configuration as you, but I still cannot get the correct results.

  • Version: v0.10.2 + your code changes

  • Configuration:

(EngineCore_DP0 pid=2678704) INFO 09-22 03:08:50 [core.py:76] Initializing a V1 LLM engine (v0.10.2) with config: model='openai/whisper-large-v3', speculative_config=None, tokenizer='openai/whisper-large-v3', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=448, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, data_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=openai/whisper-large-v3, enable_prefix_caching=False, chunked_prefill_enabled=False, use_async_output_proc=True, pooler_config=None, compilation_config={"level":3,"debug_dump_path":"","cache_dir":"","backend":"","custom_ops":[],"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output","vllm.mamba_mixer2","vllm.mamba_mixer","vllm.short_conv","vllm.linear_attention","vllm.plamo2_mamba_mixer","vllm.gdn_attention"],"use_inductor":true,"compile_sizes":[],"inductor_compile_config":{"enable_auto_functionalized_v2":false},"inductor_passes":{},"cudagraph_mode":2,"use_cudagraph":true,"cudagraph_num_of_warmups":1,"cudagraph_capture_sizes":[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"cudagraph_copy_inputs":false,"full_cuda_graph":false,"pass_config":{},"max_capture_size":512,"local_cache_dir":null}

  • Output:

[0000] latency=6.43s usage={'type': 'duration', 'seconds': 4} text= Thank you.
[0001] latency=0.08s usage={'type': 'duration', 'seconds': 15} text= Thank you.
[0002] latency=0.13s usage={'type': 'duration', 'seconds': 6} text= Thank you.
[0003] latency=0.08s usage={'type': 'duration', 'seconds': 24} text= Thank you.
[0004] latency=0.20s usage={'type': 'duration', 'seconds': 12} text= Thank you.
[0005] latency=0.09s usage={'type': 'duration', 'seconds': 14} text= Thank you.
[0006] latency=0.09s usage={'type': 'duration', 'seconds': 6} text= Thank you.
[0007] latency=0.09s usage={'type': 'duration', 'seconds': 4} text= Thank you.
[0008] latency=0.09s usage={'type': 'duration', 'seconds': 5} text= Thank you.
[0009] latency=0.09s usage={'type': 'duration', 'seconds': 8} text= Thank you.
[0010] latency=0.10s usage={'type': 'duration', 'seconds': 4} text= Thank you.
……

@baonudesifeizhai
Contributor Author

[screenshot attached] Man... I started it on 2 L40S GPUs, built from source:

> VLLM_LOGGING_LEVEL=info
> python -m vllm.entrypoints.openai.api_server
>   --model openai/whisper-large-v3
>   --compilation-config '{"cudagraph_mode": "FULL"}'

@Sugar-zsg
Contributor

I wasn’t able to run tests today, but I have a couple of questions...

  1. I can run openai/whisper-large-v3, but I cannot successfully run openai/whisper-large-v3-turbo with the code below, which you have since removed (decoder only):
        if self.model_config.is_encoder_decoder:
            if cudagraph_mode in (CUDAGraphMode.FULL,
                                  CUDAGraphMode.FULL_AND_PIECEWISE):
                logger.warning(
                    "CUDA graph decode-only mode required for encoder-decoder "
                    "models; setting cudagraph_mode=FULL_DECODE_ONLY")
                cudagraph_mode = self.compilation_config.cudagraph_mode = \
                    CUDAGraphMode.FULL_DECODE_ONLY
            elif cudagraph_mode == CUDAGraphMode.PIECEWISE:
                logger.warning(
                    "Encoder-decoder models do not support cudagraph prefill "
                    "capture; setting cudagraph_mode=NONE")
                cudagraph_mode = self.compilation_config.cudagraph_mode = \
                    CUDAGraphMode.NONE

  2. The v0 engine can use CUDA Graph and runs normally (decoder only). Could you please confirm that the current tests are indeed using the v1 engine?

  3. Your code changes seem to be mainly for capturing the encoder part with CUDA Graph, right? Have you profiled whether CUDA Graph is actually taking effect on the encoder part in the current implementation?

Thanks!

@baonudesifeizhai
Contributor Author

[screenshot attached]

1. openai/whisper-large-v3-turbo can translate the audios in the Hugging Face datasets.
2. I think vLLM forces everything through the v1 engine right now.
3. https://paste.ubuntu.com/p/BfDPtCRJmx/

For the server:

python -m vllm.entrypoints.openai.api_server \
  --model openai/whisper-large-v3-turbo \
  --compilation-config '{ "use_cudagraph": true, "cudagraph_mode": "FULL" }'

[screenshot attached] Seems good?

> I wasn’t able to run tests today, but I have a couple of questions... […]

@baonudesifeizhai
Contributor Author

Could you have a look? Thanks.

> Thank you for working on this! I will give it a try tomorrow.

@russellb
Member

> Could you have a look? Thanks.
>
> Thank you for working on this! I will give it a try tomorrow.

I've been following the comments. I was hoping to see @Sugar-zsg be able to replicate success. I will try it soon.

Please also update all commit messages to include the Signed-off-by header. That will make the DCO check pass in CI.

russellb added a commit to russellb/vllm that referenced this pull request Sep 25, 2025
Whisper does not work with full cudagraphs. That is being worked on in
PR vllm-project#25208.

The failure can be reproduced reliably via
`tests/models/multimodal/generation/test_whisper.py`, at least in my
H100 development environment. The tests passed on the PR and I'm not
sure why.

Regardless, this seems like the right change to make until vllm-project#25208
sorts out exactly what changes are needed.

Signed-off-by: Russell Bryant <rbryant@redhat.com>
@russellb
Member

The default cudagraph mode changed to FULL_AND_PIECEWISE in #25444, but note I'm changing the default for whisper back to PIECEWISE in #25701.

The whisper tests are failing for me locally in my H100 environment. They pass in CI, but they also passed in CI on #25444 even though it broke H100 for me.

With this PR, the failure is different. It's an accuracy failure instead of failing much earlier. This is what I'm running:

❯ pytest tests/models/multimodal/generation/test_whisper.py -v -s

and here is an example failure:

_________________________________________________ test_models[openai/whisper-large-v3-turbo] __________________________________________________

>   run_test(
      ^^^^^^^^
        vllm_runner,
        model,
        tensor_parallel_size=1,
    )

tests/models/multimodal/generation/test_whisper.py:128: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

>   assert output.outputs[0].text == expected
    ^^^^^^^^^^^^^^^^^
E   AssertionError: assert ' And the 0-1...the line down' == ' And the 0-1.... My, oh, my.'
E     
E     -  And the 0-1 pitch on the way to Edgar Martinez. Swung on the line down the left field line for a base hit. Here comes Joy. Here is Junior to third base. They're going to wave him in. The throw to the plate will be late. The Mariners are going to play for the American League Championship. I don't believe it. It just continues. My, oh, my.
E     +  And the 0-1 pitch on the way to Edgar Martinez. Swung on the line down

tests/models/multimodal/generation/test_whisper.py:121: AssertionError

FAILED tests/models/multimodal/generation/test_whisper.py::test_models[openai/whisper-large-v3-turbo] - AssertionError: assert ' And the 0-1...the line down' == ' And the 0-1.... My, oh, my.'

Can you give these tests a try?

@baonudesifeizhai
Contributor Author

For this branch (A100 SXM x2): pytest tests/models/multimodal/generation/test_whisper.py -v -s
[screenshot of test results attached]

However, I need to export VLLM_WORKER_MULTIPROC_METHOD=spawn and export TOKENIZERS_PARALLELISM=false, because it still has problems with multi-worker:

pytest tests/models/multimodal/generation/test_whisper.py -v -s

@baonudesifeizhai
Contributor Author

Should we fix the Whisper fork problem now?

> The default cudagraph mode changed to FULL_AND_PIECEWISE in #25444, but note I'm changing the default for whisper back to PIECEWISE in #25701. […] Can you give these tests a try?

@baonudesifeizhai
Contributor Author

export VLLM_WORKER_MULTIPROC_METHOD=spawn
export TOKENIZERS_PARALLELISM=false
export VLLM_COMPILATION_CONFIG='{"cudagraph_mode": "PIECEWISE"}'
pytest tests/models/multimodal/generation/test_whisper.py -v -s
[screenshot of test results attached]

@Sugar-zsg
Contributor

I was busy with other work for a while, but no matter how I tried, I couldn’t reproduce the same results as you reported. I reviewed the related code changes, but I couldn’t understand how this modification makes CUDA Graph take effect.

From what I can tell, it seems that you’re trying to cache the encoder inputs to ensure that the decoder receives consistent inputs each time, allowing CUDA Graph to be used. However, I have the following question:

_extract_encoder_inputs is only called once per request, and during subsequent decoding steps, scheduler_output.scheduled_encoder_inputs should be empty, meaning that the buffer logic wouldn’t actually be triggered. So in theory, this change shouldn’t have any effect, right?

Could you please explain how this modification enables CUDA Graph to work? Thank you!

@baonudesifeizhai
Contributor Author

The original PR only prevented crashes with np.atleast_1d() but didn't solve the root cause.
The real issue: during the decode phase, scheduled_encoder_inputs is empty, so _get_encoder_seq_lens() returns all zeros. This breaks cross-attention slot mapping, preventing CUDA graphs from working properly.
So I now check all active requests, not just scheduled_encoder_inputs.
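A rough sketch of that idea (the names and signature below are illustrative, not the actual gpu_model_runner.py code):

```python
import numpy as np

def get_encoder_seq_lens(active_req_ids, encoder_len_by_req_id, max_num_reqs):
    # Build the per-request encoder length array from every active request,
    # so it stays populated during decode steps where
    # scheduled_encoder_inputs is empty.
    encoder_seq_lens = np.zeros(max_num_reqs, dtype=np.int32)
    for idx, req_id in enumerate(active_req_ids):
        encoder_seq_lens[idx] = encoder_len_by_req_id.get(req_id, 0)
    return encoder_seq_lens

# Example: two in-flight Whisper requests that are already past prefill.
print(get_encoder_seq_lens(["req-0", "req-1"],
                           {"req-0": 1500, "req-1": 1500},
                           max_num_reqs=4))  # [1500 1500    0    0]
```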

> I was busy with other work for a while, but no matter how I tried, I couldn’t reproduce the same results as you reported. I reviewed the related code changes, but I couldn’t understand how this modification makes CUDA Graph take effect.
>
> From what I can tell, it seems that you’re trying to cache the encoder inputs to ensure that the decoder receives consistent inputs each time, allowing CUDA Graph to be used. However, I have the following question:
>
> _extract_encoder_inputs is only called once per request, and during subsequent decoding steps, scheduler_output.scheduled_encoder_inputs should be empty, meaning that the buffer logic wouldn’t actually be triggered. So in theory, this change shouldn’t have any effect, right?
>
> Could you please explain how this modification enables CUDA Graph to work? Thank you!

@Sugar-zsg
Contributor

After further analysis, I found that when the test prompt contains only a single token, there is no encoder input, which causes abnormal results (this also explains why I was never able to reproduce the same results as you earlier). I’ve already opened a PR to try to fix this issue.

However, during re-testing, I discovered a new problem: when running batch inference, the same code works correctly on A100 GPUs, but produces abnormal results for some batch requests when running on H20 GPUs.

A100 results:

[Batch 0] Transcription:  Yet these thoughts affected Hester Prynne less with hope than apprehension.
[Batch 1] Transcription:  Yet these thoughts affected Hester Prynne less with hope than apprehension.
[Batch 2] Transcription:  Yet these thoughts affected Hester Prynne less with hope than apprehension.
[Batch 3] Transcription:  Yet these thoughts affected Hester Prynne less with hope than apprehension.
[Batch 4] Transcription:  Yet these thoughts affected Hester Prynne less with hope than apprehension.
[Batch 5] Transcription:  Yet these thoughts affected Hester Prynne less with hope than apprehension.
[Batch 6] Transcription:  Yet these thoughts affected Hester Prynne less with hope than apprehension.
[Batch 7] Transcription:  Yet these thoughts affected Hester Prynne less with hope than apprehension.
[Batch 8] Transcription:  Yet these thoughts affected Hester Prynne less with hope than apprehension.
[Batch 9] Transcription:  Yet these thoughts affected Hester Prynne less with hope than apprehension.

H20 results:

[Batch 0] Transcription:  Yet these thoughts affected Hester Prynne less with hope than apprehension.
[Batch 1] Transcription:  Yet these thoughts affected Hester Prynne
[Batch 2] Transcription:  Yet these thoughts affected Hester Pryn
[Batch 3] Transcription:  Yet these thoughts affected Hester Prynne
[Batch 4] Transcription:  Yet these thoughts affected Hester
[Batch 5] Transcription:  Yet these thoughts affected Hester P
[Batch 6] Transcription:  Yet these thoughts affected H
[Batch 7] Transcription:  Yet these thoughts affected
[Batch 8] Transcription:  Yet these
[Batch 9] Transcription:  Yet these
