diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 240d912f10b6..07b7ef4e72fe 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -1125,6 +1125,10 @@
title: Qwen2Audio
- local: model_doc/qwen2_vl
title: Qwen2VL
+ - local: model_doc/qwen3_vl
+ title: Qwen3VL
+ - local: model_doc/qwen3_vl_moe
+ title: Qwen3VLMoe
- local: model_doc/sam2
title: SAM2
- local: model_doc/sam2_video
diff --git a/docs/source/en/model_doc/qwen3_vl.md b/docs/source/en/model_doc/qwen3_vl.md
new file mode 100644
index 000000000000..9e90363a1eba
--- /dev/null
+++ b/docs/source/en/model_doc/qwen3_vl.md
@@ -0,0 +1,117 @@
+
+*This model was added to Hugging Face Transformers on 2025-08-16.*
+
+
+
+# Qwen3-VL
+
+[Qwen3-VL](https://huggingface.co/papers/2502.13923) is a multimodal vision-language model series that comes in dense and MoE variants, each available in Instruct and Thinking versions. Building on its predecessors, Qwen3-VL delivers significant improvements in visual understanding while maintaining strong text-only capabilities. Key architectural advancements include an enhanced MRoPE with an interleaved layout for better spatial-temporal modeling, DeepStack integration to leverage multi-level features from the Vision Transformer (ViT), and improved video understanding through text-based time alignment, evolving from T-RoPE to text timestamp alignment for more precise temporal grounding. Together, these innovations enable Qwen3-VL to achieve superior performance on complex multimodal tasks.
+
+## Model usage
+
+
+
+
+```py
+import torch
+from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
+
+model = Qwen3VLForConditionalGeneration.from_pretrained(
+ "Qwen/Qwen3-VL",
+ dtype=torch.float16,
+ device_map="auto",
+ attn_implementation="sdpa"
+)
+processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL")
+messages = [
+ {
+ "role":"user",
+ "content":[
+ {
+ "type":"image",
+ "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
+ },
+ {
+ "type":"text",
+ "text":"Describe this image."
+ }
+ ]
+ }
+
+]
+
+inputs = processor.apply_chat_template(
+ messages,
+ tokenize=True,
+ add_generation_prompt=True,
+ return_dict=True,
+ return_tensors="pt",
+)
+inputs.pop("token_type_ids", None)
+
+generated_ids = model.generate(**inputs, max_new_tokens=128)
+generated_ids_trimmed = [
+ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+]
+output_text = processor.batch_decode(
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+)
+print(output_text)
+```
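+
+The snippet below is a minimal sketch of video inference. It reuses the `model` and `processor` loaded above and
+assumes the chat template accepts a `{"type": "video", ...}` content entry handled by [`Qwen3VLVideoProcessor`];
+`path/to/video.mp4` is a placeholder to replace with your own video file or URL.
+
+```py
+messages = [
+    {
+        "role": "user",
+        "content": [
+            {"type": "video", "path": "path/to/video.mp4"},
+            {"type": "text", "text": "Describe this video."}
+        ]
+    }
+]
+
+inputs = processor.apply_chat_template(
+    messages,
+    tokenize=True,
+    add_generation_prompt=True,
+    return_dict=True,
+    return_tensors="pt",
+)
+inputs.pop("token_type_ids", None)
+
+generated_ids = model.generate(**inputs, max_new_tokens=128)
+# trim the prompt tokens before decoding
+output_text = processor.batch_decode(
+    generated_ids[:, inputs.input_ids.shape[1]:], skip_special_tokens=True
+)
+print(output_text)
+```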
+
+
+
+## Qwen3VLConfig
+
+[[autodoc]] Qwen3VLConfig
+
+## Qwen3VLTextConfig
+
+[[autodoc]] Qwen3VLTextConfig
+
+## Qwen3VLProcessor
+
+[[autodoc]] Qwen3VLProcessor
+
+## Qwen3VLVideoProcessor
+
+[[autodoc]] Qwen3VLVideoProcessor
+
+## Qwen3VLVisionModel
+
+[[autodoc]] Qwen3VLVisionModel
+ - forward
+
+## Qwen3VLTextModel
+
+[[autodoc]] Qwen3VLTextModel
+ - forward
+
+## Qwen3VLModel
+
+[[autodoc]] Qwen3VLModel
+ - forward
+
+## Qwen3VLForConditionalGeneration
+
+[[autodoc]] Qwen3VLForConditionalGeneration
+ - forward
diff --git a/docs/source/en/model_doc/qwen3_vl_moe.md b/docs/source/en/model_doc/qwen3_vl_moe.md
new file mode 100644
index 000000000000..76d046efff2d
--- /dev/null
+++ b/docs/source/en/model_doc/qwen3_vl_moe.md
@@ -0,0 +1,109 @@
+
+*This model was added to Hugging Face Transformers on 2025-08-17.*
+
+
+
+# Qwen3-VL-Moe
+
+[Qwen3-VL](https://huggingface.co/papers/2502.13923) is a multimodal vision-language model series that comes in dense and MoE variants, each available in Instruct and Thinking versions. Building on its predecessors, Qwen3-VL delivers significant improvements in visual understanding while maintaining strong text-only capabilities. Key architectural advancements include an enhanced MRoPE with an interleaved layout for better spatial-temporal modeling, DeepStack integration to leverage multi-level features from the Vision Transformer (ViT), and improved video understanding through text-based time alignment, evolving from T-RoPE to text timestamp alignment for more precise temporal grounding. Together, these innovations enable Qwen3-VL to achieve superior performance on complex multimodal tasks.
+
+## Model usage
+
+
+
+
+```py
+import torch
+from transformers import Qwen3VLMoeForConditionalGeneration, AutoProcessor
+
+model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
+ "Qwen/Qwen3-VL-Moe",
+ dtype=torch.float16,
+ device_map="auto",
+ attn_implementation="sdpa"
+)
+processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-Moe")
+messages = [
+ {
+ "role":"user",
+ "content":[
+ {
+ "type":"image",
+ "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
+ },
+ {
+ "type":"text",
+ "text":"Describe this image."
+ }
+ ]
+ }
+
+]
+
+inputs = processor.apply_chat_template(
+ messages,
+ tokenize=True,
+ add_generation_prompt=True,
+ return_dict=True,
+ return_tensors="pt",
+)
+inputs.pop("token_type_ids", None)
+
+generated_ids = model.generate(**inputs, max_new_tokens=128)
+generated_ids_trimmed = [
+ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+]
+output_text = processor.batch_decode(
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+)
+print(output_text)
+```
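+
+For large MoE checkpoints, you may want to quantize the model for inference. The snippet below is a minimal sketch
+using 4-bit [bitsandbytes](https://huggingface.co/docs/bitsandbytes) quantization; it assumes `bitsandbytes` is
+installed and reuses the placeholder checkpoint name from the example above.
+
+```py
+import torch
+from transformers import AutoProcessor, BitsAndBytesConfig, Qwen3VLMoeForConditionalGeneration
+
+# quantize the linear layers to 4-bit NF4 and keep compute in float16
+quantization_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_compute_dtype=torch.float16,
+)
+
+model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
+    "Qwen/Qwen3-VL-Moe",
+    device_map="auto",
+    quantization_config=quantization_config,
+)
+processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-Moe")
+```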
+
+
+
+## Qwen3VLMoeConfig
+
+[[autodoc]] Qwen3VLMoeConfig
+
+## Qwen3VLMoeTextConfig
+
+[[autodoc]] Qwen3VLMoeTextConfig
+
+## Qwen3VLMoeVisionModel
+
+[[autodoc]] Qwen3VLMoeVisionModel
+ - forward
+
+## Qwen3VLMoeTextModel
+
+[[autodoc]] Qwen3VLMoeTextModel
+ - forward
+
+## Qwen3VLMoeModel
+
+[[autodoc]] Qwen3VLMoeModel
+ - forward
+
+## Qwen3VLMoeForConditionalGeneration
+
+[[autodoc]] Qwen3VLMoeForConditionalGeneration
+ - forward
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index 1e70e0e7b4d7..9a631255ff9b 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -278,6 +278,8 @@
from .qwen3 import *
from .qwen3_moe import *
from .qwen3_next import *
+ from .qwen3_vl import *
+ from .qwen3_vl_moe import *
from .rag import *
from .recurrent_gemma import *
from .reformer import *
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index 0d6981d685ed..c3aa38de4b31 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -325,6 +325,10 @@
("qwen3", "Qwen3Config"),
("qwen3_moe", "Qwen3MoeConfig"),
("qwen3_next", "Qwen3NextConfig"),
+ ("qwen3_vl", "Qwen3VLConfig"),
+ ("qwen3_vl_moe", "Qwen3VLMoeConfig"),
+ ("qwen3_vl_moe_text", "Qwen3VLMoeTextConfig"),
+ ("qwen3_vl_text", "Qwen3VLTextConfig"),
("rag", "RagConfig"),
("realm", "RealmConfig"),
("recurrent_gemma", "RecurrentGemmaConfig"),
@@ -763,6 +767,10 @@
("qwen3", "Qwen3"),
("qwen3_moe", "Qwen3MoE"),
("qwen3_next", "Qwen3Next"),
+ ("qwen3_vl", "Qwen3VL"),
+ ("qwen3_vl_moe", "Qwen3VLMoe"),
+ ("qwen3_vl_moe_text", "Qwen3VLMoe"),
+ ("qwen3_vl_text", "Qwen3VL"),
("rag", "RAG"),
("realm", "REALM"),
("recurrent_gemma", "RecurrentGemma"),
@@ -950,6 +958,8 @@
("internvl_vision", "internvl"),
("qwen2_5_vl_text", "qwen2_5_vl"),
("qwen2_vl_text", "qwen2_vl"),
+ ("qwen3_vl_text", "qwen3_vl"),
+ ("qwen3_vl_moe_text", "qwen3_vl_moe"),
("sam_vision_model", "sam"),
("sam2_vision_model", "sam2"),
("sam2_hiera_det_model", "sam2"),
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index 7d07ca6dc7d6..193e8f8fd940 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -156,6 +156,7 @@
("pvt_v2", ("PvtImageProcessor", "PvtImageProcessorFast")),
("qwen2_5_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
("qwen2_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
+ ("qwen3_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
("regnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
("resnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
("rt_detr", ("RTDetrImageProcessor", "RTDetrImageProcessorFast")),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index a4b9434f24b9..2109e487328e 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -319,6 +319,10 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("qwen3", "Qwen3Model"),
("qwen3_moe", "Qwen3MoeModel"),
("qwen3_next", "Qwen3NextModel"),
+ ("qwen3_vl", "Qwen3VLModel"),
+ ("qwen3_vl_moe", "Qwen3VLMoeModel"),
+ ("qwen3_vl_moe_text", "Qwen3VLMoeTextModel"),
+ ("qwen3_vl_text", "Qwen3VLTextModel"),
("recurrent_gemma", "RecurrentGemmaModel"),
("reformer", "ReformerModel"),
("regnet", "RegNetModel"),
@@ -972,6 +976,8 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("pix2struct", "Pix2StructForConditionalGeneration"),
("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),
("qwen2_vl", "Qwen2VLForConditionalGeneration"),
+ ("qwen3_vl", "Qwen3VLForConditionalGeneration"),
+ ("qwen3_vl_moe", "Qwen3VLMoeForConditionalGeneration"),
("video_llava", "VideoLlavaForConditionalGeneration"),
("vipllava", "VipLlavaForConditionalGeneration"),
("vision-encoder-decoder", "VisionEncoderDecoderModel"),
@@ -1026,6 +1032,8 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("pixtral", "LlavaForConditionalGeneration"),
("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),
("qwen2_vl", "Qwen2VLForConditionalGeneration"),
+ ("qwen3_vl", "Qwen3VLForConditionalGeneration"),
+ ("qwen3_vl_moe", "Qwen3VLMoeForConditionalGeneration"),
("shieldgemma2", "Gemma3ForConditionalGeneration"),
("smolvlm", "SmolVLMForConditionalGeneration"),
("udop", "UdopForConditionalGeneration"),
diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py
index d8db58cb7b1f..13583c55002f 100644
--- a/src/transformers/models/auto/processing_auto.py
+++ b/src/transformers/models/auto/processing_auto.py
@@ -120,6 +120,8 @@
("qwen2_5_vl", "Qwen2_5_VLProcessor"),
("qwen2_audio", "Qwen2AudioProcessor"),
("qwen2_vl", "Qwen2VLProcessor"),
+ ("qwen3_vl", "Qwen3VLProcessor"),
+ ("qwen3_vl_moe", "Qwen3VLProcessor"),
("sam", "SamProcessor"),
("sam2", "Sam2Processor"),
("sam_hq", "SamHQProcessor"),
diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py
index 688faf00c4ea..0ef450f45cb9 100644
--- a/src/transformers/models/auto/tokenization_auto.py
+++ b/src/transformers/models/auto/tokenization_auto.py
@@ -583,6 +583,8 @@
"Qwen2TokenizerFast" if is_tokenizers_available() else None,
),
),
+ ("qwen3_vl", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
+ ("qwen3_vl_moe", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
("rag", ("RagTokenizer", None)),
("realm", ("RealmTokenizer", "RealmTokenizerFast" if is_tokenizers_available() else None)),
(
diff --git a/src/transformers/models/auto/video_processing_auto.py b/src/transformers/models/auto/video_processing_auto.py
index b9a5c2204fd1..551de914626e 100644
--- a/src/transformers/models/auto/video_processing_auto.py
+++ b/src/transformers/models/auto/video_processing_auto.py
@@ -56,6 +56,8 @@
("qwen2_5_omni", "Qwen2VLVideoProcessor"),
("qwen2_5_vl", "Qwen2VLVideoProcessor"),
("qwen2_vl", "Qwen2VLVideoProcessor"),
+ ("qwen3_vl", "Qwen3VLVideoProcessor"),
+ ("qwen3_vl_moe", "Qwen3VLVideoProcessor"),
("sam2_video", "Sam2VideoVideoProcessor"),
("smolvlm", "SmolVLMVideoProcessor"),
("video_llava", "VideoLlavaVideoProcessor"),
diff --git a/src/transformers/models/qwen3_vl/__init__.py b/src/transformers/models/qwen3_vl/__init__.py
new file mode 100644
index 000000000000..e37161a2e415
--- /dev/null
+++ b/src/transformers/models/qwen3_vl/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_qwen3_vl import *
+ from .modeling_qwen3_vl import *
+ from .processing_qwen3_vl import *
+ from .video_processing_qwen3_vl import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/src/transformers/models/qwen3_vl/configuration_qwen3_vl.py b/src/transformers/models/qwen3_vl/configuration_qwen3_vl.py
new file mode 100644
index 000000000000..132ffa8be150
--- /dev/null
+++ b/src/transformers/models/qwen3_vl/configuration_qwen3_vl.py
@@ -0,0 +1,287 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/qwen3_vl/modular_qwen3_vl.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_qwen3_vl.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+
+
+class Qwen3VLVisionConfig(PretrainedConfig):
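+    r"""
+    This is the configuration class to store the configuration of a [`Qwen3VLVisionModel`]. It is used to instantiate
+    the Qwen3-VL vision encoder according to the specified arguments, defining the vision tower architecture
+    (patch embedding, rotary position embeddings, ViT blocks, and the DeepStack patch mergers selected by
+    `deepstack_visual_indexes`).
+    """
+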
+ model_type = "qwen3_vl"
+ base_config_key = "vision_config"
+
+ def __init__(
+ self,
+ depth=27,
+ hidden_size=1152,
+ hidden_act="gelu_pytorch_tanh",
+ intermediate_size=4304,
+ num_heads=16,
+ in_channels=3,
+ patch_size=16,
+ spatial_merge_size=2,
+ temporal_patch_size=2,
+ out_hidden_size=3584,
+ num_position_embeddings=2304,
+ deepstack_visual_indexes=[8, 16, 24],
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.depth = depth
+ self.hidden_size = hidden_size
+ self.hidden_act = hidden_act
+ self.intermediate_size = intermediate_size
+ self.num_heads = num_heads
+ self.in_channels = in_channels
+ self.patch_size = patch_size
+ self.spatial_merge_size = spatial_merge_size
+ self.temporal_patch_size = temporal_patch_size
+ self.out_hidden_size = out_hidden_size
+ self.num_position_embeddings = num_position_embeddings
+ self.initializer_range = initializer_range
+ self.deepstack_visual_indexes = deepstack_visual_indexes
+
+
+class Qwen3VLTextConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Qwen3VLTextModel`]. It is used to instantiate a
+ Qwen3-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of
+ Qwen3-VL-4B-Instruct [Qwen/Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 151936):
+ Vocabulary size of the Qwen3VL model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`Qwen3VLModel`].
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 22016):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_key_value_heads (`int`, *optional*, defaults to 32):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`.
+ head_dim (`int`, *optional*, defaults to 128):
+ The dimension of the head. If not specified, will default to `hidden_size // num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 128000):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether the model's input and output word embeddings should be tied.
+ rope_theta (`float`, *optional*, defaults to 5000000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`list[float]`, *optional*):
+                Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+        attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+
+ ```python
+ >>> from transformers import Qwen3VLTextModel, Qwen3VLTextConfig
+
+ >>> # Initializing a Qwen3VL style configuration
+ >>> configuration = Qwen3VLTextConfig()
+
+    >>> # Initializing a model from the Qwen3-VL-4B style configuration
+ >>> model = Qwen3VLTextModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "qwen3_vl_text"
+ base_config_key = "text_config"
+
+ def __init__(
+ self,
+ vocab_size=151936,
+ hidden_size=4096,
+ intermediate_size=22016,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=32,
+ head_dim=128,
+ hidden_act="silu",
+ max_position_embeddings=128000,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ tie_word_embeddings=False,
+ rope_theta=5000000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.head_dim = head_dim
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+
+ rope_config_validation(self, ignore_keys={"mrope_section", "mrope_interleaved"})
+
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+
+
+class Qwen3VLConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Qwen3VLModel`]. It is used to instantiate a
+ Qwen3-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of
+ Qwen3-VL-4B-Instruct [Qwen/Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLTextConfig`):
+ The config object or dictionary of the text backbone.
+ vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLVisionConfig`):
+ The config object or dictionary of the vision backbone.
+ image_token_id (`int`, *optional*, defaults to 151655):
+ The image token index to encode the image prompt.
+ video_token_id (`int`, *optional*, defaults to 151656):
+            The video token index to encode the video prompt.
+ vision_start_token_id (`int`, *optional*, defaults to 151652):
+ The start token index to encode the image prompt.
+ vision_end_token_id (`int`, *optional*, defaults to 151653):
+ The end token index to encode the image prompt.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie the word embeddings.
+
+ ```python
+ >>> from transformers import Qwen3VLForConditionalGeneration, Qwen3VLConfig
+
+ >>> # Initializing a Qwen3-VL style configuration
+ >>> configuration = Qwen3VLConfig()
+
+ >>> # Initializing a model from the Qwen3-VL-4B style configuration
+ >>> model = Qwen3VLForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "qwen3_vl"
+ sub_configs = {"vision_config": Qwen3VLVisionConfig, "text_config": Qwen3VLTextConfig}
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ text_config=None,
+ vision_config=None,
+ image_token_id=151655,
+ video_token_id=151656,
+ vision_start_token_id=151652,
+ vision_end_token_id=151653,
+ tie_word_embeddings=False,
+ **kwargs,
+ ):
+ if isinstance(vision_config, dict):
+ self.vision_config = self.sub_configs["vision_config"](**vision_config)
+ elif vision_config is None:
+ self.vision_config = self.sub_configs["vision_config"]()
+
+ if isinstance(text_config, dict):
+ self.text_config = self.sub_configs["text_config"](**text_config)
+ elif text_config is None:
+ self.text_config = self.sub_configs["text_config"]()
+
+ self.image_token_id = image_token_id
+ self.video_token_id = video_token_id
+ self.vision_start_token_id = vision_start_token_id
+ self.vision_end_token_id = vision_end_token_id
+ super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings)
+
+
+__all__ = ["Qwen3VLConfig", "Qwen3VLTextConfig"]
diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py
new file mode 100644
index 000000000000..a18366a2a534
--- /dev/null
+++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py
@@ -0,0 +1,1568 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/qwen3_vl/modular_qwen3_vl.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_qwen3_vl.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...integrations import use_kernel_forward_from_hub
+from ...masking_utils import create_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import check_model_inputs
+from .configuration_qwen3_vl import Qwen3VLConfig, Qwen3VLTextConfig, Qwen3VLVisionConfig
+
+
+class Qwen3VLVisionMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True)
+ self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, hidden_state):
+ return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state)))
+
+
+class Qwen3VLVisionPatchEmbed(nn.Module):
+ def __init__(self, config) -> None:
+ super().__init__()
+ self.patch_size = config.patch_size
+ self.temporal_patch_size = config.temporal_patch_size
+ self.in_channels = config.in_channels
+ self.embed_dim = config.hidden_size
+
+ kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
+ self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ target_dtype = self.proj.weight.dtype
+ hidden_states = hidden_states.view(
+ -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
+ )
+ hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
+ return hidden_states
+
+
+class Qwen3VLVisionRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, dim: int, theta: float = 10000.0) -> None:
+ super().__init__()
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ def forward(self, seqlen: int) -> torch.Tensor:
+ seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+ freqs = torch.outer(seq, self.inv_freq)
+ return freqs
+
+
+class Qwen3VLVisionPatchMerger(nn.Module):
+ def __init__(self, config: Qwen3VLVisionConfig, use_postshuffle_norm=False) -> None:
+ super().__init__()
+ self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)
+ self.use_postshuffle_norm = use_postshuffle_norm
+ self.norm = nn.LayerNorm(self.hidden_size if use_postshuffle_norm else config.hidden_size, eps=1e-6)
+ self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size)
+ self.act_fn = nn.GELU()
+ self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.norm(x.view(-1, self.hidden_size) if self.use_postshuffle_norm else x).view(-1, self.hidden_size)
+ x = self.linear_fc2(self.act_fn(self.linear_fc1(x)))
+ return x
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb_vision(
+ q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor]:
+ orig_q_dtype = q.dtype
+ orig_k_dtype = k.dtype
+ q, k = q.float(), k.float()
+ cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ q_embed = q_embed.to(orig_q_dtype)
+ k_embed = k_embed.to(orig_k_dtype)
+ return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs: Unpack[TransformersKwargs],
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class Qwen3VLVisionAttention(nn.Module):
+ def __init__(self, config: Qwen3VLVisionConfig) -> None:
+ super().__init__()
+ self.dim = config.hidden_size
+ self.num_heads = config.num_heads
+ self.head_dim = self.dim // self.num_heads
+ self.num_key_value_groups = 1 # needed for eager attention
+ self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
+ self.proj = nn.Linear(self.dim, self.dim)
+ self.scaling = self.head_dim**-0.5
+ self.config = config
+ self.attention_dropout = 0.0
+ self.is_causal = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ rotary_pos_emb: Optional[torch.Tensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ seq_length = hidden_states.shape[0]
+ query_states, key_states, value_states = (
+ self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
+ )
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
+
+ query_states = query_states.transpose(0, 1).unsqueeze(0)
+ key_states = key_states.transpose(0, 1).unsqueeze(0)
+ value_states = value_states.transpose(0, 1).unsqueeze(0)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ if self.config._attn_implementation == "flash_attention_2":
+ # Flash Attention 2: Use cu_seqlens for variable length attention
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+ attn_output, _ = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask=None,
+ scaling=self.scaling,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ cu_seq_lens_q=cu_seqlens,
+ cu_seq_lens_k=cu_seqlens,
+ max_length_q=max_seqlen,
+ max_length_k=max_seqlen,
+ is_causal=False,
+ **kwargs,
+ )
+ else:
+ # Other implementations: Process each chunk separately
+ lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+ splits = [
+ torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
+ ]
+
+ attn_outputs = [
+ attention_interface(
+ self,
+ q,
+ k,
+ v,
+ attention_mask=None,
+ scaling=self.scaling,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ is_causal=False,
+ **kwargs,
+ )[0]
+ for q, k, v in zip(*splits)
+ ]
+ attn_output = torch.cat(attn_outputs, dim=1)
+
+ attn_output = attn_output.reshape(seq_length, -1).contiguous()
+ attn_output = self.proj(attn_output)
+ return attn_output
+
+
+class Qwen3VLVisionBlock(GradientCheckpointingLayer):
+ def __init__(self, config, attn_implementation: str = "sdpa") -> None:
+ super().__init__()
+ self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6)
+ self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6)
+ self.attn = Qwen3VLVisionAttention(config=config)
+ self.mlp = Qwen3VLVisionMLP(config=config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ rotary_pos_emb: Optional[torch.Tensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ hidden_states = hidden_states + self.attn(
+ self.norm1(hidden_states),
+ cu_seqlens=cu_seqlens,
+ rotary_pos_emb=rotary_pos_emb,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+ return hidden_states
+
+
+class Qwen3VLTextRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: Qwen3VLTextConfig, device=None):
+ super().__init__()
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get("rope_type", "default")
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+        # `mrope_section` lives in `rope_scaling`; fall back to the default split when it is not configured
+        rope_scaling = config.rope_scaling if config.rope_scaling is not None else {}
+        self.mrope_section = rope_scaling.get("mrope_section", [24, 20, 20])
+
+ def apply_interleaved_mrope(self, freqs, mrope_section):
+ """Apply interleaved MRoPE to 3D rotary embeddings.
+        Reorganizes the frequency layout from chunked [TTT...HHH...WWW] to
+        interleaved [THWTHWTHW...TT], preserving frequency continuity.
+        Args:
+            freqs: (3, bs, seq_len, head_dim // 2)
+            mrope_section: (3,)
+        Returns:
+            freqs_t: (bs, seq_len, head_dim // 2)
+ """
+ freqs_t = freqs[0] # just overwrite the first dimension T
+ for dim, offset in enumerate((1, 2), start=1): # H, W
+ length = mrope_section[dim] * 3
+ idx = slice(offset, length, 3)
+ freqs_t[..., idx] = freqs[dim, ..., idx]
+ return freqs_t
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ # In contrast to other models, Qwen3VL has different position ids for the grids
+ # So we expand the inv_freq to shape (3, ...)
+ if position_ids.ndim == 2:
+ position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
+ inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
+ position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
+ freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class Qwen3VLTextRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps: float = 1e-6) -> None:
+ """
+ Qwen3VLTextRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+class Qwen3VLTextAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: Qwen3VLTextConfig, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
+
+ self.q_proj = nn.Linear(
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.v_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.o_proj = nn.Linear(
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+ )
+ self.q_norm = Qwen3VLTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim!
+ self.k_norm = Qwen3VLTextRMSNorm(
+ self.head_dim, eps=config.rms_norm_eps
+ ) # thus post q_norm does not need reshape
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+ key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class Qwen3VLTextMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+class Qwen3VLTextDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: Qwen3VLTextConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = Qwen3VLTextAttention(config=config, layer_idx=layer_idx)
+
+ self.mlp = Qwen3VLTextMLP(config)
+ self.input_layernorm = Qwen3VLTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = Qwen3VLTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ # Self Attention
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+    Base class for Qwen3VL outputs, with hidden states and attentions.
+ """
+)
+class Qwen3VLModelOutputWithPast(ModelOutput):
+ r"""
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[list[torch.FloatTensor]] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ rope_deltas: Optional[torch.LongTensor] = None
+
+
+@auto_docstring
+class Qwen3VLPreTrainedModel(PreTrainedModel):
+ config: Qwen3VLConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["Qwen3VLTextDecoderLayer", "Qwen3VLVisionBlock"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn = True
+ _supports_sdpa = True
+
+ _can_compile_fullgraph = True
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "hidden_states": Qwen3VLTextDecoderLayer,
+ "attentions": Qwen3VLTextAttention,
+ }
+
+
+class Qwen3VLVisionModel(Qwen3VLPreTrainedModel):
+ config: Qwen3VLVisionConfig
+ _no_split_modules = ["Qwen3VLVisionBlock"]
+
+ def __init__(self, config, *inputs, **kwargs) -> None:
+ super().__init__(config, *inputs, **kwargs)
+ self.spatial_merge_size = config.spatial_merge_size
+ self.patch_size = config.patch_size
+ self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
+
+ self.patch_embed = Qwen3VLVisionPatchEmbed(
+ config=config,
+ )
+
+ self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size)
+ self.num_grid_per_side = int(config.num_position_embeddings**0.5)
+
+ head_dim = config.hidden_size // config.num_heads
+ self.rotary_pos_emb = Qwen3VLVisionRotaryEmbedding(head_dim // 2)
+
+ self.blocks = nn.ModuleList([Qwen3VLVisionBlock(config) for _ in range(config.depth)])
+ self.merger = Qwen3VLVisionPatchMerger(
+ config=config,
+ use_postshuffle_norm=False,
+ )
+
+ self.deepstack_visual_indexes = config.deepstack_visual_indexes
+ self.deepstack_merger_list = nn.ModuleList(
+ [
+ Qwen3VLVisionPatchMerger(
+ config=config,
+ use_postshuffle_norm=True,
+ )
+ for _ in range(len(config.deepstack_visual_indexes))
+ ]
+ )
+
+ self.gradient_checkpointing = False
+
+ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
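+        """
+        Compute 2D rotary position embeddings for every patch described by `grid_thw` by indexing a shared
+        frequency table with each patch's (row, column) coordinates at full resolution.
+        """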
+ merge_size = self.spatial_merge_size
+
+ max_hw = int(grid_thw[:, 1:].max().item())
+ freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2)
+ device = freq_table.device
+
+ total_tokens = int(torch.prod(grid_thw, dim=1).sum().item())
+ pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)
+
+ offset = 0
+ for num_frames, height, width in grid_thw:
+ merged_h, merged_w = height // merge_size, width // merge_size
+
+ block_rows = torch.arange(merged_h, device=device) # block row indices
+ block_cols = torch.arange(merged_w, device=device) # block col indices
+ intra_row = torch.arange(merge_size, device=device) # intra-block row offsets
+ intra_col = torch.arange(merge_size, device=device) # intra-block col offsets
+
+ # Compute full-resolution positions
+ row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
+ col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]
+
+ row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
+ col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
+
+ coords = torch.stack((row_idx, col_idx), dim=-1)
+
+ if num_frames > 1:
+ coords = coords.repeat(num_frames, 1)
+
+ num_tokens = coords.shape[0]
+ pos_ids[offset : offset + num_tokens] = coords
+ offset += num_tokens
+
+ embeddings = freq_table[pos_ids] # lookup rotary embeddings
+ embeddings = embeddings.flatten(1)
+ return embeddings
+
+ def fast_pos_embed_interpolate(self, grid_thw):
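+        """
+        Bilinearly interpolate the learned position-embedding grid (`num_grid_per_side` x `num_grid_per_side`) to
+        the (height, width) patch grid of every image/video in `grid_thw`, repeat it across frames, and reorder the
+        result to match the spatial-merge window layout expected by the patch merger.
+        """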
+ grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]
+
+ idx_list = [[] for _ in range(4)]
+ weight_list = [[] for _ in range(4)]
+
+ for t, h, w in zip(grid_ts, grid_hs, grid_ws):
+ h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h)
+ w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w)
+
+ h_idxs_floor = h_idxs.int()
+ w_idxs_floor = w_idxs.int()
+ h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
+ w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
+
+ dh = h_idxs - h_idxs_floor
+ dw = w_idxs - w_idxs_floor
+
+ base_h = h_idxs_floor * self.num_grid_per_side
+ base_h_ceil = h_idxs_ceil * self.num_grid_per_side
+
+ indices = [
+ (base_h[None].T + w_idxs_floor[None]).flatten(),
+ (base_h[None].T + w_idxs_ceil[None]).flatten(),
+ (base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
+ (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
+ ]
+
+ weights = [
+ ((1 - dh)[None].T * (1 - dw)[None]).flatten(),
+ ((1 - dh)[None].T * dw[None]).flatten(),
+ (dh[None].T * (1 - dw)[None]).flatten(),
+ (dh[None].T * dw[None]).flatten(),
+ ]
+
+ for i in range(4):
+ idx_list[i].extend(indices[i].tolist())
+ weight_list[i].extend(weights[i].tolist())
+
+ idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=self.pos_embed.weight.device)
+ weight_tensor = torch.tensor(
+ weight_list, dtype=self.pos_embed.weight.dtype, device=self.pos_embed.weight.device
+ )
+ pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
+ patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
+
+ patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)])
+
+ patch_pos_embeds_permute = []
+ merge_size = self.config.spatial_merge_size
+ for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
+ pos_embed = pos_embed.repeat(t, 1)
+ pos_embed = (
+ pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
+ .permute(0, 1, 3, 2, 4, 5)
+ .flatten(0, 4)
+ )
+ patch_pos_embeds_permute.append(pos_embed)
+ patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
+ return patch_pos_embeds
+
+    def forward(
+        self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs
+    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
+        """
+        Args:
+            hidden_states (`torch.Tensor` of shape `(seq_len, patch_dim)`):
+                The packed, flattened image/video patches to embed, where `patch_dim` is
+                `in_channels * temporal_patch_size * patch_size**2`.
+            grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
+                The temporal, height and width dimensions of the feature grid for each image or video.
+
+        Returns:
+            `tuple(torch.Tensor, list[torch.Tensor])`: the merged visual embeddings and the list of DeepStack
+            features extracted from the layers listed in `config.deepstack_visual_indexes`.
+ """
+ hidden_states = self.patch_embed(hidden_states)
+
+ pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
+ hidden_states = hidden_states + pos_embeds
+
+ rotary_pos_emb = self.rot_pos_emb(grid_thw)
+
+ seq_len, _ = hidden_states.size()
+ hidden_states = hidden_states.reshape(seq_len, -1)
+ rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
+ emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+ position_embeddings = (emb.cos(), emb.sin())
+
+ cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+ dim=0,
+ # Select dtype based on the following factors:
+ # - FA2 requires that cu_seqlens_q must have dtype int32
+ # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
+ # See https://github.com/huggingface/transformers/pull/34852 for more information
+ dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+ )
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+
+ deepstack_feature_lists = []
+ for layer_num, blk in enumerate(self.blocks):
+ hidden_states = blk(
+ hidden_states,
+ cu_seqlens=cu_seqlens,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ if layer_num in self.deepstack_visual_indexes:
+ deepstack_feature = self.deepstack_merger_list[self.deepstack_visual_indexes.index(layer_num)](
+ hidden_states
+ )
+ deepstack_feature_lists.append(deepstack_feature)
+
+ hidden_states = self.merger(hidden_states)
+
+ return hidden_states, deepstack_feature_lists
+
+
+@auto_docstring(
+ custom_intro=(
+ "Text part of Qwen3VL, "
+ "not a pure text-only model, as DeepStack integrates visual features into the early hidden states."
+ )
+)
+class Qwen3VLTextModel(Qwen3VLPreTrainedModel):
+ config: Qwen3VLTextConfig
+ _no_split_modules = ["Qwen3VLTextDecoderLayer"]
+
+ def __init__(self, config: Qwen3VLTextConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [Qwen3VLTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = Qwen3VLTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = Qwen3VLTextRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @check_model_inputs
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ # args for deepstack
+ visual_pos_masks: Optional[torch.Tensor] = None,
+ deepstack_visual_embeds: Optional[list[torch.Tensor]] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Union[tuple, BaseModelOutputWithPast]:
+ r"""
+        visual_pos_masks (`torch.Tensor` of shape `(batch_size, seqlen)`, *optional*):
+            Boolean mask marking the positions of the visual (image/video) tokens in the sequence.
+        deepstack_visual_embeds (`list[torch.Tensor]`, *optional*):
+            The DeepStack visual embeddings: a list of `num_layers` tensors, each of shape `(visual_seqlen, embed_dim)`.
+            The features are extracted from different vision encoder layers and added to the hidden states of the
+            first decoder layers, following DeepStack (https://arxiv.org/abs/2406.04334).
+ """
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ # torch.jit.trace() doesn't support cache objects in the output
+ if use_cache and past_key_values is None and not torch.jit.is_tracing():
+ past_key_values = DynamicCache(config=self.config)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ # the hard coded `3` is for temporal, height and width.
+ if position_ids is None:
+ position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
+ elif position_ids.ndim == 2:
+ position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
+
+ if position_ids.ndim == 3 and position_ids.shape[0] == 4:
+ text_position_ids = position_ids[0]
+ position_ids = position_ids[1:]
+ else:
+ text_position_ids = position_ids[0]
+
+ attention_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=text_position_ids,
+ )
+
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ # decoder layers
+ for layer_idx, decoder_layer in enumerate(self.layers):
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=text_position_ids,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = layer_outputs
+
+ # add visual features to the hidden states of first several layers
+ if deepstack_visual_embeds is not None and layer_idx in range(len(deepstack_visual_embeds)):
+ hidden_states = self._deepstack_process(
+ hidden_states,
+ visual_pos_masks,
+ deepstack_visual_embeds[layer_idx],
+ )
+
+ hidden_states = self.norm(hidden_states)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ )
+
+ def _deepstack_process(
+ self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor, visual_embeds: torch.Tensor
+ ):
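+        # Add the DeepStack features from one vision layer onto the hidden states at the visual token positions.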
+ visual_pos_masks = visual_pos_masks.to(hidden_states.device)
+ visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype)
+ local_this = hidden_states[visual_pos_masks, :].clone() + visual_embeds
+ hidden_states[visual_pos_masks, :] = local_this
+ return hidden_states
+
+
+@auto_docstring
+class Qwen3VLModel(Qwen3VLPreTrainedModel):
+ base_model_prefix = ""
+ _checkpoint_conversion_mapping = {}
+ # Reference: fix gemma3 grad acc #37208
+ accepts_loss_kwargs = False
+ config: Qwen3VLConfig
+ _no_split_modules = ["Qwen3VLTextDecoderLayer", "Qwen3VLVisionBlock"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.visual = Qwen3VLVisionModel._from_config(config.vision_config)
+ self.language_model = Qwen3VLTextModel._from_config(config.text_config)
+ self.rope_deltas = None # cache rope_deltas here
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ def set_decoder(self, decoder):
+ self.language_model = decoder
+
+ def get_decoder(self):
+ return self.language_model
+
+ def get_rope_index(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Different from the original implementation, Qwen3VL uses timestamps rather than absolute time position ids."""
+
+        # Since we use timestamps to separate video frames, the video_grid_thw is also split so that each frame
+        # becomes its own (1, h, w) grid.
+ if video_grid_thw is not None:
+ video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0)
+ video_grid_thw[:, 0] = 1
+
+ spatial_merge_size = self.config.vision_config.spatial_merge_size
+ image_token_id = self.config.image_token_id
+ video_token_id = self.config.video_token_id
+ vision_start_token_id = self.config.vision_start_token_id
+ mrope_position_deltas = []
+ if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
+ total_input_ids = input_ids
+ if attention_mask is None:
+ attention_mask = torch.ones_like(total_input_ids)
+ position_ids = torch.ones(
+ 3,
+ input_ids.shape[0],
+ input_ids.shape[1],
+ dtype=input_ids.dtype,
+ device=input_ids.device,
+ )
+ image_index, video_index = 0, 0
+ attention_mask = attention_mask.to(total_input_ids.device)
+ for i, input_ids in enumerate(total_input_ids):
+ input_ids = input_ids[attention_mask[i] == 1]
+ image_nums, video_nums = 0, 0
+ vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
+ vision_tokens = input_ids[vision_start_indices + 1]
+ image_nums = (vision_tokens == image_token_id).sum()
+ video_nums = (vision_tokens == video_token_id).sum()
+ input_tokens = input_ids.tolist()
+ llm_pos_ids_list: list = []
+ st = 0
+ remain_images, remain_videos = image_nums, video_nums
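+                # Walk through the interleaved text / vision segments: text tokens get identical (t, h, w) position
+                # ids increasing along the sequence, while vision tokens get a 3D (t, h, w) grid of ids starting
+                # right after the preceding text positions.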
+ for _ in range(image_nums + video_nums):
+ if image_token_id in input_tokens and remain_images > 0:
+ ed_image = input_tokens.index(image_token_id, st)
+ else:
+ ed_image = len(input_tokens) + 1
+ if video_token_id in input_tokens and remain_videos > 0:
+ ed_video = input_tokens.index(video_token_id, st)
+ else:
+ ed_video = len(input_tokens) + 1
+ if ed_image < ed_video:
+ t, h, w = (
+ image_grid_thw[image_index][0],
+ image_grid_thw[image_index][1],
+ image_grid_thw[image_index][2],
+ )
+ image_index += 1
+ remain_images -= 1
+ ed = ed_image
+
+ else:
+ t, h, w = (
+ video_grid_thw[video_index][0],
+ video_grid_thw[video_index][1],
+ video_grid_thw[video_index][2],
+ )
+ video_index += 1
+ remain_videos -= 1
+ ed = ed_video
+ llm_grid_t, llm_grid_h, llm_grid_w = (
+ t.item(),
+ h.item() // spatial_merge_size,
+ w.item() // spatial_merge_size,
+ )
+ text_len = ed - st
+
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+ # t_index is always 0 because llm_grid_t is always 1 (we use timestamps to encode the temporal information for videos)
+ t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
+ h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
+ w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
+ llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
+ st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+ if st < len(input_tokens):
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ text_len = len(input_tokens) - st
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+ position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
+ mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
+ mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
+ return position_ids, mrope_position_deltas
+ else:
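+            # Text-only input: plain 1D positions broadcast to the three MRoPE axes.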
+ if attention_mask is not None:
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
+ max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
+ mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
+ else:
+ position_ids = (
+ torch.arange(input_ids.shape[1], device=input_ids.device)
+ .view(1, 1, -1)
+ .expand(3, input_ids.shape[0], -1)
+ )
+ mrope_position_deltas = torch.zeros(
+ [input_ids.shape[0], 1],
+ device=input_ids.device,
+ dtype=input_ids.dtype,
+ )
+
+ return position_ids, mrope_position_deltas
+
+ def get_video_features(
+ self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
+ ):
+ """
+ Encodes videos into continuous embeddings that can be forwarded to the language model. The deepstack visual features are also returned.
+
+ Args:
+ pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input videos.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ """
+ # Same implementation as for images
+ return self.get_image_features(pixel_values_videos, video_grid_thw)
+
+ def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
+ """
+ Encodes images into continuous embeddings that can be forwarded to the language model. The deepstack visual features are also returned.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input images.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ """
+ pixel_values = pixel_values.type(self.visual.dtype)
+ image_embeds, deepstack_image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+ split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
+ image_embeds = torch.split(image_embeds, split_sizes)
+ return image_embeds, deepstack_image_embeds
+
+ def get_placeholder_mask(
+ self,
+ input_ids: torch.LongTensor,
+ inputs_embeds: torch.FloatTensor,
+ image_features: Optional[torch.FloatTensor] = None,
+ video_features: Optional[torch.FloatTensor] = None,
+ ):
+ """
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
+ """
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ special_video_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_video_mask = special_video_mask.all(-1)
+ else:
+ special_image_mask = input_ids == self.config.image_token_id
+ special_video_mask = input_ids == self.config.video_token_id
+
+ n_image_tokens = special_image_mask.sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}"
+ )
+
+ n_video_tokens = special_video_mask.sum()
+ special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel():
+ raise ValueError(
+                f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}"
+ )
+
+ return special_image_mask, special_video_mask
+
+ @auto_docstring
+ @can_return_tuple
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ pixel_values: Optional[torch.Tensor] = None,
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, Qwen3VLModelOutputWithPast]:
+ r"""
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ """
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ image_mask = None
+ video_mask = None
+
+ if pixel_values is not None:
+ image_embeds, deepstack_image_embeds = self.get_image_features(pixel_values, image_grid_thw)
+ image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
+ image_mask, _ = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+ if pixel_values_videos is not None:
+ video_embeds, deepstack_video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
+ video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
+ _, video_mask = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
+
+ visual_pos_masks = None
+ deepstack_visual_embeds = None
+ if image_mask is not None and video_mask is not None:
+ # aggregate visual_pos_masks and deepstack_visual_embeds
+ image_mask = image_mask[..., 0]
+ video_mask = video_mask[..., 0]
+ visual_pos_masks = image_mask | video_mask
+ deepstack_visual_embeds = []
+ image_mask_joint = image_mask[visual_pos_masks]
+ video_mask_joint = video_mask[visual_pos_masks]
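+            # Scatter the per-layer image and video DeepStack features into one tensor ordered by the joint visual mask.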
+ for img_embed, vid_embed in zip(deepstack_image_embeds, deepstack_video_embeds):
+ embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1]).to(img_embed.device)
+ embed_joint[image_mask_joint, :] = img_embed
+ embed_joint[video_mask_joint, :] = vid_embed
+ deepstack_visual_embeds.append(embed_joint)
+ elif image_mask is not None:
+ image_mask = image_mask[..., 0]
+ visual_pos_masks = image_mask
+ deepstack_visual_embeds = deepstack_image_embeds
+ elif video_mask is not None:
+ video_mask = video_mask[..., 0]
+ visual_pos_masks = video_mask
+ deepstack_visual_embeds = deepstack_video_embeds
+
+ if position_ids is None:
+ attention_mask_tensor = (
+ attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
+ )
+ if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
+ attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
+ # Only apply conversion for floating point tensors (inverted masks)
+ if attention_mask_tensor.dtype.is_floating_point:
+ attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
+ attention_mask_tensor = (1.0 - attention_mask_tensor).int()
+
+ # Calculate RoPE index once per generation in the pre-fill stage only.
+ # When compiling, we can't check tensor values thus we check only input length
+ # It is safe to assume that `length!=1` means we're in pre-fill because compiled
+            # models currently cannot do assisted decoding
+ prefill_compiled_stage = is_torchdynamo_compiling() and (
+ (input_ids is not None and input_ids.shape[1] != 1)
+ or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
+ )
+ prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
+ (cache_position is not None and cache_position[0] == 0)
+ or (past_key_values is None or past_key_values.get_seq_length() == 0)
+ )
+ if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
+ position_ids, rope_deltas = self.get_rope_index(
+ input_ids,
+ image_grid_thw,
+ video_grid_thw,
+ attention_mask=attention_mask_tensor,
+ )
+ self.rope_deltas = rope_deltas
+ # then use the prev pre-calculated rope-deltas to get the correct position ids
+ else:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ delta = (
+ (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
+ if cache_position is not None
+ else 0
+ )
+ position_ids = torch.arange(seq_length, device=inputs_embeds.device)
+ position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+ if cache_position is not None: # otherwise `deltas` is an int `0`
+ delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+ position_ids = position_ids.add(delta)
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
+
+ outputs = self.language_model(
+ input_ids=None,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ visual_pos_masks=visual_pos_masks,
+ deepstack_visual_embeds=deepstack_visual_embeds,
+ **kwargs,
+ )
+
+ return Qwen3VLModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ rope_deltas=self.rope_deltas,
+ )
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Qwen3VL causal language model (or autoregressive) outputs.
+ """
+)
+class Qwen3VLCausalLMOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[list[torch.FloatTensor]] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ rope_deltas: Optional[torch.LongTensor] = None
+
+
+class Qwen3VLForConditionalGeneration(Qwen3VLPreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {}
+ _tied_weights_keys = ["lm_head.weight"]
+ # Reference: fix gemma3 grad acc #37208
+ accepts_loss_kwargs = False
+ config: Qwen3VLConfig
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = Qwen3VLModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def set_decoder(self, decoder):
+ self.model.set_decoder(decoder)
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ def get_video_features(
+ self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
+ ):
+ return self.model.get_video_features(pixel_values_videos, video_grid_thw)
+
+ def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
+ return self.model.get_image_features(pixel_values, image_grid_thw)
+
+ # Make modules available through conditional class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
+
+ @property
+ def visual(self):
+ return self.model.visual
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[list[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.Tensor] = None,
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, Qwen3VLCausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+
+ Example:
+            The snippet below is an illustrative sketch; the checkpoint name `Qwen/Qwen3-VL-4B-Instruct` and the image
+            URL are assumptions and may need to be replaced with a released checkpoint and a reachable image.
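+
+        ```python
+        >>> import requests
+        >>> import torch
+        >>> from PIL import Image
+        >>> from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
+
+        >>> # NOTE: the checkpoint name is an assumption for illustration purposes.
+        >>> model = Qwen3VLForConditionalGeneration.from_pretrained(
+        ...     "Qwen/Qwen3-VL-4B-Instruct", dtype=torch.float16, device_map="auto"
+        ... )
+        >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-4B-Instruct")
+
+        >>> messages = [
+        ...     {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What is shown in this image?"}]}
+        ... ]
+        >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        >>> image = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw)
+        >>> inputs = processor(text=[text], images=[image], return_tensors="pt").to(model.device)
+
+        >>> generated_ids = model.generate(**inputs, max_new_tokens=32)
+        >>> processor.batch_decode(generated_ids[:, inputs.input_ids.shape[1] :], skip_special_tokens=True)[0]
+        ```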
+ """
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
+
+ return Qwen3VLCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ rope_deltas=outputs.rope_deltas,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ cache_position=None,
+ position_ids=None,
+ use_cache=True,
+ pixel_values=None,
+ pixel_values_videos=None,
+ image_grid_thw=None,
+ video_grid_thw=None,
+ **kwargs,
+ ):
+ # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
+
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ position_ids=position_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ use_cache=use_cache,
+ **kwargs,
+ )
+
+        # Qwen3VL position_ids are prepared with rope_deltas in forward
+ model_inputs["position_ids"] = None
+
+ if cache_position[0] != 0:
+ model_inputs["pixel_values"] = None
+ model_inputs["pixel_values_videos"] = None
+
+ return model_inputs
+
+ def _get_image_nums_and_video_nums(
+ self,
+ input_ids: Optional[torch.LongTensor],
+ inputs_embeds: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+        Get the number of images and videos in each sample, used to compute the split lengths of the per-sample
+        visual tensors. These counts are not passed through the processor to avoid unpredictable impacts from
+        interface modifications.
+
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Returns:
+            image_nums (`torch.LongTensor` of shape `(batch_size,)`): the number of images in each sample.
+            video_nums (`torch.LongTensor` of shape `(batch_size,)`): the number of videos in each sample.
+ """
+ image_token_id = self.config.image_token_id
+ video_token_id = self.config.video_token_id
+ vision_start_token_id = self.config.vision_start_token_id
+
+ if inputs_embeds is not None:
+ vision_start_mask = (
+ inputs_embeds
+ == self.get_input_embeddings()(
+ torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ )[..., 0]
+ image_mask = (
+ inputs_embeds
+ == self.get_input_embeddings()(
+ torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ )[..., 0]
+ video_mask = (
+ inputs_embeds
+ == self.get_input_embeddings()(
+ torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ )[..., 0]
+ else:
+ vision_start_mask = input_ids == vision_start_token_id
+ image_mask = input_ids == image_token_id
+ video_mask = input_ids == video_token_id
+
+ vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
+ image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
+ video_nums = torch.sum(vision_first_mask & video_mask, dim=1)
+
+ return image_nums, video_nums
+
+ def _expand_inputs_for_generation(
+ self,
+ expand_size: int = 1,
+ is_encoder_decoder: bool = False,
+ input_ids: Optional[torch.LongTensor] = None,
+ **model_kwargs,
+ ) -> tuple[torch.LongTensor, dict[str, Any]]:
+ # Overwritten -- Support for expanding tensors without a batch size dimension
+        # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_ts
+ # pixel_values.shape[0] is sum(seqlen_images for samples)
+ # image_grid_thw.shape[0] is sum(num_images for samples)
+
+ if expand_size == 1:
+ return input_ids, model_kwargs
+
+ visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"]
+
+ def _expand_dict_for_generation_visual(dict_to_expand):
+ image_grid_thw = model_kwargs.get("image_grid_thw", None)
+ video_grid_thw = model_kwargs.get("video_grid_thw", None)
+ image_nums, video_nums = self._get_image_nums_and_video_nums(
+ input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
+ )
+
+ def _repeat_interleave_samples(x, lengths, repeat_times):
+ samples = torch.split(x, lengths)
+ repeat_args = [repeat_times] + [1] * (x.dim() - 1)
+ result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
+ return result
+
+ for key in dict_to_expand:
+ if key == "pixel_values":
+ # split images into samples
+ samples = torch.split(image_grid_thw, list(image_nums))
+ # compute the sequence length of images for each sample
+ lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+ )
+ elif key == "image_grid_thw":
+ # get the num of images for each sample
+ lengths = list(image_nums)
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+ )
+ elif key == "pixel_values_videos":
+ samples = torch.split(video_grid_thw, list(video_nums))
+ lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+ )
+ elif key == "video_grid_thw":
+ lengths = list(video_nums)
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+ )
+ elif key == "second_per_grid_ts":
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=list(video_nums), repeat_times=expand_size
+ )
+ return dict_to_expand
+
+ def _expand_dict_for_generation(dict_to_expand):
+ for key in dict_to_expand:
+ if (
+ key != "cache_position"
+ and dict_to_expand[key] is not None
+ and isinstance(dict_to_expand[key], torch.Tensor)
+ and key not in visual_keys
+ ):
+ dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
+ return dict_to_expand
+
+ model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
+
+ if input_ids is not None:
+ input_ids = input_ids.repeat_interleave(expand_size, dim=0)
+
+ model_kwargs = _expand_dict_for_generation(model_kwargs)
+
+ if is_encoder_decoder:
+ if model_kwargs.get("encoder_outputs") is None:
+ raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
+ model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
+
+ return input_ids, model_kwargs
+
+
+__all__ = [
+ "Qwen3VLVisionModel",
+ "Qwen3VLForConditionalGeneration",
+ "Qwen3VLModel",
+ "Qwen3VLPreTrainedModel",
+ "Qwen3VLTextModel",
+]
diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py
new file mode 100644
index 000000000000..ae608e81a05d
--- /dev/null
+++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py
@@ -0,0 +1,1472 @@
+# coding=utf-8
+# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Qwen3-VL model."""
+
+from typing import Callable, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...configuration_utils import PretrainedConfig
+from ...feature_extraction_utils import BatchFeature
+from ...image_utils import ImageInput
+from ...masking_utils import create_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_outputs import BaseModelOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update, rope_config_validation
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
+from ...processing_utils import ProcessingKwargs, Unpack, VideosKwargs
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+from ...utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
+from ...utils.generic import check_model_inputs
+from ...video_utils import VideoInput
+from ..qwen2_5_vl.modeling_qwen2_5_vl import (
+ Qwen2_5_VLCausalLMOutputWithPast,
+ Qwen2_5_VLForConditionalGeneration,
+ Qwen2_5_VLModel,
+ Qwen2_5_VLVisionBlock,
+)
+from ..qwen2_vl.modeling_qwen2_vl import (
+ PatchEmbed,
+ Qwen2VLModelOutputWithPast,
+ Qwen2VLPreTrainedModel,
+ TransformersKwargs,
+ VisionAttention,
+ VisionRotaryEmbedding,
+)
+from ..qwen2_vl.processing_qwen2_vl import Qwen2VLImagesKwargs, Qwen2VLProcessor
+from ..qwen3.modeling_qwen3 import (
+ Qwen3Attention,
+ Qwen3DecoderLayer,
+ Qwen3Model,
+ apply_rotary_pos_emb,
+ eager_attention_forward,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+class Qwen3VLVisionConfig(PretrainedConfig):
+ model_type = "qwen3_vl"
+ base_config_key = "vision_config"
+
+ def __init__(
+ self,
+ depth=27,
+ hidden_size=1152,
+ hidden_act="gelu_pytorch_tanh",
+ intermediate_size=4304,
+ num_heads=16,
+ in_channels=3,
+ patch_size=16,
+ spatial_merge_size=2,
+ temporal_patch_size=2,
+ out_hidden_size=3584,
+ num_position_embeddings=2304,
+ deepstack_visual_indexes=[8, 16, 24],
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.depth = depth
+ self.hidden_size = hidden_size
+ self.hidden_act = hidden_act
+ self.intermediate_size = intermediate_size
+ self.num_heads = num_heads
+ self.in_channels = in_channels
+ self.patch_size = patch_size
+ self.spatial_merge_size = spatial_merge_size
+ self.temporal_patch_size = temporal_patch_size
+ self.out_hidden_size = out_hidden_size
+ self.num_position_embeddings = num_position_embeddings
+ self.initializer_range = initializer_range
+ self.deepstack_visual_indexes = deepstack_visual_indexes
+
+
+class Qwen3VLTextConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Qwen3VLTextModel`]. It is used to instantiate a
+ Qwen3-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of
+ Qwen3-VL-4B-Instruct [Qwen/Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 151936):
+ Vocabulary size of the Qwen3VL model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`Qwen3VLModel`]
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 22016):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_key_value_heads (`int`, *optional*, defaults to 32):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`.
+ head_dim (`int`, *optional*, defaults to 128):
+ The dimension of the head. If not specified, will default to `hidden_size // num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 128000):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether the model's input and output word embeddings should be tied.
+ rope_theta (`float`, *optional*, defaults to 5000000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`list[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+        attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+
+ ```python
+ >>> from transformers import Qwen3VLTextModel, Qwen3VLTextConfig
+
+ >>> # Initializing a Qwen3VL style configuration
+ >>> configuration = Qwen3VLTextConfig()
+
+    >>> # Initializing a model from the Qwen3-VL-4B style configuration
+ >>> model = Qwen3VLTextModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "qwen3_vl_text"
+ base_config_key = "text_config"
+
+ def __init__(
+ self,
+ vocab_size=151936,
+ hidden_size=4096,
+ intermediate_size=22016,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=32,
+ head_dim=128,
+ hidden_act="silu",
+ max_position_embeddings=128000,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ tie_word_embeddings=False,
+ rope_theta=5000000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.head_dim = head_dim
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+
+ rope_config_validation(self, ignore_keys={"mrope_section", "mrope_interleaved"})
+
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+
+
+class Qwen3VLConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Qwen3VLModel`]. It is used to instantiate a
+ Qwen3-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of
+ Qwen3-VL-4B-Instruct [Qwen/Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLTextConfig`):
+ The config object or dictionary of the text backbone.
+ vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLVisionConfig`):
+ The config object or dictionary of the vision backbone.
+ image_token_id (`int`, *optional*, defaults to 151655):
+ The image token index to encode the image prompt.
+ video_token_id (`int`, *optional*, defaults to 151656):
+            The video token index to encode the video prompt.
+        vision_start_token_id (`int`, *optional*, defaults to 151652):
+            The token index marking the start of a vision (image or video) segment.
+        vision_end_token_id (`int`, *optional*, defaults to 151653):
+            The token index marking the end of a vision (image or video) segment.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie the word embeddings.
+
+ ```python
+ >>> from transformers import Qwen3VLForConditionalGeneration, Qwen3VLConfig
+
+ >>> # Initializing a Qwen3-VL style configuration
+ >>> configuration = Qwen3VLConfig()
+
+ >>> # Initializing a model from the Qwen3-VL-4B style configuration
+ >>> model = Qwen3VLForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "qwen3_vl"
+ sub_configs = {"vision_config": Qwen3VLVisionConfig, "text_config": Qwen3VLTextConfig}
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ text_config=None,
+ vision_config=None,
+ image_token_id=151655,
+ video_token_id=151656,
+ vision_start_token_id=151652,
+ vision_end_token_id=151653,
+ tie_word_embeddings=False,
+ **kwargs,
+ ):
+ if isinstance(vision_config, dict):
+ self.vision_config = self.sub_configs["vision_config"](**vision_config)
+ elif vision_config is None:
+ self.vision_config = self.sub_configs["vision_config"]()
+
+ if isinstance(text_config, dict):
+ self.text_config = self.sub_configs["text_config"](**text_config)
+ elif text_config is None:
+ self.text_config = self.sub_configs["text_config"]()
+
+ self.image_token_id = image_token_id
+ self.video_token_id = video_token_id
+ self.vision_start_token_id = vision_start_token_id
+ self.vision_end_token_id = vision_end_token_id
+ super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings)
+
+
+class Qwen3VLVisionMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True)
+ self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, hidden_state):
+ return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state)))
+
+
+class Qwen3VLVisionPatchEmbed(PatchEmbed):
+ def __init__(self, config) -> None:
+ super().__init__()
+ self.patch_size = config.patch_size
+ self.temporal_patch_size = config.temporal_patch_size
+ self.in_channels = config.in_channels
+ self.embed_dim = config.hidden_size
+
+ kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
+ self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True)
+
+
+class Qwen3VLVisionRotaryEmbedding(VisionRotaryEmbedding):
+ pass
+
+
+class Qwen3VLVisionPatchMerger(nn.Module):
+ def __init__(self, config: Qwen3VLVisionConfig, use_postshuffle_norm=False) -> None:
+ super().__init__()
+ self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)
+ self.use_postshuffle_norm = use_postshuffle_norm
+ self.norm = nn.LayerNorm(self.hidden_size if use_postshuffle_norm else config.hidden_size, eps=1e-6)
+ self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size)
+ self.act_fn = nn.GELU()
+ self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.norm(x.view(-1, self.hidden_size) if self.use_postshuffle_norm else x).view(-1, self.hidden_size)
+ x = self.linear_fc2(self.act_fn(self.linear_fc1(x)))
+ return x
+
+
+class Qwen3VLVisionAttention(VisionAttention):
+ def __init__(self, config: Qwen3VLVisionConfig) -> None:
+ super().__init__()
+ self.dim = config.hidden_size
+
+
+class Qwen3VLVisionBlock(Qwen2_5_VLVisionBlock):
+ def __init__(self, config, attn_implementation: str = "sdpa") -> None:
+ super().__init__()
+ self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6)
+ self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6)
+ self.attn = Qwen3VLVisionAttention(config=config)
+ self.mlp = Qwen3VLVisionMLP(config=config)
+
+
+class Qwen3VLTextRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: Qwen3VLTextConfig, device=None):
+ super().__init__()
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get("rope_type", "default")
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ self.mrope_section = config.rope_scaling.get("mrope_section", [24, 20, 20])
+
+ def apply_interleaved_mrope(self, freqs, mrope_section):
+        """Apply interleaved MRoPE to 3D rotary embeddings.
+        Reorganizes the frequency layout from chunked [TTT...HHH...WWW] to
+        interleaved [THWTHWTHW...TT], preserving frequency continuity.
+
+        Args:
+            freqs: (3, bs, seq_len, head_dim // 2)
+            mrope_section: (3,)
+        Returns:
+            freqs_t: (bs, seq_len, head_dim // 2)
+ """
+ freqs_t = freqs[0] # just overwrite the first dimension T
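+        # For H (offset 1) and W (offset 2), overwrite every third frequency slot within the first
+        # 3 * mrope_section[dim] entries; the remaining slots keep the T frequencies.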
+ for dim, offset in enumerate((1, 2), start=1): # H, W
+ length = mrope_section[dim] * 3
+ idx = slice(offset, length, 3)
+ freqs_t[..., idx] = freqs[dim, ..., idx]
+ return freqs_t
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ # In contrast to other models, Qwen3VL has different position ids for the grids
+ # So we expand the inv_freq to shape (3, ...)
+ if position_ids.ndim == 2:
+ position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
+ inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
+ position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
+ freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class Qwen3VLTextAttention(Qwen3Attention):
+ def __init__(self, config: Qwen3VLTextConfig, layer_idx: int):
+ super().__init__(config, layer_idx)
+ del self.sliding_window
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+ key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class Qwen3VLTextDecoderLayer(Qwen3DecoderLayer):
+ def __init__(self, config: Qwen3VLTextConfig, layer_idx: int):
+ super().__init__(config, layer_idx)
+ del self.attention_type
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
+ return super().forward(
+ hidden_states=hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+
+class Qwen3VLModelOutputWithPast(Qwen2VLModelOutputWithPast):
+ pass
+
+
+class Qwen3VLPreTrainedModel(Qwen2VLPreTrainedModel):
+ config: Qwen3VLConfig
+ _no_split_modules = ["Qwen3VLTextDecoderLayer", "Qwen3VLVisionBlock"]
+ _can_record_outputs = {
+ "hidden_states": Qwen3VLTextDecoderLayer,
+ "attentions": Qwen3VLTextAttention,
+ }
+
+
+class Qwen3VLVisionModel(Qwen3VLPreTrainedModel):
+ config: Qwen3VLVisionConfig
+ _no_split_modules = ["Qwen3VLVisionBlock"]
+
+ def __init__(self, config, *inputs, **kwargs) -> None:
+ super().__init__(config, *inputs, **kwargs)
+ self.spatial_merge_size = config.spatial_merge_size
+ self.patch_size = config.patch_size
+ self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
+
+ self.patch_embed = Qwen3VLVisionPatchEmbed(
+ config=config,
+ )
+
+ self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size)
+ self.num_grid_per_side = int(config.num_position_embeddings**0.5)
+
+ head_dim = config.hidden_size // config.num_heads
+ self.rotary_pos_emb = Qwen3VLVisionRotaryEmbedding(head_dim // 2)
+
+ self.blocks = nn.ModuleList([Qwen3VLVisionBlock(config) for _ in range(config.depth)])
+ self.merger = Qwen3VLVisionPatchMerger(
+ config=config,
+ use_postshuffle_norm=False,
+ )
+
+ self.deepstack_visual_indexes = config.deepstack_visual_indexes
+ self.deepstack_merger_list = nn.ModuleList(
+ [
+ Qwen3VLVisionPatchMerger(
+ config=config,
+ use_postshuffle_norm=True,
+ )
+ for _ in range(len(config.deepstack_visual_indexes))
+ ]
+ )
+
+ self.gradient_checkpointing = False
+
+ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
+ merge_size = self.spatial_merge_size
+
+ max_hw = int(grid_thw[:, 1:].max().item())
+ freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2)
+ device = freq_table.device
+
+ total_tokens = int(torch.prod(grid_thw, dim=1).sum().item())
+ pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)
+
+ offset = 0
+ for num_frames, height, width in grid_thw:
+ merged_h, merged_w = height // merge_size, width // merge_size
+
+ block_rows = torch.arange(merged_h, device=device) # block row indices
+ block_cols = torch.arange(merged_w, device=device) # block col indices
+ intra_row = torch.arange(merge_size, device=device) # intra-block row offsets
+ intra_col = torch.arange(merge_size, device=device) # intra-block col offsets
+
+ # Compute full-resolution positions
+ row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
+ col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]
+
+ row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
+ col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
+
+ coords = torch.stack((row_idx, col_idx), dim=-1)
+
+ if num_frames > 1:
+ coords = coords.repeat(num_frames, 1)
+
+ num_tokens = coords.shape[0]
+ pos_ids[offset : offset + num_tokens] = coords
+ offset += num_tokens
+
+ embeddings = freq_table[pos_ids] # lookup rotary embeddings
+ embeddings = embeddings.flatten(1)
+ return embeddings
+
+ def fast_pos_embed_interpolate(self, grid_thw):
+ grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]
+
+ idx_list = [[] for _ in range(4)]
+ weight_list = [[] for _ in range(4)]
+
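+        # For each image/video, bilinearly interpolate the learned `num_grid_per_side` x `num_grid_per_side`
+        # position-embedding grid onto the (h, w) patch grid: gather the four surrounding grid indices and
+        # their interpolation weights.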
+ for t, h, w in zip(grid_ts, grid_hs, grid_ws):
+ h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h)
+ w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w)
+
+ h_idxs_floor = h_idxs.int()
+ w_idxs_floor = w_idxs.int()
+ h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
+ w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
+
+ dh = h_idxs - h_idxs_floor
+ dw = w_idxs - w_idxs_floor
+
+ base_h = h_idxs_floor * self.num_grid_per_side
+ base_h_ceil = h_idxs_ceil * self.num_grid_per_side
+
+ indices = [
+ (base_h[None].T + w_idxs_floor[None]).flatten(),
+ (base_h[None].T + w_idxs_ceil[None]).flatten(),
+ (base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
+ (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
+ ]
+
+ weights = [
+ ((1 - dh)[None].T * (1 - dw)[None]).flatten(),
+ ((1 - dh)[None].T * dw[None]).flatten(),
+ (dh[None].T * (1 - dw)[None]).flatten(),
+ (dh[None].T * dw[None]).flatten(),
+ ]
+
+ for i in range(4):
+ idx_list[i].extend(indices[i].tolist())
+ weight_list[i].extend(weights[i].tolist())
+
+ idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=self.pos_embed.weight.device)
+ weight_tensor = torch.tensor(
+ weight_list, dtype=self.pos_embed.weight.dtype, device=self.pos_embed.weight.device
+ )
+ pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
+ patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
+
+ patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)])
+
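+        # Repeat the spatial position embeddings across frames and reorder them so that patches in the
+        # same spatial-merge block are contiguous, matching the patch ordering used by the vision encoder.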
+ patch_pos_embeds_permute = []
+ merge_size = self.config.spatial_merge_size
+ for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
+ pos_embed = pos_embed.repeat(t, 1)
+ pos_embed = (
+ pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
+ .permute(0, 1, 3, 2, 4, 5)
+ .flatten(0, 4)
+ )
+ patch_pos_embeds_permute.append(pos_embed)
+ patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
+ return patch_pos_embeds
+
+ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
+ """
+ Args:
+ hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
+                The packed, flattened image patches to be encoded by the vision model.
+ grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
+ The temporal, height and width of feature shape of each image in LLM.
+
+ Returns:
+            `tuple(torch.Tensor, list[torch.Tensor])`: The merged hidden states and the list of deepstack features.
+ """
+ hidden_states = self.patch_embed(hidden_states)
+
+ pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
+ hidden_states = hidden_states + pos_embeds
+
+ rotary_pos_emb = self.rot_pos_emb(grid_thw)
+
+ seq_len, _ = hidden_states.size()
+ hidden_states = hidden_states.reshape(seq_len, -1)
+ rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
+ emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+ position_embeddings = (emb.cos(), emb.sin())
+
+ cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+ dim=0,
+ # Select dtype based on the following factors:
+ # - FA2 requires that cu_seqlens_q must have dtype int32
+ # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
+ # See https://github.com/huggingface/transformers/pull/34852 for more information
+ dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+ )
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+
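+        # Capture intermediate hidden states at the configured deepstack layers and project them with the
+        # corresponding patch mergers; these features are later injected into the early LLM decoder layers.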
+ deepstack_feature_lists = []
+ for layer_num, blk in enumerate(self.blocks):
+ hidden_states = blk(
+ hidden_states,
+ cu_seqlens=cu_seqlens,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ if layer_num in self.deepstack_visual_indexes:
+ deepstack_feature = self.deepstack_merger_list[self.deepstack_visual_indexes.index(layer_num)](
+ hidden_states
+ )
+ deepstack_feature_lists.append(deepstack_feature)
+
+ hidden_states = self.merger(hidden_states)
+
+ return hidden_states, deepstack_feature_lists
+
+
+@auto_docstring(
+ custom_intro=(
+ "Text part of Qwen3VL, "
+ "not a pure text-only model, as DeepStack integrates visual features into the early hidden states."
+ )
+)
+class Qwen3VLTextModel(Qwen3VLPreTrainedModel, Qwen3Model):
+ config: Qwen3VLTextConfig
+ _no_split_modules = ["Qwen3VLTextDecoderLayer"]
+
+ def __init__(self, config: Qwen3VLTextConfig):
+ super().__init__(config)
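+        # Qwen3VL text layers do not use sliding-window attention, so drop the flag inherited from Qwen3Model.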
+ del self.has_sliding_layers
+
+ def _deepstack_process(
+ self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor, visual_embeds: torch.Tensor
+ ):
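+        # Add the per-layer deepstack visual features onto the hidden states at the visual token positions.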
+ visual_pos_masks = visual_pos_masks.to(hidden_states.device)
+ visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype)
+ local_this = hidden_states[visual_pos_masks, :].clone() + visual_embeds
+ hidden_states[visual_pos_masks, :] = local_this
+ return hidden_states
+
+ @check_model_inputs
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ # args for deepstack
+ visual_pos_masks: Optional[torch.Tensor] = None,
+ deepstack_visual_embeds: Optional[list[torch.Tensor]] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Union[tuple, BaseModelOutputWithPast]:
+ r"""
+ visual_pos_masks (`torch.Tensor` of shape `(batch_size, seqlen)`, *optional*):
+ The mask of the visual positions.
+ deepstack_visual_embeds (`list[torch.Tensor]`, *optional*):
+ The deepstack visual embeddings. The shape is (num_layers, visual_seqlen, embed_dim).
+ The feature is extracted from the different visual encoder layers, and fed to the decoder
+            hidden states. It comes from the DeepStack paper (https://arxiv.org/abs/2406.04334).
+ """
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ # torch.jit.trace() doesn't support cache objects in the output
+ if use_cache and past_key_values is None and not torch.jit.is_tracing():
+ past_key_values = DynamicCache(config=self.config)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+        # the hard-coded `3` is for the temporal, height and width dimensions.
+ if position_ids is None:
+ position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
+ elif position_ids.ndim == 2:
+ position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
+
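+        # When 4 rows of position ids are passed, the first row carries the plain text positions and the
+        # remaining 3 rows carry the temporal/height/width MRoPE positions.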
+ if position_ids.ndim == 3 and position_ids.shape[0] == 4:
+ text_position_ids = position_ids[0]
+ position_ids = position_ids[1:]
+ else:
+ text_position_ids = position_ids[0]
+
+ attention_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=text_position_ids,
+ )
+
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ # decoder layers
+ for layer_idx, decoder_layer in enumerate(self.layers):
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=text_position_ids,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = layer_outputs
+
+ # add visual features to the hidden states of first several layers
+ if deepstack_visual_embeds is not None and layer_idx in range(len(deepstack_visual_embeds)):
+ hidden_states = self._deepstack_process(
+ hidden_states,
+ visual_pos_masks,
+ deepstack_visual_embeds[layer_idx],
+ )
+
+ hidden_states = self.norm(hidden_states)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ )
+
+
+@auto_docstring
+class Qwen3VLModel(Qwen2_5_VLModel):
+ config: Qwen3VLConfig
+ _checkpoint_conversion_mapping = {}
+ _no_split_modules = ["Qwen3VLTextDecoderLayer", "Qwen3VLVisionBlock"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.visual = Qwen3VLVisionModel._from_config(config.vision_config)
+ self.language_model = Qwen3VLTextModel._from_config(config.text_config)
+
+ def get_rope_index(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Different from the original implementation, Qwen3VL use timestamps rather than absolute time position ids."""
+
+        # Since we use timestamps to separate videos, the video_grid_thw should also be split so each frame has its own entry
+ if video_grid_thw is not None:
+ video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0)
+ video_grid_thw[:, 0] = 1
+
+ spatial_merge_size = self.config.vision_config.spatial_merge_size
+ image_token_id = self.config.image_token_id
+ video_token_id = self.config.video_token_id
+ vision_start_token_id = self.config.vision_start_token_id
+ mrope_position_deltas = []
+ if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
+ total_input_ids = input_ids
+ if attention_mask is None:
+ attention_mask = torch.ones_like(total_input_ids)
+ position_ids = torch.ones(
+ 3,
+ input_ids.shape[0],
+ input_ids.shape[1],
+ dtype=input_ids.dtype,
+ device=input_ids.device,
+ )
+ image_index, video_index = 0, 0
+ attention_mask = attention_mask.to(total_input_ids.device)
+ for i, input_ids in enumerate(total_input_ids):
+ input_ids = input_ids[attention_mask[i] == 1]
+ image_nums, video_nums = 0, 0
+ vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
+ vision_tokens = input_ids[vision_start_indices + 1]
+ image_nums = (vision_tokens == image_token_id).sum()
+ video_nums = (vision_tokens == video_token_id).sum()
+ input_tokens = input_ids.tolist()
+ llm_pos_ids_list: list = []
+ st = 0
+ remain_images, remain_videos = image_nums, video_nums
+ for _ in range(image_nums + video_nums):
+ if image_token_id in input_tokens and remain_images > 0:
+ ed_image = input_tokens.index(image_token_id, st)
+ else:
+ ed_image = len(input_tokens) + 1
+ if video_token_id in input_tokens and remain_videos > 0:
+ ed_video = input_tokens.index(video_token_id, st)
+ else:
+ ed_video = len(input_tokens) + 1
+ if ed_image < ed_video:
+ t, h, w = (
+ image_grid_thw[image_index][0],
+ image_grid_thw[image_index][1],
+ image_grid_thw[image_index][2],
+ )
+ image_index += 1
+ remain_images -= 1
+ ed = ed_image
+
+ else:
+ t, h, w = (
+ video_grid_thw[video_index][0],
+ video_grid_thw[video_index][1],
+ video_grid_thw[video_index][2],
+ )
+ video_index += 1
+ remain_videos -= 1
+ ed = ed_video
+ llm_grid_t, llm_grid_h, llm_grid_w = (
+ t.item(),
+ h.item() // spatial_merge_size,
+ w.item() // spatial_merge_size,
+ )
+ text_len = ed - st
+
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+ # t_index is always 0 because llm_grid_t is always 1 (we use timestamps to encode the temporal information for videos)
+ t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
+ h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
+ w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
+ llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
+ st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+ if st < len(input_tokens):
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ text_len = len(input_tokens) - st
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+ position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
+ mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
+ mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
+ return position_ids, mrope_position_deltas
+ else:
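+            # Text-only inputs: fall back to plain 1D positions replicated across the 3 MRoPE axes.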
+ if attention_mask is not None:
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
+ max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
+ mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
+ else:
+ position_ids = (
+ torch.arange(input_ids.shape[1], device=input_ids.device)
+ .view(1, 1, -1)
+ .expand(3, input_ids.shape[0], -1)
+ )
+ mrope_position_deltas = torch.zeros(
+ [input_ids.shape[0], 1],
+ device=input_ids.device,
+ dtype=input_ids.dtype,
+ )
+
+ return position_ids, mrope_position_deltas
+
+ def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
+ """
+ Encodes images into continuous embeddings that can be forwarded to the language model. The deepstack visual features are also returned.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input images.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ """
+ pixel_values = pixel_values.type(self.visual.dtype)
+ image_embeds, deepstack_image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+ split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
+ image_embeds = torch.split(image_embeds, split_sizes)
+ return image_embeds, deepstack_image_embeds
+
+ def get_video_features(
+ self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
+ ):
+ """
+ Encodes videos into continuous embeddings that can be forwarded to the language model. The deepstack visual features are also returned.
+
+ Args:
+ pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input videos.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ """
+ # Same implementation as for images
+ return self.get_image_features(pixel_values_videos, video_grid_thw)
+
+ @auto_docstring
+ @can_return_tuple
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ pixel_values: Optional[torch.Tensor] = None,
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, Qwen3VLModelOutputWithPast]:
+ r"""
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ """
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ image_mask = None
+ video_mask = None
+
+ if pixel_values is not None:
+ image_embeds, deepstack_image_embeds = self.get_image_features(pixel_values, image_grid_thw)
+ image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
+ image_mask, _ = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+ if pixel_values_videos is not None:
+ video_embeds, deepstack_video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
+ video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
+ _, video_mask = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
+
+ visual_pos_masks = None
+ deepstack_visual_embeds = None
+ if image_mask is not None and video_mask is not None:
+ # aggregate visual_pos_masks and deepstack_visual_embeds
+ image_mask = image_mask[..., 0]
+ video_mask = video_mask[..., 0]
+ visual_pos_masks = image_mask | video_mask
+ deepstack_visual_embeds = []
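+            # For each deepstack layer, scatter image and video features into one tensor ordered by their
+            # positions in the joint visual mask.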
+ image_mask_joint = image_mask[visual_pos_masks]
+ video_mask_joint = video_mask[visual_pos_masks]
+ for img_embed, vid_embed in zip(deepstack_image_embeds, deepstack_video_embeds):
+ embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1]).to(img_embed.device)
+ embed_joint[image_mask_joint, :] = img_embed
+ embed_joint[video_mask_joint, :] = vid_embed
+ deepstack_visual_embeds.append(embed_joint)
+ elif image_mask is not None:
+ image_mask = image_mask[..., 0]
+ visual_pos_masks = image_mask
+ deepstack_visual_embeds = deepstack_image_embeds
+ elif video_mask is not None:
+ video_mask = video_mask[..., 0]
+ visual_pos_masks = video_mask
+ deepstack_visual_embeds = deepstack_video_embeds
+
+ if position_ids is None:
+ attention_mask_tensor = (
+ attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
+ )
+ if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
+ attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
+ # Only apply conversion for floating point tensors (inverted masks)
+ if attention_mask_tensor.dtype.is_floating_point:
+ attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
+ attention_mask_tensor = (1.0 - attention_mask_tensor).int()
+
+ # Calculate RoPE index once per generation in the pre-fill stage only.
+ # When compiling, we can't check tensor values thus we check only input length
+ # It is safe to assume that `length!=1` means we're in pre-fill because compiled
+            # models currently cannot do assisted decoding
+ prefill_compiled_stage = is_torchdynamo_compiling() and (
+ (input_ids is not None and input_ids.shape[1] != 1)
+ or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
+ )
+ prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
+ (cache_position is not None and cache_position[0] == 0)
+ or (past_key_values is None or past_key_values.get_seq_length() == 0)
+ )
+ if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
+ position_ids, rope_deltas = self.get_rope_index(
+ input_ids,
+ image_grid_thw,
+ video_grid_thw,
+ attention_mask=attention_mask_tensor,
+ )
+ self.rope_deltas = rope_deltas
+            # otherwise, reuse the previously calculated rope deltas to get the correct position ids
+ else:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ delta = (
+ (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
+ if cache_position is not None
+ else 0
+ )
+ position_ids = torch.arange(seq_length, device=inputs_embeds.device)
+ position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+ if cache_position is not None: # otherwise `deltas` is an int `0`
+ delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+ position_ids = position_ids.add(delta)
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
+
+ outputs = self.language_model(
+ input_ids=None,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ visual_pos_masks=visual_pos_masks,
+ deepstack_visual_embeds=deepstack_visual_embeds,
+ **kwargs,
+ )
+
+ return Qwen3VLModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ rope_deltas=self.rope_deltas,
+ )
+
+
+class Qwen3VLCausalLMOutputWithPast(Qwen2_5_VLCausalLMOutputWithPast):
+ pass
+
+
+class Qwen3VLForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
+ config: Qwen3VLConfig
+ _checkpoint_conversion_mapping = {}
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[list[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.Tensor] = None,
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, Qwen3VLCausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+
+ Example:
+ TODO: Add example
+ """
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
+
+ return Qwen3VLCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ rope_deltas=outputs.rope_deltas,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ cache_position=None,
+ position_ids=None,
+ use_cache=True,
+ pixel_values=None,
+ pixel_values_videos=None,
+ image_grid_thw=None,
+ video_grid_thw=None,
+ **kwargs,
+ ):
+ # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
+
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ position_ids=position_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ use_cache=use_cache,
+ **kwargs,
+ )
+
+        # Qwen3VL position_ids are prepared with rope_deltas in forward
+ model_inputs["position_ids"] = None
+
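+        # Pixel values are only needed for the prefill forward pass; drop them once the cache is non-empty.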
+ if cache_position[0] != 0:
+ model_inputs["pixel_values"] = None
+ model_inputs["pixel_values_videos"] = None
+
+ return model_inputs
+
+
+class Qwen3VLVideosProcessorKwargs(VideosKwargs, total=False):
+ pass
+
+
+class Qwen3VLImagesKwargs(Qwen2VLImagesKwargs):
+ pass
+
+
+class Qwen3VLProcessorKwargs(ProcessingKwargs, total=False):
+ images_kwargs: Qwen3VLImagesKwargs
+ videos_kwargs: Qwen3VLVideosProcessorKwargs
+ _defaults = {
+ "text_kwargs": {
+ "padding": False,
+ "return_token_type_ids": False,
+ "return_mm_token_type_ids": False,
+ },
+ "videos_kwargs": {"return_metadata": True},
+ }
+
+
+class Qwen3VLProcessor(Qwen2VLProcessor):
+ r"""
+    Constructs a Qwen3VL processor which wraps a Qwen2VL image processor, a Qwen3VL video processor and a Qwen2 tokenizer into a single processor.
+    [`Qwen3VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`], [`Qwen3VLVideoProcessor`] and [`Qwen2TokenizerFast`]. See the
+ [`~Qwen3VLProcessor.__call__`] and [`~Qwen3VLProcessor.decode`] for more information.
+ Args:
+ image_processor ([`Qwen2VLImageProcessor`], *optional*):
+ The image processor is a required input.
+ tokenizer ([`Qwen2TokenizerFast`], *optional*):
+ The tokenizer is a required input.
+ video_processor ([`Qwen3VLVideoProcessor`], *optional*):
+ The video processor is a required input.
+ chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+ in a chat into a tokenizable string.
+ """
+
+ def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs):
+ super().__init__(image_processor, tokenizer, video_processor, chat_template, **kwargs)
+ self.vision_start_token = (
+ "<|vision_start|>" if not hasattr(tokenizer, "vision_start_token") else tokenizer.vision_start_token
+ )
+ self.vision_end_token = (
+ "<|vision_end|>" if not hasattr(tokenizer, "vision_end_token") else tokenizer.vision_end_token
+ )
+ self.vision_start_token_id = (
+ tokenizer.vision_start_token_id
+ if getattr(tokenizer, "vision_start_token_id", None)
+ else tokenizer.convert_tokens_to_ids(self.vision_start_token)
+ )
+ self.vision_end_token_id = (
+ tokenizer.vision_end_token_id
+ if getattr(tokenizer, "vision_end_token_id", None)
+ else tokenizer.convert_tokens_to_ids(self.vision_end_token)
+ )
+
+ def __call__(
+ self,
+ images: ImageInput = None,
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
+ videos: VideoInput = None,
+ **kwargs: Unpack[Qwen3VLProcessorKwargs],
+ ) -> BatchFeature:
+ """
+        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
+        and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
+        the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwargs` arguments to
+ Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.
+
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. Both channels-first and channels-last formats are supported.
+ text (`str`, `list[str]`, `list[list[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ videos (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`):
+                The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
+ tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Acceptable values are:
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+ `None`).
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+ - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
+ - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
+ - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
+ """
+ output_kwargs = self._merge_kwargs(
+ Qwen3VLProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+ if images is not None:
+ image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
+ image_grid_thw = image_inputs["image_grid_thw"]
+ else:
+ image_inputs = {}
+ image_grid_thw = None
+
+ if videos is not None:
+ videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
+ video_grid_thw = videos_inputs["video_grid_thw"]
+ # If user has not requested video metadata, pop it
+ if "return_metadata" not in kwargs:
+ video_metadata = videos_inputs.pop("video_metadata")
+ else:
+ video_metadata = videos_inputs["video_metadata"]
+ video_grid_thw = videos_inputs["video_grid_thw"]
+ else:
+ videos_inputs = {}
+ video_grid_thw = None
+
+ if not isinstance(text, list):
+ text = [text]
+
+ text = text.copy() # below lines change text in-place
+ if image_grid_thw is not None:
+ merge_length = self.image_processor.merge_size**2
+ index = 0
+ for i in range(len(text)):
+ while self.image_token in text[i]:
+ num_image_tokens = image_grid_thw[index].prod() // merge_length
+ text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
+ index += 1
+ text[i] = text[i].replace("<|placeholder|>", self.image_token)
+
+ if video_grid_thw is not None:
+ merge_length = self.video_processor.merge_size**2
+ index = 0
+ for i in range(len(text)):
+ while self.video_token in text[i]:
+ metadata = video_metadata[i]
+ if metadata.fps is None:
+ logger.warning_once(
+ "Qwen3VL requires frame timestamps to construct prompts, but the `fps` of the input video could not be inferred. "
+ "Probably `video_metadata` was missing from inputs and you passed pre-sampled frames. "
+ "Defaulting to `fps=24`. Please provide `video_metadata` for more accurate results."
+ )
+ metadata.fps = 24 if metadata.fps is None else metadata.fps
+
+ # if timestamps are not provided, calculate them
+ curr_timestamp = self._calculate_timestamps(
+ metadata.frames_indices,
+ metadata.fps,
+ self.video_processor.merge_size,
+ )
+
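+                    # Build the expanded video prompt: each frame contributes a "<t seconds>" timestamp tag
+                    # followed by vision_start + per-frame placeholder tokens + vision_end.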
+ video_placeholder = ""
+ frame_seqlen = video_grid_thw[index][1:].prod() // merge_length
+                    for frame_idx in range(video_grid_thw[index][0]):
+                        curr_time = curr_timestamp[frame_idx]
+                        video_placeholder += f"<{curr_time:.1f} seconds>"
+                        video_placeholder += (
+                            self.vision_start_token + "<|placeholder|>" * frame_seqlen + self.vision_end_token
+                        )
+                    if f"{self.vision_start_token}{self.video_token}{self.vision_end_token}" in text[i]:
+                        text[i] = text[i].replace(
+                            f"{self.vision_start_token}{self.video_token}{self.vision_end_token}", video_placeholder, 1
+                        )
+                    else:
+                        # vLLM may pass the video token directly
+                        text[i] = text[i].replace(self.video_token, video_placeholder, 1)
+                    index += 1
+
+ text[i] = text[i].replace("<|placeholder|>", self.video_token)
+
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+ return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
+ text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
+ self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
+
+ if return_mm_token_type_ids:
+ array_ids = np.array(text_inputs["input_ids"])
+ mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
+ mm_token_type_ids[array_ids == self.image_token_id] = 1
+ text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
+
+ return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)
+
+ def _calculate_timestamps(self, indices: Union[list[int], np.ndarray], video_fps: float, merge_size: int = 2):
+ if not isinstance(indices, list):
+ indices = indices.tolist()
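+        # Pad the sampled frame indices to a multiple of `merge_size` by repeating the last index.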
+ if len(indices) % merge_size != 0:
+ indices.extend(indices[-1] for _ in range(merge_size - len(indices) % merge_size))
+ timestamps = [idx / video_fps for idx in indices]
+        # @JJJYmmm: frames are merged by self.merge_size,
+        # so we average the timestamps of the first and last frames within each temporal patch
+ timestamps = [
+ (timestamps[i] + timestamps[i + merge_size - 1]) / 2 for i in range(0, len(timestamps), merge_size)
+ ]
+ return timestamps
+
+
+__all__ = [
+ "Qwen3VLConfig",
+ "Qwen3VLTextConfig",
+ "Qwen3VLVisionModel",
+ "Qwen3VLForConditionalGeneration",
+ "Qwen3VLModel",
+ "Qwen3VLPreTrainedModel",
+ "Qwen3VLProcessor",
+ "Qwen3VLTextModel",
+]
diff --git a/src/transformers/models/qwen3_vl/processing_qwen3_vl.py b/src/transformers/models/qwen3_vl/processing_qwen3_vl.py
new file mode 100644
index 000000000000..cac82e738f39
--- /dev/null
+++ b/src/transformers/models/qwen3_vl/processing_qwen3_vl.py
@@ -0,0 +1,328 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/qwen3_vl/modular_qwen3_vl.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_qwen3_vl.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+
+import numpy as np
+
+from ...feature_extraction_utils import BatchFeature
+from ...image_utils import ImageInput
+from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+from ...utils import logging
+from ...video_utils import VideoInput
+
+
+logger = logging.get_logger(__name__)
+
+
+class Qwen3VLVideosProcessorKwargs(VideosKwargs, total=False):
+ pass
+
+
+class Qwen3VLImagesKwargs(ImagesKwargs):
+ min_pixels: Optional[int]
+ max_pixels: Optional[int]
+ patch_size: Optional[int]
+ temporal_patch_size: Optional[int]
+ merge_size: Optional[int]
+
+
+class Qwen3VLProcessorKwargs(ProcessingKwargs, total=False):
+ images_kwargs: Qwen3VLImagesKwargs
+ videos_kwargs: Qwen3VLVideosProcessorKwargs
+ _defaults = {
+ "text_kwargs": {
+ "padding": False,
+ "return_token_type_ids": False,
+ "return_mm_token_type_ids": False,
+ },
+ "videos_kwargs": {"return_metadata": True},
+ }
+
+
+class Qwen3VLProcessor(ProcessorMixin):
+ r"""
+    Constructs a Qwen3VL processor which wraps a Qwen2VL image processor, a Qwen3VL video processor and a Qwen2 tokenizer into a single processor.
+    [`Qwen3VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`], [`Qwen3VLVideoProcessor`] and [`Qwen2TokenizerFast`]. See the
+ [`~Qwen3VLProcessor.__call__`] and [`~Qwen3VLProcessor.decode`] for more information.
+ Args:
+ image_processor ([`Qwen2VLImageProcessor`], *optional*):
+ The image processor is a required input.
+ tokenizer ([`Qwen2TokenizerFast`], *optional*):
+ The tokenizer is a required input.
+ video_processor ([`Qwen3VLVideoProcessor`], *optional*):
+ The video processor is a required input.
+ chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+ in a chat into a tokenizable string.
+ """
+
+ attributes = ["image_processor", "tokenizer", "video_processor"]
+ image_processor_class = "AutoImageProcessor"
+ video_processor_class = "AutoVideoProcessor"
+ tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
+
+ def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs):
+ super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)
+ self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
+ self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
+ self.image_token_id = (
+ tokenizer.image_token_id
+ if getattr(tokenizer, "image_token_id", None)
+ else tokenizer.convert_tokens_to_ids(self.image_token)
+ )
+ self.video_token_id = (
+ tokenizer.video_token_id
+ if getattr(tokenizer, "video_token_id", None)
+ else tokenizer.convert_tokens_to_ids(self.video_token)
+ )
+ self.vision_start_token = (
+ "<|vision_start|>" if not hasattr(tokenizer, "vision_start_token") else tokenizer.vision_start_token
+ )
+ self.vision_end_token = (
+ "<|vision_end|>" if not hasattr(tokenizer, "vision_end_token") else tokenizer.vision_end_token
+ )
+ self.vision_start_token_id = (
+ tokenizer.vision_start_token_id
+ if getattr(tokenizer, "vision_start_token_id", None)
+ else tokenizer.convert_tokens_to_ids(self.vision_start_token)
+ )
+ self.vision_end_token_id = (
+ tokenizer.vision_end_token_id
+ if getattr(tokenizer, "vision_end_token_id", None)
+ else tokenizer.convert_tokens_to_ids(self.vision_end_token)
+ )
+
+ def __call__(
+ self,
+ images: ImageInput = None,
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
+ videos: VideoInput = None,
+ **kwargs: Unpack[Qwen3VLProcessorKwargs],
+ ) -> BatchFeature:
+ """
+        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
+        and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
+        the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwargs` arguments to
+ Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.
+
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. Both channels-first and channels-last formats are supported.
+ text (`str`, `list[str]`, `list[list[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ videos (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`):
+                The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
+ tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Acceptable values are:
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+ `None`).
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+ - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
+ - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
+ - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
+ """
+ output_kwargs = self._merge_kwargs(
+ Qwen3VLProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+ if images is not None:
+ image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
+ image_grid_thw = image_inputs["image_grid_thw"]
+ else:
+ image_inputs = {}
+ image_grid_thw = None
+
+ if videos is not None:
+ videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
+ video_grid_thw = videos_inputs["video_grid_thw"]
+ # If user has not requested video metadata, pop it
+ if "return_metadata" not in kwargs:
+ video_metadata = videos_inputs.pop("video_metadata")
+ else:
+ video_metadata = videos_inputs["video_metadata"]
+ video_grid_thw = videos_inputs["video_grid_thw"]
+ else:
+ videos_inputs = {}
+ video_grid_thw = None
+
+ if not isinstance(text, list):
+ text = [text]
+
+ text = text.copy() # below lines change text in-place
+ if image_grid_thw is not None:
+ merge_length = self.image_processor.merge_size**2
+ index = 0
+ for i in range(len(text)):
+ while self.image_token in text[i]:
+ num_image_tokens = image_grid_thw[index].prod() // merge_length
+ text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
+ index += 1
+ text[i] = text[i].replace("<|placeholder|>", self.image_token)
+
+ if video_grid_thw is not None:
+ merge_length = self.video_processor.merge_size**2
+ index = 0
+ for i in range(len(text)):
+ while self.video_token in text[i]:
+ metadata = video_metadata[i]
+ if metadata.fps is None:
+ logger.warning_once(
+ "Qwen3VL requires frame timestamps to construct prompts, but the `fps` of the input video could not be inferred. "
+ "Probably `video_metadata` was missing from inputs and you passed pre-sampled frames. "
+ "Defaulting to `fps=24`. Please provide `video_metadata` for more accurate results."
+ )
+ metadata.fps = 24 if metadata.fps is None else metadata.fps
+
+ # if timestamps are not provided, calculate them
+ curr_timestamp = self._calculate_timestamps(
+ metadata.frames_indices,
+ metadata.fps,
+ self.video_processor.merge_size,
+ )
+
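+                    # Build the expanded video prompt: each frame contributes a "<t seconds>" timestamp tag
+                    # followed by vision_start + per-frame placeholder tokens + vision_end.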
+ video_placeholder = ""
+ frame_seqlen = video_grid_thw[index][1:].prod() // merge_length
+                    for frame_idx in range(video_grid_thw[index][0]):
+                        curr_time = curr_timestamp[frame_idx]
+                        video_placeholder += f"<{curr_time:.1f} seconds>"
+                        video_placeholder += (
+                            self.vision_start_token + "<|placeholder|>" * frame_seqlen + self.vision_end_token
+                        )
+                    if f"{self.vision_start_token}{self.video_token}{self.vision_end_token}" in text[i]:
+                        text[i] = text[i].replace(
+                            f"{self.vision_start_token}{self.video_token}{self.vision_end_token}", video_placeholder, 1
+                        )
+                    else:
+                        # vLLM may pass the video token directly
+                        text[i] = text[i].replace(self.video_token, video_placeholder, 1)
+                    index += 1
+
+ text[i] = text[i].replace("<|placeholder|>", self.video_token)
+
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+ return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
+ text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
+ self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
+
+ if return_mm_token_type_ids:
+ array_ids = np.array(text_inputs["input_ids"])
+ mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
+ mm_token_type_ids[array_ids == self.image_token_id] = 1
+ text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
+
+ return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)
+
+ def _get_num_multimodal_tokens(self, image_sizes=None, video_sizes=None, **kwargs):
+ """
+ Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
+ Args:
+ image_sizes (`list[list[int]]`, *optional*):
+ The input sizes formatted as (height, width) per each image.
+ video_sizes (`list[list[int]]`, *optional*):
+ The input sizes formatted as (num_frames, height, width) per each video.
+ Returns:
+ `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
+ input modalities, along with other useful data.
+ """
+
+ vision_data = {}
+ if image_sizes is not None:
+ images_kwargs = Qwen3VLProcessorKwargs._defaults.get("images_kwargs", {})
+ images_kwargs.update(kwargs)
+ merge_size = images_kwargs.get("merge_size", None) or self.image_processor.merge_size
+
+ num_image_patches = [
+ self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
+ for image_size in image_sizes
+ ]
+ num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches]
+ vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
+
+ if video_sizes is not None:
+ videos_kwargs = Qwen3VLProcessorKwargs._defaults.get("videos_kwargs", {})
+            videos_kwargs.update(kwargs)
+            merge_size = videos_kwargs.get("merge_size", None) or self.video_processor.merge_size
+            num_video_patches = [
+ self.video_processor.get_number_of_video_patches(*video_size, videos_kwargs)
+ for video_size in video_sizes
+ ]
+ num_video_tokens = [(num_patches // merge_size**2) for num_patches in num_video_patches]
+ vision_data["num_video_tokens"] = num_video_tokens
+
+ return MultiModalData(**vision_data)
+
+ def post_process_image_text_to_text(
+ self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
+ ):
+ """
+ Post-process the output of the model to decode the text.
+
+ Args:
+ generated_outputs (`torch.Tensor` or `np.ndarray`):
+ The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
+ or `(sequence_length,)`.
+ skip_special_tokens (`bool`, *optional*, defaults to `True`):
+ Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
+ Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
+ **kwargs:
+                Additional arguments to be passed to the tokenizer's `batch_decode` method.
+
+ Returns:
+ `list[str]`: The decoded text.
+ """
+ return self.tokenizer.batch_decode(
+ generated_outputs,
+ skip_special_tokens=skip_special_tokens,
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+ **kwargs,
+ )
+
+ def _calculate_timestamps(self, indices: Union[list[int], np.ndarray], video_fps: float, merge_size: int = 2):
+ if not isinstance(indices, list):
+ indices = indices.tolist()
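+        # Pad the sampled frame indices to a multiple of `merge_size` by repeating the last index.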
+ if len(indices) % merge_size != 0:
+ indices.extend(indices[-1] for _ in range(merge_size - len(indices) % merge_size))
+ timestamps = [idx / video_fps for idx in indices]
+        # @JJJYmmm: frames are merged by self.merge_size,
+        # so we average the timestamps of the first and last frames within each temporal patch
+ timestamps = [
+ (timestamps[i] + timestamps[i + merge_size - 1]) / 2 for i in range(0, len(timestamps), merge_size)
+ ]
+ return timestamps
+
+
+__all__ = ["Qwen3VLProcessor"]
diff --git a/src/transformers/models/qwen3_vl/video_processing_qwen3_vl.py b/src/transformers/models/qwen3_vl/video_processing_qwen3_vl.py
new file mode 100644
index 000000000000..c4648788c9dc
--- /dev/null
+++ b/src/transformers/models/qwen3_vl/video_processing_qwen3_vl.py
@@ -0,0 +1,276 @@
+# coding=utf-8
+# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""video processor class for Qwen3-VL."""
+
+import math
+from typing import Optional, Union
+
+import numpy as np
+import torch
+
+from ...feature_extraction_utils import BatchFeature
+from ...image_utils import ChannelDimension, PILImageResampling, SizeDict, get_image_size
+from ...processing_utils import Unpack, VideosKwargs
+from ...utils import TensorType, add_start_docstrings, logging
+from ...video_processing_utils import BASE_VIDEO_PROCESSOR_DOCSTRING, BaseVideoProcessor
+from ...video_utils import VideoMetadata, group_videos_by_shape, reorder_videos
+
+
+logger = logging.get_logger(__name__)
+
+
+def smart_resize(
+ num_frames: int,
+ height: int,
+ width: int,
+ temporal_factor: int = 2,
+ factor: int = 32,
+ min_pixels: int = 128 * 128,
+ max_pixels: int = 16 * 16 * 2 * 2 * 2 * 6144,
+):
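+    # Round height/width to multiples of `factor` and frames to `temporal_factor`, then rescale the spatial
+    # dims so that the total t_bar * h_bar * w_bar volume stays within [min_pixels, max_pixels].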
+ if num_frames < temporal_factor:
+ raise ValueError(f"t:{num_frames} must be larger than temporal_factor:{temporal_factor}")
+ if height < factor or width < factor:
+ raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
+ elif max(height, width) / min(height, width) > 200:
+ raise ValueError(
+ f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
+ )
+ h_bar = round(height / factor) * factor
+ w_bar = round(width / factor) * factor
+ t_bar = round(num_frames / temporal_factor) * temporal_factor
+
+ if t_bar * h_bar * w_bar > max_pixels:
+ beta = math.sqrt((num_frames * height * width) / max_pixels)
+ h_bar = max(factor, math.floor(height / beta / factor) * factor)
+ w_bar = max(factor, math.floor(width / beta / factor) * factor)
+ elif t_bar * h_bar * w_bar < min_pixels:
+ beta = math.sqrt(min_pixels / (num_frames * height * width))
+ h_bar = math.ceil(height * beta / factor) * factor
+ w_bar = math.ceil(width * beta / factor) * factor
+
+ return h_bar, w_bar
+
+
+class Qwen3VLVideoProcessorInitKwargs(VideosKwargs):
+ patch_size: Optional[int]
+ temporal_patch_size: Optional[int]
+ merge_size: Optional[int]
+ min_frames: Optional[int]
+ max_frames: Optional[int]
+
+
+@add_start_docstrings(
+ "Constructs a fast Qwen3-VL image processor that dynamically resizes videos based on the original videos.",
+ BASE_VIDEO_PROCESSOR_DOCSTRING,
+ """
+ patch_size (`int`, *optional*, defaults to 16):
+        The spatial patch size of the vision encoder.
+ temporal_patch_size (`int`, *optional*, defaults to 2):
+ The temporal patch size of the vision encoder.
+ merge_size (`int`, *optional*, defaults to 2):
+        The merge size used to map vision encoder patches to LLM tokens.
+ """,
+)
+class Qwen3VLVideoProcessor(BaseVideoProcessor):
+ resample = PILImageResampling.BICUBIC
+ size = {"shortest_edge": 128 * 32 * 32, "longest_edge": 32 * 32 * 768}
+ image_mean = [0.5, 0.5, 0.5]
+ image_std = [0.5, 0.5, 0.5]
+ do_resize = True
+ do_rescale = True
+ do_normalize = True
+ do_convert_rgb = True
+ patch_size = 16
+ temporal_patch_size = 2
+ merge_size = 2
+ fps = 2
+ min_frames = 4
+ max_frames = 768
+ do_sample_frames = True
+ valid_kwargs = Qwen3VLVideoProcessorInitKwargs
+ model_input_names = ["pixel_values_videos", "video_grid_thw"]
+
+ def __init__(self, **kwargs: Unpack[Qwen3VLVideoProcessorInitKwargs]):
+ super().__init__(**kwargs)
+ if self.size is not None and (
+ self.size.get("shortest_edge", None) is None or self.size.get("longest_edge", None) is None
+ ):
+ raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
+
+ def _further_process_kwargs(
+ self,
+ size: Optional[SizeDict] = None,
+ **kwargs,
+ ) -> dict:
+ """
+ Update kwargs that need further processing before being validated
+ Can be overridden by subclasses to customize the processing of kwargs.
+ """
+ if size is not None and ("shortest_edge" not in size or "longest_edge" not in size):
+ raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
+
+ return super()._further_process_kwargs(size=size, **kwargs)
+
+ def sample_frames(
+ self,
+ metadata: VideoMetadata,
+ num_frames: Optional[int] = None,
+ fps: Optional[Union[int, float]] = None,
+ **kwargs,
+ ):
+ """
+ Default sampling function which uniformly samples the desired number of frames between 0 and total number of frames.
+        If `fps` is passed along with metadata, `fps` frames per second are sampled uniformly. Arguments `num_frames`
+ and `fps` are mutually exclusive.
+
+        Args:
+            metadata (`VideoMetadata`):
+                Metadata of the video containing information about total duration, fps and total number of frames.
+            num_frames (`int`, *optional*):
+                Maximum number of frames to sample. Defaults to `self.num_frames`.
+            fps (`int` or `float`, *optional*):
+                Target frames to sample per second. Defaults to `self.fps`.
+        Returns:
+            `np.ndarray`:
+                Indices of the frames to sample.
+ """
+ if fps is not None and num_frames is not None:
+ raise ValueError("`num_frames` and `fps` are mutually exclusive arguments, please use only one!")
+
+ total_num_frames = metadata.total_num_frames
+ fps = fps if fps is not None else self.fps
+
+ # If num_frames is not given but fps is, calculate num_frames from fps
+ if num_frames is None and fps is not None:
+ if metadata.fps is None:
+ metadata.fps = 24
+ logger.warning_once(
+ "Asked to sample `fps` frames per second but no video metadata was provided which is required when sampling with `fps`. "
+ "Defaulting to `fps=24`. Please provide `video_metadata` for more accurate results."
+ )
+ num_frames = int(total_num_frames / metadata.fps * fps)
+ num_frames = min(min(max(num_frames, self.min_frames), self.max_frames), total_num_frames)
+
+ if num_frames is None:
+ num_frames = min(max(total_num_frames, self.min_frames), self.max_frames)
+
+ indices = np.linspace(0, total_num_frames - 1, num_frames).round().astype(int)
+
+ return indices
+
+ def _preprocess(
+ self,
+ videos: list[torch.Tensor],
+ do_convert_rgb: bool = True,
+ do_resize: bool = True,
+ size: Optional[SizeDict] = None,
+ interpolation: PILImageResampling = PILImageResampling.BICUBIC,
+ do_rescale: bool = True,
+ rescale_factor: float = 1 / 255.0,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ patch_size: Optional[int] = None,
+ temporal_patch_size: Optional[int] = None,
+ merge_size: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ **kwargs,
+ ):
+ grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
+ resized_videos_grouped = {}
+
+ for shape, stacked_videos in grouped_videos.items():
+ B, T, C, H, W = stacked_videos.shape
+ num_frames, height, width = T, H, W
+ if do_resize:
+ resized_height, resized_width = smart_resize(
+ num_frames=num_frames,
+ height=height,
+ width=width,
+ temporal_factor=temporal_patch_size,
+ factor=patch_size * merge_size,
+ min_pixels=size.shortest_edge,
+ max_pixels=size.longest_edge,
+ )
+                stacked_videos = stacked_videos.view(B * T, C, H, W)
+                stacked_videos = self.resize(
+                    stacked_videos,
+                    size=SizeDict(height=resized_height, width=resized_width),
+                    interpolation=interpolation,
+                )
+                stacked_videos = stacked_videos.view(B, T, C, resized_height, resized_width)
+ resized_videos_grouped[shape] = stacked_videos
+ resized_videos = reorder_videos(resized_videos_grouped, grouped_videos_index)
+
+ # Group videos by size for further processing
+ # Needed in case do_resize is False, or resize returns videos with different sizes
+ grouped_videos, grouped_videos_index = group_videos_by_shape(resized_videos)
+ processed_videos_grouped = {}
+ processed_grids = {}
+ for shape, stacked_videos in grouped_videos.items():
+ resized_height, resized_width = get_image_size(stacked_videos[0], channel_dim=ChannelDimension.FIRST)
+
+ # Fused rescale and normalize
+ stacked_videos = self.rescale_and_normalize(
+ stacked_videos, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+ )
+ patches = stacked_videos
+
+ # Check that videos have `num_frames` divisible by `temporal_patch_size`
+ if patches.shape[1] % temporal_patch_size != 0:
+ repeats = patches[:, -1:].repeat(1, temporal_patch_size - 1, 1, 1, 1)
+ patches = torch.cat([patches, repeats], dim=1)
+ batch_size, grid_t, channel = patches.shape[:3]
+ grid_t = grid_t // temporal_patch_size
+ grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
+
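+            # Split each video into (temporal_patch_size, patch_size, patch_size) patches grouped by spatial-merge
+            # blocks, then flatten to (num_patches, channel * temporal_patch_size * patch_size**2) per video.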
+ patches = patches.view(
+ batch_size,
+ grid_t,
+ temporal_patch_size,
+ channel,
+ grid_h // merge_size,
+ merge_size,
+ patch_size,
+ grid_w // merge_size,
+ merge_size,
+ patch_size,
+ )
+ patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
+ flatten_patches = patches.reshape(
+ batch_size,
+ grid_t * grid_h * grid_w,
+ channel * temporal_patch_size * patch_size * patch_size,
+ )
+
+ processed_videos_grouped[shape] = flatten_patches
+ processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size
+
+ processed_videos = reorder_videos(processed_videos_grouped, grouped_videos_index)
+ processed_grids = reorder_videos(processed_grids, grouped_videos_index)
+ pixel_values_videos = torch.cat(processed_videos, dim=0)
+ video_grid_thw = torch.tensor(processed_grids)
+ data = {
+ "pixel_values_videos": pixel_values_videos,
+ "video_grid_thw": video_grid_thw,
+ }
+
+ return BatchFeature(data=data, tensor_type=return_tensors)
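+ # Shape walkthrough (illustrative, assuming patch_size=16, temporal_patch_size=2,
+ # merge_size=2): a video resized to 2 frames of 224x224 gives grid_t=1 and
+ # grid_h = grid_w = 224 // 16 = 14, so pixel_values_videos holds 1 * 14 * 14 = 196
+ # rows of dimension 3 * 2 * 16 * 16 = 1536 and video_grid_thw stores [[1, 14, 14]].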
+
+
+__all__ = ["Qwen3VLVideoProcessor"]
diff --git a/src/transformers/models/qwen3_vl_moe/__init__.py b/src/transformers/models/qwen3_vl_moe/__init__.py
new file mode 100644
index 000000000000..a4000cb27272
--- /dev/null
+++ b/src/transformers/models/qwen3_vl_moe/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_qwen3_vl_moe import *
+ from .modeling_qwen3_vl_moe import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/src/transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py
new file mode 100644
index 000000000000..c4a31e8f9f92
--- /dev/null
+++ b/src/transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py
@@ -0,0 +1,331 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_qwen3_vl_moe.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+
+
+class Qwen3VLMoeTextConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Qwen3VLMoeTextModel`]. It is used to instantiate a
+ Qwen3-VL-MOE model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of
+ Qwen3-VL-30B-A3B-Instruct [Qwen/Qwen3-VL-30B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Instruct).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 151936):
+ Vocabulary size of the Qwen3VLMoe model. Defines the number of different tokens that can be represented by the
+ `input_ids` passed when calling [`Qwen3VLMoeTextModel`]
+ hidden_size (`int`, *optional*, defaults to 2048):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 5632):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 24):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_key_value_heads (`int`, *optional*, defaults to 16):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details checkout [this
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 128000):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether the model's input and output word embeddings should be tied.
+ rope_theta (`float`, *optional*, defaults to 5000000.0):
+ The base period of the RoPE embeddings.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ decoder_sparse_step (`int`, *optional*, defaults to 1):
+ The frequency of the MoE layer.
+ moe_intermediate_size (`int`, *optional*, defaults to 1408):
+ Intermediate size of the routed expert.
+ num_experts_per_tok (`int`, *optional*, defaults to 4):
+ Number of selected experts.
+ num_experts (`int`, *optional*, defaults to 60):
+ Number of routed experts.
+ norm_topk_prob (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the topk probabilities.
+ mlp_only_layers (`List[int]`, *optional*, defaults to `[]`):
+ Indicate which layers use Qwen3VLMoeMLP rather than Qwen3VLMoeSparseMoeBlock.
+ The list contains layer indices from 0 to num_layers-1 if we have num_layers layers.
+ If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+ head_dim (`int`, *optional*):
+ The dimension of the head. If not specified, will default to `hidden_size // num_attention_heads`.
+
+ ```python
+ >>> from transformers import Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeConfig
+
+ >>> # Initializing a Qwen3VLMoe style configuration
+ >>> configuration = Qwen3VLMoeConfig()
+
+ >>> # Initializing a model from the Qwen3-VL-30B-A3B style configuration
+ >>> model = Qwen3VLMoeForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "qwen3_vl_moe_text"
+ base_config_key = "text_config"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ # Default tensor parallel plan for base model `Qwen3VLMoe`
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.mlp.gate_proj": "colwise",
+ "layers.*.mlp.up_proj": "colwise",
+ "layers.*.mlp.down_proj": "rowwise",
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+
+ def __init__(
+ self,
+ vocab_size=151936,
+ hidden_size=2048,
+ intermediate_size=5632,
+ num_hidden_layers=24,
+ num_attention_heads=16,
+ num_key_value_heads=16,
+ hidden_act="silu",
+ max_position_embeddings=128000,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ tie_word_embeddings=False,
+ rope_theta=5000000.0,
+ attention_bias=False,
+ attention_dropout=0.0,
+ decoder_sparse_step=1,
+ moe_intermediate_size=1408,
+ num_experts_per_tok=4,
+ num_experts=60,
+ norm_topk_prob=True,
+ mlp_only_layers=None,
+ rope_scaling=None,
+ head_dim=None,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.rope_scaling = rope_scaling
+ self.head_dim = head_dim or hidden_size // num_attention_heads
+
+ rope_config_validation(self, ignore_keys={"mrope_section", "mrope_interleaved"})
+
+ # MoE arguments
+ self.decoder_sparse_step = decoder_sparse_step
+ self.moe_intermediate_size = moe_intermediate_size
+ self.num_experts_per_tok = num_experts_per_tok
+ self.num_experts = num_experts
+ self.norm_topk_prob = norm_topk_prob
+ self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers
+
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+
+
+class Qwen3VLMoeVisionConfig(PretrainedConfig):
+ model_type = "qwen3_vl_moe"
+ base_config_key = "vision_config"
+
+ def __init__(
+ self,
+ depth=27,
+ hidden_size=1152,
+ hidden_act="gelu_pytorch_tanh",
+ intermediate_size=4304,
+ num_heads=16,
+ in_channels=3,
+ patch_size=16,
+ spatial_merge_size=2,
+ temporal_patch_size=2,
+ out_hidden_size=3584,
+ num_position_embeddings=2304,
+ deepstack_visual_indexes=[8, 16, 24],
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.depth = depth
+ self.hidden_size = hidden_size
+ self.hidden_act = hidden_act
+ self.intermediate_size = intermediate_size
+ self.num_heads = num_heads
+ self.in_channels = in_channels
+ self.patch_size = patch_size
+ self.spatial_merge_size = spatial_merge_size
+ self.temporal_patch_size = temporal_patch_size
+ self.out_hidden_size = out_hidden_size
+ self.num_position_embeddings = num_position_embeddings
+ self.initializer_range = initializer_range
+ self.deepstack_visual_indexes = deepstack_visual_indexes
+
+
+class Qwen3VLMoeConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Qwen3VLMoeModel`]. It is used to instantiate a
+ Qwen3-VL-MOE model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of
+ Qwen3-VL-30B-A3B-Instruct [Qwen/Qwen3-VL-30B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Instruct).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLMoeTextConfig`):
+ The config object or dictionary of the text backbone.
+ vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLMoeVisionConfig`):
+ The config object or dictionary of the vision backbone.
+ image_token_id (`int`, *optional*, defaults to 151655):
+ The image token index to encode the image prompt.
+ video_token_id (`int`, *optional*, defaults to 151656):
+ The video token index to encode the video prompt.
+ vision_start_token_id (`int`, *optional*, defaults to 151652):
+ The token index marking the start of a vision (image or video) segment.
+ vision_end_token_id (`int`, *optional*, defaults to 151653):
+ The token index marking the end of a vision (image or video) segment.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie the word embeddings.
+
+ ```python
+ >>> from transformers import Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeConfig
+
+ >>> # Initializing a Qwen3-VL-MOE style configuration
+ >>> configuration = Qwen3VLMoeConfig()
+
+ >>> # Initializing a model from the Qwen3-VL-30B-A3B style configuration
+ >>> model = Qwen3VLMoeForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "qwen3_vl_moe"
+ sub_configs = {"vision_config": Qwen3VLMoeVisionConfig, "text_config": Qwen3VLMoeTextConfig}
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ text_config=None,
+ vision_config=None,
+ image_token_id=151655,
+ video_token_id=151656,
+ vision_start_token_id=151652,
+ vision_end_token_id=151653,
+ tie_word_embeddings=False,
+ **kwargs,
+ ):
+ if isinstance(vision_config, dict):
+ self.vision_config = self.sub_configs["vision_config"](**vision_config)
+ elif vision_config is None:
+ self.vision_config = self.sub_configs["vision_config"]()
+
+ if isinstance(text_config, dict):
+ self.text_config = self.sub_configs["text_config"](**text_config)
+ elif text_config is None:
+ self.text_config = self.sub_configs["text_config"]()
+
+ self.image_token_id = image_token_id
+ self.video_token_id = video_token_id
+ self.vision_start_token_id = vision_start_token_id
+ self.vision_end_token_id = vision_end_token_id
+ super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings)
+
+
+__all__ = ["Qwen3VLMoeConfig", "Qwen3VLMoeTextConfig"]
diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py
new file mode 100644
index 000000000000..74b793f096f3
--- /dev/null
+++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py
@@ -0,0 +1,1711 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_qwen3_vl_moe.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...integrations import use_kernel_forward_from_hub
+from ...masking_utils import create_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import OutputRecorder, check_model_inputs
+from .configuration_qwen3_vl_moe import Qwen3VLMoeConfig, Qwen3VLMoeTextConfig, Qwen3VLMoeVisionConfig
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class Qwen3VLMoeTextRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ Qwen3VLMoeTextRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class Qwen3VLMoeTextRouter(nn.Linear):
+ def __init__(self, config):
+ super().__init__(config.hidden_size, config.num_experts, bias=False)
+ self.hidden_size = config.hidden_size
+ self.top_k = config.num_experts_per_tok
+ # since all the models use norm_topk_prob, we don't need an extra check for it
+ # self.norm_topk_prob = config.norm_topk_prob
+
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.reshape(-1, self.hidden_size)
+ router_logits = super().forward(hidden_states)
+ routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float)
+ routing_weights, router_indices = torch.topk(routing_weights, self.top_k, dim=-1)
+ routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
+ routing_weights = routing_weights.to(hidden_states.dtype)
+ router_weights = torch.zeros_like(router_logits).scatter_(1, router_indices, routing_weights)
+ return router_weights, router_logits, router_indices
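+ # Routing sketch with made-up numbers: for num_experts=4 and top_k=2, logits
+ # [2.0, 1.0, 0.5, 0.1] softmax to ~[0.57, 0.21, 0.13, 0.09]; the top-2 weights are
+ # renormalized to ~[0.73, 0.27] and scattered back, so router_weights for that
+ # token is ~[0.73, 0.27, 0.0, 0.0] while router_logits keeps the raw scores.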
+
+
+class Qwen3VLMoeTextExperts(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.num_experts = config.num_experts
+ self.intermediate_size = config.moe_intermediate_size
+ self.hidden_size = config.hidden_size
+ self.expert_dim = self.intermediate_size
+ self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
+ self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size)))
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(
+ self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor
+ ) -> torch.Tensor:
+ """
+ During training it is more efficient to loop over the experts and compute the output for each expert
+ individually, as otherwise the memory would explode.
+
+ For inference we can sacrifice some memory and compute the output for all experts at once by repeating the inputs.
+
+ Args:
+ hidden_states (torch.Tensor): (batch_size * token_num, hidden_size)
+ routing_weights (torch.Tensor): (batch_size * token_num, num_experts)
+ router_indices (torch.Tensor): (batch_size * token_num, top_k)
+ Returns:
+ torch.Tensor
+ """
+ batch_size = hidden_states.shape[0]
+ hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size)
+ if self.training:
+ next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
+ with torch.no_grad():
+ expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts)
+ expert_mask = expert_mask.permute(2, 1, 0)
+ # we sum on the top_k and on the sequence length to get which experts
+ # are hit this time around
+ expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+ for expert_idx in expert_hit[:]:
+ with torch.no_grad():
+ _, token_idx = torch.where(expert_mask[expert_idx[0]])
+ current_state = hidden_states[token_idx]
+ gate_up = current_state @ self.gate_up_proj[expert_idx]
+ gate, up = gate_up.chunk(2, dim=-1)
+ gated_output = up * self.act_fn(gate)
+ out = gated_output @ self.down_proj[expert_idx]
+ weighted_output = out[0] * routing_weights[token_idx, expert_idx, None]
+ next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
+ next_states = next_states.view(batch_size, -1, self.hidden_size)
+ else:
+ hidden_states = hidden_states.repeat(self.num_experts, 1)
+ hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
+ gate_up = torch.bmm(hidden_states, self.gate_up_proj)
+ gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors
+ next_states = torch.bmm((up * self.act_fn(gate)), self.down_proj)
+ next_states = next_states.reshape(self.num_experts, batch_size, -1, self.hidden_size)
+ next_states = (
+ next_states * routing_weights.transpose(0, 1).view(self.num_experts, batch_size, -1)[..., None]
+ )
+ next_states = next_states.sum(dim=0)
+ return next_states
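+ # Shape sketch for the inference path (illustrative: batch=1, seq=8, hidden_size=2048,
+ # num_experts=60, moe_intermediate_size=1408): the tokens are tiled to (60, 8, 2048),
+ # bmm with gate_up_proj of shape (60, 2048, 2816) gives (60, 8, 2816); after gating and
+ # down_proj the per-expert outputs are reshaped to (60, 1, 8, 2048), weighted by the
+ # routing weights and summed over the expert dimension back to (1, 8, 2048).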
+
+
+class Qwen3VLMoeTextSparseMoeBlock(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.num_experts = config.num_experts
+ self.gate = Qwen3VLMoeTextRouter(config)
+ self.experts = Qwen3VLMoeTextExperts(config)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ router_weights, router_logits, router_indices = self.gate(hidden_states)
+ routed_out = self.experts(hidden_states, router_weights, router_indices)
+ return routed_out, router_logits
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs: Unpack[TransformersKwargs],
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+class Qwen3VLMoeTextAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: Qwen3VLMoeTextConfig, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
+
+ self.q_proj = nn.Linear(
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.v_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.o_proj = nn.Linear(
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+ )
+ self.q_norm = Qwen3VLMoeTextRMSNorm(
+ self.head_dim, eps=config.rms_norm_eps
+ ) # unlike olmo, only on the head dim!
+ self.k_norm = Qwen3VLMoeTextRMSNorm(
+ self.head_dim, eps=config.rms_norm_eps
+ ) # thus post q_norm does not need reshape
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+ key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class Qwen3VLMoeTextMLP(nn.Module):
+ def __init__(self, config, intermediate_size=None):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+class Qwen3VLMoeTextDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: Qwen3VLMoeTextConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = Qwen3VLMoeTextAttention(config, layer_idx)
+
+ if (layer_idx not in config.mlp_only_layers) and (
+ config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
+ ):
+ self.mlp = Qwen3VLMoeTextSparseMoeBlock(config)
+ else:
+ self.mlp = Qwen3VLMoeTextMLP(config, intermediate_size=config.intermediate_size)
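+ # Layer placement example (illustrative): with the default decoder_sparse_step=1 and
+ # mlp_only_layers=[], every layer gets the sparse MoE block; with decoder_sparse_step=2
+ # only layers 1, 3, 5, ... would be sparse and the rest use the dense Qwen3VLMoeTextMLP.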
+
+ self.input_layernorm = Qwen3VLMoeTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = Qwen3VLMoeTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[tuple[torch.Tensor]] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> torch.FloatTensor:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, sequence_length)` where padding elements are indicated by 0.
+ position_ids (`torch.LongTensor` of shape `(batch, sequence_length)`, *optional*):
+ Indices of positions of each input sequence token in the sequence.
+ past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+ into the model
+ """
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ # For the MoE layers, we need to unpack
+ if isinstance(hidden_states, tuple):
+ hidden_states, _ = hidden_states
+ hidden_states = residual + hidden_states
+
+ return hidden_states
+
+
+@auto_docstring
+class Qwen3VLMoePreTrainedModel(PreTrainedModel):
+ config: Qwen3VLMoeConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["Qwen3VLMoeTextDecoderLayer", "Qwen3VLMoeVisionBlock"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+ _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "router_logits": OutputRecorder(Qwen3VLMoeTextSparseMoeBlock, index=1),
+ "hidden_states": Qwen3VLMoeTextDecoderLayer,
+ "attentions": Qwen3VLMoeTextAttention,
+ }
+
+ def _init_weights(self, module):
+ """Initialize the weights."""
+ super()._init_weights(module)
+ if hasattr(self.config, "initializer_range"):
+ std = self.config.initializer_range
+ else:
+ std = getattr(self.config.get_text_config(), "initializer_range", 0.02)
+ if isinstance(module, Qwen3VLMoeTextExperts):
+ module.gate_up_proj.data.normal_(mean=0.0, std=std)
+ module.down_proj.data.normal_(mean=0.0, std=std)
+
+
+class Qwen3VLMoeVisionMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True)
+ self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, hidden_state):
+ return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state)))
+
+
+class Qwen3VLMoeVisionPatchEmbed(nn.Module):
+ def __init__(self, config) -> None:
+ super().__init__()
+ self.patch_size = config.patch_size
+ self.temporal_patch_size = config.temporal_patch_size
+ self.in_channels = config.in_channels
+ self.embed_dim = config.hidden_size
+
+ kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
+ self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ target_dtype = self.proj.weight.dtype
+ hidden_states = hidden_states.view(
+ -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
+ )
+ hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
+ return hidden_states
+
+
+class Qwen3VLMoeVisionRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, dim: int, theta: float = 10000.0) -> None:
+ super().__init__()
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ def forward(self, seqlen: int) -> torch.Tensor:
+ seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+ freqs = torch.outer(seq, self.inv_freq)
+ return freqs
+
+
+class Qwen3VLMoeVisionPatchMerger(nn.Module):
+ def __init__(self, config: Qwen3VLMoeVisionConfig, use_postshuffle_norm=False) -> None:
+ super().__init__()
+ self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)
+ self.use_postshuffle_norm = use_postshuffle_norm
+ self.norm = nn.LayerNorm(self.hidden_size if use_postshuffle_norm else config.hidden_size, eps=1e-6)
+ self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size)
+ self.act_fn = nn.GELU()
+ self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.norm(x.view(-1, self.hidden_size) if self.use_postshuffle_norm else x).view(-1, self.hidden_size)
+ x = self.linear_fc2(self.act_fn(self.linear_fc1(x)))
+ return x
+
+
+def apply_rotary_pos_emb_vision(
+ q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor]:
+ orig_q_dtype = q.dtype
+ orig_k_dtype = k.dtype
+ q, k = q.float(), k.float()
+ cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ q_embed = q_embed.to(orig_q_dtype)
+ k_embed = k_embed.to(orig_k_dtype)
+ return q_embed, k_embed
+
+
+class Qwen3VLMoeVisionAttention(nn.Module):
+ def __init__(self, config: Qwen3VLMoeVisionConfig) -> None:
+ super().__init__()
+ self.dim = config.hidden_size
+ self.num_heads = config.num_heads
+ self.head_dim = self.dim // self.num_heads
+ self.num_key_value_groups = 1 # needed for eager attention
+ self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
+ self.proj = nn.Linear(self.dim, self.dim)
+ self.scaling = self.head_dim**-0.5
+ self.config = config
+ self.attention_dropout = 0.0
+ self.is_causal = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ rotary_pos_emb: Optional[torch.Tensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ seq_length = hidden_states.shape[0]
+ query_states, key_states, value_states = (
+ self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
+ )
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
+
+ query_states = query_states.transpose(0, 1).unsqueeze(0)
+ key_states = key_states.transpose(0, 1).unsqueeze(0)
+ value_states = value_states.transpose(0, 1).unsqueeze(0)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ if self.config._attn_implementation == "flash_attention_2":
+ # Flash Attention 2: Use cu_seqlens for variable length attention
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+ attn_output, _ = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask=None,
+ scaling=self.scaling,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ cu_seq_lens_q=cu_seqlens,
+ cu_seq_lens_k=cu_seqlens,
+ max_length_q=max_seqlen,
+ max_length_k=max_seqlen,
+ is_causal=False,
+ **kwargs,
+ )
+ else:
+ # Other implementations: Process each chunk separately
+ lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+ splits = [
+ torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
+ ]
+
+ attn_outputs = [
+ attention_interface(
+ self,
+ q,
+ k,
+ v,
+ attention_mask=None,
+ scaling=self.scaling,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ is_causal=False,
+ **kwargs,
+ )[0]
+ for q, k, v in zip(*splits)
+ ]
+ attn_output = torch.cat(attn_outputs, dim=1)
+
+ attn_output = attn_output.reshape(seq_length, -1).contiguous()
+ attn_output = self.proj(attn_output)
+ return attn_output
+
+
+class Qwen3VLMoeVisionBlock(GradientCheckpointingLayer):
+ def __init__(self, config, attn_implementation: str = "sdpa") -> None:
+ super().__init__()
+ self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6)
+ self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6)
+ self.attn = Qwen3VLMoeVisionAttention(config=config)
+ self.mlp = Qwen3VLMoeVisionMLP(config=config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ rotary_pos_emb: Optional[torch.Tensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ hidden_states = hidden_states + self.attn(
+ self.norm1(hidden_states),
+ cu_seqlens=cu_seqlens,
+ rotary_pos_emb=rotary_pos_emb,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+ return hidden_states
+
+
+class Qwen3VLMoeVisionModel(Qwen3VLMoePreTrainedModel):
+ config: Qwen3VLMoeVisionConfig
+ _no_split_modules = ["Qwen3VLMoeVisionBlock"]
+
+ def __init__(self, config, *inputs, **kwargs) -> None:
+ super().__init__(config, *inputs, **kwargs)
+ self.spatial_merge_size = config.spatial_merge_size
+ self.patch_size = config.patch_size
+ self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
+
+ self.patch_embed = Qwen3VLMoeVisionPatchEmbed(
+ config=config,
+ )
+
+ self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size)
+ self.num_grid_per_side = int(config.num_position_embeddings**0.5)
+
+ head_dim = config.hidden_size // config.num_heads
+ self.rotary_pos_emb = Qwen3VLMoeVisionRotaryEmbedding(head_dim // 2)
+
+ self.blocks = nn.ModuleList([Qwen3VLMoeVisionBlock(config) for _ in range(config.depth)])
+ self.merger = Qwen3VLMoeVisionPatchMerger(
+ config=config,
+ use_postshuffle_norm=False,
+ )
+
+ self.deepstack_visual_indexes = config.deepstack_visual_indexes
+ self.deepstack_merger_list = nn.ModuleList(
+ [
+ Qwen3VLMoeVisionPatchMerger(
+ config=config,
+ use_postshuffle_norm=True,
+ )
+ for _ in range(len(config.deepstack_visual_indexes))
+ ]
+ )
+
+ self.gradient_checkpointing = False
+
+ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
+ merge_size = self.spatial_merge_size
+
+ max_hw = int(grid_thw[:, 1:].max().item())
+ freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2)
+ device = freq_table.device
+
+ total_tokens = int(torch.prod(grid_thw, dim=1).sum().item())
+ pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)
+
+ offset = 0
+ for num_frames, height, width in grid_thw:
+ merged_h, merged_w = height // merge_size, width // merge_size
+
+ block_rows = torch.arange(merged_h, device=device) # block row indices
+ block_cols = torch.arange(merged_w, device=device) # block col indices
+ intra_row = torch.arange(merge_size, device=device) # intra-block row offsets
+ intra_col = torch.arange(merge_size, device=device) # intra-block col offsets
+
+ # Compute full-resolution positions
+ row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
+ col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]
+
+ row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
+ col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
+
+ coords = torch.stack((row_idx, col_idx), dim=-1)
+
+ if num_frames > 1:
+ coords = coords.repeat(num_frames, 1)
+
+ num_tokens = coords.shape[0]
+ pos_ids[offset : offset + num_tokens] = coords
+ offset += num_tokens
+
+ embeddings = freq_table[pos_ids] # lookup rotary embeddings
+ embeddings = embeddings.flatten(1)
+ return embeddings
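+ # e.g. (illustrative) for a single 4x4 patch grid with spatial_merge_size=2, pos_ids
+ # enumerates (row, col) pairs block by block: (0,0), (0,1), (1,0), (1,1), (0,2), (0,3), ...
+ # so that patches merged together stay adjacent in the rotary table lookup.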
+
+ def fast_pos_embed_interpolate(self, grid_thw):
+ grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]
+
+ idx_list = [[] for _ in range(4)]
+ weight_list = [[] for _ in range(4)]
+
+ for t, h, w in zip(grid_ts, grid_hs, grid_ws):
+ h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h)
+ w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w)
+
+ h_idxs_floor = h_idxs.int()
+ w_idxs_floor = w_idxs.int()
+ h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
+ w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
+
+ dh = h_idxs - h_idxs_floor
+ dw = w_idxs - w_idxs_floor
+
+ base_h = h_idxs_floor * self.num_grid_per_side
+ base_h_ceil = h_idxs_ceil * self.num_grid_per_side
+
+ indices = [
+ (base_h[None].T + w_idxs_floor[None]).flatten(),
+ (base_h[None].T + w_idxs_ceil[None]).flatten(),
+ (base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
+ (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
+ ]
+
+ weights = [
+ ((1 - dh)[None].T * (1 - dw)[None]).flatten(),
+ ((1 - dh)[None].T * dw[None]).flatten(),
+ (dh[None].T * (1 - dw)[None]).flatten(),
+ (dh[None].T * dw[None]).flatten(),
+ ]
+
+ for i in range(4):
+ idx_list[i].extend(indices[i].tolist())
+ weight_list[i].extend(weights[i].tolist())
+
+ idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=self.pos_embed.weight.device)
+ weight_tensor = torch.tensor(
+ weight_list, dtype=self.pos_embed.weight.dtype, device=self.pos_embed.weight.device
+ )
+ pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
+ patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
+
+ patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)])
+
+ patch_pos_embeds_permute = []
+ merge_size = self.config.spatial_merge_size
+ for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
+ pos_embed = pos_embed.repeat(t, 1)
+ pos_embed = (
+ pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
+ .permute(0, 1, 3, 2, 4, 5)
+ .flatten(0, 4)
+ )
+ patch_pos_embeds_permute.append(pos_embed)
+ patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
+ return patch_pos_embeds
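+ # Interpolation sketch (illustrative): with the default num_position_embeddings=2304 the
+ # learned grid is 48x48; each target location is mapped onto that grid and bilinearly
+ # interpolated from its four nearest entries with weights (1-dh)(1-dw), (1-dh)dw,
+ # dh(1-dw) and dh*dw before being reordered to match the spatial merge pattern.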
+
+ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
+ """
+ Args:
+ hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
+ The flattened image/video patches to be embedded.
+ grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
+ The temporal, height and width of the feature grid of each image or video.
+
+ Returns:
+ `tuple(torch.Tensor, list[torch.Tensor])`: the merged visual features and the deepstack features.
+ """
+ hidden_states = self.patch_embed(hidden_states)
+
+ pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
+ hidden_states = hidden_states + pos_embeds
+
+ rotary_pos_emb = self.rot_pos_emb(grid_thw)
+
+ seq_len, _ = hidden_states.size()
+ hidden_states = hidden_states.reshape(seq_len, -1)
+ rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
+ emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+ position_embeddings = (emb.cos(), emb.sin())
+
+ cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+ dim=0,
+ # Select dtype based on the following factors:
+ # - FA2 requires that cu_seqlens_q must have dtype int32
+ # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
+ # See https://github.com/huggingface/transformers/pull/34852 for more information
+ dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+ )
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
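+ # cu_seqlens example (illustrative): for grid_thw = [[1, 28, 28], [2, 14, 14]] the
+ # per-frame token counts are [784, 196, 196], so cu_seqlens = [0, 784, 980, 1176]
+ # and attention is restricted to each frame's own span.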
+
+ deepstack_feature_lists = []
+ for layer_num, blk in enumerate(self.blocks):
+ hidden_states = blk(
+ hidden_states,
+ cu_seqlens=cu_seqlens,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ if layer_num in self.deepstack_visual_indexes:
+ deepstack_feature = self.deepstack_merger_list[self.deepstack_visual_indexes.index(layer_num)](
+ hidden_states
+ )
+ deepstack_feature_lists.append(deepstack_feature)
+
+ hidden_states = self.merger(hidden_states)
+
+ return hidden_states, deepstack_feature_lists
+
+
+class Qwen3VLMoeTextRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: Qwen3VLMoeTextConfig, device=None):
+ super().__init__()
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get("rope_type", "default")
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ self.mrope_section = config.rope_scaling.get("mrope_section", [24, 20, 20])
+
+ def apply_interleaved_mrope(self, freqs, mrope_section):
+ """Apply interleaved MRoPE to 3D rotary embeddings.
+ Reorganizes the frequency layout from chunked [TTT...HHH...WWW] to
+ interleaved [THWTHWTHW...TT], preserving frequency continuity.
+ Args:
+ freqs: (3, bs, seq_len, head_dim // 2)
+ mrope_section: (3,)
+ Returns:
+ freqs_t: (bs, seq_len, head_dim // 2)
+ """
+ freqs_t = freqs[0] # just overwrite the first dimension T
+ for dim, offset in enumerate((1, 2), start=1): # H, W
+ length = mrope_section[dim] * 3
+ idx = slice(offset, length, 3)
+ freqs_t[..., idx] = freqs[dim, ..., idx]
+ return freqs_t
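+ # Index layout sketch (illustrative, with the fallback mrope_section=[24, 20, 20] and
+ # 64 rotary frequencies): H overwrites indices 1, 4, ..., 58 (slice(1, 60, 3)),
+ # W overwrites indices 2, 5, ..., 59 (slice(2, 60, 3)), and the remaining 24 indices
+ # keep the temporal frequencies.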
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ # In contrast to other models, Qwen3VLMoe has different position ids for the grids
+ # So we expand the inv_freq to shape (3, ...)
+ if position_ids.ndim == 2:
+ position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
+ inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
+ position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
+ freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+@auto_docstring(
+ custom_intro=(
+ "The text part of Qwen3VLMoe. It is not a pure text-only model, "
+ "as DeepStack integrates visual features into the early hidden states."
+ )
+)
+class Qwen3VLMoeTextModel(Qwen3VLMoePreTrainedModel):
+ config: Qwen3VLMoeTextConfig
+ _no_split_modules = ["Qwen3VLMoeTextDecoderLayer"]
+
+ def __init__(self, config: Qwen3VLMoeTextConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [Qwen3VLMoeTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = Qwen3VLMoeTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = Qwen3VLMoeTextRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @check_model_inputs
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ # args for deepstack
+ visual_pos_masks: Optional[torch.Tensor] = None,
+ deepstack_visual_embeds: Optional[list[torch.Tensor]] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Union[tuple, BaseModelOutputWithPast]:
+ r"""
+ visual_pos_masks (`torch.Tensor` of shape `(batch_size, seqlen)`, *optional*):
+ The mask of the visual positions.
+ deepstack_visual_embeds (`list[torch.Tensor]`, *optional*):
+ The deepstack visual embeddings of shape (num_layers, visual_seqlen, embed_dim).
+ The features are extracted from different vision encoder layers and added to the early decoder
+ hidden states, following DeepStack (https://arxiv.org/abs/2406.04334).
+ """
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ # torch.jit.trace() doesn't support cache objects in the output
+ if use_cache and past_key_values is None and not torch.jit.is_tracing():
+ past_key_values = DynamicCache(config=self.config)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ # the hard coded `3` is for temporal, height and width.
+ if position_ids is None:
+ position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
+ elif position_ids.ndim == 2:
+ position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
+
+ if position_ids.ndim == 3 and position_ids.shape[0] == 4:
+ text_position_ids = position_ids[0]
+ position_ids = position_ids[1:]
+ else:
+ text_position_ids = position_ids[0]
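+ # e.g. position_ids of shape (4, bs, seq) carries [text, t, h, w] rows: the text row
+ # drives the causal mask below, while the remaining three rows feed the multimodal RoPE.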
+
+ attention_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=text_position_ids,
+ )
+
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ # decoder layers
+ for layer_idx, decoder_layer in enumerate(self.layers):
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=text_position_ids,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = layer_outputs
+
+ # add visual features to the hidden states of first several layers
+ if deepstack_visual_embeds is not None and layer_idx in range(len(deepstack_visual_embeds)):
+ hidden_states = self._deepstack_process(
+ hidden_states,
+ visual_pos_masks,
+ deepstack_visual_embeds[layer_idx],
+ )
+
+ hidden_states = self.norm(hidden_states)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ )
+
+ def _deepstack_process(
+ self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor, visual_embeds: torch.Tensor
+ ):
+ visual_pos_masks = visual_pos_masks.to(hidden_states.device)
+ visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype)
+ local_this = hidden_states[visual_pos_masks, :].clone() + visual_embeds
+ hidden_states[visual_pos_masks, :] = local_this
+ return hidden_states
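+ # For illustration: with three deepstack_visual_indexes, the features captured from those
+ # vision blocks are added to the hidden states of decoder layers 0, 1 and 2 respectively,
+ # only at the positions flagged by visual_pos_masks.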
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Qwen3VLMoe outputs, with hidden states and attentions.
+ """
+)
+class Qwen3VLMoeModelOutputWithPast(ModelOutput):
+ r"""
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[list[torch.FloatTensor]] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ rope_deltas: Optional[torch.LongTensor] = None
+
+
+@auto_docstring
+class Qwen3VLMoeModel(Qwen3VLMoePreTrainedModel):
+ base_model_prefix = ""
+ _checkpoint_conversion_mapping = {}
+ # Reference: fix gemma3 grad acc #37208
+ accepts_loss_kwargs = False
+ config: Qwen3VLMoeConfig
+ _no_split_modules = ["Qwen3VLMoeTextDecoderLayer", "Qwen3VLMoeVisionBlock"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.visual = Qwen3VLMoeVisionModel._from_config(config.vision_config)
+ self.language_model = Qwen3VLMoeTextModel._from_config(config.text_config)
+ self.rope_deltas = None # cache rope_deltas here
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ def set_decoder(self, decoder):
+ self.language_model = decoder
+
+ def get_decoder(self):
+ return self.language_model
+
+ def get_rope_index(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Different from the original implementation, Qwen3VLMoe uses timestamps rather than absolute time position ids."""
+
+ # Since we use timestamps to separate videos, the video_grid_thw should also be split into per-frame entries
+ if video_grid_thw is not None:
+ video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0)
+ video_grid_thw[:, 0] = 1
+
+ spatial_merge_size = self.config.vision_config.spatial_merge_size
+ image_token_id = self.config.image_token_id
+ video_token_id = self.config.video_token_id
+ vision_start_token_id = self.config.vision_start_token_id
+ mrope_position_deltas = []
+ if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
+ total_input_ids = input_ids
+ if attention_mask is None:
+ attention_mask = torch.ones_like(total_input_ids)
+ position_ids = torch.ones(
+ 3,
+ input_ids.shape[0],
+ input_ids.shape[1],
+ dtype=input_ids.dtype,
+ device=input_ids.device,
+ )
+ image_index, video_index = 0, 0
+ attention_mask = attention_mask.to(total_input_ids.device)
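+ # Build 3D (t, h, w) position ids sample by sample: text spans get identical ids on all three axes,
+ # while vision spans get a grid of (t, h, w) indices offset by the running text position.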
+ for i, input_ids in enumerate(total_input_ids):
+ input_ids = input_ids[attention_mask[i] == 1]
+ image_nums, video_nums = 0, 0
+ vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
+ vision_tokens = input_ids[vision_start_indices + 1]
+ image_nums = (vision_tokens == image_token_id).sum()
+ video_nums = (vision_tokens == video_token_id).sum()
+ input_tokens = input_ids.tolist()
+ llm_pos_ids_list: list = []
+ st = 0
+ remain_images, remain_videos = image_nums, video_nums
+ for _ in range(image_nums + video_nums):
+ if image_token_id in input_tokens and remain_images > 0:
+ ed_image = input_tokens.index(image_token_id, st)
+ else:
+ ed_image = len(input_tokens) + 1
+ if video_token_id in input_tokens and remain_videos > 0:
+ ed_video = input_tokens.index(video_token_id, st)
+ else:
+ ed_video = len(input_tokens) + 1
+ if ed_image < ed_video:
+ t, h, w = (
+ image_grid_thw[image_index][0],
+ image_grid_thw[image_index][1],
+ image_grid_thw[image_index][2],
+ )
+ image_index += 1
+ remain_images -= 1
+ ed = ed_image
+
+ else:
+ t, h, w = (
+ video_grid_thw[video_index][0],
+ video_grid_thw[video_index][1],
+ video_grid_thw[video_index][2],
+ )
+ video_index += 1
+ remain_videos -= 1
+ ed = ed_video
+ llm_grid_t, llm_grid_h, llm_grid_w = (
+ t.item(),
+ h.item() // spatial_merge_size,
+ w.item() // spatial_merge_size,
+ )
+ text_len = ed - st
+
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+ # t_index is always 0 because llm_grid_t is always 1 (we use timestamps to encode the temporal information for videos)
+ t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
+ h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
+ w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
+ llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
+ st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+ if st < len(input_tokens):
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ text_len = len(input_tokens) - st
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+ position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
+ mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
+ mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
+ return position_ids, mrope_position_deltas
+ else:
+ if attention_mask is not None:
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
+ max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
+ mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
+ else:
+ position_ids = (
+ torch.arange(input_ids.shape[1], device=input_ids.device)
+ .view(1, 1, -1)
+ .expand(3, input_ids.shape[0], -1)
+ )
+ mrope_position_deltas = torch.zeros(
+ [input_ids.shape[0], 1],
+ device=input_ids.device,
+ dtype=input_ids.dtype,
+ )
+
+ return position_ids, mrope_position_deltas
+
+ def get_video_features(
+ self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
+ ):
+ """
+ Encodes videos into continuous embeddings that can be forwarded to the language model. The deepstack visual features are also returned.
+
+ Args:
+ pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input videos.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ """
+ # Same implementation as for images
+ return self.get_image_features(pixel_values_videos, video_grid_thw)
+
+ def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
+ """
+ Encodes images into continuous embeddings that can be forwarded to the language model. The deepstack visual features are also returned.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input images.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ """
+ pixel_values = pixel_values.type(self.visual.dtype)
+ image_embeds, deepstack_image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+ split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
+ image_embeds = torch.split(image_embeds, split_sizes)
+ return image_embeds, deepstack_image_embeds
+
+ def get_placeholder_mask(
+ self,
+ input_ids: torch.LongTensor,
+ inputs_embeds: torch.FloatTensor,
+ image_features: Optional[torch.FloatTensor] = None,
+ video_features: Optional[torch.FloatTensor] = None,
+ ):
+ """
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
+ """
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ special_video_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_video_mask = special_video_mask.all(-1)
+ else:
+ special_image_mask = input_ids == self.config.image_token_id
+ special_video_mask = input_ids == self.config.video_token_id
+
+ n_image_tokens = special_image_mask.sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}"
+ )
+
+ n_video_tokens = special_video_mask.sum()
+ special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel():
+ raise ValueError(
+ f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}"
+ )
+
+ return special_image_mask, special_video_mask
+
+ @auto_docstring
+ @can_return_tuple
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ pixel_values: Optional[torch.Tensor] = None,
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, Qwen3VLMoeModelOutputWithPast]:
+ r"""
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ """
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ image_mask = None
+ video_mask = None
+
+ if pixel_values is not None:
+ image_embeds, deepstack_image_embeds = self.get_image_features(pixel_values, image_grid_thw)
+ image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
+ image_mask, _ = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+ if pixel_values_videos is not None:
+ video_embeds, deepstack_video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
+ video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
+ _, video_mask = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
+
+ visual_pos_masks = None
+ deepstack_visual_embeds = None
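+ # Merge image and video placeholder masks into a single visual mask and align the per-layer
+ # deepstack embeddings with it so they can be added at the right token positions in the decoder.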
+ if image_mask is not None and video_mask is not None:
+ # aggregate visual_pos_masks and deepstack_visual_embeds
+ image_mask = image_mask[..., 0]
+ video_mask = video_mask[..., 0]
+ visual_pos_masks = image_mask | video_mask
+ deepstack_visual_embeds = []
+ image_mask_joint = image_mask[visual_pos_masks]
+ video_mask_joint = video_mask[visual_pos_masks]
+ for img_embed, vid_embed in zip(deepstack_image_embeds, deepstack_video_embeds):
+ embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1]).to(img_embed.device)
+ embed_joint[image_mask_joint, :] = img_embed
+ embed_joint[video_mask_joint, :] = vid_embed
+ deepstack_visual_embeds.append(embed_joint)
+ elif image_mask is not None:
+ image_mask = image_mask[..., 0]
+ visual_pos_masks = image_mask
+ deepstack_visual_embeds = deepstack_image_embeds
+ elif video_mask is not None:
+ video_mask = video_mask[..., 0]
+ visual_pos_masks = video_mask
+ deepstack_visual_embeds = deepstack_video_embeds
+
+ if position_ids is None:
+ attention_mask_tensor = (
+ attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
+ )
+ if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
+ attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
+ # Only apply conversion for floating point tensors (inverted masks)
+ if attention_mask_tensor.dtype.is_floating_point:
+ attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
+ attention_mask_tensor = (1.0 - attention_mask_tensor).int()
+
+ # Calculate RoPE index once per generation in the pre-fill stage only.
+ # When compiling, we can't check tensor values, so we only check the input length.
+ # It is safe to assume that `length != 1` means we're in the pre-fill stage because compiled
+ # models currently cannot do assisted decoding
+ prefill_compiled_stage = is_torchdynamo_compiling() and (
+ (input_ids is not None and input_ids.shape[1] != 1)
+ or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
+ )
+ prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
+ (cache_position is not None and cache_position[0] == 0)
+ or (past_key_values is None or past_key_values.get_seq_length() == 0)
+ )
+ if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
+ position_ids, rope_deltas = self.get_rope_index(
+ input_ids,
+ image_grid_thw,
+ video_grid_thw,
+ attention_mask=attention_mask_tensor,
+ )
+ self.rope_deltas = rope_deltas
+ # otherwise, reuse the previously computed rope_deltas to get the correct position ids
+ else:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ delta = (
+ (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
+ if cache_position is not None
+ else 0
+ )
+ position_ids = torch.arange(seq_length, device=inputs_embeds.device)
+ position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+ if cache_position is not None: # otherwise `delta` is the int `0`
+ delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+ position_ids = position_ids.add(delta)
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
+
+ outputs = self.language_model(
+ input_ids=None,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ visual_pos_masks=visual_pos_masks,
+ deepstack_visual_embeds=deepstack_visual_embeds,
+ **kwargs,
+ )
+
+ return Qwen3VLMoeModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ rope_deltas=self.rope_deltas,
+ )
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Qwen3VLMoe causal language model (or autoregressive) outputs.
+ """
+)
+class Qwen3VLMoeCausalLMOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[list[torch.FloatTensor]] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ rope_deltas: Optional[torch.LongTensor] = None
+
+
+class Qwen3VLMoeForConditionalGeneration(Qwen3VLMoePreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {}
+ _tied_weights_keys = ["lm_head.weight"]
+ # Reference: fix gemma3 grad acc #37208
+ accepts_loss_kwargs = False
+ config: Qwen3VLMoeConfig
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = Qwen3VLMoeModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def set_decoder(self, decoder):
+ self.model.set_decoder(decoder)
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ def get_video_features(
+ self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
+ ):
+ return self.model.get_video_features(pixel_values_videos, video_grid_thw)
+
+ def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
+ return self.model.get_image_features(pixel_values, image_grid_thw)
+
+ # Make modules available through conditional class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
+
+ @property
+ def visual(self):
+ return self.model.visual
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[list[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.Tensor] = None,
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, Qwen3VLMoeCausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+
+ Example:
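+
+ The snippet below is an illustrative sketch rather than a tested reference: the checkpoint name is
+ taken from the configuration docstring above and is assumed, so swap in a released Qwen3-VL MoE
+ repository if it differs.
+
+ ```python
+ >>> import torch
+ >>> from transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration
+
+ >>> # checkpoint name assumed from the config docstring defaults
+ >>> model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
+ ...     "Qwen/Qwen3-VL-30B-A3B-Instruct", dtype=torch.bfloat16, device_map="auto"
+ ... )
+ >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-30B-A3B-Instruct")
+
+ >>> messages = [
+ ...     {
+ ...         "role": "user",
+ ...         "content": [
+ ...             {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
+ ...             {"type": "text", "text": "Describe this image."},
+ ...         ],
+ ...     }
+ ... ]
+ >>> inputs = processor.apply_chat_template(
+ ...     messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt"
+ ... ).to(model.device)
+ >>> _ = inputs.pop("token_type_ids", None)
+
+ >>> generated_ids = model.generate(**inputs, max_new_tokens=64)
+ >>> print(processor.batch_decode(generated_ids[:, inputs["input_ids"].shape[1] :], skip_special_tokens=True)[0])
+ ```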
+ """
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
+
+ return Qwen3VLMoeCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ rope_deltas=outputs.rope_deltas,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ cache_position=None,
+ position_ids=None,
+ use_cache=True,
+ pixel_values=None,
+ pixel_values_videos=None,
+ image_grid_thw=None,
+ video_grid_thw=None,
+ **kwargs,
+ ):
+ # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
+
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ position_ids=position_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ use_cache=use_cache,
+ **kwargs,
+ )
+
+ # Qwen3VLMoe position_ids are prepared with rope_deltas in forward
+ model_inputs["position_ids"] = None
+
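+ # Vision inputs are only needed for the prefill forward pass; drop them on later decoding steps.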
+ if cache_position[0] != 0:
+ model_inputs["pixel_values"] = None
+ model_inputs["pixel_values_videos"] = None
+
+ return model_inputs
+
+ def _get_image_nums_and_video_nums(
+ self,
+ input_ids: Optional[torch.LongTensor],
+ inputs_embeds: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
+ These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications.
+
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Returns:
+ image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`)
+ video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`)
+ """
+ image_token_id = self.config.image_token_id
+ video_token_id = self.config.video_token_id
+ vision_start_token_id = self.config.vision_start_token_id
+
+ if inputs_embeds is not None:
+ vision_start_mask = (
+ inputs_embeds
+ == self.get_input_embeddings()(
+ torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ )[..., 0]
+ image_mask = (
+ inputs_embeds
+ == self.get_input_embeddings()(
+ torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ )[..., 0]
+ video_mask = (
+ inputs_embeds
+ == self.get_input_embeddings()(
+ torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ )[..., 0]
+ else:
+ vision_start_mask = input_ids == vision_start_token_id
+ image_mask = input_ids == image_token_id
+ video_mask = input_ids == video_token_id
+
+ vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
+ image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
+ video_nums = torch.sum(vision_first_mask & video_mask, dim=1)
+
+ return image_nums, video_nums
+
+ def _expand_inputs_for_generation(
+ self,
+ expand_size: int = 1,
+ is_encoder_decoder: bool = False,
+ input_ids: Optional[torch.LongTensor] = None,
+ **model_kwargs,
+ ) -> tuple[torch.LongTensor, dict[str, Any]]:
+ # Overwritten -- Support for expanding tensors without a batch size dimension
+ # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_ts
+ # pixel_values.shape[0] is sum(seqlen_images for samples)
+ # image_grid_thw.shape[0] is sum(num_images for samples)
+
+ if expand_size == 1:
+ return input_ids, model_kwargs
+
+ visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"]
+
+ def _expand_dict_for_generation_visual(dict_to_expand):
+ image_grid_thw = model_kwargs.get("image_grid_thw", None)
+ video_grid_thw = model_kwargs.get("video_grid_thw", None)
+ image_nums, video_nums = self._get_image_nums_and_video_nums(
+ input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
+ )
+
+ def _repeat_interleave_samples(x, lengths, repeat_times):
+ samples = torch.split(x, lengths)
+ repeat_args = [repeat_times] + [1] * (x.dim() - 1)
+ result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
+ return result
+
+ for key in dict_to_expand:
+ if key == "pixel_values":
+ # split images into samples
+ samples = torch.split(image_grid_thw, list(image_nums))
+ # compute the sequence length of images for each sample
+ lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+ )
+ elif key == "image_grid_thw":
+ # get the num of images for each sample
+ lengths = list(image_nums)
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+ )
+ elif key == "pixel_values_videos":
+ samples = torch.split(video_grid_thw, list(video_nums))
+ lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+ )
+ elif key == "video_grid_thw":
+ lengths = list(video_nums)
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+ )
+ elif key == "second_per_grid_ts":
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=list(video_nums), repeat_times=expand_size
+ )
+ return dict_to_expand
+
+ def _expand_dict_for_generation(dict_to_expand):
+ for key in dict_to_expand:
+ if (
+ key != "cache_position"
+ and dict_to_expand[key] is not None
+ and isinstance(dict_to_expand[key], torch.Tensor)
+ and key not in visual_keys
+ ):
+ dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
+ return dict_to_expand
+
+ model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
+
+ if input_ids is not None:
+ input_ids = input_ids.repeat_interleave(expand_size, dim=0)
+
+ model_kwargs = _expand_dict_for_generation(model_kwargs)
+
+ if is_encoder_decoder:
+ if model_kwargs.get("encoder_outputs") is None:
+ raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
+ model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
+
+ return input_ids, model_kwargs
+
+
+__all__ = [
+ "Qwen3VLMoeVisionModel",
+ "Qwen3VLMoeForConditionalGeneration",
+ "Qwen3VLMoeModel",
+ "Qwen3VLMoePreTrainedModel",
+ "Qwen3VLMoeTextModel",
+]
diff --git a/src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py
new file mode 100644
index 000000000000..456d7c60aa89
--- /dev/null
+++ b/src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py
@@ -0,0 +1,434 @@
+# coding=utf-8
+# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Qwen3-VL-MOE model."""
+
+import torch
+import torch.nn as nn
+
+from ...activations import ACT2FN
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+from ...modeling_utils import PreTrainedModel
+from ...utils import logging
+from ..qwen3_moe.modeling_qwen3_moe import (
+ Qwen3MoeDecoderLayer,
+ Qwen3MoePreTrainedModel,
+ Qwen3MoeRMSNorm,
+)
+from ..qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig, Qwen3VLVisionConfig
+from ..qwen3_vl.modeling_qwen3_vl import (
+ Qwen3VLForConditionalGeneration,
+ Qwen3VLModel,
+ Qwen3VLTextAttention,
+ Qwen3VLTextModel,
+ Qwen3VLVisionModel,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+class Qwen3VLMoeTextConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Qwen3VLMoeTextModel`]. It is used to instantiate a
+ Qwen3-VL-MOE model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of
+ Qwen3-VL-30B-A3B-Instruct [Qwen/Qwen3-VL-30B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Instruct).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 151936):
+ Vocabulary size of the Qwen3VLMoe text model. Defines the number of different tokens that can be represented by the
+ `input_ids` passed when calling [`Qwen3VLMoeTextModel`]
+ hidden_size (`int`, *optional*, defaults to 2048):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 5632):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 24):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_key_value_heads (`int`, *optional*, defaults to 16):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details checkout [this
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If `num_key_value_heads` is set to `None`, it will default to `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 128000):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether the model's input and output word embeddings should be tied.
+ rope_theta (`float`, *optional*, defaults to 5000000.0):
+ The base period of the RoPE embeddings.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ decoder_sparse_step (`int`, *optional*, defaults to 1):
+ The frequency of the MoE layer.
+ moe_intermediate_size (`int`, *optional*, defaults to 1408):
+ Intermediate size of the routed expert.
+ num_experts_per_tok (`int`, *optional*, defaults to 4):
+ Number of selected experts.
+ num_experts (`int`, *optional*, defaults to 60):
+ Number of routed experts.
+ norm_topk_prob (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the topk probabilities.
+ mlp_only_layers (`List[int]`, *optional*, defaults to `[]`):
+ Indicates which layers use Qwen3VLMoeMLP rather than Qwen3VLMoeSparseMoeBlock.
+ The list contains layer indices from 0 to num_layers-1 if we have num_layers layers.
+ If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+ head_dim (`int`, *optional*):
+ The dimension of the head. If not specified, will default to `hidden_size // num_attention_heads`.
+
+ ```python
+ >>> from transformers import Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeConfig
+
+ >>> # Initializing a Qwen3VLMoe style configuration
+ >>> configuration = Qwen3VLMoeConfig()
+
+ >>> # Initializing a model from the Qwen3-VL-30B-A3B style configuration
+ >>> model = Qwen3VLMoeForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "qwen3_vl_moe_text"
+ base_config_key = "text_config"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ # Default tensor parallel plan for base model `Qwen3VLMoe`
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.mlp.gate_proj": "colwise",
+ "layers.*.mlp.up_proj": "colwise",
+ "layers.*.mlp.down_proj": "rowwise",
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+
+ def __init__(
+ self,
+ vocab_size=151936,
+ hidden_size=2048,
+ intermediate_size=5632,
+ num_hidden_layers=24,
+ num_attention_heads=16,
+ num_key_value_heads=16,
+ hidden_act="silu",
+ max_position_embeddings=128000,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ tie_word_embeddings=False,
+ rope_theta=5000000.0,
+ attention_bias=False,
+ attention_dropout=0.0,
+ decoder_sparse_step=1,
+ moe_intermediate_size=1408,
+ num_experts_per_tok=4,
+ num_experts=60,
+ norm_topk_prob=True,
+ mlp_only_layers=None,
+ rope_scaling=None,
+ head_dim=None,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.rope_scaling = rope_scaling
+ self.head_dim = head_dim or hidden_size // num_attention_heads
+
+ rope_config_validation(self, ignore_keys={"mrope_section", "mrope_interleaved"})
+
+ # MoE arguments
+ self.decoder_sparse_step = decoder_sparse_step
+ self.moe_intermediate_size = moe_intermediate_size
+ self.num_experts_per_tok = num_experts_per_tok
+ self.num_experts = num_experts
+ self.norm_topk_prob = norm_topk_prob
+ self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers
+
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+
+
+class Qwen3VLMoeVisionConfig(Qwen3VLVisionConfig):
+ pass
+
+
+class Qwen3VLMoeConfig(Qwen3VLConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Qwen3VLMoeModel`]. It is used to instantiate a
+ Qwen3-VL-MOE model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of
+ Qwen3-VL-30B-A3B-Instruct [Qwen/Qwen3-VL-30B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Instruct).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLMoeTextConfig`):
+ The config object or dictionary of the text backbone.
+ vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLMoeVisionConfig`):
+ The config object or dictionary of the vision backbone.
+ image_token_id (`int`, *optional*, defaults to 151655):
+ The image token index to encode the image prompt.
+ video_token_id (`int`, *optional*, defaults to 151656):
+ The video token index to encode the video prompt.
+ vision_start_token_id (`int`, *optional*, defaults to 151652):
+ The token index marking the start of a vision (image or video) segment in the prompt.
+ vision_end_token_id (`int`, *optional*, defaults to 151653):
+ The token index marking the end of a vision (image or video) segment in the prompt.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie the word embeddings.
+
+ ```python
+ >>> from transformers import Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeConfig
+
+ >>> # Initializing a Qwen3-VL-MOE style configuration
+ >>> configuration = Qwen3VLMoeConfig()
+
+ >>> # Initializing a model from the Qwen3-VL-30B-A3B style configuration
+ >>> model = Qwen3VLMoeForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "qwen3_vl_moe"
+ sub_configs = {"vision_config": Qwen3VLMoeVisionConfig, "text_config": Qwen3VLMoeTextConfig}
+
+
+class Qwen3VLMoeTextRMSNorm(Qwen3MoeRMSNorm):
+ pass
+
+
+class Qwen3VLMoeTextRouter(nn.Linear):
+ def __init__(self, config):
+ super().__init__(config.hidden_size, config.num_experts, bias=False)
+ self.hidden_size = config.hidden_size
+ self.top_k = config.num_experts_per_tok
+ # since all the models use norm_topk_prob, we don't need to have an extra check for it
+ # self.norm_topk_prob = config.norm_topk_prob
+
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.reshape(-1, self.hidden_size)
+ router_logits = super().forward(hidden_states)
+ routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float)
+ routing_weights, router_indices = torch.topk(routing_weights, self.top_k, dim=-1)
+ routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
+ routing_weights = routing_weights.to(hidden_states.dtype)
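+ # Scatter the normalized top-k weights back into a dense (num_tokens, num_experts) matrix; unselected experts keep weight 0.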
+ router_weights = torch.zeros_like(router_logits).scatter_(1, router_indices, routing_weights)
+ return router_weights, router_logits, router_indices
+
+
+class Qwen3VLMoeTextExperts(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.num_experts = config.num_experts
+ self.intermediate_size = config.moe_intermediate_size
+ self.hidden_size = config.hidden_size
+ self.expert_dim = self.intermediate_size
+ self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
+ self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size)))
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(
+ self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor
+ ) -> torch.Tensor:
+ """
+ When training, it is more efficient to just loop over the experts and compute the output for each expert,
+ as otherwise the memory would explode.
+
+ For inference, we can sacrifice some memory and compute the output for all experts at once by repeating the inputs.
+
+ Args:
+ hidden_states (torch.Tensor): (batch_size * token_num, hidden_size)
+ routing_weights (torch.Tensor): (batch_size * token_num, num_experts)
+ router_indices (torch.Tensor): (batch_size * token_num, top_k)
+ Returns:
+ torch.Tensor
+ """
+ batch_size = hidden_states.shape[0]
+ hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size)
+ if self.training:
+ next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
+ with torch.no_grad():
+ expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts)
+ expert_mask = expert_mask.permute(2, 1, 0)
+ # we sum on the top_k and on the sequence length to get which experts
+ # are hit this time around
+ expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+ for expert_idx in expert_hit[:]:
+ with torch.no_grad():
+ _, token_idx = torch.where(expert_mask[expert_idx[0]])
+ current_state = hidden_states[token_idx]
+ gate_up = current_state @ self.gate_up_proj[expert_idx]
+ gate, up = gate_up.chunk(2, dim=-1)
+ gated_output = up * self.act_fn(gate)
+ out = gated_output @ self.down_proj[expert_idx]
+ weighted_output = out[0] * routing_weights[token_idx, expert_idx, None]
+ next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
+ next_states = next_states.view(batch_size, -1, self.hidden_size)
+ else:
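+ # Inference path: replicate every token across all experts, run batched matmuls, then combine the expert outputs using the dense routing weights.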
+ hidden_states = hidden_states.repeat(self.num_experts, 1)
+ hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
+ gate_up = torch.bmm(hidden_states, self.gate_up_proj)
+ gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors
+ next_states = torch.bmm((up * self.act_fn(gate)), self.down_proj)
+ next_states = next_states.reshape(self.num_experts, batch_size, -1, self.hidden_size)
+ next_states = (
+ next_states * routing_weights.transpose(0, 1).view(self.num_experts, batch_size, -1)[..., None]
+ )
+ next_states = next_states.sum(dim=0)
+ return next_states
+
+
+class Qwen3VLMoeTextSparseMoeBlock(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.num_experts = config.num_experts
+ self.gate = Qwen3VLMoeTextRouter(config)
+ self.experts = Qwen3VLMoeTextExperts(config)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ router_weights, router_logits, router_indices = self.gate(hidden_states)
+ routed_out = self.experts(hidden_states, router_weights, router_indices)
+ return routed_out, router_logits
+
+
+class Qwen3VLMoeTextAttention(Qwen3VLTextAttention):
+ pass
+
+
+class Qwen3VLMoeTextDecoderLayer(Qwen3MoeDecoderLayer):
+ pass
+
+
+class Qwen3VLMoePreTrainedModel(Qwen3MoePreTrainedModel):
+ config: Qwen3VLMoeConfig
+ _no_split_modules = ["Qwen3VLMoeTextDecoderLayer", "Qwen3VLMoeVisionBlock"]
+
+ def _init_weights(self, module):
+ """Initialize the weights."""
+ PreTrainedModel._init_weights(self, module)
+ if hasattr(self.config, "initializer_range"):
+ std = self.config.initializer_range
+ else:
+ std = getattr(self.config.get_text_config(), "initializer_range", 0.02)
+ if isinstance(module, Qwen3VLMoeTextExperts):
+ module.gate_up_proj.data.normal_(mean=0.0, std=std)
+ module.down_proj.data.normal_(mean=0.0, std=std)
+
+
+class Qwen3VLMoeVisionModel(Qwen3VLVisionModel):
+ pass
+
+
+class Qwen3VLMoeTextModel(Qwen3VLTextModel):
+ pass
+
+
+class Qwen3VLMoeModel(Qwen3VLModel):
+ pass
+
+
+class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
+ pass
+
+
+__all__ = [
+ "Qwen3VLMoeConfig",
+ "Qwen3VLMoeTextConfig",
+ "Qwen3VLMoeVisionModel",
+ "Qwen3VLMoeForConditionalGeneration",
+ "Qwen3VLMoeModel",
+ "Qwen3VLMoePreTrainedModel",
+ "Qwen3VLMoeTextModel",
+]
diff --git a/tests/models/qwen3_vl/__init__.py b/tests/models/qwen3_vl/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/qwen3_vl/test_modeling_qwen3_vl.py b/tests/models/qwen3_vl/test_modeling_qwen3_vl.py
new file mode 100644
index 000000000000..35031bf542aa
--- /dev/null
+++ b/tests/models/qwen3_vl/test_modeling_qwen3_vl.py
@@ -0,0 +1,299 @@
+# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Testing suite for the PyTorch Qwen3-VL model."""
+
+import copy
+import unittest
+
+from transformers import (
+ Qwen3VLConfig,
+ Qwen3VLForConditionalGeneration,
+ Qwen3VLModel,
+ is_torch_available,
+)
+from transformers.testing_utils import (
+ require_torch,
+ torch_device,
+)
+
+from ...generation.test_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import (
+ ModelTesterMixin,
+ floats_tensor,
+ ids_tensor,
+)
+
+
+if is_torch_available():
+ import torch
+
+
+class Qwen3VLVisionText2TextModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=3,
+ seq_length=7,
+ num_channels=3,
+ ignore_index=-100,
+ image_size=16,
+ text_config={
+ "bos_token_id": 0,
+ "eos_token_id": 1,
+ "pad_token_id": 2,
+ "hidden_act": "silu",
+ "head_dim": 8,
+ "hidden_size": 32,
+ "vocab_size": 99,
+ "intermediate_size": 37,
+ "max_position_embeddings": 512,
+ "model_type": "qwen3_vl",
+ "num_attention_heads": 4,
+ "num_hidden_layers": 4,
+ "num_key_value_heads": 2,
+ "rope_theta": 10000,
+ "tie_word_embeddings": True,
+ "rope_scaling": {"rope_type": "default", "mrope_section": [16, 8, 8], "mrope_interleaved": True},
+ },
+ vision_config={
+ "depth": 2,
+ "in_chans": 3,
+ "hidden_act": "gelu_pytorch_tanh",
+ "intermediate_size": 32,
+ "out_hidden_size": 32,
+ "hidden_size": 32,
+ "num_heads": 4,
+ "patch_size": 16,
+ "spatial_merge_size": 1,
+ "temporal_patch_size": 2,
+ "num_position_embeddings": 16,
+ "deepstack_visual_indexes": [0, 1],
+ },
+ image_token_id=3,
+ video_token_id=4,
+ vision_start_token_id=5,
+ vision_end_token_id=6,
+ tie_word_embeddings=True,
+ is_training=True,
+ ):
+ self.parent = parent
+ self.ignore_index = ignore_index
+ self.is_training = is_training
+
+ self.vision_config = vision_config
+ self.text_config = text_config
+
+ self.vocab_size = text_config["vocab_size"]
+ self.bos_token_id = text_config["bos_token_id"]
+ self.eos_token_id = text_config["eos_token_id"]
+ self.pad_token_id = text_config["pad_token_id"]
+ self.head_dim = text_config["head_dim"]
+ self.hidden_size = text_config["hidden_size"]
+ self.intermediate_size = text_config["intermediate_size"]
+ self.num_hidden_layers = text_config["num_hidden_layers"]
+ self.num_attention_heads = text_config["num_attention_heads"]
+ self.num_key_value_heads = text_config["num_key_value_heads"]
+ self.rope_theta = text_config["rope_theta"]
+ self.rope_scaling = text_config["rope_scaling"]
+ self.hidden_act = text_config["hidden_act"]
+ self.max_position_embeddings = text_config["max_position_embeddings"]
+ self.model_type = text_config["model_type"]
+
+ self.vision_start_token_id = vision_start_token_id
+ self.vision_end_token_id = vision_end_token_id
+ self.image_token_id = image_token_id
+ self.video_token_id = video_token_id
+ self.tie_word_embeddings = tie_word_embeddings
+
+ self.batch_size = batch_size
+ self.num_channels = num_channels
+ self.image_size = image_size
+ self.num_image_tokens = 32
+ self.seq_length = seq_length + self.num_image_tokens
+
+ def get_config(self):
+ return Qwen3VLConfig(
+ text_config=self.text_config,
+ vision_config=self.vision_config,
+ image_token_id=self.image_token_id,
+ video_token_id=self.video_token_id,
+ vision_start_token_id=self.vision_start_token_id,
+ vision_end_token_id=self.vision_end_token_id,
+ tie_word_embeddings=self.tie_word_embeddings,
+ )
+
+ def prepare_config_and_inputs(self):
+ config = self.get_config()
+ patch_size = config.vision_config.patch_size
+ temporal_patch_size = config.vision_config.temporal_patch_size
+ pixel_values = floats_tensor(
+ [
+ self.batch_size * (self.image_size**2) // (patch_size**2),
+ self.num_channels * (patch_size**2) * temporal_patch_size,
+ ]
+ )
+
+ return config, pixel_values
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values = config_and_inputs
+ input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+ attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
+
+ input_ids[:, -1] = self.pad_token_id
+ input_ids[input_ids == self.video_token_id] = self.pad_token_id
+ input_ids[input_ids == self.image_token_id] = self.pad_token_id
+ input_ids[input_ids == self.vision_start_token_id] = self.pad_token_id
+ input_ids[:, self.num_image_tokens] = self.image_token_id
+ input_ids[:, self.num_image_tokens - 1] = self.vision_start_token_id
+ inputs_dict = {
+ "pixel_values": pixel_values,
+ "image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size, device=torch_device),
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ }
+ return config, inputs_dict
+
+
+@require_torch
+class Qwen3VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+ """
+ Model tester for `Qwen3VLForConditionalGeneration`.
+ """
+
+ all_model_classes = (
+ (
+ Qwen3VLModel,
+ Qwen3VLForConditionalGeneration,
+ )
+ if is_torch_available()
+ else ()
+ )
+ test_pruning = False
+ test_head_masking = False
+
+ def setUp(self):
+ self.model_tester = Qwen3VLVisionText2TextModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=Qwen3VLConfig, has_text_modality=False)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_mismatching_num_image_tokens(self):
+ """
+ Tests that VLMs raise an error with an explicit message saying what is wrong
+ when the number of images doesn't match the number of image tokens in the text.
+ Also we need to test multi-image cases when one prompt has multiple image tokens.
+ """
+ config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ for model_class in self.all_model_classes:
+ model = model_class(config).to(torch_device)
+ _ = model(**input_dict) # successful forward with no modifications
+ curr_input_dict = copy.deepcopy(input_dict)
+
+ # remove one image but leave the image token in text
+ patch_size = config.vision_config.patch_size
+ one_img_length = (self.model_tester.image_size**2) // (patch_size**2)
+ curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-one_img_length:, ...]
+ curr_input_dict["image_grid_thw"] = curr_input_dict["image_grid_thw"][-1:, ...]
+ with self.assertRaises(ValueError):
+ _ = model(**curr_input_dict)
+
+ # simulate multi-image case by concatenating inputs where each has exactly one image/image-token
+ input_ids = curr_input_dict["input_ids"][:1]
+ pixel_values = curr_input_dict["pixel_values"][:one_img_length]
+ image_grid_thw = curr_input_dict["image_grid_thw"][:1]
+ input_ids = torch.cat([input_ids, input_ids], dim=0)
+
+ # one image and two image tokens raise an error
+ with self.assertRaises(ValueError):
+ _ = model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ image_grid_thw=image_grid_thw,
+ )
+
+ # two images and two image tokens don't raise an error
+ pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
+ image_grid_thw = torch.cat([image_grid_thw, image_grid_thw], dim=0)
+ _ = model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ image_grid_thw=image_grid_thw,
+ )
+
+ def test_video_forward(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ B = self.model_tester.batch_size
+ C = config.vision_config.in_chans
+ T = config.vision_config.temporal_patch_size
+ P = config.vision_config.patch_size
+
+ input_ids = ids_tensor([B, self.model_tester.seq_length], self.model_tester.vocab_size)
+
+ F = 4
+ patch_H = self.model_tester.image_size // P
+ patch_W = self.model_tester.image_size // P
+ patch_T = F // T
+ patches_per_video = patch_T * patch_H * patch_W
+ patches_per_frame = patch_H * patch_W
+ pixel_values_videos = floats_tensor(
+ [
+ # first dim: batch_size * num_patches
+ B * patches_per_video,
+ # second dim: in_channels * temporal_patch_size * patch_size^2
+ C * T * (P**2),
+ ]
+ )
+
+ # qwen3vl uses timestamps for video, so split it into patch_T sub-videos
+ video_grid_thw = torch.tensor([[1, patch_H, patch_W] for _ in range(patch_T)] * B)
+
+ # sanity check
+ self.assertEqual(pixel_values_videos.shape[0], video_grid_thw.prod(dim=1).sum().item())
+
+ # Insert video token sequence
+ input_ids[:, -1] = self.model_tester.pad_token_id
+ input_ids[input_ids == self.model_tester.video_token_id] = self.model_tester.pad_token_id
+ input_ids[input_ids == self.model_tester.image_token_id] = self.model_tester.pad_token_id
+ input_ids[input_ids == self.model_tester.vision_start_token_id] = self.model_tester.pad_token_id
+ input_ids[:, self.model_tester.num_image_tokens] = self.model_tester.video_token_id
+
+ insertion_point = self.model_tester.num_image_tokens
+
+ self.assertLessEqual((B * patches_per_video) + insertion_point, self.model_tester.seq_length)
+ for b in range(B):
+ # each frame is separated by a vision_start_token_id
+ for frame_idx in range(patch_T):
+ input_ids[b, insertion_point + frame_idx * (patches_per_frame + 1)] = (
+ self.model_tester.vision_start_token_id
+ )
+ input_ids[
+ b,
+ insertion_point + frame_idx * (patches_per_frame + 1) + 1 : insertion_point
+ + (frame_idx + 1) * (patches_per_frame + 1),
+ ] = self.model_tester.video_token_id
+
+ for model_class in self.all_model_classes:
+ # TODO: we should remove this because we use timestamps for video
+ model = model_class(config).to(torch_device)
+ outputs = model(
+ input_ids=input_ids,
+ pixel_values_videos=pixel_values_videos,
+ video_grid_thw=video_grid_thw,
+ )
+ self.assertIsNotNone(outputs)
diff --git a/tests/models/qwen3_vl/test_processing_qwen3_vl.py b/tests/models/qwen3_vl/test_processing_qwen3_vl.py
new file mode 100644
index 000000000000..87636dcf607d
--- /dev/null
+++ b/tests/models/qwen3_vl/test_processing_qwen3_vl.py
@@ -0,0 +1,379 @@
+# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import shutil
+import tempfile
+import unittest
+
+import numpy as np
+import pytest
+
+from transformers import AutoProcessor, Qwen2TokenizerFast
+from transformers.testing_utils import require_av, require_torch, require_torchvision, require_vision
+from transformers.utils import is_torch_available, is_vision_available
+
+from ...test_processing_common import ProcessorTesterMixin
+
+
+if is_vision_available():
+ from transformers import Qwen2VLImageProcessorFast, Qwen3VLProcessor
+
+if is_torch_available():
+ import torch
+
+
+@require_vision
+@require_torch
+@require_torchvision
+@unittest.skip("The checkpoint is not yet released")
+class Qwen3VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
+ processor_class = Qwen3VLProcessor
+
+ @classmethod
+ def setUpClass(cls):
+ cls.tmpdirname = tempfile.mkdtemp()
+ processor = Qwen3VLProcessor.from_pretrained(
+ "Qwen/Qwen3-VL-4B-Instruct", patch_size=4, max_pixels=56 * 56, min_pixels=28 * 28
+ )
+ processor.save_pretrained(cls.tmpdirname)
+ cls.image_token = processor.image_token
+
+ def get_tokenizer(self, **kwargs):
+ return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
+
+ def get_image_processor(self, **kwargs):
+ return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
+
+ def get_video_processor(self, **kwargs):
+ return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).video_processor
+
+ def get_processor(self, **kwargs):
+ return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs)
+
+ @classmethod
+ def tearDownClass(cls):
+ shutil.rmtree(cls.tmpdirname, ignore_errors=True)
+
+ # Copied from tests.models.llava.test_processing_llava.LlavaProcessorTest.test_get_num_vision_tokens
+ def test_get_num_vision_tokens(self):
+ "Tests general functionality of the helper used internally in vLLM"
+
+ processor = self.get_processor()
+
+ output = processor._get_num_multimodal_tokens(image_sizes=[(100, 100), (300, 100), (500, 30)])
+ self.assertTrue("num_image_tokens" in output)
+ self.assertEqual(len(output["num_image_tokens"]), 3)
+
+ self.assertTrue("num_image_patches" in output)
+ self.assertEqual(len(output["num_image_patches"]), 3)
+
+ def test_save_load_pretrained_default(self):
+ tokenizer = self.get_tokenizer()
+ image_processor = self.get_image_processor()
+ video_processor = self.get_video_processor()
+
+ processor = Qwen3VLProcessor(
+ tokenizer=tokenizer, image_processor=image_processor, video_processor=video_processor
+ )
+ processor.save_pretrained(self.tmpdirname)
+ processor = Qwen3VLProcessor.from_pretrained(self.tmpdirname, use_fast=True)
+
+ self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())
+ self.assertEqual(processor.image_processor.to_json_string(), image_processor.to_json_string())
+ self.assertIsInstance(processor.tokenizer, Qwen2TokenizerFast)
+ self.assertIsInstance(processor.image_processor, Qwen2VLImageProcessorFast)
+
+ def test_image_processor(self):
+ image_processor = self.get_image_processor()
+ tokenizer = self.get_tokenizer()
+ video_processor = self.get_video_processor()
+
+ processor = Qwen3VLProcessor(
+ tokenizer=tokenizer, image_processor=image_processor, video_processor=video_processor
+ )
+
+ image_input = self.prepare_image_inputs()
+
+ input_image_proc = image_processor(image_input, return_tensors="pt")
+ input_processor = processor(images=image_input, text="dummy", return_tensors="pt")
+
+ for key in input_image_proc:
+ self.assertAlmostEqual(input_image_proc[key].sum(), input_processor[key].sum(), delta=1e-2)
+
+ def test_processor(self):
+ image_processor = self.get_image_processor()
+ tokenizer = self.get_tokenizer()
+ video_processor = self.get_video_processor()
+
+ processor = Qwen3VLProcessor(
+ tokenizer=tokenizer, image_processor=image_processor, video_processor=video_processor
+ )
+
+ input_str = "lower newer"
+ image_input = self.prepare_image_inputs()
+ inputs = processor(text=input_str, images=image_input)
+
+ self.assertListEqual(
+ list(inputs.keys()),
+ ["input_ids", "attention_mask", "pixel_values", "image_grid_thw"],
+ )
+
+ # test if it raises when no input is passed
+ with pytest.raises(ValueError):
+ processor()
+
+ # test if it raises when no text is passed
+ with pytest.raises(TypeError):
+ processor(images=image_input)
+
+ def test_model_input_names(self):
+ image_processor = self.get_image_processor()
+ tokenizer = self.get_tokenizer()
+ video_processor = self.get_video_processor()
+
+ processor = Qwen3VLProcessor(
+ tokenizer=tokenizer, image_processor=image_processor, video_processor=video_processor
+ )
+
+ input_str = "lower newer"
+ image_input = self.prepare_image_inputs()
+ video_inputs = self.prepare_video_inputs()
+
+ inputs = processor(text=input_str, images=image_input, videos=video_inputs, do_sample_frames=False)
+
+ self.assertListEqual(list(inputs.keys()), processor.model_input_names)
+
+ @require_torch
+ @require_av
+ def _test_apply_chat_template(
+ self,
+ modality: str,
+ batch_size: int,
+ return_tensors: str,
+ input_name: str,
+ processor_name: str,
+ input_data: list[str],
+ ):
+ processor = self.get_processor()
+ if processor.chat_template is None:
+ self.skipTest("Processor has no chat template")
+
+ if processor_name not in self.processor_class.attributes:
+ self.skipTest(f"{processor_name} attribute not present in {self.processor_class}")
+
+ batch_messages = [
+ [
+ {
+ "role": "user",
+ "content": [{"type": "text", "text": "Describe this."}],
+ },
+ ]
+ ] * batch_size
+
+ # Test that jinja can be applied
+ formatted_prompt = processor.apply_chat_template(batch_messages, add_generation_prompt=True, tokenize=False)
+ self.assertEqual(len(formatted_prompt), batch_size)
+
+ # Test that tokenizing with template and directly with `self.tokenizer` gives same output
+ formatted_prompt_tokenized = processor.apply_chat_template(
+ batch_messages, add_generation_prompt=True, tokenize=True, return_tensors=return_tensors
+ )
+ add_special_tokens = True
+ if processor.tokenizer.bos_token is not None and formatted_prompt[0].startswith(processor.tokenizer.bos_token):
+ add_special_tokens = False
+ tok_output = processor.tokenizer(
+ formatted_prompt, return_tensors=return_tensors, add_special_tokens=add_special_tokens
+ )
+ expected_output = tok_output.input_ids
+ self.assertListEqual(expected_output.tolist(), formatted_prompt_tokenized.tolist())
+
+ # Test that kwargs passed to processor's `__call__` are actually used
+ tokenized_prompt_100 = processor.apply_chat_template(
+ batch_messages,
+ add_generation_prompt=True,
+ tokenize=True,
+ padding="max_length",
+ truncation=True,
+ return_tensors=return_tensors,
+ max_length=100,
+ )
+ self.assertEqual(len(tokenized_prompt_100[0]), 100)
+
+ # Test that `return_dict=True` returns text related inputs in the dict
+ out_dict_text = processor.apply_chat_template(
+ batch_messages,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ return_tensors=return_tensors,
+ )
+ self.assertTrue(all(key in out_dict_text for key in ["input_ids", "attention_mask"]))
+ self.assertEqual(len(out_dict_text["input_ids"]), batch_size)
+ self.assertEqual(len(out_dict_text["attention_mask"]), batch_size)
+
+ # Test that with modality URLs and `return_dict=True`, we get modality inputs in the dict
+ for idx, url in enumerate(input_data[:batch_size]):
+ batch_messages[idx][0]["content"] = [batch_messages[idx][0]["content"][0], {"type": modality, "url": url}]
+
+ out_dict = processor.apply_chat_template(
+ batch_messages,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ return_tensors=return_tensors,
+ max_frames=2, # load at most 2 frames, otherwise the test is too slow
+ )
+ input_name = getattr(self, input_name)
+ self.assertTrue(input_name in out_dict)
+ self.assertEqual(len(out_dict["input_ids"]), batch_size)
+ self.assertEqual(len(out_dict["attention_mask"]), batch_size)
+
+ if modality == "video":
+ # Qwen pixel values don't scale with batch size the same way as in other models, so compute the expected video token count from video_grid_thw
+ expected_video_token_count = 0
+ for thw in out_dict["video_grid_thw"]:
+ expected_video_token_count += thw[0] * thw[1] * thw[2]
+ mm_len = expected_video_token_count
+ else:
+ mm_len = batch_size * 192
+ self.assertEqual(len(out_dict[input_name]), mm_len)
+
+ return_tensor_to_type = {"pt": torch.Tensor, "np": np.ndarray, None: list}
+ for k in out_dict:
+ self.assertIsInstance(out_dict[k], return_tensor_to_type[return_tensors])
+
+ @require_av
+ @unittest.skip("qwen3_vl can't sample frames from image frames directly, user can use `qwen-vl-utils`")
+ def test_apply_chat_template_video_1(self):
+ pass
+
+ @require_av
+ @unittest.skip("qwen3_vl can't sample frames from image frames directly, user can use `qwen-vl-utils`")
+ def test_apply_chat_template_video_2(self):
+ pass
+
+ @require_av
+ def test_apply_chat_template_video_frame_sampling(self):
+ processor = self.get_processor()
+ if processor.chat_template is None:
+ self.skipTest("Processor has no chat template")
+
+ signature = inspect.signature(processor.__call__)
+ if "videos" not in {*signature.parameters.keys()} or (
+ signature.parameters.get("videos") is not None
+ and signature.parameters["videos"].annotation == inspect._empty
+ ):
+ self.skipTest("Processor doesn't accept videos at input")
+
+ messages = [
+ [
+ {
+ "role": "user",
+ "content": [
+ {"type": "video"},
+ {"type": "text", "text": "What is shown in this video?"},
+ ],
+ },
+ ]
+ ]
+
+ formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+ self.assertEqual(len(formatted_prompt), 1)
+
+ formatted_prompt_tokenized = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
+ expected_output = processor.tokenizer(formatted_prompt, return_tensors=None).input_ids
+ self.assertListEqual(expected_output, formatted_prompt_tokenized)
+
+ out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
+ self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
+
+ # Add video URL for return dict and load with `num_frames` arg
+ messages[0][0]["content"][0] = {
+ "type": "video",
+ "url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4",
+ }
+ num_frames = 3
+ out_dict_with_video = processor.apply_chat_template(
+ messages,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ num_frames=num_frames,
+ )
+ self.assertTrue(self.videos_input_name in out_dict_with_video)
+ self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 360)
+
+ # Load with `fps` arg
+ fps = 1
+ out_dict_with_video = processor.apply_chat_template(
+ messages,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ fps=fps,
+ )
+ self.assertTrue(self.videos_input_name in out_dict_with_video)
+ self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 900)
+
+ # Load with `fps` and `num_frames` args, should raise an error
+ with self.assertRaises(ValueError):
+ out_dict_with_video = processor.apply_chat_template(
+ messages,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ fps=fps,
+ num_frames=num_frames,
+ )
+
+ # Load without any arg should load the whole video
+ out_dict_with_video = processor.apply_chat_template(
+ messages,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ )
+ self.assertTrue(self.videos_input_name in out_dict_with_video)
+ self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 27000)
+
+ # Load video as a list of frames (i.e. images). NOTE: each frame should have the same size
+ # because we assume they come from one video
+ messages[0][0]["content"][0] = {
+ "type": "video",
+ "url": [
+ "https://www.ilankelman.org/stopsigns/australia.jpg",
+ "https://www.ilankelman.org/stopsigns/australia.jpg",
+ ],
+ }
+ out_dict_with_video = processor.apply_chat_template(
+ messages,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ do_sample_frames=False,
+ )
+ self.assertTrue(self.videos_input_name in out_dict_with_video)
+ self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 160)
+
+ def test_kwargs_overrides_custom_image_processor_kwargs(self):
+ processor = self.get_processor()
+ self.skip_processor_without_typed_kwargs(processor)
+
+ input_str = self.prepare_text_inputs()
+ image_input = self.prepare_image_inputs()
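+ # The expected patch counts below depend on the image sizes produced by prepare_image_inputs
+ # and on the max_pixels override vs. the processor's default configuration.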
+ inputs = processor(text=input_str, images=image_input, max_pixels=56 * 56 * 4, return_tensors="pt")
+ self.assertEqual(inputs[self.images_input_name].shape[0], 612)
+ inputs = processor(text=input_str, images=image_input, return_tensors="pt")
+ self.assertEqual(inputs[self.images_input_name].shape[0], 100)
diff --git a/tests/models/qwen3_vl/test_video_processing_qwen3_vl.py b/tests/models/qwen3_vl/test_video_processing_qwen3_vl.py
new file mode 100644
index 000000000000..9230f0f9502e
--- /dev/null
+++ b/tests/models/qwen3_vl/test_video_processing_qwen3_vl.py
@@ -0,0 +1,330 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+from transformers.image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD
+from transformers.testing_utils import require_torch, require_vision
+from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
+
+from ...test_video_processing_common import VideoProcessingTestMixin, prepare_video_inputs
+
+
+if is_torch_available():
+ from PIL import Image
+
+if is_vision_available() and is_torchvision_available():
+ from transformers import Qwen3VLVideoProcessor
+ from transformers.models.qwen3_vl.video_processing_qwen3_vl import smart_resize
+
+
+class Qwen3VLVideoProcessingTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=5,
+ num_frames=8,
+ num_channels=3,
+ min_resolution=32,
+ max_resolution=80,
+ temporal_patch_size=2,
+ patch_size=16,
+ merge_size=2,
+ do_resize=True,
+ size=None,
+ do_normalize=True,
+ image_mean=IMAGENET_STANDARD_MEAN,
+ image_std=IMAGENET_STANDARD_STD,
+ do_convert_rgb=True,
+ ):
+ size = size if size is not None else {"longest_edge": 20, "shortest_edge": 10}
+ self.parent = parent
+ self.batch_size = batch_size
+ self.num_frames = num_frames
+ self.num_channels = num_channels
+ self.min_resolution = min_resolution
+ self.max_resolution = max_resolution
+ self.do_resize = do_resize
+ self.size = size
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean
+ self.image_std = image_std
+ self.do_convert_rgb = do_convert_rgb
+ self.temporal_patch_size = temporal_patch_size
+ self.patch_size = patch_size
+ self.merge_size = merge_size
+
+ def prepare_video_processor_dict(self):
+ return {
+ "do_resize": self.do_resize,
+ "size": self.size,
+ "do_normalize": self.do_normalize,
+ "image_mean": self.image_mean,
+ "image_std": self.image_std,
+ "do_convert_rgb": self.do_convert_rgb,
+ "do_sample_frames": True,
+ }
+
+ def prepare_video_metadata(self, videos):
+ video_metadata = []
+ for video in videos:
+ if isinstance(video, list):
+ num_frames = len(video)
+ elif hasattr(video, "shape"):
+ if len(video.shape) == 4: # (T, H, W, C)
+ num_frames = video.shape[0]
+ else:
+ num_frames = 1
+ else:
+ num_frames = self.num_frames
+
+ metadata = {
+ "fps": 2,
+ "duration": num_frames / 2,
+ "total_num_frames": num_frames,
+ }
+ video_metadata.append(metadata)
+ return video_metadata
+
+ def expected_output_video_shape(self, videos):
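+ # Mirror the processor's patchification: each video is resized via smart_resize and split into
+ # (grid_t, grid_h, grid_w) patches, each flattened to num_channels * temporal_patch_size * patch_size**2 values.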
+ grid_t = self.num_frames // self.temporal_patch_size
+ hidden_dim = self.num_channels * self.temporal_patch_size * self.patch_size * self.patch_size
+ seq_len = 0
+ for video in videos:
+ if isinstance(video, list) and isinstance(video[0], Image.Image):
+ video = np.stack([np.array(frame) for frame in video])
+ elif hasattr(video, "shape"):
+ pass
+ else:
+ video = np.array(video)
+
+ if hasattr(video, "shape") and len(video.shape) >= 3:
+ if len(video.shape) == 4:
+ t, height, width = video.shape[:3]
+ elif len(video.shape) == 3:
+ height, width = video.shape[:2]
+ t = 1
+ else:
+ t, height, width = self.num_frames, self.min_resolution, self.min_resolution
+ else:
+ t, height, width = self.num_frames, self.min_resolution, self.min_resolution
+
+ resized_height, resized_width = smart_resize(
+ t,
+ height,
+ width,
+ factor=self.patch_size * self.merge_size,
+ min_pixels=self.size["shortest_edge"],
+ max_pixels=self.size["longest_edge"],
+ )
+ grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
+ seq_len += grid_t * grid_h * grid_w
+ return [seq_len, hidden_dim]
+
+ def prepare_video_inputs(self, equal_resolution=False, return_tensors="pil"):
+ videos = prepare_video_inputs(
+ batch_size=self.batch_size,
+ num_frames=self.num_frames,
+ num_channels=self.num_channels,
+ min_resolution=self.min_resolution,
+ max_resolution=self.max_resolution,
+ equal_resolution=equal_resolution,
+ return_tensors=return_tensors,
+ )
+ return videos
+
+
+@require_torch
+@require_vision
+class Qwen3VLVideoProcessingTest(VideoProcessingTestMixin, unittest.TestCase):
+ fast_video_processing_class = Qwen3VLVideoProcessor if is_torchvision_available() else None
+ input_name = "pixel_values_videos"
+
+ def setUp(self):
+ super().setUp()
+ self.video_processor_tester = Qwen3VLVideoProcessingTester(self)
+
+ @property
+ def video_processor_dict(self):
+ return self.video_processor_tester.prepare_video_processor_dict()
+
+ def test_video_processor_from_dict_with_kwargs(self):
+ video_processor = self.fast_video_processing_class.from_dict(self.video_processor_dict)
+ self.assertEqual(video_processor.size, {"longest_edge": 20, "shortest_edge": 10})
+
+ video_processor = self.fast_video_processing_class.from_dict(
+ self.video_processor_dict, size={"longest_edge": 42, "shortest_edge": 42}
+ )
+ self.assertEqual(video_processor.size, {"longest_edge": 42, "shortest_edge": 42})
+
+ def test_call_pil(self):
+ for video_processing_class in self.video_processor_list:
+ video_processing = video_processing_class(**self.video_processor_dict)
+ video_inputs = self.video_processor_tester.prepare_video_inputs(
+ equal_resolution=False, return_tensors="pil"
+ )
+
+ for video in video_inputs:
+ self.assertIsInstance(video[0], Image.Image)
+
+ video_metadata = self.video_processor_tester.prepare_video_metadata(video_inputs)
+ encoded_videos = video_processing(
+ video_inputs[0], video_metadata=[video_metadata[0]], return_tensors="pt"
+ )[self.input_name]
+ expected_output_video_shape = self.video_processor_tester.expected_output_video_shape([video_inputs[0]])
+ self.assertEqual(list(encoded_videos.shape), expected_output_video_shape)
+ encoded_videos = video_processing(video_inputs, video_metadata=video_metadata, return_tensors="pt")[
+ self.input_name
+ ]
+ expected_output_video_shape = self.video_processor_tester.expected_output_video_shape(video_inputs)
+ self.assertEqual(list(encoded_videos.shape), expected_output_video_shape)
+
+ def test_call_numpy(self):
+ for video_processing_class in self.video_processor_list:
+ video_processing = video_processing_class(**self.video_processor_dict)
+ video_inputs = self.video_processor_tester.prepare_video_inputs(
+ equal_resolution=False, return_tensors="np"
+ )
+
+ video_metadata = self.video_processor_tester.prepare_video_metadata(video_inputs)
+ encoded_videos = video_processing(
+ video_inputs[0], video_metadata=[video_metadata[0]], return_tensors="pt"
+ )[self.input_name]
+ expected_output_video_shape = self.video_processor_tester.expected_output_video_shape([video_inputs[0]])
+ self.assertEqual(list(encoded_videos.shape), expected_output_video_shape)
+
+ encoded_videos = video_processing(video_inputs, video_metadata=video_metadata, return_tensors="pt")[
+ self.input_name
+ ]
+ expected_output_video_shape = self.video_processor_tester.expected_output_video_shape(video_inputs)
+ self.assertEqual(list(encoded_videos.shape), expected_output_video_shape)
+
+ def test_call_pytorch(self):
+ for video_processing_class in self.video_processor_list:
+ video_processing = video_processing_class(**self.video_processor_dict)
+ video_inputs = self.video_processor_tester.prepare_video_inputs(
+ equal_resolution=False, return_tensors="pt"
+ )
+ video_metadata = self.video_processor_tester.prepare_video_metadata(video_inputs)
+ encoded_videos = video_processing(
+ video_inputs[0], video_metadata=[video_metadata[0]], return_tensors="pt"
+ )[self.input_name]
+ expected_output_video_shape = self.video_processor_tester.expected_output_video_shape([video_inputs[0]])
+ self.assertEqual(list(encoded_videos.shape), expected_output_video_shape)
+ encoded_videos = video_processing(video_inputs, video_metadata=video_metadata, return_tensors="pt")[
+ self.input_name
+ ]
+ expected_output_video_shape = self.video_processor_tester.expected_output_video_shape(video_inputs)
+ self.assertEqual(list(encoded_videos.shape), expected_output_video_shape)
+
+ @unittest.skip("Skip for now, the test needs adjustment for Qwen3VL")
+ def test_call_numpy_4_channels(self):
+ for video_processing_class in self.video_processor_list:
+ # Test that can process videos which have an arbitrary number of channels
+ # Initialize video_processing
+ video_processor = video_processing_class(**self.video_processor_dict)
+
+ # create random numpy tensors
+ self.video_processor_tester.num_channels = 4
+ video_inputs = self.video_processor_tester.prepare_video_inputs(
+ equal_resolution=False, return_tensors="np"
+ )
+
+ # Test not batched input
+ encoded_videos = video_processor(
+ video_inputs[0],
+ return_tensors="pt",
+ input_data_format="channels_last",
+ image_mean=0,
+ image_std=1,
+ )[self.input_name]
+ expected_output_video_shape = self.video_processor_tester.expected_output_video_shape([video_inputs[0]])
+ self.assertEqual(list(encoded_videos.shape), expected_output_video_shape)
+
+ # Test batched
+ encoded_videos = video_processor(
+ video_inputs,
+ return_tensors="pt",
+ input_data_format="channels_last",
+ image_mean=0,
+ image_std=1,
+ )[self.input_name]
+ expected_output_video_shape = self.video_processor_tester.expected_output_video_shape(video_inputs)
+ self.assertEqual(list(encoded_videos.shape), expected_output_video_shape)
+
+ def test_nested_input(self):
+ """Tests that the processor can work with nested list where each video is a list of arrays"""
+ for video_processing_class in self.video_processor_list:
+ video_processing = video_processing_class(**self.video_processor_dict)
+ video_inputs = self.video_processor_tester.prepare_video_inputs(
+ equal_resolution=False, return_tensors="np"
+ )
+
+ video_inputs_nested = [list(video) for video in video_inputs]
+ video_metadata = self.video_processor_tester.prepare_video_metadata(video_inputs)
+
+ # Test not batched input
+ encoded_videos = video_processing(
+ video_inputs_nested[0], video_metadata=[video_metadata[0]], return_tensors="pt"
+ )[self.input_name]
+ expected_output_video_shape = self.video_processor_tester.expected_output_video_shape([video_inputs[0]])
+ self.assertEqual(list(encoded_videos.shape), expected_output_video_shape)
+
+ # Test batched
+ encoded_videos = video_processing(video_inputs_nested, video_metadata=video_metadata, return_tensors="pt")[
+ self.input_name
+ ]
+ expected_output_video_shape = self.video_processor_tester.expected_output_video_shape(video_inputs)
+ self.assertEqual(list(encoded_videos.shape), expected_output_video_shape)
+
+ def test_call_sample_frames(self):
+ for video_processing_class in self.video_processor_list:
+ video_processor_dict = self.video_processor_dict.copy()
+ video_processing = video_processing_class(**video_processor_dict)
+
+ prev_num_frames = self.video_processor_tester.num_frames
+ self.video_processor_tester.num_frames = 8
+ prev_min_resolution = getattr(self.video_processor_tester, "min_resolution", None)
+ prev_max_resolution = getattr(self.video_processor_tester, "max_resolution", None)
+ self.video_processor_tester.min_resolution = 56
+ self.video_processor_tester.max_resolution = 112
+
+ video_inputs = self.video_processor_tester.prepare_video_inputs(
+ equal_resolution=False,
+ return_tensors="torch",
+ )
+
+ metadata = [[{"total_num_frames": 8, "fps": 4}]]
+ batched_metadata = metadata * len(video_inputs)
+
+ encoded_videos = video_processing(video_inputs[0], return_tensors="pt", video_metadata=metadata)[
+ self.input_name
+ ]
+ encoded_videos_batched = video_processing(
+ video_inputs, return_tensors="pt", video_metadata=batched_metadata
+ )[self.input_name]
+
+ self.assertIsNotNone(encoded_videos)
+ self.assertIsNotNone(encoded_videos_batched)
+ self.assertEqual(len(encoded_videos.shape), 2)
+ self.assertEqual(len(encoded_videos_batched.shape), 2)
+
+ self.video_processor_tester.num_frames = prev_num_frames
+ if prev_min_resolution is not None:
+ self.video_processor_tester.min_resolution = prev_min_resolution
+ if prev_max_resolution is not None:
+ self.video_processor_tester.max_resolution = prev_max_resolution
diff --git a/tests/models/qwen3_vl_moe/__init__.py b/tests/models/qwen3_vl_moe/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py b/tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py
new file mode 100644
index 000000000000..adae69a81fa8
--- /dev/null
+++ b/tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py
@@ -0,0 +1,298 @@
+# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Testing suite for the PyTorch Qwen3VLMoe model."""
+
+import copy
+import unittest
+
+from transformers import (
+ Qwen3VLMoeConfig,
+ Qwen3VLMoeForConditionalGeneration,
+ Qwen3VLMoeModel,
+ is_torch_available,
+)
+from transformers.testing_utils import (
+ require_torch,
+ torch_device,
+)
+
+from ...generation.test_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import (
+ ModelTesterMixin,
+ floats_tensor,
+ ids_tensor,
+)
+
+
+if is_torch_available():
+ import torch
+
+
+class Qwen3VLMoeVisionText2TextModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=3,
+ seq_length=7,
+ num_channels=3,
+ ignore_index=-100,
+ image_size=16,
+ text_config={
+ "bos_token_id": 0,
+ "eos_token_id": 1,
+ "pad_token_id": 2,
+ "hidden_act": "silu",
+ "hidden_size": 32,
+ "vocab_size": 99,
+ "intermediate_size": 37,
+ "max_position_embeddings": 512,
+ "model_type": "qwen3_vl_moe",
+ "num_attention_heads": 4,
+ "num_key_value_heads": 2,
+ "num_hidden_layers": 4,
+ "moe_intermediate_size": 16,
+ "num_experts_per_tok": 4,
+ "num_experts": 8,
+ "rope_theta": 10000,
+ "tie_word_embeddings": True,
+ "rope_scaling": {"rope_type": "default", "mrope_section": [16, 8, 8], "mrope_interleaved": True},
+ },
+ vision_config={
+ "depth": 2,
+ "in_chans": 3,
+ "hidden_act": "gelu_pytorch_tanh",
+ "intermediate_size": 32,
+ "out_hidden_size": 32,
+ "hidden_size": 32,
+ "num_heads": 4,
+ "patch_size": 16,
+ "spatial_merge_size": 1,
+ "temporal_patch_size": 2,
+ "num_position_embeddings": 16,
+ "deepstack_visual_indexes": [0, 1],
+ },
+ image_token_id=3,
+ video_token_id=4,
+ vision_start_token_id=5,
+ vision_end_token_id=6,
+ tie_word_embeddings=True,
+ is_training=True,
+ ):
+ self.parent = parent
+ self.ignore_index = ignore_index
+ self.is_training = is_training
+
+ self.vision_config = vision_config
+ self.text_config = text_config
+
+ self.vocab_size = text_config["vocab_size"]
+ self.bos_token_id = text_config["bos_token_id"]
+ self.eos_token_id = text_config["eos_token_id"]
+ self.pad_token_id = text_config["pad_token_id"]
+ self.hidden_size = text_config["hidden_size"]
+ self.intermediate_size = text_config["intermediate_size"]
+ self.num_hidden_layers = text_config["num_hidden_layers"]
+ self.num_attention_heads = text_config["num_attention_heads"]
+ self.num_key_value_heads = text_config["num_key_value_heads"]
+ self.rope_theta = text_config["rope_theta"]
+ self.rope_scaling = text_config["rope_scaling"]
+ self.hidden_act = text_config["hidden_act"]
+ self.max_position_embeddings = text_config["max_position_embeddings"]
+ self.model_type = text_config["model_type"]
+
+ self.vision_start_token_id = vision_start_token_id
+ self.vision_end_token_id = vision_end_token_id
+ self.image_token_id = image_token_id
+ self.video_token_id = video_token_id
+ self.tie_word_embeddings = tie_word_embeddings
+
+ self.batch_size = batch_size
+ self.num_channels = num_channels
+ self.image_size = image_size
+ self.num_image_tokens = 32
+ self.seq_length = seq_length + self.num_image_tokens
+
+ def get_config(self):
+ return Qwen3VLMoeConfig(
+ text_config=self.text_config,
+ vision_config=self.vision_config,
+ image_token_id=self.image_token_id,
+ video_token_id=self.video_token_id,
+ vision_start_token_id=self.vision_start_token_id,
+ vision_end_token_id=self.vision_end_token_id,
+ tie_word_embeddings=self.tie_word_embeddings,
+ )
+
+ def prepare_config_and_inputs(self):
+ config = self.get_config()
+ patch_size = config.vision_config.patch_size
+ temporal_patch_size = config.vision_config.temporal_patch_size
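+ # pixel_values holds flattened vision patches: one row per patch across the whole batch,
+ # each of size num_channels * temporal_patch_size * patch_size**2.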
+ pixel_values = floats_tensor(
+ [
+ self.batch_size * (self.image_size**2) // (patch_size**2),
+ self.num_channels * (patch_size**2) * temporal_patch_size,
+ ]
+ )
+
+ return config, pixel_values
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values = config_and_inputs
+ input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+ attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
+
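+ # Scrub any special tokens sampled at random, then place a single vision_start + image token
+ # pair per sample (each image here collapses to one patch since image_size == patch_size).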
+ input_ids[:, -1] = self.pad_token_id
+ input_ids[input_ids == self.video_token_id] = self.pad_token_id
+ input_ids[input_ids == self.image_token_id] = self.pad_token_id
+ input_ids[input_ids == self.vision_start_token_id] = self.pad_token_id
+ input_ids[:, self.num_image_tokens] = self.image_token_id
+ input_ids[:, self.num_image_tokens - 1] = self.vision_start_token_id
+ inputs_dict = {
+ "pixel_values": pixel_values,
+ "image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size, device=torch_device),
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ }
+ return config, inputs_dict
+
+
+@require_torch
+class Qwen3VLMoeModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+ """
+ Model tester for `Qwen3VLMoeForConditionalGeneration`.
+ """
+
+ all_model_classes = (
+ (
+ Qwen3VLMoeModel,
+ Qwen3VLMoeForConditionalGeneration,
+ )
+ if is_torch_available()
+ else ()
+ )
+ test_pruning = False
+ test_head_masking = False
+
+ def setUp(self):
+ self.model_tester = Qwen3VLMoeVisionText2TextModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=Qwen3VLMoeConfig, has_text_modality=False)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_mismatching_num_image_tokens(self):
+ """
+ Tests that VLMs throw an error with an explicit message saying what is wrong
+ when the number of images doesn't match the number of image tokens in the text.
+ Also, we need to test multi-image cases when one prompt has multiple image tokens.
+ """
+ config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ for model_class in self.all_model_classes:
+ model = model_class(config).to(torch_device)
+ _ = model(**input_dict) # successful forward with no modifications
+ curr_input_dict = copy.deepcopy(input_dict)
+
+ # remove one image but leave the image token in text
+ patch_size = config.vision_config.patch_size
+ one_img_length = (self.model_tester.image_size**2) // (patch_size**2)
+ curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-one_img_length:, ...]
+ curr_input_dict["image_grid_thw"] = curr_input_dict["image_grid_thw"][-1:, ...]
+ with self.assertRaises(ValueError):
+ _ = model(**curr_input_dict)
+
+ # simulate multi-image case by concatenating inputs where each has exactly one image/image-token
+ input_ids = curr_input_dict["input_ids"][:1]
+ pixel_values = curr_input_dict["pixel_values"][:one_img_length]
+ image_grid_thw = curr_input_dict["image_grid_thw"][:1]
+ input_ids = torch.cat([input_ids, input_ids], dim=0)
+
+ # one image and two image tokens raise an error
+ with self.assertRaises(ValueError):
+ _ = model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ image_grid_thw=image_grid_thw,
+ )
+
+ # two images and two image tokens don't raise an error
+ pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
+ image_grid_thw = torch.cat([image_grid_thw, image_grid_thw], dim=0)
+ _ = model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ image_grid_thw=image_grid_thw,
+ )
+
+ def test_video_forward(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ B = self.model_tester.batch_size
+ C = config.vision_config.in_chans
+ T = config.vision_config.temporal_patch_size
+ P = config.vision_config.patch_size
+
+ input_ids = ids_tensor([B, self.model_tester.seq_length], self.model_tester.vocab_size)
+
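+ # Simulate a short video with F frames; every temporal_patch_size (T) frames form one temporal
+ # patch, so each sample contributes patch_T * patch_H * patch_W flattened vision patches.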
+ F = 4
+ patch_H = self.model_tester.image_size // P
+ patch_W = self.model_tester.image_size // P
+ patch_T = F // T
+ patches_per_video = patch_T * patch_H * patch_W
+ patches_per_frame = patch_H * patch_W
+ pixel_values_videos = floats_tensor(
+ [
+ # first dim: batch_size * num_patches
+ B * patches_per_video,
+ # second dim: in_channels * temporal_patch_size * patch_size^2
+ C * T * (P**2),
+ ]
+ )
+ # Qwen3VL uses timestamps for video, so split each video into patch_T sub-videos
+ video_grid_thw = torch.tensor([[1, patch_H, patch_W] for _ in range(patch_T)] * B)
+
+ # sanity check
+ self.assertEqual(pixel_values_videos.shape[0], video_grid_thw.prod(dim=1).sum().item())
+
+ # Insert video token sequence
+ input_ids[:, -1] = self.model_tester.pad_token_id
+ input_ids[input_ids == self.model_tester.video_token_id] = self.model_tester.pad_token_id
+ input_ids[input_ids == self.model_tester.image_token_id] = self.model_tester.pad_token_id
+ input_ids[input_ids == self.model_tester.vision_start_token_id] = self.model_tester.pad_token_id
+ input_ids[:, self.model_tester.num_image_tokens] = self.model_tester.video_token_id
+
+ insertion_point = self.model_tester.num_image_tokens
+
+ self.assertLessEqual((B * patches_per_video) + insertion_point, self.model_tester.seq_length)
+ for b in range(B):
+ # each frame is separated by a vision_start_token_id
+ for frame_idx in range(patch_T):
+ input_ids[b, insertion_point + frame_idx * (patches_per_frame + 1)] = (
+ self.model_tester.vision_start_token_id
+ )
+ input_ids[
+ b,
+ insertion_point + frame_idx * (patches_per_frame + 1) + 1 : insertion_point
+ + (frame_idx + 1) * (patches_per_frame + 1),
+ ] = self.model_tester.video_token_id
+
+ for model_class in self.all_model_classes:
+ # TODO: we should remove this because we use timestamps for video
+ model = model_class(config).to(torch_device)
+ outputs = model(
+ input_ids=input_ids,
+ pixel_values_videos=pixel_values_videos,
+ video_grid_thw=video_grid_thw,
+ )
+ self.assertIsNotNone(outputs)
diff --git a/utils/check_repo.py b/utils/check_repo.py
index 8a73468a1e49..e932e5bfc24c 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -71,6 +71,8 @@
"Qwen2AudioEncoder",
"Qwen2VisionTransformerPretrainedModel",
"Qwen2_5_VisionTransformerPretrainedModel",
+ "Qwen3VLVisionModel",
+ "Qwen3VLMoeVisionModel",
"SwitchTransformersStack",
"TFDPRSpanPredictor",
"MaskFormerSwinModel",
@@ -151,13 +153,17 @@
"ChameleonVQVAE", # VQVAE here is used only for encoding (discretizing) and is tested as part of bigger model
"Qwen2VLModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2VLForConditionalGeneration.
"Qwen2_5_VLModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5_VLForConditionalGeneration.
- "Qwen2_5OmniForConditionalGeneration", # Not a regular model. Testted in Qwen2_5OmniModelIntegrationTest
- "Qwen2_5OmniTalkerForConditionalGeneration", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntegrationTest.
- "Qwen2_5OmniTalkerModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntegrationTest.
- "Qwen2_5OmniThinkerTextModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntegrationTest.
- "Qwen2_5OmniToken2WavModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntegrationTest.
- "Qwen2_5OmniToken2WavDiTModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntegrationTest.
- "Qwen2_5OmniToken2WavBigVGANModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntegrationTest.
+ "Qwen3VLModel", # Building part of bigger (tested) model. Tested implicitly through Qwen3VLForConditionalGeneration.
+ "Qwen3VLMoeModel", # Building part of bigger (tested) model. Tested implicitly through Qwen3VLMoeForConditionalGeneration.
+ "Qwen3VLTextModel", # Building part of bigger (tested) model.
+ "Qwen3VLMoeTextModel", # Building part of bigger (tested) model.
+ "Qwen2_5OmniForConditionalGeneration", # Not a regular model. Testted in Qwen2_5OmniModelIntergrationTest
+ "Qwen2_5OmniTalkerForConditionalGeneration", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntergrationTest.
+ "Qwen2_5OmniTalkerModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntergrationTest.
+ "Qwen2_5OmniThinkerTextModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntergrationTest.
+ "Qwen2_5OmniToken2WavModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntergrationTest.
+ "Qwen2_5OmniToken2WavDiTModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntergrationTest.
+ "Qwen2_5OmniToken2WavBigVGANModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntergrationTest.
"MllamaTextModel", # Building part of bigger (tested) model. # TODO: add tests
"MllamaVisionModel", # Building part of bigger (tested) model. # TODO: add tests
"Llama4TextModel", # Building part of bigger (tested) model. # TODO: add tests