From 3aae4639b407af72e5cd8787a0427a33aa332dc4 Mon Sep 17 00:00:00 2001
From: yewentao256
Date: Fri, 26 Sep 2025 14:36:20 -0700
Subject: [PATCH 1/9] temp fix qwen3 vl issue

Signed-off-by: yewentao256
---
 vllm/model_executor/models/qwen3_vl.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index ede477cde1a2..ac22cbaf1cbd 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -332,6 +332,9 @@ def __init__(
             check_upstream_fa_availability(
                 torch.get_default_dtype()):
             self.attn_backend = _Backend.FLASH_ATTN
+        assert self.attn_backend == _Backend.TORCH_SDPA, \
+            f"Qwen3-VL does not support {self.attn_backend} backend now. " \
+            f"Consider `export VLLM_ATTENTION_BACKEND=TORCH_SDPA` to enable it."
 
     @property
     def dtype(self) -> torch.dtype:

From ab03f323e1a1dc8fe433f176547aa99aea41ec74 Mon Sep 17 00:00:00 2001
From: yewentao256
Date: Fri, 26 Sep 2025 15:12:28 -0700
Subject: [PATCH 2/9] update through comments

Signed-off-by: yewentao256
---
 vllm/model_executor/models/qwen3_vl.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index ac22cbaf1cbd..dd0765e5c8fa 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -66,7 +66,7 @@
                                         PromptReplacement, PromptUpdate,
                                         PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.platforms import _Backend
+from vllm.platforms import _Backend, current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import uses_mrope
 from vllm.utils import is_list_of
@@ -332,9 +332,12 @@ def __init__(
             check_upstream_fa_availability(
                 torch.get_default_dtype()):
             self.attn_backend = _Backend.FLASH_ATTN
-        assert self.attn_backend == _Backend.TORCH_SDPA, \
-            f"Qwen3-VL does not support {self.attn_backend} backend now. " \
-            f"Consider `export VLLM_ATTENTION_BACKEND=TORCH_SDPA` to enable it."
+        if current_platform.is_device_capability(
+                100) and self.attn_backend != _Backend.TORCH_SDPA:
+            raise NotImplementedError(
+                f"Qwen3-VL does not support {self.attn_backend} backend now. "
+                f"Consider `export VLLM_ATTENTION_BACKEND=TORCH_SDPA` .")
+        logger.info_once(f"Qwen3-VL attn_backend: {self.attn_backend}")
 
     @property
     def dtype(self) -> torch.dtype:

From 24841ca2cfcb77ab91e152f4bf5831f7e5b5ae80 Mon Sep 17 00:00:00 2001
From: yewentao256
Date: Fri, 26 Sep 2025 15:25:43 -0700
Subject: [PATCH 3/9] warning and fallback

Signed-off-by: yewentao256
---
 vllm/model_executor/models/qwen3_vl.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index dd0765e5c8fa..0645abfc616d 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -332,11 +332,14 @@ def __init__(
             check_upstream_fa_availability(
                 torch.get_default_dtype()):
             self.attn_backend = _Backend.FLASH_ATTN
+            logger.debug_once("Upstream Flash Attention is available, "
+                              "set attn_backend to FLASH_ATTN.")
         if current_platform.is_device_capability(
                 100) and self.attn_backend != _Backend.TORCH_SDPA:
-            raise NotImplementedError(
-                f"Qwen3-VL does not support {self.attn_backend} backend now. "
-                f"Consider `export VLLM_ATTENTION_BACKEND=TORCH_SDPA` .")
+            logger.warning_once(
+                f"Qwen3-VL does not support {self.attn_backend} backend "
+                "on Blackwell now. Set attn_backend to TORCH_SDPA.")
+            self.attn_backend = _Backend.TORCH_SDPA
         logger.info_once(f"Qwen3-VL attn_backend: {self.attn_backend}")
 
     @property

From ca79b81d7a57e8428fafbc78420b9fd4c09572e7 Mon Sep 17 00:00:00 2001
From: yewentao256
Date: Fri, 26 Sep 2025 16:19:41 -0700
Subject: [PATCH 4/9] update other blocks

Signed-off-by: yewentao256
---
 vllm/model_executor/models/qwen2_5_vl.py | 27 +++++++++++++--------
 vllm/model_executor/models/qwen3_vl.py   | 31 +++++++++++++-----------
 2 files changed, 34 insertions(+), 24 deletions(-)

diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index bd6c0b162cb4..bab06d9ab64c 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -277,6 +277,7 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        forced_attn_backend: Optional[_Backend] = None,
     ) -> None:
         super().__init__()
         # Per attention head and per partition values.
@@ -304,16 +305,20 @@ def __init__(
                                       prefix=f"{prefix}.proj",
                                       disable_tp=use_data_parallel)
 
-        # Detect attention implementation.
-        self.attn_backend = get_vit_attn_backend(
-            head_size=self.hidden_size_per_attention_head,
-            dtype=torch.get_default_dtype())
+        # detect attention implementation unless forced by caller
         self.use_upstream_fa = False
-        if self.attn_backend != _Backend.FLASH_ATTN and \
-            check_upstream_fa_availability(
-                torch.get_default_dtype()):
-            self.attn_backend = _Backend.FLASH_ATTN
-            self.use_upstream_fa = True
+        if forced_attn_backend is not None:
+            self.attn_backend = forced_attn_backend
+            self.use_upstream_fa = (forced_attn_backend == _Backend.FLASH_ATTN)
+        else:
+            self.attn_backend = get_vit_attn_backend(
+                head_size=self.hidden_size_per_attention_head,
+                dtype=torch.get_default_dtype())
+            if self.attn_backend != _Backend.FLASH_ATTN and \
+                check_upstream_fa_availability(
+                    torch.get_default_dtype()):
+                self.attn_backend = _Backend.FLASH_ATTN
+                self.use_upstream_fa = True
 
         if self.attn_backend not in {
             _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
@@ -446,6 +451,7 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        forced_attn_backend: Optional[_Backend] = None,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -458,7 +464,8 @@ def __init__(
             projection_size=dim,
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
-            use_data_parallel=use_data_parallel)
+            use_data_parallel=use_data_parallel,
+            forced_attn_backend=forced_attn_backend)
         self.mlp = Qwen2_5_VisionMLP(dim,
                                      mlp_hidden_dim,
                                      act_fn=act_fn,

diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index 0645abfc616d..fcee3ad2abaf 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -161,6 +161,7 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        forced_attn_backend: Optional[_Backend] = None,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -173,7 +174,8 @@ def __init__(
             projection_size=dim,
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
-            use_data_parallel=use_data_parallel)
+            use_data_parallel=use_data_parallel,
+            forced_attn_backend=forced_attn_backend)
         self.mlp = Qwen3_VisionMLP(dim,
                                   mlp_hidden_dim,
                                   act_fn=act_fn,
@@ -290,19 +292,6 @@ def __init__(
         head_dim = self.hidden_size // self.num_heads
         self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
 
-        self.blocks = nn.ModuleList([
-            Qwen3_VisionBlock(
-                dim=self.hidden_size,
-                num_heads=self.num_heads,
-                mlp_hidden_dim=vision_config.intermediate_size,
-                act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
-                norm_layer=norm_layer,
-                quant_config=quant_config,
-                prefix=f"{prefix}.blocks.{layer_idx}",
-                use_data_parallel=use_data_parallel)
-            for layer_idx in range(vision_config.depth)
-        ])
-
         self.merger = Qwen3_VisionPatchMerger(
             d_model=vision_config.out_hidden_size,
             context_dim=self.hidden_size,
@@ -342,6 +331,20 @@ def __init__(
             self.attn_backend = _Backend.TORCH_SDPA
         logger.info_once(f"Qwen3-VL attn_backend: {self.attn_backend}")
 
+        self.blocks = nn.ModuleList([
+            Qwen3_VisionBlock(
+                dim=self.hidden_size,
+                num_heads=self.num_heads,
+                mlp_hidden_dim=vision_config.intermediate_size,
+                act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
+                norm_layer=norm_layer,
+                quant_config=quant_config,
+                prefix=f"{prefix}.blocks.{layer_idx}",
+                use_data_parallel=use_data_parallel,
+                forced_attn_backend=self.attn_backend)
+            for layer_idx in range(vision_config.depth)
+        ])
+
     @property
     def dtype(self) -> torch.dtype:
         return self.patch_embed.proj.weight.dtype

From 2840a067394c6abcf6484079d85553ea796eb4f0 Mon Sep 17 00:00:00 2001
From: yewentao256
Date: Fri, 26 Sep 2025 16:26:57 -0700
Subject: [PATCH 5/9] add TODO

Signed-off-by: yewentao256
---
 vllm/model_executor/models/qwen3_vl.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index fcee3ad2abaf..fbe6a7af32c2 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -325,6 +325,8 @@ def __init__(
                 "set attn_backend to FLASH_ATTN.")
         if current_platform.is_device_capability(
                 100) and self.attn_backend != _Backend.TORCH_SDPA:
+            # TODO(Roger/Wentao): remove this after FA
+            # or XFORMERS's issue fixed on Blackwell
             logger.warning_once(
                 f"Qwen3-VL does not support {self.attn_backend} backend "
                 "on Blackwell now. Set attn_backend to TORCH_SDPA.")

From 002d7b939754b844a3832dd7fbf9e5fd0c0e9f1b Mon Sep 17 00:00:00 2001
From: Roger Wang
Date: Fri, 26 Sep 2025 17:03:49 -0700
Subject: [PATCH 6/9] cleanup

Signed-off-by: Roger Wang
---
 vllm/model_executor/models/qwen2_5_vl.py | 61 +++++++++++-------------
 vllm/model_executor/models/qwen3_vl.py   | 21 +++++---
 2 files changed, 42 insertions(+), 40 deletions(-)

diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index bab06d9ab64c..aae86d45ffd3 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -277,7 +277,8 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         use_data_parallel: bool = False,
-        forced_attn_backend: Optional[_Backend] = None,
+        attn_backend: _Backend = _Backend.TORCH_SDPA,
+        use_upstream_fa: bool = False,
     ) -> None:
         super().__init__()
         # Per attention head and per partition values.
@@ -304,29 +305,8 @@ def __init__(
                                       quant_config=quant_config,
                                       prefix=f"{prefix}.proj",
                                       disable_tp=use_data_parallel)
-
-        # detect attention implementation unless forced by caller
-        self.use_upstream_fa = False
-        if forced_attn_backend is not None:
-            self.attn_backend = forced_attn_backend
-            self.use_upstream_fa = (forced_attn_backend == _Backend.FLASH_ATTN)
-        else:
-            self.attn_backend = get_vit_attn_backend(
-                head_size=self.hidden_size_per_attention_head,
-                dtype=torch.get_default_dtype())
-            if self.attn_backend != _Backend.FLASH_ATTN and \
-                check_upstream_fa_availability(
-                    torch.get_default_dtype()):
-                self.attn_backend = _Backend.FLASH_ATTN
-                self.use_upstream_fa = True
-
-        if self.attn_backend not in {
-            _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
-            _Backend.ROCM_AITER_FA
-        }:
-            raise RuntimeError(
-                f"Qwen2.5-VL does not support {self.attn_backend} backend now."
-            )
+        self.attn_backend = attn_backend
+        self.use_upstream_fa = use_upstream_fa
         self.is_flash_attn_backend = self.attn_backend in {
             _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
         }
@@ -451,7 +431,8 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         use_data_parallel: bool = False,
-        forced_attn_backend: Optional[_Backend] = None,
+        attn_backend: Optional[_Backend] = None,
+        use_upstream_fa: bool = False,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -465,7 +446,8 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
             use_data_parallel=use_data_parallel,
-            forced_attn_backend=forced_attn_backend)
+            attn_backend=attn_backend,
+            use_upstream_fa=use_upstream_fa)
         self.mlp = Qwen2_5_VisionMLP(dim,
                                      mlp_hidden_dim,
                                      act_fn=act_fn,
@@ -637,6 +619,23 @@ def __init__(
         head_dim = self.hidden_size // self.num_heads
         self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
 
+        self.use_upstream_fa = False
+        self.attn_backend = get_vit_attn_backend(
+            head_size=head_dim, dtype=torch.get_default_dtype())
+        if self.attn_backend != _Backend.FLASH_ATTN and \
+            check_upstream_fa_availability(
+                torch.get_default_dtype()):
+            self.attn_backend = _Backend.FLASH_ATTN
+            self.use_upstream_fa = True
+
+        if self.attn_backend not in {
+            _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
+            _Backend.ROCM_AITER_FA
+        }:
+            raise RuntimeError(
+                f"Qwen2.5-VL does not support {self.attn_backend} backend now."
+            )
+
         self.blocks = nn.ModuleList([
             Qwen2_5_VisionBlock(dim=self.hidden_size,
                                 num_heads=self.num_heads,
@@ -646,7 +645,9 @@ def __init__(
                                 norm_layer=norm_layer,
                                 quant_config=quant_config,
                                 prefix=f"{prefix}.blocks.{layer_idx}",
-                                use_data_parallel=use_data_parallel)
+                                use_data_parallel=use_data_parallel,
+                                attn_backend=self.attn_backend,
+                                use_upstream_fa=self.use_upstream_fa)
             for layer_idx in range(depth)
         ])
         self.merger = Qwen2_5_VisionPatchMerger(
             d_model=vision_config.out_hidden_size,
@@ -658,12 +659,6 @@ def __init__(
             prefix=f"{prefix}.merger",
             use_data_parallel=use_data_parallel,
         )
-        self.attn_backend = get_vit_attn_backend(
-            head_size=head_dim, dtype=torch.get_default_dtype())
-        if self.attn_backend != _Backend.FLASH_ATTN and \
-            check_upstream_fa_availability(
-                torch.get_default_dtype()):
-            self.attn_backend = _Backend.FLASH_ATTN
 
     @property
     def dtype(self) -> torch.dtype:

diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index fbe6a7af32c2..6b8e13fd2d1c 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -317,21 +317,27 @@ def __init__(
 
         self.attn_backend = get_vit_attn_backend(
             head_size=head_dim, dtype=torch.get_default_dtype())
+        self.use_upstream_fa = False
         if self.attn_backend != _Backend.FLASH_ATTN and \
             check_upstream_fa_availability(
                 torch.get_default_dtype()):
             self.attn_backend = _Backend.FLASH_ATTN
-            logger.debug_once("Upstream Flash Attention is available, "
-                              "set attn_backend to FLASH_ATTN.")
+            self.use_upstream_fa = True
+
+        if self.attn_backend not in {
+            _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
+            _Backend.ROCM_AITER_FA
+        }:
+            raise RuntimeError(
+                f"Qwen3-VL does not support {self.attn_backend} backend now.")
         if current_platform.is_device_capability(
                 100) and self.attn_backend != _Backend.TORCH_SDPA:
             # TODO(Roger/Wentao): remove this after FA
             # or XFORMERS's issue fixed on Blackwell
-            logger.warning_once(
-                f"Qwen3-VL does not support {self.attn_backend} backend "
-                "on Blackwell now. Set attn_backend to TORCH_SDPA.")
+            logger.info_once("Qwen3-VL vision attention does not support "
+                             f"{self.attn_backend} backend on Blackwell now. "
+                             "Vision attention backend is set to TORCH_SDPA.")
             self.attn_backend = _Backend.TORCH_SDPA
-        logger.info_once(f"Qwen3-VL attn_backend: {self.attn_backend}")
 
         self.blocks = nn.ModuleList([
             Qwen3_VisionBlock(
@@ -343,7 +349,8 @@ def __init__(
                 quant_config=quant_config,
                 prefix=f"{prefix}.blocks.{layer_idx}",
                 use_data_parallel=use_data_parallel,
-                forced_attn_backend=self.attn_backend)
+                attn_backend=self.attn_backend,
+                use_upstream_fa=self.use_upstream_fa)
             for layer_idx in range(vision_config.depth)
         ])

From 7c9c46fe8a40040ee62170921562c69109faba27 Mon Sep 17 00:00:00 2001
From: Roger Wang
Date: Fri, 26 Sep 2025 17:04:16 -0700
Subject: [PATCH 7/9] update

Signed-off-by: Roger Wang
---
 vllm/model_executor/models/qwen3_vl.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index 6b8e13fd2d1c..067446c6675f 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -161,7 +161,8 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         use_data_parallel: bool = False,
-        forced_attn_backend: Optional[_Backend] = None,
+        attn_backend: _Backend = _Backend.TORCH_SDPA,
+        use_upstream_fa: bool = False,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -175,7 +176,8 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
             use_data_parallel=use_data_parallel,
-            forced_attn_backend=forced_attn_backend)
+            attn_backend=attn_backend,
+            use_upstream_fa=use_upstream_fa)
         self.mlp = Qwen3_VisionMLP(dim,
                                    mlp_hidden_dim,
                                    act_fn=act_fn,

From d9cf02e8ef21bad2aec84c945cf1ddf4b05d4f72 Mon Sep 17 00:00:00 2001
From: Roger Wang
Date: Fri, 26 Sep 2025 17:11:18 -0700
Subject: [PATCH 8/9] cleanup

Signed-off-by: Roger Wang
---
 vllm/model_executor/models/qwen2_5_vl.py | 27 ++++++++++++------------
 vllm/model_executor/models/qwen3_vl.py   | 6 +++---
 2 files changed, 16 insertions(+), 17 deletions(-)

diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index aae86d45ffd3..45f6c7ebbd54 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -619,14 +619,14 @@ def __init__(
         head_dim = self.hidden_size // self.num_heads
         self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
 
-        self.use_upstream_fa = False
+        use_upstream_fa = False
         self.attn_backend = get_vit_attn_backend(
             head_size=head_dim, dtype=torch.get_default_dtype())
         if self.attn_backend != _Backend.FLASH_ATTN and \
             check_upstream_fa_availability(
                 torch.get_default_dtype()):
             self.attn_backend = _Backend.FLASH_ATTN
-            self.use_upstream_fa = True
+            use_upstream_fa = True
 
         if self.attn_backend not in {
             _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
@@ -637,18 +637,17 @@ def __init__(
             )
 
         self.blocks = nn.ModuleList([
-            Qwen2_5_VisionBlock(dim=self.hidden_size,
-                                num_heads=self.num_heads,
-                                mlp_hidden_dim=vision_config.intermediate_size,
-                                act_fn=get_act_and_mul_fn(
-                                    vision_config.hidden_act),
-                                norm_layer=norm_layer,
-                                quant_config=quant_config,
-                                prefix=f"{prefix}.blocks.{layer_idx}",
-                                use_data_parallel=use_data_parallel,
-                                attn_backend=self.attn_backend,
-                                use_upstream_fa=self.use_upstream_fa)
-            for layer_idx in range(depth)
+            Qwen2_5_VisionBlock(
+                dim=self.hidden_size,
+                num_heads=self.num_heads,
+                mlp_hidden_dim=vision_config.intermediate_size,
+                act_fn=get_act_and_mul_fn(vision_config.hidden_act),
+                norm_layer=norm_layer,
+                quant_config=quant_config,
+                prefix=f"{prefix}.blocks.{layer_idx}",
+                use_data_parallel=use_data_parallel,
+                attn_backend=self.attn_backend,
+                use_upstream_fa=use_upstream_fa) for layer_idx in range(depth)
         ])
         self.merger = Qwen2_5_VisionPatchMerger(
             d_model=vision_config.out_hidden_size,

diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index 067446c6675f..03c08694fa05 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -319,12 +319,12 @@ def __init__(
 
         self.attn_backend = get_vit_attn_backend(
             head_size=head_dim, dtype=torch.get_default_dtype())
-        self.use_upstream_fa = False
+        use_upstream_fa = False
        if self.attn_backend != _Backend.FLASH_ATTN and \
             check_upstream_fa_availability(
                 torch.get_default_dtype()):
             self.attn_backend = _Backend.FLASH_ATTN
-            self.use_upstream_fa = True
+            use_upstream_fa = True
 
         if self.attn_backend not in {
             _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
@@ -352,7 +352,7 @@ def __init__(
                 prefix=f"{prefix}.blocks.{layer_idx}",
                 use_data_parallel=use_data_parallel,
                 attn_backend=self.attn_backend,
-                use_upstream_fa=self.use_upstream_fa)
+                use_upstream_fa=use_upstream_fa)
             for layer_idx in range(vision_config.depth)
         ])

From df6448695234c4ea829c6a8b4ca042d2e7501193 Mon Sep 17 00:00:00 2001
From: Roger Wang
Date: Fri, 26 Sep 2025 17:13:28 -0700
Subject: [PATCH 9/9] update

Signed-off-by: Roger Wang
---
 vllm/model_executor/models/qwen2_5_vl.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index 45f6c7ebbd54..f4e210955604 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -431,7 +431,7 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         use_data_parallel: bool = False,
-        attn_backend: Optional[_Backend] = None,
+        attn_backend: _Backend = _Backend.TORCH_SDPA,
         use_upstream_fa: bool = False,
     ) -> None:
         super().__init__()
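
For readers tracing the final behavior rather than the individual commits: the sketch below condenses the backend-resolution flow the last few patches converge on -- detect the ViT attention backend once in the vision tower, prefer upstream flash-attention when it is importable, validate against the supported set, force TORCH_SDPA on Blackwell (compute capability 10.0), and hand the resolved `attn_backend`/`use_upstream_fa` pair to every block instead of re-detecting per layer. The names `get_vit_attn_backend`, `check_upstream_fa_availability`, `current_platform.is_device_capability(100)` and the `_Backend` values come from the diff; the standalone enum and function here are illustrative stand-ins, not vLLM's actual API.

# Illustrative sketch only; the Backend enum and boolean arguments stand in for
# vLLM's _Backend, get_vit_attn_backend(), check_upstream_fa_availability() and
# current_platform.is_device_capability(100) -- assumptions, not the real API.
from enum import Enum, auto


class Backend(Enum):
    FLASH_ATTN = auto()
    TORCH_SDPA = auto()
    XFORMERS = auto()
    ROCM_AITER_FA = auto()


SUPPORTED = {Backend.FLASH_ATTN, Backend.TORCH_SDPA,
             Backend.XFORMERS, Backend.ROCM_AITER_FA}


def resolve_vit_attn_backend(detected: Backend,
                             upstream_fa_available: bool,
                             is_blackwell: bool) -> tuple[Backend, bool]:
    """Return (attn_backend, use_upstream_fa) for the vision tower."""
    backend, use_upstream_fa = detected, False
    # Prefer upstream flash-attention when the detected backend is not FA
    # but the upstream package is available.
    if backend != Backend.FLASH_ATTN and upstream_fa_available:
        backend, use_upstream_fa = Backend.FLASH_ATTN, True
    if backend not in SUPPORTED:
        raise RuntimeError(f"Unsupported ViT attention backend: {backend}")
    # Blackwell: FA/XFORMERS currently misbehave for Qwen3-VL vision attention,
    # so force SDPA (see the TODO added in patch 5). use_upstream_fa is left
    # unchanged, matching the diff; it only matters when the backend is FA.
    if is_blackwell and backend != Backend.TORCH_SDPA:
        backend = Backend.TORCH_SDPA
    return backend, use_upstream_fa


if __name__ == "__main__":
    # A detected XFORMERS backend is upgraded to FLASH_ATTN, then forced to
    # TORCH_SDPA on Blackwell; the blocks receive the resolved pair once.
    print(resolve_vit_attn_backend(Backend.XFORMERS,
                                   upstream_fa_available=True,
                                   is_blackwell=True))

The design point the series lands on is that the transformer (Qwen2_5_VisionTransformer / Qwen3_VisionTransformer) resolves the backend exactly once and passes `attn_backend` and `use_upstream_fa` down to each vision block, rather than each attention layer repeating the detection and Blackwell fallback on its own.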