Your current environment
```
$ pip show torch vllm flash-attn
Name: torch
Version: 2.8.0
Name: vllm
Version: 0.11.0
Name: flash_attn
Version: 2.8.3
```
🐛 Describe the bug
The unit-test code is as follows. The plain Qwen3-0.6B model runs fine, but Qwen3-VL-4B does not run:
```python
# coding=utf-8
"""
Unit tests to verify the availability and compatibility of FlashAttention (FA) and vLLM.
"""
import torch
from flash_attn import flash_attn_func
import unittest
import vllm
# from vllm.attention.backends import get_attn_backend

class TestFA_VLLM(unittest.TestCase):
    def testFA(self):
        # Check CUDA availability and the current device
        print(f"CUDA available: {torch.cuda.is_available()}")
        print(f"Current device: {torch.cuda.current_device()}")
        print(f"Device name: {torch.cuda.get_device_name()}")
        # Try creating simple tensors on the GPU and running FlashAttention
        try:
            q = torch.randn(1, 1, 16, 64, dtype=torch.float16, device='cuda')
            k = torch.randn(1, 1, 16, 64, dtype=torch.float16, device='cuda')
            v = torch.randn(1, 1, 16, 64, dtype=torch.float16, device='cuda')
            output = flash_attn_func(q, k, v)
            print("FlashAttention test passed!")
        except Exception as e:
            print(f"FlashAttention test failed: {e}")

    def oriTestVLLM(self):
        # Print the attention backend currently in use
        print("Available CUDA devices:", torch.cuda.device_count())
        print("Current device:", torch.cuda.current_device())
        print("Device name:", torch.cuda.get_device_name())
        # Check the vLLM configuration
        print("vLLM version:", vllm.__version__)
        # Try creating a small model to trigger backend initialization
        try:
            from vllm import LLM
            llm = LLM(model="Qwen/Qwen3-0.6B", max_model_len=256)
            print("vLLM initialized successfully!")
            prompt = "This is a test prompt."
            response = llm.generate(prompt)
            print("rollout test passed! Generated text:", response)
        except Exception as e:
            print(f"vLLM initialization failed: {e}")

    def testVLLM(self):
        # Print the attention backend currently in use
        print("Available CUDA devices:", torch.cuda.device_count())
        print("Current device:", torch.cuda.current_device())
        print("Device name:", torch.cuda.get_device_name())
        # Try creating a small model to trigger backend initialization
        try:
            MODEL_PATH = "Qwen/Qwen3-VL-4B-Instruct"
            from vllm import LLM, SamplingParams
            from vllm.assets.image import ImageAsset  # vLLM built-in helper: path -> PIL image
            from vllm.assets.video import VideoAsset  # same idea if you want to add video later
            # Any image will do
            image_path = ""
            from PIL import Image
            image = Image.open(image_path)
            # Option B: a URL
            # image = ImageAsset("image", "https://xxx.jpg").pil_image
            # Chat template required by Qwen3-VL
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image},  # image field
                        {"type": "text", "text": "Please describe this image."}
                    ]
                }
            ]
            # Use transformers' apply_chat_template to turn messages into model input
            from transformers import AutoTokenizer
            tok = AutoTokenizer.from_pretrained(MODEL_PATH)
            prompt = tok.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
            # ---------- Generation ----------
            sampling_params = SamplingParams(
                temperature=0.7,
                max_tokens=512,
                stop_token_ids=[tok.eos_token_id, tok.convert_tokens_to_ids("<|im_end|>")]
            )
            llm = LLM(model=MODEL_PATH, max_model_len=4096,
                      limit_mm_per_prompt={"image": 1, "video": 0},  # at most 1 image per prompt
                      dtype="bfloat16",  # fine on A100/H100; use "float16" on consumer GPUs
                      gpu_memory_utilization=0.9)
            print("vLLM initialized successfully!")
            outputs = llm.generate(
                {"prompt": prompt, "multi_modal_data": {"image": image}},  # key: pass the image in too
                sampling_params=sampling_params
            )
            response = outputs[0].outputs[0].text
            print("rollout test passed! Generated text:", response)
        except Exception as e:
            print(f"vLLM initialization failed: {e}")

if __name__ == "__main__":
    unittest.main()
```

The error is:

```
vllm/vllm_flash_attn/flash_attn_interface.py", line 233, in flash_attn_varlen_func
[rank0]:     out, softmax_lse = torch.ops._vllm_fa2_C.varlen_fwd(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 1243, in __call__
[rank0]:     return self._op(*args, **kwargs)
[rank0]: torch.AcceleratorError: CUDA error: the provided PTX was compiled with an unsupported toolchain.
```
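As far as I understand, this CUDA error usually means the failing kernel (here vLLM's bundled flash-attn op, `torch.ops._vllm_fa2_C.varlen_fwd`) was compiled with a CUDA toolkit newer than what the installed driver supports. A minimal diagnostic sketch to compare the two (assumes `nvidia-smi` is on PATH; the banner-line index is an assumption about its standard output layout):

```python
# Hedged diagnostic: compare the CUDA toolkit torch was built with against
# the highest CUDA version the driver reports via nvidia-smi.
import subprocess
import torch

print("torch built with CUDA toolkit:", torch.version.cuda)
print("GPU:", torch.cuda.get_device_name(), torch.cuda.get_device_capability())
# nvidia-smi's banner line reports "CUDA Version: X.Y", the newest toolkit
# the driver supports. (Line index 2 assumes the standard nvidia-smi layout.)
print(subprocess.check_output(["nvidia-smi"], text=True).splitlines()[2])
```

If the driver's reported CUDA version is older than the toolkit the bundled kernels were built with, that would be consistent with `_vllm_fa2_C` failing while the separately built upstream flash-attn wheel works.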
Then I reviewed the code at https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen3_vl.py#L375. By default it sets `use_upstream_fa = False`; when I change it to `True`, it works. The vLLM version is 0.11.0.
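For reference, this is a sketch of the one-line local change that made Qwen3-VL-4B run for me (against v0.11.0; the exact line number may differ on `main`, and judging by the flag name this switches the vision attention from vLLM's bundled kernels to the upstream flash-attn package):

```python
# vllm/model_executor/models/qwen3_vl.py, around L375 in v0.11.0
# Original:
#     use_upstream_fa = False
# Changed to (works on my setup):
use_upstream_fa = True
```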
Before submitting a new issue...
- Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.