
[Bug]: vLLM inference on Qwen3-VL fails when use_upstream_fa is False #28903

Description

@hedes1992

Your current environment

pip show torch vllm flash-attn

Name: torch
Version: 2.8.0

Name: vllm
Version: 0.11.0

Name: flash_attn
Version: 2.8.3

🐛 Describe the bug

Unit-test code is below. The plain Qwen3-0.6B test runs fine, but the Qwen3-VL-4B test does not.

# coding=utf-8
"""
Unit tests to verify the availability and compatibility of FlashAttention (FA) and vLLM.
"""

import torch
from flash_attn import flash_attn_func
import unittest
import vllm
# from vllm.attention.backends import get_attn_backend

class TestFA_VLLM(unittest.TestCase):
    def testFA(self,):
        # Check whether CUDA is available and which device is in use
        print(f"CUDA available: {torch.cuda.is_available()}")
        print(f"Current device: {torch.cuda.current_device()}")
        print(f"Device name: {torch.cuda.get_device_name()}")

        # Create small tensors on the GPU and run a FlashAttention forward pass
        try:
            q = torch.randn(1, 1, 16, 64, dtype=torch.float16, device='cuda')
            k = torch.randn(1, 1, 16, 64, dtype=torch.float16, device='cuda')
            v = torch.randn(1, 1, 16, 64, dtype=torch.float16, device='cuda')
            output = flash_attn_func(q, k, v)
            print("FlashAttention test passed!")
        except Exception as e:
            print(f"FlashAttention test failed: {e}")
    
    def oriTestVLLM(self,):
        # Print basic CUDA device info
        print("Available CUDA devices:", torch.cuda.device_count())
        print("Current device:", torch.cuda.current_device())
        print("Device name:", torch.cuda.get_device_name())

        # Check the vLLM version
        print("vLLM version:", vllm.__version__)

        # Try to build a small model to trigger backend initialization
        try:
            from vllm import LLM
            llm = LLM(model="Qwen/Qwen3-0.6B", max_model_len=256)
            print("vLLM initialized successfully!")
            prompt = "This is a test prompt."
            response = llm.generate(prompt)
            print("Rollout test succeeded! Generated text:", response)
        except Exception as e:
            print(f"vLLM initialization failed: {e}")
    
    def testVLLM(self,):
        # Print basic CUDA device info
        print("Available CUDA devices:", torch.cuda.device_count())
        print("Current device:", torch.cuda.current_device())
        print("Device name:", torch.cuda.get_device_name())

        # Try to build a small model to trigger backend initialization
        try:
            MODEL_PATH = "Qwen/Qwen3-VL-4B-Instruct"
            from vllm import LLM, SamplingParams
            from vllm.assets.image import ImageAsset           # vLLM built-in helper: path/URL -> PIL image
            from vllm.assets.video import VideoAsset           # same idea if video is added later

            # Option A: any local image will do
            image_path = ""
            from PIL import Image
            image = Image.open(image_path)
            # Option B: from a URL
            # image = ImageAsset("image", "https://xxx.jpg").pil_image

            # Chat template expected by Qwen3-VL
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image},   # image field
                        {"type": "text",  "text": "Please describe this image."}
                    ]
                }
            ]
            # Use transformers' apply_chat_template to turn the messages into the model prompt
            from transformers import AutoTokenizer
            tok = AutoTokenizer.from_pretrained(MODEL_PATH)
            prompt = tok.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
            # ---------- Generation ----------
            sampling_params = SamplingParams(
                temperature=0.7,
                max_tokens=512,
                stop_token_ids=[tok.eos_token_id, tok.convert_tokens_to_ids("<|im_end|>")]
            )

            llm = LLM(model=MODEL_PATH, max_model_len=4096,
                limit_mm_per_prompt={"image": 1, "video": 0},  # at most 1 image per prompt
                dtype="bfloat16",            # fine on A100/H100; use "float16" on consumer cards
                gpu_memory_utilization=0.9,)
            print("vLLM initialized successfully!")
            

            outputs = llm.generate(
                {"prompt": prompt, "multi_modal_data": {"image": image}},  # key point: pass the image along with the prompt
                sampling_params=sampling_params
            )

            response = outputs[0].outputs[0].text
            print("Rollout test succeeded! Generated text:", response)
        except Exception as e:
            print(f"vLLM initialization failed: {e}")

if __name__ == "__main__":
    unittest.main()

The error is:

vllm/vllm_flash_attn/flash_attn_interface.py", line 233, in flash_attn_varlen_func
[rank0]:     out, softmax_lse = torch.ops._vllm_fa2_C.varlen_fwd(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 1243, in __call__
[rank0]:     return self._op(*args, **kwargs)
[rank0]: torch.AcceleratorError: CUDA error: the provided PTX was compiled with an unsupported toolchain.
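A quick environment check can confirm whether this is a CUDA toolkit/driver mismatch, which is what "PTX was compiled with an unsupported toolchain" usually points at. A minimal sketch using only standard torch APIs and nvidia-smi (nothing vLLM-specific assumed):

import subprocess
import torch

# The pieces that have to agree: the CUDA toolkit the prebuilt kernels were
# compiled with vs. the driver installed on the machine.
print("torch:", torch.__version__)
print("torch built with CUDA:", torch.version.cuda)
print("device:", torch.cuda.get_device_name(0),
      "capability:", torch.cuda.get_device_capability(0))
try:
    driver = subprocess.check_output(
        ["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
        text=True,
    ).strip()
    print("NVIDIA driver:", driver)
except Exception as e:
    print("could not query nvidia-smi:", e)

If the driver is older than the toolkit the failing kernel was compiled with, this error is the expected symptom.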

Then I reviewed the code at https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen3_vl.py#L375

It sets use_upstream_fa = False by default; when I change it to True, the model runs. The vLLM version is 0.11.0.
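For reference, the workaround is just flipping that default in the installed copy of qwen3_vl.py. An annotated sketch of the edit (not a proposed fix; the attribute name is taken from the linked source line):

# vllm/model_executor/models/qwen3_vl.py, around the linked line
# before:
#     use_upstream_fa = False
# workaround that makes Qwen3-VL-4B run on this machine:
#     use_upstream_fa = True
# Judging by the name and the traceback, True routes the vision attention
# through the upstream flash_attn package (2.8.3, which passes testFA above)
# instead of the bundled torch.ops._vllm_fa2_C kernels that raise the PTX error.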

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
