
Conversation

@lgeiger
Contributor

@lgeiger lgeiger commented Oct 17, 2025

Purpose

This PR removes calls to .contiguous() in the QwenVisionAttention module. Tested with both the xformers and flash attention backends.
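For context, a minimal standalone check (a sketch of the premise, not the vLLM code path itself): PyTorch's scaled_dot_product_attention, for example, accepts the non-contiguous q/k/v views produced by splitting and permuting a fused qkv projection, so an explicit .contiguous() is not needed for correctness on that path:

import torch
import torch.nn.functional as F

# Toy shapes standing in for the vision-attention tensors; not vLLM code.
s, b, h, d = 16, 1, 4, 32
qkv = torch.randn(s, b, 3 * h * d)
q, k, v = qkv.chunk(3, dim=-1)  # views into qkv, already non-contiguous

def to_bhsd(x):
    # (s, b, h*d) -> (b, h, s, d) via view + permute; stays non-contiguous
    return x.view(s, b, h, d).permute(1, 2, 0, 3)

q, k, v = (to_bhsd(x) for x in (q, k, v))
assert not q.is_contiguous()

out = F.scaled_dot_product_attention(q, k, v)
ref = F.scaled_dot_product_attention(q.contiguous(), k.contiguous(), v.contiguous())
torch.testing.assert_close(out, ref)  # same result without the extra copies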

Test Plan

vllm serve Qwen/Qwen2.5-VL-3B-Instruct

vllm bench serve --backend openai-chat --model Qwen/Qwen2.5-VL-3B-Instruct --endpoint /v1/chat/completions --dataset-name hf --dataset-path lmarena-ai/VisionArena-Chat --hf-split train --num-prompts 1000

Test Result

Before

============ Serving Benchmark Result ============
Successful requests:                     998
Failed requests:                         2
Benchmark duration (s):                  100.72
Total input tokens:                      94309
Total generated tokens:                  105944
Request throughput (req/s):              9.91
Output token throughput (tok/s):         1051.90
Peak output token throughput (tok/s):    5104.00
Peak concurrent requests:                998.00
Total Token throughput (tok/s):          1988.28
---------------Time to First Token----------------
Mean TTFT (ms):                          44047.03
Median TTFT (ms):                        43652.87
P99 TTFT (ms):                           95956.25
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          199.36
Median TPOT (ms):                        217.51
P99 TPOT (ms):                           252.47
---------------Inter-token Latency----------------
Mean ITL (ms):                           199.33
Median ITL (ms):                         220.68
P99 ITL (ms):                            437.93
==================================================

After

============ Serving Benchmark Result ============
Successful requests:                     998
Failed requests:                         2
Benchmark duration (s):                  100.23
Total input tokens:                      94290
Total generated tokens:                  106295
Request throughput (req/s):              9.96
Output token throughput (tok/s):         1060.55
Peak output token throughput (tok/s):    4899.00
Peak concurrent requests:                998.00
Total Token throughput (tok/s):          2001.33
---------------Time to First Token----------------
Mean TTFT (ms):                          43825.06
Median TTFT (ms):                        42747.21
P99 TTFT (ms):                           95296.01
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          197.58
Median TPOT (ms):                        213.70
P99 TPOT (ms):                           252.91
---------------Inter-token Latency----------------
Mean ITL (ms):                           197.16
Median ITL (ms):                         218.65
P99 ITL (ms):                            462.42
==================================================

@lgeiger lgeiger requested a review from sighingnow as a code owner October 17, 2025 14:25
@mergify mergify bot added the qwen Related to Qwen models label Oct 17, 2025
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
@lgeiger lgeiger force-pushed the qwenvl-unnecessary-contiguous branch from 1d05935 to 0940321 Compare October 17, 2025 14:27
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request aims to optimize the Qwen-VL models by removing what appear to be unnecessary .contiguous() calls. While this change might be safe for some attention backends like Torch SDPA or xFormers, it introduces a critical bug in the Flash Attention path. The rearrange operation used to prepare tensors for Flash Attention requires contiguous inputs, and removing the .contiguous() call will lead to a runtime error. My review adds comments to restore the necessary calls to ensure the model works correctly with all supported backends.

@tjtanaa
Contributor

tjtanaa commented Oct 18, 2025

Evaluated on ROCm

Server command: vllm serve --model Qwen/Qwen2.5-VL-3B-Instruct

Evaluation score on chartqa


For detailed information on this command, run:
  run.py eval_vllm --model_name Qwen/Qwen2.5-VL-3B-Instruct --url http://0.0.0.0:8000 --output_dir ./chartqa --eval_name chartqa - --help
================================================================================
Metrics:
{
    "explicit_prompt_relaxed_correctness": 0.8108,
    "anywhere_in_answer_relaxed_correctness": 0.8144
}
================================================================================

Member

@DarkLight1337 DarkLight1337 left a comment

Thanks for improving this

@vllm-bot vllm-bot merged commit 5c2acb2 into vllm-project:main Oct 18, 2025
6 checks passed
@lgeiger lgeiger deleted the qwenvl-unnecessary-contiguous branch October 18, 2025 17:07
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025
…ct#27106)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
adabeyta pushed a commit to adabeyta/vllm that referenced this pull request Oct 20, 2025
…ct#27106)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
@JartX
Contributor

JartX commented Oct 21, 2025

@lgeiger Hi! With this PR, Qwen/Qwen3-VL-30B-A3B-Instruct describes an invoice as a cheesecake or a sushi restaurant.
See vllm-project/llm-compressor#1939 (comment). Can you help me resolve it?

JartX added a commit to JartX/vllm that referenced this pull request Oct 21, 2025
@JartX
Contributor

JartX commented Oct 21, 2025

@tjtanaa @DarkLight1337 Could you help me with this? Right now, I can only send an image to Qwen3 VL and get a consistent result by reverting the commit.

@lgeiger
Contributor Author

lgeiger commented Oct 21, 2025

@JartX I can't seem to reproduce your issue. I verified that on an NVIDIA L40S GPU I get identical lm_eval results with this PR and with it reverted, for all backends, using the Qwen3-VL-30B-A3B-Instruct-FP8 model.

I tested it using:

lm_eval --model vllm-vlm --model_args "pretrained=Qwen/Qwen3-VL-30B-A3B-Instruct-FP8,max_model_len=10000" --tasks chartqa --batch_size auto --apply_chat_template

Which returns the same results on main and with this PR reverted:

main (vit_attn_backend=xformers)

| Tasks   | Version | Filter | n-shot | Metric            |     |  Value |     | Stderr |
| ------- | ------: | ------ | -----: | ----------------- | --- | -----: | --- | -----: |
| chartqa |       0 | none   |      0 | anywhere_accuracy | ↑   | 0.8720 | ±   | 0.0067 |
|         |         | none   |      0 | exact_match       | ↑   | 0.6372 | ±   | 0.0096 |
|         |         | none   |      0 | relaxed_accuracy  | ↑   | 0.8624 | ±   | 0.0069 |

reverted (vit_attn_backend=xformers)

| Tasks   | Version | Filter | n-shot | Metric            |     |  Value |     | Stderr |
| ------- | ------: | ------ | -----: | ----------------- | --- | -----: | --- | -----: |
| chartqa |       0 | none   |      0 | anywhere_accuracy | ↑   | 0.8720 | ±   | 0.0067 |
|         |         | none   |      0 | exact_match       | ↑   | 0.6372 | ±   | 0.0096 |
|         |         | none   |      0 | relaxed_accuracy  | ↑   | 0.8624 | ±   | 0.0069 |

main (vit_attn_backend=flash_attn, VLLM_WORKER_MULTIPROC_METHOD=spawn)

| Tasks   | Version | Filter | n-shot | Metric            |     |  Value |     | Stderr |
| ------- | ------: | ------ | -----: | ----------------- | --- | -----: | --- | -----: |
| chartqa |       0 | none   |      0 | anywhere_accuracy | ↑   | 0.8764 | ±   | 0.0066 |
|         |         | none   |      0 | exact_match       | ↑   | 0.6412 | ±   | 0.0096 |
|         |         | none   |      0 | relaxed_accuracy  | ↑   | 0.8648 | ±   | 0.0068 |

reverted (vit_attn_backend=flash_attn, VLLM_WORKER_MULTIPROC_METHOD=spawn)

| Tasks   | Version | Filter | n-shot | Metric            |     |  Value |     | Stderr |
| ------- | ------: | ------ | -----: | ----------------- | --- | -----: | --- | -----: |
| chartqa |       0 | none   |      0 | anywhere_accuracy | ↑   | 0.8764 | ±   | 0.0066 |
|         |         | none   |      0 | exact_match       | ↑   | 0.6412 | ±   | 0.0096 |
|         |         | none   |      0 | relaxed_accuracy  | ↑   | 0.8648 | ±   | 0.0068 |

main (vit_attn_backend=TORCH_SDPA)

| Tasks   | Version | Filter | n-shot | Metric            |     |  Value |     | Stderr |
| ------- | ------: | ------ | -----: | ----------------- | --- | -----: | --- | -----: |
| chartqa |       0 | none   |      0 | anywhere_accuracy | ↑   | 0.8784 | ±   | 0.0065 |
|         |         | none   |      0 | exact_match       | ↑   | 0.6416 | ±   | 0.0096 |
|         |         | none   |      0 | relaxed_accuracy  | ↑   | 0.8676 | ±   | 0.0068 |

reverted (vit_attn_backend=TORCH_SDPA)

| Tasks   | Version | Filter | n-shot | Metric            |     |  Value |     | Stderr |
| ------- | ------: | ------ | -----: | ----------------- | --- | -----: | --- | -----: |
| chartqa |       0 | none   |      0 | anywhere_accuracy | ↑   | 0.8784 | ±   | 0.0065 |
|         |         | none   |      0 | exact_match       | ↑   | 0.6416 | ±   | 0.0096 |
|         |         | none   |      0 | relaxed_accuracy  | ↑   | 0.8676 | ±   | 0.0068 |

@JartX
Contributor

JartX commented Oct 21, 2025

@lgeiger Thank you for your quick response. This happens when I send the image via a request, for example via OpenWebUI.

@JartX
Contributor

JartX commented Oct 21, 2025

@lgeiger My docker run command:

docker run -it --rm \
--name vllm \
--tty \
--restart unless-stopped \
--shm-size=48gb \
-p 80:8000 \
-v $(pwd)/app:/app \
-v $(pwd)/models:/models \
-v $(pwd)/huggecache:/root/.cache/huggingface \
--device /dev/kfd \
--device /dev/dri \
--group-add video \
--ipc=host \
--network host \
--cap-add SYS_PTRACE \
--security-opt seccomp=unconfined \
--privileged \
-e HSA_OVERRIDE_GFX_VERSION=11.0.0 \
-e VLLM_USE_V1=1 \
-e VLLM_V1_USE_PREFILL_DECODE_ATTENTION=1 \
vllm-rocm-251021-revert \
vllm serve Qwen/Qwen3-VL-30B-A3B-Instruct \
--gpu-memory-utilization 0.97 \
--max_model_len 40960 \
-tp 4 \
--served-model-name QWEN3VL \
--port 80 \
--enable-auto-tool-choice \
--compilation-config '{"full_cuda_graph": true}' \
--disable-log-requests \
--tool-call-parser hermes \
--dtype float16 \

ROCm, RDNA3

@lgeiger
Contributor Author

lgeiger commented Oct 21, 2025

@JartX To double-check, can you reproduce the lm_eval results that I posted above, or do you see a regression with this PR compared to without it?

@JartX
Contributor

JartX commented Oct 21, 2025

Using:
VLLM_WORKER_MULTIPROC_METHOD=spawn VLLM_ATTENTION_BACKEND=TORCH_SDPA lm_eval --model vllm-vlm --model_args "pretrained=Qwen/Qwen3-VL-30B-A3B-Instruct,max_model_len=10000,tensor_parallel_size=4" --tasks chartqa --batch_size auto --apply_chat_template
Processed prompts: 2%| | 53/2500 [16:43<9:05:51, 13.38s/it, est. speed input: 36.35 toks/s, output: 9
I'm going to try launching vllm serve first and then running lm_eval against the API.

@JartX
Contributor

JartX commented Oct 21, 2025

@lgeiger
Would you be so kind as to try it via the API? I'm launching it this way; the other way would take me 15 hours or more, while with the vLLM server it takes me 1:30 to 2:30 hours.
VLLM_ATTENTION_BACKEND=TORCH_SDPA lm_eval \
--model local-chat-completions \
--model_args model=Qwen/Qwen3-VL-30B-A3B-Instruct,base_url=http://localhost:8000/v1/chat/completions,num_concurrent=10,max_retries=3,tokenized_requests=False \
--tasks chartqa \
--batch_size auto \
--apply_chat_template

Many thanks!

@tjtanaa
Contributor

tjtanaa commented Oct 22, 2025

@lgeiger This issue will need to be debugged on AMD GPUs.

I remember @DarkLight1337 mentioned that on ROCm we might need the .contiguous(). I last verified on MI300X without the .contiguous() and it did not seem to have a problem, but maybe that is because I was testing with _Backend.FLASH_ATTN and _Backend.ROCM_AITER_FA. It could be that with _Backend.TORCH_SDPA we need to call .contiguous().

However, I am afraid I might have missed some edge cases that @JartX found.

@JartX
Contributor

JartX commented Oct 22, 2025

With contiguous
local-chat-completions (model=QWEN3VL,base_url=http://localhost:8000/v1/chat/completions,num_concurrent=10,max_retries=3,tokenized_requests=False), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: auto

| Tasks   | Version | Filter | n-shot | Metric            |  Value |     | Stderr |
| ------- | ------: | ------ | -----: | ----------------- | -----: | --- | -----: |
| chartqa |       0 | none   |      0 | anywhere_accuracy | 0.8712 | ±   | 0.0067 |
|         |         | none   |      0 | exact_match       | 0.6380 | ±   | 0.0096 |
|         |         | none   |      0 | relaxed_accuracy  | 0.8604 | ±   | 0.0069 |

Without contiguous
local-chat-completions (model=QWEN3VL,base_url=http://localhost:8000/v1/chat/completions,num_concurrent=10,max_retries=3,tokenized_requests=False), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: auto

| Tasks   | Version | Filter | n-shot | Metric            |  Value |     | Stderr |
| ------- | ------: | ------ | -----: | ----------------- | -----: | --- | -----: |
| chartqa |       0 | none   |      0 | anywhere_accuracy | 0.0408 | ±   | 0.0040 |
|         |         | none   |      0 | exact_match       | 0.0272 | ±   | 0.0033 |
|         |         | none   |      0 | relaxed_accuracy  | 0.0404 | ±   | 0.0039 |

@lgeiger @tjtanaa
I would say that it varies quite a bit

@tjtanaa
Contributor

tjtanaa commented Oct 22, 2025

@JartX which ViT attention backend did you test? TORCH_SDPA?

@JartX
Contributor

JartX commented Oct 22, 2025

@tjtanaa Yep, I can only test TORCH_SDPA, and you have to modify the source code to select it. In my other attention PR the code can be changed to force it.

@JartX
Contributor

JartX commented Oct 22, 2025

@tjtanaa #27190

albertoperdomo2 pushed a commit to albertoperdomo2/vllm that referenced this pull request Oct 23, 2025
…ct#27106)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
Signed-off-by: Alberto Perdomo <aperdomo@redhat.com>
@JartX
Contributor

JartX commented Oct 24, 2025

@lgeiger @tjtanaa Hey guys, I see there are more reports of this problem. Could we try to fix it? Thanks a lot :)

@lgeiger
Contributor Author

lgeiger commented Oct 24, 2025

> @lgeiger @tjtanaa Hey guys, I see there are more reports of this problem. Could we try to fix it? Thanks a lot :)

I don't have access to an AMD GPU, so it would be great if either @JartX or @tjtanaa could make a PR with a fix.

xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
…ct#27106)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
…ct#27106)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
@JartX
Contributor

JartX commented Oct 24, 2025

> @lgeiger @tjtanaa Hey guys, I see there are more reports of this problem. Could we try to fix it? Thanks a lot :)
>
> I don't have access to an AMD GPU, so it would be great if either @JartX or @tjtanaa could make a PR with a fix.

@lgeiger @tjtanaa
If removing the .contiguous() is really necessary, we could filter by the attention backend to decide whether to use it or not. Another option would be to revert, but since you've tried it and it improves performance, I'd bet on the solution I just mentioned. Or simply add a conditional on on_gfx9 to decide whether to use the .contiguous(), since gfx9 will be using a more capable attention backend than SDPA.

@tjtanaa
Contributor

tjtanaa commented Oct 25, 2025

@DarkLight1337 @JartX

I have validated this on an AMD Instinct GPU. In general, on ROCm with the TORCH_SDPA backend we do need to call .contiguous().
However, we shouldn't revert the PR, as the call is only needed when both conditions hold: ROCm and TORCH_SDPA.

        elif self.attn_backend == _Backend.TORCH_SDPA:
            # Execute attention entry by entry for speed & less VRAM.
            from vllm.platforms import current_platform
+            if current_platform.is_rocm():
+                q = q.contiguous()
+                k = k.contiguous()
+                v = v.contiguous()
            outputs = []
            for i in range(1, len(cu_seqlens)):
                start_idx = cu_seqlens[i - 1]
                end_idx = cu_seqlens[i]
                q_i = q[:, start_idx:end_idx]
                k_i = k[:, start_idx:end_idx]
                v_i = v[:, start_idx:end_idx]
                q_i, k_i, v_i = (
                    rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
                )
                output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
                output_i = rearrange(output_i, "b h s d -> b s h d ")
                outputs.append(output_i)
            context_layer = torch.cat(outputs, dim=1)
            context_layer = rearrange(
                context_layer, "b s h d -> s b (h d)"
            ).contiguous()

So instead of placing it in the usual place, we should only call .contiguous() in the TORCH_SDPA path.

However, at the moment ROCm will never use TORCH_SDPA, so we also need PR #27190.

The current optimization only covers qwen2_5_vl.py. However, I believe this "remove unnecessary .contiguous() calls" change could be applied to many other model files (a rough sketch of a shared guard follows the list below):

  • qwen2_5_vl.py (Done by this PR)
  • qwen2_vl.py
  • ernie45_vl.py
  • glm4_1v.py
  • keye.py
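
For the files listed above, one way to share the guard would be a small helper along these lines. This is only a hypothetical sketch (maybe_contiguous_for_sdpa does not exist in vLLM); it reuses the current_platform.is_rocm() check and the _Backend.TORCH_SDPA comparison shown in the diff above.

import torch

from vllm.platforms import current_platform


def maybe_contiguous_for_sdpa(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    is_sdpa_backend: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Hypothetical helper: force contiguity only for the ROCm + TORCH_SDPA
    # combination, so the other backends keep the copy-free behaviour
    # introduced by this PR.
    if is_sdpa_backend and current_platform.is_rocm():
        return q.contiguous(), k.contiguous(), v.contiguous()
    return q, k, v

Each vision attention module would then call q, k, v = maybe_contiguous_for_sdpa(q, k, v, self.attn_backend == _Backend.TORCH_SDPA) at the top of its SDPA branch.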

@tjtanaa
Contributor

tjtanaa commented Oct 25, 2025

@JartX, can you add this fix to your PR #27190?

@JartX
Contributor

JartX commented Oct 25, 2025

@tjtanaa
Of course, I was going to suggest it to you; I was reading and you answered me just now, hahaha. I'll add it and we'll move forward so the tests can pass :)

@JartX
Contributor

JartX commented Oct 25, 2025

@tjtanaa You can also edit the PR, maintainers have access, so if that's the case, add the fix and we can co-develop it :)

@tjtanaa
Contributor

tjtanaa commented Oct 25, 2025

@JartX I don't have maintainer access, and it would be better if you could fix it on your end. Thanks a lot for tracking down the issues. 😄

@JartX
Contributor

JartX commented Oct 25, 2025

@tjtanaa So I was wrong, sorry, I thought you did, hahaha; that way we'd both appear at once. Also, do you mind if I add you as a co-author of the commit?

@tjtanaa
Copy link
Contributor

tjtanaa commented Oct 25, 2025

@JartX Sure. Thank you so much for considering me as a co-author of the commit.
I mainly contribute AMD features and bug fixes on Instinct GPUs, plus some vision-model-related features.

@JartX
Contributor

JartX commented Oct 25, 2025

@tjtanaa Of course, all together! 💪

0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
…ct#27106)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
…ct#27106)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>