[Models][QwenVL] Remove unnecessary .contiguous() calls
#27106
Conversation
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
Force-pushed from 1d05935 to 0940321
Code Review
This pull request aims to optimize the Qwen-VL models by removing what appear to be unnecessary .contiguous() calls. While this change might be safe for some attention backends like Torch SDPA or xFormers, it introduces a critical bug in the Flash Attention path. The rearrange operation used to prepare tensors for Flash Attention requires contiguous inputs, and removing the .contiguous() call will lead to a runtime error. My review adds comments to restore the necessary calls to ensure the model works correctly with all supported backends.
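For context, here is a minimal standalone sketch (not vLLM code; the tensor shape is made up) of why contiguity comes up at all: a pure stride permutation such as the rearrange used in these models returns a non-contiguous view without copying any data, while .contiguous() materialises a fresh copy. Whether a given attention kernel accepts the non-contiguous view directly is exactly the question the review above raises.

import torch
from einops import rearrange

q = torch.randn(2, 1024, 16, 80)             # (batch, seq, heads, head_dim)
q_bhsd = rearrange(q, "b s h d -> b h s d")  # pure permutation: a view, no copy
print(q_bhsd.is_contiguous())                # False
print(q_bhsd.data_ptr() == q.data_ptr())     # True: shares the original storage

q_copy = q_bhsd.contiguous()                 # allocates new memory and copies the data
print(q_copy.is_contiguous())                # True
print(q_copy.data_ptr() == q.data_ptr())     # False: separate allocation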
Evaluated on ROCm.
Server command:
Evaluation score on chartqa:
Thanks for improving this
@lgeiger Hi! With this PR, Qwen/Qwen3-VL-30B-A3B-Instruct classifies an invoice as a cheesecake or a sushi restaurant.
Revert "[Models][QwenVL] Remove unnecessary .contiguous() calls (vllm-project#27106)". This reverts commit 5c2acb2.
@tjtanaa @DarkLight1337 Could you help me with this? Right now, the only way I can send an image to Qwen3 VL and get a consistent result is by reverting the commit.
@JartX I can't seem to reproduce your issue. I verified that on an Nvidia L40s GPU I'm getting identical lm-eval results with this PR and with it reverted, for all backends. The command I tested with returns the same results on main and with this PR reverted.
@lgeiger Thank you for your quick response. This happens when I send the image via a request, for example via OpenWebUI.
@lgeiger My docker run command (ROCm, RDNA3):
docker run -it --rm …
@JartX To double check, can you reproduce the lm-eval results that I posted above, or do you also see a regression both with and without this PR?
Using:
@lgeiger Many thanks!
@lgeiger This issue will need to be debugged on AMD GPUs. I remember @DarkLight1337 mentioned that on ROCm we might need the .contiguous() calls. However, I am afraid I might have missed some edge cases that @JartX found.
With Contiguous
Without Contiguous
@JartX which ViT attention backend did you test? TORCH_SDPA?
@tjtanaa Yep, I can only test TORCH_SDPA, and you have to set it in the source code to use it. In my other attention PR the code can be changed to force it.
@lgeiger @tjtanaa
I have validated this on an AMD Instinct GPU. Generally on ROCm, if we are using the TORCH_SDPA backend, we need to cast to contiguous():

elif self.attn_backend == _Backend.TORCH_SDPA:
    # Execute attention entry by entry for speed & less VRAM.
    from vllm.platforms import current_platform
+   if current_platform.is_rocm():
+       q = q.contiguous()
+       k = k.contiguous()
+       v = v.contiguous()
    outputs = []
    for i in range(1, len(cu_seqlens)):
        start_idx = cu_seqlens[i - 1]
        end_idx = cu_seqlens[i]
        q_i = q[:, start_idx:end_idx]
        k_i = k[:, start_idx:end_idx]
        v_i = v[:, start_idx:end_idx]
        q_i, k_i, v_i = (
            rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
        )
        output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
        output_i = rearrange(output_i, "b h s d -> b s h d")
        outputs.append(output_i)
    context_layer = torch.cat(outputs, dim=1)
    context_layer = rearrange(
        context_layer, "b s h d -> s b (h d)"
    ).contiguous()

So instead of placing the cast at the usual place, we should only cast to contiguous in the TORCH_SDPA branch. However, at the current moment ROCm will never use TORCH_SDPA; we also need PR #27190. The current optimization is only done for …
@tjtanaa
@tjtanaa You can also touch the PR; maintainers have access, so if that's the case, add the fix and we can co-develop it :)
@JartX I don't have maintainer access, and it would be better if you could fix it on your end. Thanks a lot for finding the issues. 😄
@tjtanaa So I was wrong, sorry, I thought you did, hahaha; that way we'd both appear at once. Also, do you mind if I put you as a co-author of the commit?
@JartX Sure. Thank you so much for considering me as the co-author of the commit.
@tjtanaa Of course, all together! 💪
Purpose
This PR removes calls to .contiguous() in the QwenVisionAttention module. Tested with both the xformers and flash attention backends. (An illustrative sketch of the avoided copy overhead follows the results below.)
Test Plan
Test Result
Before
After
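As a purely illustrative sketch of the copy overhead that a dropped .contiguous() avoids (this is not the PR's actual benchmark; the tensor shape and the timing method are assumptions):

import torch
from torch.utils import benchmark
from einops import rearrange

# Arbitrary vision-tower-like activation shape: (batch, seq, heads, head_dim).
x = torch.randn(1, 4096, 16, 80)

def with_copy():
    # Permute and force a contiguous copy, roughly what the removed calls did.
    return rearrange(x, "b s h d -> b h s d").contiguous()

def without_copy():
    # Permute only; kernels that accept strided inputs need no copy at all.
    return rearrange(x, "b s h d -> b h s d")

for name, fn in [("with .contiguous()", with_copy), ("without .contiguous()", without_copy)]:
    t = benchmark.Timer(stmt="fn()", globals={"fn": fn}).blocked_autorange()
    print(f"{name}: {t.median * 1e6:.1f} us")

The without_copy path performs no data allocation, which is where both the time and the memory savings come from.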