
Commit c171eca

yewentao256 authored and xuebwang-amd committed

[Bugfix] Allow Only SDPA Backend for ViT on B200 for Qwen3-VL (vllm-project#25788)

Signed-off-by: xuebwang-amd <xuebwang@amd.com>

1 parent 6fb605a, commit c171eca

File tree

2 files changed: +75 -51 lines

vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 37 additions & 36 deletions

@@ -274,6 +274,8 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend: _Backend = _Backend.TORCH_SDPA,
+        use_upstream_fa: bool = False,
     ) -> None:
         super().__init__()
         # Per attention head and per partition values.
@@ -300,25 +302,8 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.proj",
             disable_tp=use_data_parallel)
-
-        # Detect attention implementation.
-        self.attn_backend = get_vit_attn_backend(
-            head_size=self.hidden_size_per_attention_head,
-            dtype=torch.get_default_dtype())
-        self.use_upstream_fa = False
-        if self.attn_backend != _Backend.FLASH_ATTN and \
-            check_upstream_fa_availability(
-                torch.get_default_dtype()):
-            self.attn_backend = _Backend.FLASH_ATTN
-            self.use_upstream_fa = True
-
-        if self.attn_backend not in {
-                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
-                _Backend.ROCM_AITER_FA
-        }:
-            raise RuntimeError(
-                f"Qwen2.5-VL does not support {self.attn_backend} backend now."
-            )
+        self.attn_backend = attn_backend
+        self.use_upstream_fa = use_upstream_fa
         self.is_flash_attn_backend = self.attn_backend in {
             _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
         }
@@ -443,6 +428,8 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend: _Backend = _Backend.TORCH_SDPA,
+        use_upstream_fa: bool = False,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -455,7 +442,9 @@ def __init__(
             projection_size=dim,
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
-            use_data_parallel=use_data_parallel)
+            use_data_parallel=use_data_parallel,
+            attn_backend=attn_backend,
+            use_upstream_fa=use_upstream_fa)
         self.mlp = Qwen2_5_VisionMLP(dim,
                                      mlp_hidden_dim,
                                      act_fn=act_fn,
@@ -627,17 +616,35 @@ def __init__(
         head_dim = self.hidden_size // self.num_heads
         self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)

+        use_upstream_fa = False
+        self.attn_backend = get_vit_attn_backend(
+            head_size=head_dim, dtype=torch.get_default_dtype())
+        if self.attn_backend != _Backend.FLASH_ATTN and \
+            check_upstream_fa_availability(
+                torch.get_default_dtype()):
+            self.attn_backend = _Backend.FLASH_ATTN
+            use_upstream_fa = True
+
+        if self.attn_backend not in {
+                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
+                _Backend.ROCM_AITER_FA
+        }:
+            raise RuntimeError(
+                f"Qwen2.5-VL does not support {self.attn_backend} backend now."
+            )
+
         self.blocks = nn.ModuleList([
-            Qwen2_5_VisionBlock(dim=self.hidden_size,
-                                num_heads=self.num_heads,
-                                mlp_hidden_dim=vision_config.intermediate_size,
-                                act_fn=get_act_and_mul_fn(
-                                    vision_config.hidden_act),
-                                norm_layer=norm_layer,
-                                quant_config=quant_config,
-                                prefix=f"{prefix}.blocks.{layer_idx}",
-                                use_data_parallel=use_data_parallel)
-            for layer_idx in range(depth)
+            Qwen2_5_VisionBlock(
+                dim=self.hidden_size,
+                num_heads=self.num_heads,
+                mlp_hidden_dim=vision_config.intermediate_size,
+                act_fn=get_act_and_mul_fn(vision_config.hidden_act),
+                norm_layer=norm_layer,
+                quant_config=quant_config,
+                prefix=f"{prefix}.blocks.{layer_idx}",
+                use_data_parallel=use_data_parallel,
+                attn_backend=self.attn_backend,
+                use_upstream_fa=use_upstream_fa) for layer_idx in range(depth)
         ])
         self.merger = Qwen2_5_VisionPatchMerger(
             d_model=vision_config.out_hidden_size,
@@ -648,12 +655,6 @@ def __init__(
             prefix=f"{prefix}.merger",
             use_data_parallel=use_data_parallel,
         )
-        self.attn_backend = get_vit_attn_backend(
-            head_size=head_dim, dtype=torch.get_default_dtype())
-        if self.attn_backend != _Backend.FLASH_ATTN and \
-            check_upstream_fa_availability(
-                torch.get_default_dtype()):
-            self.attn_backend = _Backend.FLASH_ATTN

     @property
     def dtype(self) -> torch.dtype:
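For readers skimming the diff above: the refactor moves backend detection out of Qwen2_5_VisionAttention and into the vision transformer, which resolves the backend once in its __init__ and hands the result to every block. Below is a minimal, self-contained sketch of that pattern; _Backend, detect_backend, VisionBlock, and VisionTransformer are illustrative stand-ins written for this note, not vLLM APIs.

from enum import Enum, auto


class _Backend(Enum):  # stand-in for vllm.platforms._Backend
    FLASH_ATTN = auto()
    TORCH_SDPA = auto()


def detect_backend() -> tuple[_Backend, bool]:
    # Stand-in for get_vit_attn_backend() plus the upstream flash-attn probe;
    # returns (backend, use_upstream_fa).
    return _Backend.TORCH_SDPA, False


class VisionBlock:
    def __init__(self,
                 attn_backend: _Backend = _Backend.TORCH_SDPA,
                 use_upstream_fa: bool = False) -> None:
        # After this change the block simply stores what the transformer
        # decided, instead of re-running backend detection itself.
        self.attn_backend = attn_backend
        self.use_upstream_fa = use_upstream_fa


class VisionTransformer:
    def __init__(self, depth: int) -> None:
        # Decide once ...
        attn_backend, use_upstream_fa = detect_backend()
        self.attn_backend = attn_backend
        # ... then thread the same decision into every block.
        self.blocks = [
            VisionBlock(attn_backend=attn_backend,
                        use_upstream_fa=use_upstream_fa)
            for _ in range(depth)
        ]


vit = VisionTransformer(depth=2)
assert all(b.attn_backend is vit.attn_backend for b in vit.blocks)

Resolving the backend once keeps every block consistent and avoids repeating the flash-attn availability probe per layer.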

vllm/model_executor/models/qwen3_vl.py

Lines changed: 38 additions & 15 deletions

@@ -63,7 +63,7 @@
                                         PromptReplacement, PromptUpdate,
                                         PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.platforms import _Backend
+from vllm.platforms import _Backend, current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import uses_mrope
 from vllm.utils import is_list_of
@@ -158,6 +158,8 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend: _Backend = _Backend.TORCH_SDPA,
+        use_upstream_fa: bool = False,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -170,7 +172,9 @@ def __init__(
             projection_size=dim,
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
-            use_data_parallel=use_data_parallel)
+            use_data_parallel=use_data_parallel,
+            attn_backend=attn_backend,
+            use_upstream_fa=use_upstream_fa)
         self.mlp = Qwen3_VisionMLP(dim,
                                    mlp_hidden_dim,
                                    act_fn=act_fn,
@@ -287,19 +291,6 @@ def __init__(
         head_dim = self.hidden_size // self.num_heads
         self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)

-        self.blocks = nn.ModuleList([
-            Qwen3_VisionBlock(
-                dim=self.hidden_size,
-                num_heads=self.num_heads,
-                mlp_hidden_dim=vision_config.intermediate_size,
-                act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
-                norm_layer=norm_layer,
-                quant_config=quant_config,
-                prefix=f"{prefix}.blocks.{layer_idx}",
-                use_data_parallel=use_data_parallel)
-            for layer_idx in range(vision_config.depth)
-        ])
-
         self.merger = Qwen3_VisionPatchMerger(
             d_model=vision_config.out_hidden_size,
             context_dim=self.hidden_size,
@@ -325,10 +316,42 @@ def __init__(

         self.attn_backend = get_vit_attn_backend(
             head_size=head_dim, dtype=torch.get_default_dtype())
+        use_upstream_fa = False
         if self.attn_backend != _Backend.FLASH_ATTN and \
             check_upstream_fa_availability(
                 torch.get_default_dtype()):
             self.attn_backend = _Backend.FLASH_ATTN
+            use_upstream_fa = True
+
+        if self.attn_backend not in {
+                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
+                _Backend.ROCM_AITER_FA
+        }:
+            raise RuntimeError(
+                f"Qwen3-VL does not support {self.attn_backend} backend now.")
+        if current_platform.is_device_capability(
+                100) and self.attn_backend != _Backend.TORCH_SDPA:
+            # TODO(Roger/Wentao): remove this after FA
+            # or XFORMERS's issue fixed on Blackwell
+            logger.info_once("Qwen3-VL vision attention does not support "
+                             f"{self.attn_backend} backend on Blackwell now. "
+                             "Vision attention backend is set to TORCH_SDPA.")
+            self.attn_backend = _Backend.TORCH_SDPA
+
+        self.blocks = nn.ModuleList([
+            Qwen3_VisionBlock(
+                dim=self.hidden_size,
+                num_heads=self.num_heads,
+                mlp_hidden_dim=vision_config.intermediate_size,
+                act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
+                norm_layer=norm_layer,
+                quant_config=quant_config,
+                prefix=f"{prefix}.blocks.{layer_idx}",
+                use_data_parallel=use_data_parallel,
+                attn_backend=self.attn_backend,
+                use_upstream_fa=use_upstream_fa)
+            for layer_idx in range(vision_config.depth)
+        ])

     @property
     def dtype(self) -> torch.dtype:
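The key behavioral change in this file is the Blackwell guard: when the device reports compute capability 10.0 (B200) and the detected ViT backend is not TORCH_SDPA, the vision tower falls back to TORCH_SDPA. Here is a minimal, self-contained sketch of that override logic; _Backend is a stand-in for vllm.platforms._Backend, and the capability is passed in explicitly instead of calling current_platform.is_device_capability(100) as the diff does.

from enum import Enum, auto


class _Backend(Enum):  # stand-in for vllm.platforms._Backend
    FLASH_ATTN = auto()
    TORCH_SDPA = auto()
    XFORMERS = auto()
    ROCM_AITER_FA = auto()


def pick_qwen3_vl_vit_backend(detected: _Backend,
                              device_capability: int) -> _Backend:
    # Mirrors the guard added in the diff: on compute capability 10.0
    # (reported as 100, i.e. Blackwell/B200), anything other than TORCH_SDPA
    # is forced back to TORCH_SDPA until the FA/XFORMERS issues are fixed.
    if device_capability == 100 and detected != _Backend.TORCH_SDPA:
        return _Backend.TORCH_SDPA
    return detected


# Example: FLASH_ATTN was detected, but the GPU is a B200, so SDPA is used.
assert pick_qwen3_vl_vit_backend(_Backend.FLASH_ATTN, 100) is _Backend.TORCH_SDPA
# On earlier architectures (e.g. capability 90 for Hopper) the detected
# backend is kept as-is.
assert pick_qwen3_vl_vit_backend(_Backend.FLASH_ATTN, 90) is _Backend.FLASH_ATTN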

0 commit comments
