
Commit da3a941

JartX authored and lgeiger committed
[FIXBUG] Qwen3VL hallucinations without Contiguous on Torch.SDPA (vllm-project#27744)
Signed-off-by: JartX <sagformas@epdcenter.es>
Co-authored-by: Lukas Geiger <lukas.geiger94@gmail.com>
1 parent bbaf8e9 commit da3a941
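
For context, here is a small standalone PyTorch snippet (not part of the commit) illustrating the property the fix enforces: a transposed or permuted tensor is only a view with swapped strides, so it is no longer contiguous, and .contiguous() materializes a dense copy. The shapes below are illustrative.

import torch

# A packed (batch, seq, heads, head_dim) tensor, illustrative sizes.
x = torch.randn(2, 16, 8, 64)

# transpose() returns a view with swapped strides, not new storage,
# so the result is no longer contiguous in memory.
y = x.transpose(1, 2)
print(x.is_contiguous())  # True
print(y.is_contiguous())  # False

# .contiguous() copies the data into a dense, row-major layout,
# which is what the commit forces for q/k/v on ROCm before SDPA.
z = y.contiguous()
print(z.is_contiguous())  # True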

File tree

1 file changed (+8, -0 lines)


vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 8 additions & 0 deletions
@@ -428,6 +428,14 @@ def forward(
             )
         elif self.attn_backend == _Backend.TORCH_SDPA:
             # Execute attention entry by entry for speed & less VRAM.
+            from vllm.platforms import current_platform
+
+            # Never remove the next contiguous logic
+            # Without it, hallucinations occur with the backend
+            if current_platform.is_rocm():
+                q = q.contiguous()
+                k = k.contiguous()
+                v = v.contiguous()
             outputs = []
             for i in range(1, len(cu_seqlens)):
                 start_idx = cu_seqlens[i - 1]
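
To show the patched branch in context, here is a minimal, self-contained sketch (not the vLLM source verbatim) of a per-entry torch SDPA loop over packed variable-length sequences, with the ROCm contiguous fix applied up front. The function name sdpa_per_entry, the is_rocm flag, and the tensor shapes are illustrative assumptions based on the hunk's surroundings (cu_seqlens slicing and an outputs list).

import torch
import torch.nn.functional as F


def sdpa_per_entry(q, k, v, cu_seqlens, is_rocm=False):
    # q, k, v: packed (1, total_seqlen, num_heads, head_dim) tensors.
    # cu_seqlens: cumulative sequence lengths, e.g. [0, n0, n0 + n1, ...].
    if is_rocm:
        # Mirror of the committed fix: force dense layouts before SDPA,
        # since non-contiguous q/k/v triggered hallucinations on ROCm.
        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()

    outputs = []
    for i in range(1, len(cu_seqlens)):
        start_idx, end_idx = cu_seqlens[i - 1], cu_seqlens[i]
        # Slice one entry and move heads ahead of sequence:
        # (1, s, h, d) -> (1, h, s, d), the layout SDPA expects.
        q_i = q[:, start_idx:end_idx].transpose(1, 2)
        k_i = k[:, start_idx:end_idx].transpose(1, 2)
        v_i = v[:, start_idx:end_idx].transpose(1, 2)
        out_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
        outputs.append(out_i.transpose(1, 2))  # back to (1, s, h, d)
    return torch.cat(outputs, dim=1)


# Usage: two packed entries of lengths 4 and 6, 8 heads, head_dim 64.
q = torch.randn(1, 10, 8, 64)
k = torch.randn(1, 10, 8, 64)
v = torch.randn(1, 10, 8, 64)
out = sdpa_per_entry(q, k, v, cu_seqlens=[0, 4, 10], is_rocm=True)
print(out.shape)  # torch.Size([1, 10, 8, 64])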

Comments (0)