diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 240d912f10b6..07b7ef4e72fe 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -1125,6 +1125,10 @@ title: Qwen2Audio - local: model_doc/qwen2_vl title: Qwen2VL + - local: model_doc/qwen3_vl + title: Qwen3VL + - local: model_doc/qwen3_vl_moe + title: Qwen3VLMoe - local: model_doc/sam2 title: SAM2 - local: model_doc/sam2_video diff --git a/docs/source/en/model_doc/qwen3_vl.md b/docs/source/en/model_doc/qwen3_vl.md new file mode 100644 index 000000000000..9e90363a1eba --- /dev/null +++ b/docs/source/en/model_doc/qwen3_vl.md @@ -0,0 +1,117 @@ + +*This model was released on None and added to Hugging Face Transformers on 2025-08-16.* + +
+<div style="float: right;">
+    <div class="flex flex-wrap space-x-1">
+        <img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
+        <img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
+        <img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
+    </div>
+</div>
+ +# Qwen3-VL + +[Qwen3-VL](https://huggingface.co/papers/2502.13923) is a multimodal vision-language model series, encompassing both dense and MoE variants, as well as Instruct and Thinking versions. Building upon its predecessors, Qwen3-VL delivers significant improvements in visual understanding while maintaining strong pure text capabilities. Key architectural advancements include: enhanced MRope with interleaved layout for better spatial-temporal modeling, DeepStack integration to effectively leverage multi-level features from the Vision Transformer (ViT), and improved video understanding through text-based time alignment—evolving from T-RoPE to text timestamp alignment for more precise temporal grounding. These innovations collectively enable Qwen3-VL to achieve superior performance in complex multimodal tasks. + +Model usage + + + + +```py +import torch +from transformers import Qwen3VLForConditionalGeneration, AutoProcessor + +model = Qwen3VLForConditionalGeneration.from_pretrained( + "Qwen/Qwen3-VL", + dtype=torch.float16, + device_map="auto", + attn_implementation="sdpa" +) +processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL") +messages = [ + { + "role":"user", + "content":[ + { + "type":"image", + "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg" + }, + { + "type":"text", + "text":"Describe this image." + } + ] + } + +] + +inputs = processor.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + return_dict=True, + return_tensors="pt", +) +inputs.pop("token_type_ids", None) + +generated_ids = model.generate(**inputs, max_new_tokens=128) +generated_ids_trimmed = [ + out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) +] +output_text = processor.batch_decode( + generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False +) +print(output_text) +``` + + + +## Qwen3VLConfig + +[[autodoc]] Qwen3VLConfig + +## Qwen3VLTextConfig + +[[autodoc]] Qwen3VLTextConfig + +## Qwen3VLProcessor + +[[autodoc]] Qwen3VLProcessor + +## Qwen3VLVideoProcessor + +[[autodoc]] Qwen3VLVideoProcessor + +## Qwen3VLVisionModel + +[[autodoc]] Qwen3VLVisionModel + - forward + +## Qwen3VLTextModel + +[[autodoc]] Qwen3VLTextModel + - forward + +## Qwen3VLModel + +[[autodoc]] Qwen3VLModel + - forward + +## Qwen3VLForConditionalGeneration + +[[autodoc]] Qwen3VLForConditionalGeneration + - forward diff --git a/docs/source/en/model_doc/qwen3_vl_moe.md b/docs/source/en/model_doc/qwen3_vl_moe.md new file mode 100644 index 000000000000..76d046efff2d --- /dev/null +++ b/docs/source/en/model_doc/qwen3_vl_moe.md @@ -0,0 +1,109 @@ + +*This model was released on None and added to Hugging Face Transformers on 2025-08-17.* + +
+<div style="float: right;">
+    <div class="flex flex-wrap space-x-1">
+        <img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
+        <img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
+        <img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
+    </div>
+</div>
+ +# Qwen3-VL-Moe + +[Qwen3-VL](https://huggingface.co/papers/2502.13923) is a multimodal vision-language model series, encompassing both dense and MoE variants, as well as Instruct and Thinking versions. Building upon its predecessors, Qwen3-VL delivers significant improvements in visual understanding while maintaining strong pure text capabilities. Key architectural advancements include: enhanced MRope with interleaved layout for better spatial-temporal modeling, DeepStack integration to effectively leverage multi-level features from the Vision Transformer (ViT), and improved video understanding through text-based time alignment—evolving from T-RoPE to text timestamp alignment for more precise temporal grounding. These innovations collectively enable Qwen3-VL to achieve superior performance in complex multimodal tasks. + +Model usage + + + + +```py +import torch +from transformers import Qwen3VLMoeForConditionalGeneration, AutoProcessor + +model = Qwen3VLMoeForConditionalGeneration.from_pretrained( + "Qwen/Qwen3-VL-Moe", + dtype=torch.float16, + device_map="auto", + attn_implementation="sdpa" +) +processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-Moe") +messages = [ + { + "role":"user", + "content":[ + { + "type":"image", + "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg" + }, + { + "type":"text", + "text":"Describe this image." + } + ] + } + +] + +inputs = processor.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + return_dict=True, + return_tensors="pt", +) +inputs.pop("token_type_ids", None) + +generated_ids = model.generate(**inputs, max_new_tokens=128) +generated_ids_trimmed = [ + out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) +] +output_text = processor.batch_decode( + generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False +) +print(output_text) +``` + + + +## Qwen3VLMoeConfig + +[[autodoc]] Qwen3VLMoeConfig + +## Qwen3VLMoeTextConfig + +[[autodoc]] Qwen3VLMoeTextConfig + +## Qwen3VLMoeVisionModel + +[[autodoc]] Qwen3VLMoeVisionModel + - forward + +## Qwen3VLMoeTextModel + +[[autodoc]] Qwen3VLMoeTextModel + - forward + +## Qwen3VLMoeModel + +[[autodoc]] Qwen3VLMoeModel + - forward + +## Qwen3VLMoeForConditionalGeneration + +[[autodoc]] Qwen3VLMoeForConditionalGeneration + - forward diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 1e70e0e7b4d7..9a631255ff9b 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -278,6 +278,8 @@ from .qwen3 import * from .qwen3_moe import * from .qwen3_next import * + from .qwen3_vl import * + from .qwen3_vl_moe import * from .rag import * from .recurrent_gemma import * from .reformer import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 0d6981d685ed..c3aa38de4b31 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -325,6 +325,10 @@ ("qwen3", "Qwen3Config"), ("qwen3_moe", "Qwen3MoeConfig"), ("qwen3_next", "Qwen3NextConfig"), + ("qwen3_vl", "Qwen3VLConfig"), + ("qwen3_vl_moe", "Qwen3VLMoeConfig"), + ("qwen3_vl_moe_text", "Qwen3VLMoeTextConfig"), + ("qwen3_vl_text", "Qwen3VLTextConfig"), ("rag", "RagConfig"), ("realm", "RealmConfig"), ("recurrent_gemma", "RecurrentGemmaConfig"), @@ -763,6 +767,10 @@ ("qwen3", "Qwen3"), ("qwen3_moe", "Qwen3MoE"), ("qwen3_next", 
"Qwen3Next"), + ("qwen3_vl", "Qwen3VL"), + ("qwen3_vl_moe", "Qwen3VLMoe"), + ("qwen3_vl_moe_text", "Qwen3VLMoe"), + ("qwen3_vl_text", "Qwen3VL"), ("rag", "RAG"), ("realm", "REALM"), ("recurrent_gemma", "RecurrentGemma"), @@ -950,6 +958,8 @@ ("internvl_vision", "internvl"), ("qwen2_5_vl_text", "qwen2_5_vl"), ("qwen2_vl_text", "qwen2_vl"), + ("qwen3_vl_text", "qwen3_vl"), + ("qwen3_vl_moe_text", "qwen3_vl_moe"), ("sam_vision_model", "sam"), ("sam2_vision_model", "sam2"), ("sam2_hiera_det_model", "sam2"), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 7d07ca6dc7d6..193e8f8fd940 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -156,6 +156,7 @@ ("pvt_v2", ("PvtImageProcessor", "PvtImageProcessorFast")), ("qwen2_5_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")), ("qwen2_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")), + ("qwen3_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")), ("regnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")), ("resnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")), ("rt_detr", ("RTDetrImageProcessor", "RTDetrImageProcessorFast")), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index a4b9434f24b9..2109e487328e 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -319,6 +319,10 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("qwen3", "Qwen3Model"), ("qwen3_moe", "Qwen3MoeModel"), ("qwen3_next", "Qwen3NextModel"), + ("qwen3_vl", "Qwen3VLModel"), + ("qwen3_vl_moe", "Qwen3VLMoeModel"), + ("qwen3_vl_moe_text", "Qwen3VLMoeTextModel"), + ("qwen3_vl_text", "Qwen3VLTextModel"), ("recurrent_gemma", "RecurrentGemmaModel"), ("reformer", "ReformerModel"), ("regnet", "RegNetModel"), @@ -972,6 +976,8 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("pix2struct", "Pix2StructForConditionalGeneration"), ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), ("qwen2_vl", "Qwen2VLForConditionalGeneration"), + ("qwen3_vl", "Qwen3VLForConditionalGeneration"), + ("qwen3_vl_moe", "Qwen3VLMoeForConditionalGeneration"), ("video_llava", "VideoLlavaForConditionalGeneration"), ("vipllava", "VipLlavaForConditionalGeneration"), ("vision-encoder-decoder", "VisionEncoderDecoderModel"), @@ -1026,6 +1032,8 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("pixtral", "LlavaForConditionalGeneration"), ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), ("qwen2_vl", "Qwen2VLForConditionalGeneration"), + ("qwen3_vl", "Qwen3VLForConditionalGeneration"), + ("qwen3_vl_moe", "Qwen3VLMoeForConditionalGeneration"), ("shieldgemma2", "Gemma3ForConditionalGeneration"), ("smolvlm", "SmolVLMForConditionalGeneration"), ("udop", "UdopForConditionalGeneration"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index d8db58cb7b1f..13583c55002f 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -120,6 +120,8 @@ ("qwen2_5_vl", "Qwen2_5_VLProcessor"), ("qwen2_audio", "Qwen2AudioProcessor"), ("qwen2_vl", "Qwen2VLProcessor"), + ("qwen3_vl", "Qwen3VLProcessor"), + ("qwen3_vl_moe", "Qwen3VLProcessor"), ("sam", "SamProcessor"), ("sam2", "Sam2Processor"), ("sam_hq", "SamHQProcessor"), diff --git 
a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 688faf00c4ea..0ef450f45cb9 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -583,6 +583,8 @@ "Qwen2TokenizerFast" if is_tokenizers_available() else None, ), ), + ("qwen3_vl", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)), + ("qwen3_vl_moe", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)), ("rag", ("RagTokenizer", None)), ("realm", ("RealmTokenizer", "RealmTokenizerFast" if is_tokenizers_available() else None)), ( diff --git a/src/transformers/models/auto/video_processing_auto.py b/src/transformers/models/auto/video_processing_auto.py index b9a5c2204fd1..551de914626e 100644 --- a/src/transformers/models/auto/video_processing_auto.py +++ b/src/transformers/models/auto/video_processing_auto.py @@ -56,6 +56,8 @@ ("qwen2_5_omni", "Qwen2VLVideoProcessor"), ("qwen2_5_vl", "Qwen2VLVideoProcessor"), ("qwen2_vl", "Qwen2VLVideoProcessor"), + ("qwen3_vl", "Qwen3VLVideoProcessor"), + ("qwen3_vl_moe", "Qwen3VLVideoProcessor"), ("sam2_video", "Sam2VideoVideoProcessor"), ("smolvlm", "SmolVLMVideoProcessor"), ("video_llava", "VideoLlavaVideoProcessor"), diff --git a/src/transformers/models/qwen3_vl/__init__.py b/src/transformers/models/qwen3_vl/__init__.py new file mode 100644 index 000000000000..e37161a2e415 --- /dev/null +++ b/src/transformers/models/qwen3_vl/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_qwen3_vl import * + from .modeling_qwen3_vl import * + from .processing_qwen3_vl import * + from .video_processing_qwen3_vl import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/qwen3_vl/configuration_qwen3_vl.py b/src/transformers/models/qwen3_vl/configuration_qwen3_vl.py new file mode 100644 index 000000000000..132ffa8be150 --- /dev/null +++ b/src/transformers/models/qwen3_vl/configuration_qwen3_vl.py @@ -0,0 +1,287 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen3_vl/modular_qwen3_vl.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen3_vl.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation + + +class Qwen3VLVisionConfig(PretrainedConfig): + model_type = "qwen3_vl" + base_config_key = "vision_config" + + def __init__( + self, + depth=27, + hidden_size=1152, + hidden_act="gelu_pytorch_tanh", + intermediate_size=4304, + num_heads=16, + in_channels=3, + patch_size=16, + spatial_merge_size=2, + temporal_patch_size=2, + out_hidden_size=3584, + num_position_embeddings=2304, + deepstack_visual_indexes=[8, 16, 24], + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + + self.depth = depth + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.out_hidden_size = out_hidden_size + self.num_position_embeddings = num_position_embeddings + self.initializer_range = initializer_range + self.deepstack_visual_indexes = deepstack_visual_indexes + + +class Qwen3VLTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen3VLTextModel`]. It is used to instantiate a + Qwen3-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen3-VL-4B-Instruct [Qwen/Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 151936): + Vocabulary size of the Qwen3VL model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen3VLModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 22016): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 32): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. 
For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`. + head_dim (`int`, *optional*, defaults to 128): + The dimension of the head. If not specified, will default to `hidden_size // num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 128000): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 5000000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`list[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`list[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. 
Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + + ```python + >>> from transformers import Qwen3VLTextModel, Qwen3VLTextConfig + + >>> # Initializing a Qwen3VL style configuration + >>> configuration = Qwen3VLTextConfig() + + >>> # Initializing a model from the Qwen3-VL-7B style configuration + >>> model = Qwen3VLTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen3_vl_text" + base_config_key = "text_config" + + def __init__( + self, + vocab_size=151936, + hidden_size=4096, + intermediate_size=22016, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + head_dim=128, + hidden_act="silu", + max_position_embeddings=128000, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=5000000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + rope_config_validation(self, ignore_keys={"mrope_section", "mrope_interleaved"}) + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + +class Qwen3VLConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen3VLModel`]. It is used to instantiate a + Qwen3-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen3-VL-4B-Instruct [Qwen/Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLTextConfig`): + The config object or dictionary of the text backbone. + vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLVisionConfig`): + The config object or dictionary of the vision backbone. + image_token_id (`int`, *optional*, defaults to 151655): + The image token index to encode the image prompt. + video_token_id (`int`, *optional*, defaults to 151656): + The video token index to encode the image prompt. 
+ vision_start_token_id (`int`, *optional*, defaults to 151652): + The start token index to encode the image prompt. + vision_end_token_id (`int`, *optional*, defaults to 151653): + The end token index to encode the image prompt. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie the word embeddings. + + ```python + >>> from transformers import Qwen3VLForConditionalGeneration, Qwen3VLConfig + + >>> # Initializing a Qwen3-VL style configuration + >>> configuration = Qwen3VLConfig() + + >>> # Initializing a model from the Qwen3-VL-4B style configuration + >>> model = Qwen3VLForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen3_vl" + sub_configs = {"vision_config": Qwen3VLVisionConfig, "text_config": Qwen3VLTextConfig} + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + text_config=None, + vision_config=None, + image_token_id=151655, + video_token_id=151656, + vision_start_token_id=151652, + vision_end_token_id=151653, + tie_word_embeddings=False, + **kwargs, + ): + if isinstance(vision_config, dict): + self.vision_config = self.sub_configs["vision_config"](**vision_config) + elif vision_config is None: + self.vision_config = self.sub_configs["vision_config"]() + + if isinstance(text_config, dict): + self.text_config = self.sub_configs["text_config"](**text_config) + elif text_config is None: + self.text_config = self.sub_configs["text_config"]() + + self.image_token_id = image_token_id + self.video_token_id = video_token_id + self.vision_start_token_id = vision_start_token_id + self.vision_end_token_id = vision_end_token_id + super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings) + + +__all__ = ["Qwen3VLConfig", "Qwen3VLTextConfig"] diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py new file mode 100644 index 000000000000..a18366a2a534 --- /dev/null +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -0,0 +1,1568 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen3_vl/modular_qwen3_vl.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen3_vl.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass +from typing import Any, Callable, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling +from ...utils.deprecation import deprecate_kwarg +from ...utils.generic import check_model_inputs +from .configuration_qwen3_vl import Qwen3VLConfig, Qwen3VLTextConfig, Qwen3VLVisionConfig + + +class Qwen3VLVisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True) + self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state))) + + +class Qwen3VLVisionPatchEmbed(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.patch_size = config.patch_size + self.temporal_patch_size = config.temporal_patch_size + self.in_channels = config.in_channels + self.embed_dim = config.hidden_size + + kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size] + self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view( + -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + return hidden_states + + +class Qwen3VLVisionRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class Qwen3VLVisionPatchMerger(nn.Module): + def __init__(self, config: Qwen3VLVisionConfig, use_postshuffle_norm=False) -> None: + super().__init__() + self.hidden_size = config.hidden_size * (config.spatial_merge_size**2) + self.use_postshuffle_norm = use_postshuffle_norm + self.norm = nn.LayerNorm(self.hidden_size if use_postshuffle_norm else config.hidden_size, eps=1e-6) + self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size) + self.act_fn = nn.GELU() + self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.norm(x.view(-1, 
self.hidden_size) if self.use_postshuffle_norm else x).view(-1, self.hidden_size) + x = self.linear_fc2(self.act_fn(self.linear_fc1(x))) + return x + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_vision( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + orig_q_dtype = q.dtype + orig_k_dtype = k.dtype + q, k = q.float(), k.float() + cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float() + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + q_embed = q_embed.to(orig_q_dtype) + k_embed = k_embed.to(orig_k_dtype) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Qwen3VLVisionAttention(nn.Module): + def __init__(self, config: Qwen3VLVisionConfig) -> None: + super().__init__() + self.dim = config.hidden_size + self.num_heads = config.num_heads + self.head_dim = self.dim // self.num_heads + self.num_key_value_groups = 1 # needed for eager attention + self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True) + self.proj = nn.Linear(self.dim, self.dim) + self.scaling = self.head_dim**-0.5 + self.config = config + self.attention_dropout = 0.0 + self.is_causal = False + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + query_states, key_states, value_states = ( + self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + ) + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) + + query_states = query_states.transpose(0, 1).unsqueeze(0) + key_states = key_states.transpose(0, 1).unsqueeze(0) + value_states = 
value_states.transpose(0, 1).unsqueeze(0) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + if self.config._attn_implementation == "flash_attention_2": + # Flash Attention 2: Use cu_seqlens for variable length attention + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + attn_output, _ = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + cu_seq_lens_q=cu_seqlens, + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, + is_causal=False, + **kwargs, + ) + else: + # Other implementations: Process each chunk separately + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + splits = [ + torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states) + ] + + attn_outputs = [ + attention_interface( + self, + q, + k, + v, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + is_causal=False, + **kwargs, + )[0] + for q, k, v in zip(*splits) + ] + attn_output = torch.cat(attn_outputs, dim=1) + + attn_output = attn_output.reshape(seq_length, -1).contiguous() + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen3VLVisionBlock(GradientCheckpointingLayer): + def __init__(self, config, attn_implementation: str = "sdpa") -> None: + super().__init__() + self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6) + self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6) + self.attn = Qwen3VLVisionAttention(config=config) + self.mlp = Qwen3VLVisionMLP(config=config) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class Qwen3VLTextRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: Qwen3VLTextConfig, device=None): + super().__init__() + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", "default") + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + self.mrope_section = config.rope_scaling.get("mrope_section", [24, 20, 20]) + + def apply_interleaved_mrope(self, freqs, mrope_section): + """Apply interleaved MRoPE to 3D rotary embeddings. + Reorganizes frequency layout from chunked [TTT...HHH...WWW] to + interleaved [THTHWHTHW...TT], preserving frequency continuity. 
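+
+        For illustration (example values only): with mrope_section = [2, 1, 1], i.e. head_dim // 2 = 4,
+        output index 1 takes the H frequency, index 2 takes the W frequency, and indices 0 and 3 keep
+        the T frequencies, giving the interleaved layout [T, H, W, T]. With the default [24, 20, 20],
+        indices 1, 4, ..., 58 take H, indices 2, 5, ..., 59 take W, and the remaining indices
+        (including the tail 60-63) keep T.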
+ args: + x: (3, bs, seq_len, head_dim // 2) + mrope_section: (3,) + returns: + x_t: (bs, seq_len, head_dim // 2) + """ + freqs_t = freqs[0] # just overwrite the first dimension T + for dim, offset in enumerate((1, 2), start=1): # H, W + length = mrope_section[dim] * 3 + idx = slice(offset, length, 3) + freqs_t[..., idx] = freqs[dim, ..., idx] + return freqs_t + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + # In contrast to other models, Qwen3VL has different position ids for the grids + # So we expand the inv_freq to shape (3, ...) + if position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + freqs = self.apply_interleaved_mrope(freqs, self.mrope_section) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +@use_kernel_forward_from_hub("RMSNorm") +class Qwen3VLTextRMSNorm(nn.Module): + def __init__(self, hidden_size, eps: float = 1e-6) -> None: + """ + Qwen3VLTextRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class Qwen3VLTextAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Qwen3VLTextConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.q_norm = Qwen3VLTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! + self.k_norm = Qwen3VLTextRMSNorm( + self.head_dim, eps=config.rms_norm_eps + ) # thus post q_norm does not need reshape + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Qwen3VLTextMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + 
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class Qwen3VLTextDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Qwen3VLTextConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Qwen3VLTextAttention(config=config, layer_idx=layer_idx) + + self.mlp = Qwen3VLTextMLP(config) + self.input_layernorm = Qwen3VLTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen3VLTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Llava outputs, with hidden states and attentions. + """ +) +class Qwen3VLModelOutputWithPast(ModelOutput): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. 
+ """ + + last_hidden_state: Optional[torch.FloatTensor] = None + past_key_values: Optional[list[torch.FloatTensor]] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + +@auto_docstring +class Qwen3VLPreTrainedModel(PreTrainedModel): + config: Qwen3VLConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen3VLTextDecoderLayer", "Qwen3VLVisionBlock"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn = True + _supports_sdpa = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": Qwen3VLTextDecoderLayer, + "attentions": Qwen3VLTextAttention, + } + + +class Qwen3VLVisionModel(Qwen3VLPreTrainedModel): + config: Qwen3VLVisionConfig + _no_split_modules = ["Qwen3VLVisionBlock"] + + def __init__(self, config, *inputs, **kwargs) -> None: + super().__init__(config, *inputs, **kwargs) + self.spatial_merge_size = config.spatial_merge_size + self.patch_size = config.patch_size + self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + + self.patch_embed = Qwen3VLVisionPatchEmbed( + config=config, + ) + + self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size) + self.num_grid_per_side = int(config.num_position_embeddings**0.5) + + head_dim = config.hidden_size // config.num_heads + self.rotary_pos_emb = Qwen3VLVisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList([Qwen3VLVisionBlock(config) for _ in range(config.depth)]) + self.merger = Qwen3VLVisionPatchMerger( + config=config, + use_postshuffle_norm=False, + ) + + self.deepstack_visual_indexes = config.deepstack_visual_indexes + self.deepstack_merger_list = nn.ModuleList( + [ + Qwen3VLVisionPatchMerger( + config=config, + use_postshuffle_norm=True, + ) + for _ in range(len(config.deepstack_visual_indexes)) + ] + ) + + self.gradient_checkpointing = False + + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + merge_size = self.spatial_merge_size + + max_hw = int(grid_thw[:, 1:].max().item()) + freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2) + device = freq_table.device + + total_tokens = int(torch.prod(grid_thw, dim=1).sum().item()) + pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) + + offset = 0 + for num_frames, height, width in grid_thw: + merged_h, merged_w = height // merge_size, width // merge_size + + block_rows = torch.arange(merged_h, device=device) # block row indices + block_cols = torch.arange(merged_w, device=device) # block col indices + intra_row = torch.arange(merge_size, device=device) # intra-block row offsets + intra_col = torch.arange(merge_size, device=device) # intra-block col offsets + + # Compute full-resolution positions + row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] + col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] + + row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + + coords = torch.stack((row_idx, col_idx), dim=-1) + + if num_frames > 1: + coords = coords.repeat(num_frames, 1) + + num_tokens = coords.shape[0] + pos_ids[offset : offset + num_tokens] = coords + offset += num_tokens + + embeddings = freq_table[pos_ids] # lookup rotary embeddings + embeddings = 
embeddings.flatten(1) + return embeddings + + def fast_pos_embed_interpolate(self, grid_thw): + grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2] + + idx_list = [[] for _ in range(4)] + weight_list = [[] for _ in range(4)] + + for t, h, w in zip(grid_ts, grid_hs, grid_ws): + h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) + w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) + + h_idxs_floor = h_idxs.int() + w_idxs_floor = w_idxs.int() + h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + + dh = h_idxs - h_idxs_floor + dw = w_idxs - w_idxs_floor + + base_h = h_idxs_floor * self.num_grid_per_side + base_h_ceil = h_idxs_ceil * self.num_grid_per_side + + indices = [ + (base_h[None].T + w_idxs_floor[None]).flatten(), + (base_h[None].T + w_idxs_ceil[None]).flatten(), + (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), + (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), + ] + + weights = [ + ((1 - dh)[None].T * (1 - dw)[None]).flatten(), + ((1 - dh)[None].T * dw[None]).flatten(), + (dh[None].T * (1 - dw)[None]).flatten(), + (dh[None].T * dw[None]).flatten(), + ] + + for i in range(4): + idx_list[i].extend(indices[i].tolist()) + weight_list[i].extend(weights[i].tolist()) + + idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=self.pos_embed.weight.device) + weight_tensor = torch.tensor( + weight_list, dtype=self.pos_embed.weight.dtype, device=self.pos_embed.weight.device + ) + pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None] + patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] + + patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)]) + + patch_pos_embeds_permute = [] + merge_size = self.config.spatial_merge_size + for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): + pos_embed = pos_embed.repeat(t, 1) + pos_embed = ( + pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1) + .permute(0, 1, 3, 2, 4, 5) + .flatten(0, 4) + ) + patch_pos_embeds_permute.append(pos_embed) + patch_pos_embeds = torch.cat(patch_pos_embeds_permute) + return patch_pos_embeds + + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Args: + hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): + The final hidden states of the model. + grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): + The temporal, height and width of feature shape of each image in LLM. + + Returns: + `torch.Tensor`: hidden_states. 
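+            Note that, in addition to the merged hidden states, this method also returns a list of
+            deepstack features, one tensor per layer index in `config.deepstack_visual_indexes`.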
+ """ + hidden_states = self.patch_embed(hidden_states) + + pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + hidden_states = hidden_states + pos_embeds + + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 for more information + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + deepstack_feature_lists = [] + for layer_num, blk in enumerate(self.blocks): + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings, + **kwargs, + ) + if layer_num in self.deepstack_visual_indexes: + deepstack_feature = self.deepstack_merger_list[self.deepstack_visual_indexes.index(layer_num)]( + hidden_states + ) + deepstack_feature_lists.append(deepstack_feature) + + hidden_states = self.merger(hidden_states) + + return hidden_states, deepstack_feature_lists + + +@auto_docstring( + custom_intro=( + "Text part of Qwen3VL, " + "not a pure text-only model, as DeepStack integrates visual features into the early hidden states." + ) +) +class Qwen3VLTextModel(Qwen3VLPreTrainedModel): + config: Qwen3VLTextConfig + _no_split_modules = ["Qwen3VLTextDecoderLayer"] + + def __init__(self, config: Qwen3VLTextConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Qwen3VLTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Qwen3VLTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen3VLTextRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + # args for deepstack + visual_pos_masks: Optional[torch.Tensor] = None, + deepstack_visual_embeds: Optional[list[torch.Tensor]] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[tuple, BaseModelOutputWithPast]: + r""" + visual_pos_masks (`torch.Tensor` of shape `(batch_size, seqlen)`, *optional*): + The mask of the visual positions. + deepstack_visual_embeds (`list[torch.Tensor]`, *optional*): + The deepstack visual embeddings. The shape is (num_layers, visual_seqlen, embed_dim). + The feature is extracted from the different visual encoder layers, and fed to the decoder + hidden states. It's from the paper DeepStack(https://arxiv.org/abs/2406.04334). 
+ """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + # torch.jit.trace() doesn't support cache objects in the output + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + past_key_values = DynamicCache(config=self.config) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # the hard coded `3` is for temporal, height and width. + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + if position_ids.ndim == 3 and position_ids.shape[0] == 4: + text_position_ids = position_ids[0] + position_ids = position_ids[1:] + else: + text_position_ids = position_ids[0] + + attention_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=text_position_ids, + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + for layer_idx, decoder_layer in enumerate(self.layers): + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=text_position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = layer_outputs + + # add visual features to the hidden states of first several layers + if deepstack_visual_embeds is not None and layer_idx in range(len(deepstack_visual_embeds)): + hidden_states = self._deepstack_process( + hidden_states, + visual_pos_masks, + deepstack_visual_embeds[layer_idx], + ) + + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + def _deepstack_process( + self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor, visual_embeds: torch.Tensor + ): + visual_pos_masks = visual_pos_masks.to(hidden_states.device) + visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype) + local_this = hidden_states[visual_pos_masks, :].clone() + visual_embeds + hidden_states[visual_pos_masks, :] = local_this + return hidden_states + + +@auto_docstring +class Qwen3VLModel(Qwen3VLPreTrainedModel): + base_model_prefix = "" + _checkpoint_conversion_mapping = {} + # Reference: fix gemma3 grad acc #37208 + accepts_loss_kwargs = False + config: Qwen3VLConfig + _no_split_modules = ["Qwen3VLTextDecoderLayer", "Qwen3VLVisionBlock"] + + def __init__(self, config): + super().__init__(config) + self.visual = Qwen3VLVisionModel._from_config(config.vision_config) + self.language_model = Qwen3VLTextModel._from_config(config.text_config) + self.rope_deltas = None # cache rope_deltas here + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) 
+ + def set_decoder(self, decoder): + self.language_model = decoder + + def get_decoder(self): + return self.language_model + + def get_rope_index( + self, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Different from the original implementation, Qwen3VL use timestamps rather than absolute time position ids.""" + + # Since we use timestamps to seperate videos, like , the video_grid_thw should also be split + if video_grid_thw is not None: + video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0) + video_grid_thw[:, 0] = 1 + + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + # t_index is always 0 because llm_grid_t is always 1 (we use timestamps to encode the temporal information for videos) + t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w 
+ + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + def get_video_features( + self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None + ): + """ + Encodes videos into continuous embeddings that can be forwarded to the language model. The deepstack visual features are also returned. + + Args: + pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input videos. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + """ + # Same implementation as for images + return self.get_image_features(pixel_values_videos, video_grid_thw) + + def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. The deepstack visual features are also returned. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + """ + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds, deepstack_image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() + image_embeds = torch.split(image_embeds, split_sizes) + return image_embeds, deepstack_image_embeds + + def get_placeholder_mask( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: Optional[torch.FloatTensor] = None, + video_features: Optional[torch.FloatTensor] = None, + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. 
+ """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + special_video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_video_mask = special_video_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + special_video_mask = input_ids == self.config.video_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}" + ) + + n_video_tokens = special_video_mask.sum() + special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel(): + raise ValueError( + f"Videos features and video tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}" + ) + + return special_image_mask, special_video_mask + + @auto_docstring + @can_return_tuple + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, Qwen3VLModelOutputWithPast]: + r""" + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. 
+ """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + image_mask = None + video_mask = None + + if pixel_values is not None: + image_embeds, deepstack_image_embeds = self.get_image_features(pixel_values, image_grid_thw) + image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + image_mask, _ = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + video_embeds, deepstack_video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) + video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + _, video_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + visual_pos_masks = None + deepstack_visual_embeds = None + if image_mask is not None and video_mask is not None: + # aggregate visual_pos_masks and deepstack_visual_embeds + image_mask = image_mask[..., 0] + video_mask = video_mask[..., 0] + visual_pos_masks = image_mask | video_mask + deepstack_visual_embeds = [] + image_mask_joint = image_mask[visual_pos_masks] + video_mask_joint = video_mask[visual_pos_masks] + for img_embed, vid_embed in zip(deepstack_image_embeds, deepstack_video_embeds): + embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1]).to(img_embed.device) + embed_joint[image_mask_joint, :] = img_embed + embed_joint[video_mask_joint, :] = vid_embed + deepstack_visual_embeds.append(embed_joint) + elif image_mask is not None: + image_mask = image_mask[..., 0] + visual_pos_masks = image_mask + deepstack_visual_embeds = deepstack_image_embeds + elif video_mask is not None: + video_mask = video_mask[..., 0] + visual_pos_masks = video_mask + deepstack_visual_embeds = deepstack_video_embeds + + if position_ids is None: + attention_mask_tensor = ( + attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"] + ) + if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4: + attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2) + # Only apply conversion for floating point tensors (inverted masks) + if attention_mask_tensor.dtype.is_floating_point: + attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min + attention_mask_tensor = (1.0 - attention_mask_tensor).int() + + # Calculate RoPE index once per generation in the pre-fill stage only. 
+ # When compiling, we can't check tensor values thus we check only input length + # It is safe to assume that `length!=1` means we're in pre-fill because compiled + # models currently cannot do asssisted decoding + prefill_compiled_stage = is_torchdynamo_compiling() and ( + (input_ids is not None and input_ids.shape[1] != 1) + or (inputs_embeds is not None and inputs_embeds.shape[1] != 1) + ) + prefill_noncompiled_stage = not is_torchdynamo_compiling() and ( + (cache_position is not None and cache_position[0] == 0) + or (past_key_values is None or past_key_values.get_seq_length() == 0) + ) + if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None: + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + attention_mask=attention_mask_tensor, + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = ( + (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) + if cache_position is not None + else 0 + ) + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + outputs = self.language_model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + visual_pos_masks=visual_pos_masks, + deepstack_visual_embeds=deepstack_visual_embeds, + **kwargs, + ) + + return Qwen3VLModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Qwen3VL causal language model (or autoregressive) outputs. + """ +) +class Qwen3VLCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. 
+ """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[list[torch.FloatTensor]] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + +class Qwen3VLForConditionalGeneration(Qwen3VLPreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = {} + _tied_weights_keys = ["lm_head.weight"] + # Reference: fix gemma3 grad acc #37208 + accepts_loss_kwargs = False + config: Qwen3VLConfig + + def __init__(self, config): + super().__init__(config) + self.model = Qwen3VLModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def set_decoder(self, decoder): + self.model.set_decoder(decoder) + + def get_decoder(self): + return self.model.get_decoder() + + def get_video_features( + self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None + ): + return self.model.get_video_features(pixel_values_videos, video_grid_thw) + + def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): + return self.model.get_image_features(pixel_values, image_grid_thw) + + # Make modules available through conditional class for BC + @property + def language_model(self): + return self.model.language_model + + @property + def visual(self): + return self.model.visual + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, Qwen3VLCausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. 
+ + Example: + TODO: Add example + """ + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + + return Qwen3VLCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=outputs.rope_deltas, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + use_cache=use_cache, + **kwargs, + ) + + # Qwen3VL position_ids are prepareed with rope_deltas in forward + model_inputs["position_ids"] = None + + if cache_position[0] != 0: + model_inputs["pixel_values"] = None + model_inputs["pixel_values_videos"] = None + + return model_inputs + + def _get_image_nums_and_video_nums( + self, + input_ids: Optional[torch.LongTensor], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Get the number of images and videos for each sample to calculate the separation length of the sample tensor. + These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. 
+ + Returns: + image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`) + video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`) + """ + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + + if inputs_embeds is not None: + vision_start_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + image_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + video_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + else: + vision_start_mask = input_ids == vision_start_token_id + image_mask = input_ids == image_token_id + video_mask = input_ids == video_token_id + + vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1) + image_nums = torch.sum(vision_first_mask & image_mask, dim=1) + video_nums = torch.sum(vision_first_mask & video_mask, dim=1) + + return image_nums, video_nums + + def _expand_inputs_for_generation( + self, + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: Optional[torch.LongTensor] = None, + **model_kwargs, + ) -> tuple[torch.LongTensor, dict[str, Any]]: + # Overwritten -- Support for expanding tensors without a batch size dimension + # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t + # pixel_values.shape[0] is sum(seqlen_images for samples) + # image_grid_thw.shape[0] is sum(num_images for samples) + + if expand_size == 1: + return input_ids, model_kwargs + + visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"] + + def _expand_dict_for_generation_visual(dict_to_expand): + image_grid_thw = model_kwargs.get("image_grid_thw", None) + video_grid_thw = model_kwargs.get("video_grid_thw", None) + image_nums, video_nums = self._get_image_nums_and_video_nums( + input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None) + ) + + def _repeat_interleave_samples(x, lengths, repeat_times): + samples = torch.split(x, lengths) + repeat_args = [repeat_times] + [1] * (x.dim() - 1) + result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0) + return result + + for key in dict_to_expand: + if key == "pixel_values": + # split images into samples + samples = torch.split(image_grid_thw, list(image_nums)) + # compute the sequence length of images for each sample + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "image_grid_thw": + # get the num of images for each sample + lengths = list(image_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "pixel_values_videos": + samples = torch.split(video_grid_thw, list(video_nums)) + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "video_grid_thw": + lengths = list(video_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, 
repeat_times=expand_size + ) + elif key == "second_per_grid_ts": + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=list(video_nums), repeat_times=expand_size + ) + return dict_to_expand + + def _expand_dict_for_generation(dict_to_expand): + for key in dict_to_expand: + if ( + key != "cache_position" + and dict_to_expand[key] is not None + and isinstance(dict_to_expand[key], torch.Tensor) + and key not in visual_keys + ): + dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) + return dict_to_expand + + model_kwargs = _expand_dict_for_generation_visual(model_kwargs) + + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + + model_kwargs = _expand_dict_for_generation(model_kwargs) + + if is_encoder_decoder: + if model_kwargs.get("encoder_outputs") is None: + raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") + model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) + + return input_ids, model_kwargs + + +__all__ = [ + "Qwen3VLVisionModel", + "Qwen3VLForConditionalGeneration", + "Qwen3VLModel", + "Qwen3VLPreTrainedModel", + "Qwen3VLTextModel", +] diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py new file mode 100644 index 000000000000..ae608e81a05d --- /dev/null +++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -0,0 +1,1472 @@ +# coding=utf-8 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
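
> The modular file below re-declares the same vision and text components, including `get_rope_index`. As a quick orientation for the 3D (temporal, height, width) position layout that function produces, here is a toy, single-image re-derivation; `toy_mrope_positions` is a made-up name, and the sketch ignores batching, attention masks and videos.

```py
import torch

def toy_mrope_positions(num_text_before, grid_thw, num_text_after, merge_size=2):
    """Toy re-derivation (single image, no batching) of the 3D position ids that
    Qwen3VLModel.get_rope_index assigns: text tokens advance all three axes
    together, image tokens share one temporal index and spread over height/width."""
    t, h, w = grid_thw
    llm_t, llm_h, llm_w = t, h // merge_size, w // merge_size

    chunks = []
    # Text (plus the vision_start token) before the image: identical t/h/w positions.
    chunks.append(torch.arange(num_text_before).expand(3, -1))
    # Image block starts right after the text positions.
    start = num_text_before
    t_idx = torch.arange(llm_t).view(-1, 1).expand(-1, llm_h * llm_w).flatten()
    h_idx = torch.arange(llm_h).view(1, -1, 1).expand(llm_t, -1, llm_w).flatten()
    w_idx = torch.arange(llm_w).view(1, 1, -1).expand(llm_t, llm_h, -1).flatten()
    chunks.append(torch.stack([t_idx, h_idx, w_idx]) + start)
    # Trailing text resumes after the largest position used so far.
    start = chunks[-1].max().item() + 1
    chunks.append(torch.arange(num_text_after).expand(3, -1) + start)
    return torch.cat(chunks, dim=1)

# 4 leading tokens (3 text + vision_start), a 1x4x4 patch grid (-> 1x2x2 after
# spatial merging), and 2 trailing tokens:
pos = toy_mrope_positions(4, (1, 4, 4), 2)
print(pos)
# tensor([[0, 1, 2, 3, 4, 4, 4, 4, 6, 7],
#         [0, 1, 2, 3, 4, 4, 5, 5, 6, 7],
#         [0, 1, 2, 3, 4, 5, 4, 5, 6, 7]])
# rope_delta = pos.max() + 1 - seq_len = 8 - 10 = -2 is what gets cached for decoding.
```

> The temporal row stays constant across an image because videos are split by timestamp tokens upstream, so each visual block always has a grid with `t == 1`.
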
+"""PyTorch Qwen3-VL model.""" + +from typing import Callable, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...configuration_utils import PretrainedConfig +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput +from ...masking_utils import create_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import BaseModelOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update, rope_config_validation +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import ProcessingKwargs, Unpack, VideosKwargs +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging +from ...utils.generic import check_model_inputs +from ...video_utils import VideoInput +from ..qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VLCausalLMOutputWithPast, + Qwen2_5_VLForConditionalGeneration, + Qwen2_5_VLModel, + Qwen2_5_VLVisionBlock, +) +from ..qwen2_vl.modeling_qwen2_vl import ( + PatchEmbed, + Qwen2VLModelOutputWithPast, + Qwen2VLPreTrainedModel, + TransformersKwargs, + VisionAttention, + VisionRotaryEmbedding, +) +from ..qwen2_vl.processing_qwen2_vl import Qwen2VLImagesKwargs, Qwen2VLProcessor +from ..qwen3.modeling_qwen3 import ( + Qwen3Attention, + Qwen3DecoderLayer, + Qwen3Model, + apply_rotary_pos_emb, + eager_attention_forward, +) + + +logger = logging.get_logger(__name__) + + +class Qwen3VLVisionConfig(PretrainedConfig): + model_type = "qwen3_vl" + base_config_key = "vision_config" + + def __init__( + self, + depth=27, + hidden_size=1152, + hidden_act="gelu_pytorch_tanh", + intermediate_size=4304, + num_heads=16, + in_channels=3, + patch_size=16, + spatial_merge_size=2, + temporal_patch_size=2, + out_hidden_size=3584, + num_position_embeddings=2304, + deepstack_visual_indexes=[8, 16, 24], + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + + self.depth = depth + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.out_hidden_size = out_hidden_size + self.num_position_embeddings = num_position_embeddings + self.initializer_range = initializer_range + self.deepstack_visual_indexes = deepstack_visual_indexes + + +class Qwen3VLTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen3VLTextModel`]. It is used to instantiate a + Qwen3-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen3-VL-4B-Instruct [Qwen/Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 151936): + Vocabulary size of the Qwen3VL model. 
Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen3VLModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 22016): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 32): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`. + head_dim (`int`, *optional*, defaults to 128): + The dimension of the head. If not specified, will default to `hidden_size // num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 128000): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 5000000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. 
+ `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`list[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`list[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + + ```python + >>> from transformers import Qwen3VLTextModel, Qwen3VLTextConfig + + >>> # Initializing a Qwen3VL style configuration + >>> configuration = Qwen3VLTextConfig() + + >>> # Initializing a model from the Qwen3-VL-7B style configuration + >>> model = Qwen3VLTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen3_vl_text" + base_config_key = "text_config" + + def __init__( + self, + vocab_size=151936, + hidden_size=4096, + intermediate_size=22016, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + head_dim=128, + hidden_act="silu", + max_position_embeddings=128000, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=5000000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + rope_config_validation(self, ignore_keys={"mrope_section", "mrope_interleaved"}) + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + +class Qwen3VLConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen3VLModel`]. It is used to instantiate a + Qwen3-VL model according to the specified arguments, defining the model architecture. 
Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen3-VL-4B-Instruct [Qwen/Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLTextConfig`): + The config object or dictionary of the text backbone. + vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLVisionConfig`): + The config object or dictionary of the vision backbone. + image_token_id (`int`, *optional*, defaults to 151655): + The image token index to encode the image prompt. + video_token_id (`int`, *optional*, defaults to 151656): + The video token index to encode the image prompt. + vision_start_token_id (`int`, *optional*, defaults to 151652): + The start token index to encode the image prompt. + vision_end_token_id (`int`, *optional*, defaults to 151653): + The end token index to encode the image prompt. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie the word embeddings. + + ```python + >>> from transformers import Qwen3VLForConditionalGeneration, Qwen3VLConfig + + >>> # Initializing a Qwen3-VL style configuration + >>> configuration = Qwen3VLConfig() + + >>> # Initializing a model from the Qwen3-VL-4B style configuration + >>> model = Qwen3VLForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen3_vl" + sub_configs = {"vision_config": Qwen3VLVisionConfig, "text_config": Qwen3VLTextConfig} + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + text_config=None, + vision_config=None, + image_token_id=151655, + video_token_id=151656, + vision_start_token_id=151652, + vision_end_token_id=151653, + tie_word_embeddings=False, + **kwargs, + ): + if isinstance(vision_config, dict): + self.vision_config = self.sub_configs["vision_config"](**vision_config) + elif vision_config is None: + self.vision_config = self.sub_configs["vision_config"]() + + if isinstance(text_config, dict): + self.text_config = self.sub_configs["text_config"](**text_config) + elif text_config is None: + self.text_config = self.sub_configs["text_config"]() + + self.image_token_id = image_token_id + self.video_token_id = video_token_id + self.vision_start_token_id = vision_start_token_id + self.vision_end_token_id = vision_end_token_id + super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings) + + +class Qwen3VLVisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True) + self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state))) + + +class Qwen3VLVisionPatchEmbed(PatchEmbed): + def __init__(self, config) -> None: + super().__init__() + self.patch_size = config.patch_size + self.temporal_patch_size = config.temporal_patch_size + self.in_channels = config.in_channels + self.embed_dim = config.hidden_size + + kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size] + 
self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True) + + +class Qwen3VLVisionRotaryEmbedding(VisionRotaryEmbedding): + pass + + +class Qwen3VLVisionPatchMerger(nn.Module): + def __init__(self, config: Qwen3VLVisionConfig, use_postshuffle_norm=False) -> None: + super().__init__() + self.hidden_size = config.hidden_size * (config.spatial_merge_size**2) + self.use_postshuffle_norm = use_postshuffle_norm + self.norm = nn.LayerNorm(self.hidden_size if use_postshuffle_norm else config.hidden_size, eps=1e-6) + self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size) + self.act_fn = nn.GELU() + self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.norm(x.view(-1, self.hidden_size) if self.use_postshuffle_norm else x).view(-1, self.hidden_size) + x = self.linear_fc2(self.act_fn(self.linear_fc1(x))) + return x + + +class Qwen3VLVisionAttention(VisionAttention): + def __init__(self, config: Qwen3VLVisionConfig) -> None: + super().__init__() + self.dim = config.hidden_size + + +class Qwen3VLVisionBlock(Qwen2_5_VLVisionBlock): + def __init__(self, config, attn_implementation: str = "sdpa") -> None: + super().__init__() + self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6) + self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6) + self.attn = Qwen3VLVisionAttention(config=config) + self.mlp = Qwen3VLVisionMLP(config=config) + + +class Qwen3VLTextRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: Qwen3VLTextConfig, device=None): + super().__init__() + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", "default") + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + self.mrope_section = config.rope_scaling.get("mrope_section", [24, 20, 20]) + + def apply_interleaved_mrope(self, freqs, mrope_section): + """Apply interleaved MRoPE to 3D rotary embeddings. + Reorganizes frequency layout from chunked [TTT...HHH...WWW] to + interleaved [THTHWHTHW...TT], preserving frequency continuity. + args: + x: (3, bs, seq_len, head_dim // 2) + mrope_section: (3,) + returns: + x_t: (bs, seq_len, head_dim // 2) + """ + freqs_t = freqs[0] # just overwrite the first dimension T + for dim, offset in enumerate((1, 2), start=1): # H, W + length = mrope_section[dim] * 3 + idx = slice(offset, length, 3) + freqs_t[..., idx] = freqs[dim, ..., idx] + return freqs_t + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + # In contrast to other models, Qwen3VL has different position ids for the grids + # So we expand the inv_freq to shape (3, ...) 
+ if position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + freqs = self.apply_interleaved_mrope(freqs, self.mrope_section) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Qwen3VLTextAttention(Qwen3Attention): + def __init__(self, config: Qwen3VLTextConfig, layer_idx: int): + super().__init__(config, layer_idx) + del self.sliding_window + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Qwen3VLTextDecoderLayer(Qwen3DecoderLayer): + def __init__(self, config: Qwen3VLTextConfig, layer_idx: int): + super().__init__(config, layer_idx) + del self.attention_type + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + return super().forward( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, 
+ ) + + +class Qwen3VLModelOutputWithPast(Qwen2VLModelOutputWithPast): + pass + + +class Qwen3VLPreTrainedModel(Qwen2VLPreTrainedModel): + config: Qwen3VLConfig + _no_split_modules = ["Qwen3VLTextDecoderLayer", "Qwen3VLVisionBlock"] + _can_record_outputs = { + "hidden_states": Qwen3VLTextDecoderLayer, + "attentions": Qwen3VLTextAttention, + } + + +class Qwen3VLVisionModel(Qwen3VLPreTrainedModel): + config: Qwen3VLVisionConfig + _no_split_modules = ["Qwen3VLVisionBlock"] + + def __init__(self, config, *inputs, **kwargs) -> None: + super().__init__(config, *inputs, **kwargs) + self.spatial_merge_size = config.spatial_merge_size + self.patch_size = config.patch_size + self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + + self.patch_embed = Qwen3VLVisionPatchEmbed( + config=config, + ) + + self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size) + self.num_grid_per_side = int(config.num_position_embeddings**0.5) + + head_dim = config.hidden_size // config.num_heads + self.rotary_pos_emb = Qwen3VLVisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList([Qwen3VLVisionBlock(config) for _ in range(config.depth)]) + self.merger = Qwen3VLVisionPatchMerger( + config=config, + use_postshuffle_norm=False, + ) + + self.deepstack_visual_indexes = config.deepstack_visual_indexes + self.deepstack_merger_list = nn.ModuleList( + [ + Qwen3VLVisionPatchMerger( + config=config, + use_postshuffle_norm=True, + ) + for _ in range(len(config.deepstack_visual_indexes)) + ] + ) + + self.gradient_checkpointing = False + + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + merge_size = self.spatial_merge_size + + max_hw = int(grid_thw[:, 1:].max().item()) + freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2) + device = freq_table.device + + total_tokens = int(torch.prod(grid_thw, dim=1).sum().item()) + pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) + + offset = 0 + for num_frames, height, width in grid_thw: + merged_h, merged_w = height // merge_size, width // merge_size + + block_rows = torch.arange(merged_h, device=device) # block row indices + block_cols = torch.arange(merged_w, device=device) # block col indices + intra_row = torch.arange(merge_size, device=device) # intra-block row offsets + intra_col = torch.arange(merge_size, device=device) # intra-block col offsets + + # Compute full-resolution positions + row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] + col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] + + row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + + coords = torch.stack((row_idx, col_idx), dim=-1) + + if num_frames > 1: + coords = coords.repeat(num_frames, 1) + + num_tokens = coords.shape[0] + pos_ids[offset : offset + num_tokens] = coords + offset += num_tokens + + embeddings = freq_table[pos_ids] # lookup rotary embeddings + embeddings = embeddings.flatten(1) + return embeddings + + def fast_pos_embed_interpolate(self, grid_thw): + grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2] + + idx_list = [[] for _ in range(4)] + weight_list = [[] for _ in range(4)] + + for t, h, w in zip(grid_ts, grid_hs, grid_ws): + h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) + w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) + + h_idxs_floor = h_idxs.int() + w_idxs_floor = 
w_idxs.int() + h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + + dh = h_idxs - h_idxs_floor + dw = w_idxs - w_idxs_floor + + base_h = h_idxs_floor * self.num_grid_per_side + base_h_ceil = h_idxs_ceil * self.num_grid_per_side + + indices = [ + (base_h[None].T + w_idxs_floor[None]).flatten(), + (base_h[None].T + w_idxs_ceil[None]).flatten(), + (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), + (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), + ] + + weights = [ + ((1 - dh)[None].T * (1 - dw)[None]).flatten(), + ((1 - dh)[None].T * dw[None]).flatten(), + (dh[None].T * (1 - dw)[None]).flatten(), + (dh[None].T * dw[None]).flatten(), + ] + + for i in range(4): + idx_list[i].extend(indices[i].tolist()) + weight_list[i].extend(weights[i].tolist()) + + idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=self.pos_embed.weight.device) + weight_tensor = torch.tensor( + weight_list, dtype=self.pos_embed.weight.dtype, device=self.pos_embed.weight.device + ) + pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None] + patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] + + patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)]) + + patch_pos_embeds_permute = [] + merge_size = self.config.spatial_merge_size + for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): + pos_embed = pos_embed.repeat(t, 1) + pos_embed = ( + pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1) + .permute(0, 1, 3, 2, 4, 5) + .flatten(0, 4) + ) + patch_pos_embeds_permute.append(pos_embed) + patch_pos_embeds = torch.cat(patch_pos_embeds_permute) + return patch_pos_embeds + + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Args: + hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): + The final hidden states of the model. + grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): + The temporal, height and width of feature shape of each image in LLM. + + Returns: + `torch.Tensor`: hidden_states. 
+ """ + hidden_states = self.patch_embed(hidden_states) + + pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + hidden_states = hidden_states + pos_embeds + + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 for more information + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + deepstack_feature_lists = [] + for layer_num, blk in enumerate(self.blocks): + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings, + **kwargs, + ) + if layer_num in self.deepstack_visual_indexes: + deepstack_feature = self.deepstack_merger_list[self.deepstack_visual_indexes.index(layer_num)]( + hidden_states + ) + deepstack_feature_lists.append(deepstack_feature) + + hidden_states = self.merger(hidden_states) + + return hidden_states, deepstack_feature_lists + + +@auto_docstring( + custom_intro=( + "Text part of Qwen3VL, " + "not a pure text-only model, as DeepStack integrates visual features into the early hidden states." + ) +) +class Qwen3VLTextModel(Qwen3VLPreTrainedModel, Qwen3Model): + config: Qwen3VLTextConfig + _no_split_modules = ["Qwen3VLTextDecoderLayer"] + + def __init__(self, config: Qwen3VLTextConfig): + super().__init__(config) + del self.has_sliding_layers + + def _deepstack_process( + self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor, visual_embeds: torch.Tensor + ): + visual_pos_masks = visual_pos_masks.to(hidden_states.device) + visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype) + local_this = hidden_states[visual_pos_masks, :].clone() + visual_embeds + hidden_states[visual_pos_masks, :] = local_this + return hidden_states + + @check_model_inputs + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + # args for deepstack + visual_pos_masks: Optional[torch.Tensor] = None, + deepstack_visual_embeds: Optional[list[torch.Tensor]] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[tuple, BaseModelOutputWithPast]: + r""" + visual_pos_masks (`torch.Tensor` of shape `(batch_size, seqlen)`, *optional*): + The mask of the visual positions. + deepstack_visual_embeds (`list[torch.Tensor]`, *optional*): + The deepstack visual embeddings. The shape is (num_layers, visual_seqlen, embed_dim). + The feature is extracted from the different visual encoder layers, and fed to the decoder + hidden states. It's from the paper DeepStack(https://arxiv.org/abs/2406.04334). 
+ """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + # torch.jit.trace() doesn't support cache objects in the output + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + past_key_values = DynamicCache(config=self.config) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # the hard coded `3` is for temporal, height and width. + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + if position_ids.ndim == 3 and position_ids.shape[0] == 4: + text_position_ids = position_ids[0] + position_ids = position_ids[1:] + else: + text_position_ids = position_ids[0] + + attention_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=text_position_ids, + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + for layer_idx, decoder_layer in enumerate(self.layers): + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=text_position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = layer_outputs + + # add visual features to the hidden states of first several layers + if deepstack_visual_embeds is not None and layer_idx in range(len(deepstack_visual_embeds)): + hidden_states = self._deepstack_process( + hidden_states, + visual_pos_masks, + deepstack_visual_embeds[layer_idx], + ) + + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +@auto_docstring +class Qwen3VLModel(Qwen2_5_VLModel): + config: Qwen3VLConfig + _checkpoint_conversion_mapping = {} + _no_split_modules = ["Qwen3VLTextDecoderLayer", "Qwen3VLVisionBlock"] + + def __init__(self, config): + super().__init__(config) + self.visual = Qwen3VLVisionModel._from_config(config.vision_config) + self.language_model = Qwen3VLTextModel._from_config(config.text_config) + + def get_rope_index( + self, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Different from the original implementation, Qwen3VL use timestamps rather than absolute time position ids.""" + + # Since we use timestamps to seperate videos, like , the video_grid_thw should also be split + if video_grid_thw is not None: + video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0) + video_grid_thw[:, 0] = 1 + + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = 
self.config.vision_start_token_id + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + # t_index is always 0 because llm_grid_t is always 1 (we use timestamps to encode the temporal information for videos) + t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + 
mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. The deepstack visual features are also returned. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + """ + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds, deepstack_image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() + image_embeds = torch.split(image_embeds, split_sizes) + return image_embeds, deepstack_image_embeds + + def get_video_features( + self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None + ): + """ + Encodes videos into continuous embeddings that can be forwarded to the language model. The deepstack visual features are also returned. + + Args: + pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input videos. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + """ + # Same implementation as for images + return self.get_image_features(pixel_values_videos, video_grid_thw) + + @auto_docstring + @can_return_tuple + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, Qwen3VLModelOutputWithPast]: + r""" + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. 
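
Because Qwen3-VL encodes time through textual timestamps rather than the temporal axis of MRoPE, `get_rope_index` above first flattens every video grid into per-frame grids with `t = 1`. A small sketch of that preprocessing step, with invented grid values:

```py
import torch

# (num_videos, 3): temporal, height, width of the patch grid for each video (invented values).
video_grid_thw = torch.tensor([[3, 4, 4],   # 3 frames
                               [2, 6, 8]])  # 2 frames

# One row per frame, temporal size forced to 1; the frames are then separated in the prompt
# by "<x.x seconds>" text markers instead of distinct temporal RoPE positions.
per_frame = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0)
per_frame[:, 0] = 1
print(per_frame)
# tensor([[1, 4, 4],
#         [1, 4, 4],
#         [1, 4, 4],
#         [1, 6, 8],
#         [1, 6, 8]])
```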
+ """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + image_mask = None + video_mask = None + + if pixel_values is not None: + image_embeds, deepstack_image_embeds = self.get_image_features(pixel_values, image_grid_thw) + image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + image_mask, _ = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + video_embeds, deepstack_video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) + video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + _, video_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + visual_pos_masks = None + deepstack_visual_embeds = None + if image_mask is not None and video_mask is not None: + # aggregate visual_pos_masks and deepstack_visual_embeds + image_mask = image_mask[..., 0] + video_mask = video_mask[..., 0] + visual_pos_masks = image_mask | video_mask + deepstack_visual_embeds = [] + image_mask_joint = image_mask[visual_pos_masks] + video_mask_joint = video_mask[visual_pos_masks] + for img_embed, vid_embed in zip(deepstack_image_embeds, deepstack_video_embeds): + embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1]).to(img_embed.device) + embed_joint[image_mask_joint, :] = img_embed + embed_joint[video_mask_joint, :] = vid_embed + deepstack_visual_embeds.append(embed_joint) + elif image_mask is not None: + image_mask = image_mask[..., 0] + visual_pos_masks = image_mask + deepstack_visual_embeds = deepstack_image_embeds + elif video_mask is not None: + video_mask = video_mask[..., 0] + visual_pos_masks = video_mask + deepstack_visual_embeds = deepstack_video_embeds + + if position_ids is None: + attention_mask_tensor = ( + attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"] + ) + if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4: + attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2) + # Only apply conversion for floating point tensors (inverted masks) + if attention_mask_tensor.dtype.is_floating_point: + attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min + attention_mask_tensor = (1.0 - attention_mask_tensor).int() + + # Calculate RoPE index once per generation in the pre-fill stage only. 
+ # When compiling, we can't check tensor values thus we check only input length + # It is safe to assume that `length!=1` means we're in pre-fill because compiled + # models currently cannot do asssisted decoding + prefill_compiled_stage = is_torchdynamo_compiling() and ( + (input_ids is not None and input_ids.shape[1] != 1) + or (inputs_embeds is not None and inputs_embeds.shape[1] != 1) + ) + prefill_noncompiled_stage = not is_torchdynamo_compiling() and ( + (cache_position is not None and cache_position[0] == 0) + or (past_key_values is None or past_key_values.get_seq_length() == 0) + ) + if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None: + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + attention_mask=attention_mask_tensor, + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = ( + (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) + if cache_position is not None + else 0 + ) + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + outputs = self.language_model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + visual_pos_masks=visual_pos_masks, + deepstack_visual_embeds=deepstack_visual_embeds, + **kwargs, + ) + + return Qwen3VLModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + + +class Qwen3VLCausalLMOutputWithPast(Qwen2_5_VLCausalLMOutputWithPast): + pass + + +class Qwen3VLForConditionalGeneration(Qwen2_5_VLForConditionalGeneration): + config: Qwen3VLConfig + _checkpoint_conversion_mapping = {} + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, Qwen3VLCausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. 
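
The branch above computes the multimodal RoPE index only at prefill and then reuses the cached `rope_deltas` offset while decoding. A toy illustration of how the cached delta continues the position sequence (all numbers are invented):

```py
import torch

# Say the 12-token prompt contained an image, and prefill assigned it a maximum
# 3D position index of 9 (vision tokens share positions, so max < prompt length).
prompt_len, max_pos = 12, 9
rope_delta = max_pos + 1 - prompt_len  # -2, cached on the model as rope_deltas

# First decoding step: cache_position equals the prompt length.
cache_position = torch.tensor([prompt_len])
position_ids = torch.arange(1) + cache_position[0] + rope_delta
print(position_ids.item())  # 10 == max_pos + 1, decoding continues right after the prompt
```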
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + + Example: + TODO: Add example + """ + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + + return Qwen3VLCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=outputs.rope_deltas, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + use_cache=use_cache, + **kwargs, + ) + + # Qwen3VL position_ids are prepareed with rope_deltas in forward + model_inputs["position_ids"] = None + + if cache_position[0] != 0: + model_inputs["pixel_values"] = None + model_inputs["pixel_values_videos"] = None + + return model_inputs + + +class Qwen3VLVideosProcessorKwargs(VideosKwargs, total=False): + pass + + +class Qwen3VLImagesKwargs(Qwen2VLImagesKwargs): + pass + + +class Qwen3VLProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: Qwen3VLImagesKwargs + videos_kwargs: Qwen3VLVideosProcessorKwargs + _defaults = { + "text_kwargs": { + "padding": False, + "return_token_type_ids": False, + "return_mm_token_type_ids": False, + }, + "videos_kwargs": {"return_metadata": True}, + } + + +class Qwen3VLProcessor(Qwen2VLProcessor): + r""" + Constructs a Qwen3VL processor which wraps a Qwen3VL image processor and a Qwen2 tokenizer into a single processor. + [`Qwen3VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the + [`~Qwen3VLProcessor.__call__`] and [`~Qwen3VLProcessor.decode`] for more information. + Args: + image_processor ([`Qwen2VLImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`Qwen2TokenizerFast`], *optional*): + The tokenizer is a required input. + video_processor ([`Qwen3VLVideoProcessor`], *optional*): + The video processor is a required input. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. 
+ """ + + def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs): + super().__init__(image_processor, tokenizer, video_processor, chat_template, **kwargs) + self.vision_start_token = ( + "<|vision_start|>" if not hasattr(tokenizer, "vision_start_token") else tokenizer.vision_start_token + ) + self.vision_end_token = ( + "<|vision_end|>" if not hasattr(tokenizer, "vision_end_token") else tokenizer.vision_end_token + ) + self.vision_start_token_id = ( + tokenizer.vision_start_token_id + if getattr(tokenizer, "vision_start_token_id", None) + else tokenizer.convert_tokens_to_ids(self.vision_start_token) + ) + self.vision_end_token_id = ( + tokenizer.vision_end_token_id + if getattr(tokenizer, "vision_end_token_id", None) + else tokenizer.convert_tokens_to_ids(self.vision_end_token) + ) + + def __call__( + self, + images: ImageInput = None, + text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None, + videos: VideoInput = None, + **kwargs: Unpack[Qwen3VLProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to + Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`str`, `list[str]`, `list[list[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + videos (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`): + The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch + tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`. + - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. + - **video_grid_thw** -- List of video 3D grid in LLM. 
Returned when `videos` is not `None`. + """ + output_kwargs = self._merge_kwargs( + Qwen3VLProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + if images is not None: + image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) + image_grid_thw = image_inputs["image_grid_thw"] + else: + image_inputs = {} + image_grid_thw = None + + if videos is not None: + videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"]) + video_grid_thw = videos_inputs["video_grid_thw"] + # If user has not requested video metadata, pop it + if "return_metadata" not in kwargs: + video_metadata = videos_inputs.pop("video_metadata") + else: + video_metadata = videos_inputs["video_metadata"] + video_grid_thw = videos_inputs["video_grid_thw"] + else: + videos_inputs = {} + video_grid_thw = None + + if not isinstance(text, list): + text = [text] + + text = text.copy() # below lines change text in-place + if image_grid_thw is not None: + merge_length = self.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.image_token in text[i]: + num_image_tokens = image_grid_thw[index].prod() // merge_length + text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.image_token) + + if video_grid_thw is not None: + merge_length = self.video_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.video_token in text[i]: + metadata = video_metadata[i] + if metadata.fps is None: + logger.warning_once( + "Qwen3VL requires frame timestamps to construct prompts, but the `fps` of the input video could not be inferred. " + "Probably `video_metadata` was missing from inputs and you passed pre-sampled frames. " + "Defaulting to `fps=24`. Please provide `video_metadata` for more accurate results." 
+ ) + metadata.fps = 24 if metadata.fps is None else metadata.fps + + # if timestamps are not provided, calculate them + curr_timestamp = self._calculate_timestamps( + metadata.frames_indices, + metadata.fps, + self.video_processor.merge_size, + ) + + video_placeholder = "" + frame_seqlen = video_grid_thw[index][1:].prod() // merge_length + for frame_idx in range(video_grid_thw[index][0]): + curr_time = curr_timestamp[frame_idx] + video_placeholder += f"<{curr_time:.1f} seconds>" + video_placeholder += ( + self.vision_start_token + "<|placeholder|>" * frame_seqlen + self.vision_end_token + ) + if f"{self.vision_start_token}{self.video_token}{self.vision_end_token}" in text[i]: + text[i] = text[i].replace( + f"{self.vision_start_token}{self.video_token}{self.vision_end_token}", video_placeholder, 1 + ) + else: + # vllm may input video token directly + text[i] = text[i].replace(self.video_token, video_placeholder, 1) + index += 1 + + text[i] = text[i].replace("<|placeholder|>", self.video_token) + + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None) + text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) + self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"]) + + if return_mm_token_type_ids: + array_ids = np.array(text_inputs["input_ids"]) + mm_token_type_ids = np.zeros_like(text_inputs["input_ids"]) + mm_token_type_ids[array_ids == self.image_token_id] = 1 + text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist() + + return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors) + + def _calculate_timestamps(self, indices: Union[list[int], np.ndarray], video_fps: float, merge_size: int = 2): + if not isinstance(indices, list): + indices = indices.tolist() + if len(indices) % merge_size != 0: + indices.extend(indices[-1] for _ in range(merge_size - len(indices) % merge_size)) + timestamps = [idx / video_fps for idx in indices] + # @JJJYmmm frames are merged by self.merge_size, \ + # so we need to average the timestamps between the first/last frame within the temporal patch + timestamps = [ + (timestamps[i] + timestamps[i + merge_size - 1]) / 2 for i in range(0, len(timestamps), merge_size) + ] + return timestamps + + +__all__ = [ + "Qwen3VLConfig", + "Qwen3VLTextConfig", + "Qwen3VLVisionModel", + "Qwen3VLForConditionalGeneration", + "Qwen3VLModel", + "Qwen3VLPreTrainedModel", + "Qwen3VLProcessor", + "Qwen3VLTextModel", +] diff --git a/src/transformers/models/qwen3_vl/processing_qwen3_vl.py b/src/transformers/models/qwen3_vl/processing_qwen3_vl.py new file mode 100644 index 000000000000..cac82e738f39 --- /dev/null +++ b/src/transformers/models/qwen3_vl/processing_qwen3_vl.py @@ -0,0 +1,328 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen3_vl/modular_qwen3_vl.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen3_vl.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
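
The processor above interleaves `<x.x seconds>` markers with per-frame vision blocks, and `_calculate_timestamps` averages the first and last frame time inside each temporal patch. A rough sketch of the resulting prompt text, assuming hypothetical frame indices, fps and per-frame token count:

```py
# Hypothetical sampled frame indices and fps; merge_size = 2 means two frames per temporal patch.
frame_indices, fps, merge_size = [0, 12, 24, 36], 24.0, 2

times = [idx / fps for idx in frame_indices]  # [0.0, 0.5, 1.0, 1.5]
# Average the first/last timestamp inside each temporal patch, as _calculate_timestamps does.
patch_times = [(times[i] + times[i + merge_size - 1]) / 2 for i in range(0, len(times), merge_size)]

frame_seqlen = 4  # video pad tokens per temporal patch (grid_h * grid_w // merge_size**2), invented here
prompt = ""
for t in patch_times:
    prompt += f"<{t:.1f} seconds>" + "<|vision_start|>" + "<|video_pad|>" * frame_seqlen + "<|vision_end|>"
print(prompt)
```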
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Union + +import numpy as np + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import logging +from ...video_utils import VideoInput + + +logger = logging.get_logger(__name__) + + +class Qwen3VLVideosProcessorKwargs(VideosKwargs, total=False): + pass + + +class Qwen3VLImagesKwargs(ImagesKwargs): + min_pixels: Optional[int] + max_pixels: Optional[int] + patch_size: Optional[int] + temporal_patch_size: Optional[int] + merge_size: Optional[int] + + +class Qwen3VLProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: Qwen3VLImagesKwargs + videos_kwargs: Qwen3VLVideosProcessorKwargs + _defaults = { + "text_kwargs": { + "padding": False, + "return_token_type_ids": False, + "return_mm_token_type_ids": False, + }, + "videos_kwargs": {"return_metadata": True}, + } + + +class Qwen3VLProcessor(ProcessorMixin): + r""" + Constructs a Qwen3VL processor which wraps a Qwen3VL image processor and a Qwen2 tokenizer into a single processor. + [`Qwen3VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the + [`~Qwen3VLProcessor.__call__`] and [`~Qwen3VLProcessor.decode`] for more information. + Args: + image_processor ([`Qwen2VLImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`Qwen2TokenizerFast`], *optional*): + The tokenizer is a required input. + video_processor ([`Qwen3VLVideoProcessor`], *optional*): + The video processor is a required input. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. 
+ """ + + attributes = ["image_processor", "tokenizer", "video_processor"] + image_processor_class = "AutoImageProcessor" + video_processor_class = "AutoVideoProcessor" + tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") + + def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs): + super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template) + self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token + self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token + self.image_token_id = ( + tokenizer.image_token_id + if getattr(tokenizer, "image_token_id", None) + else tokenizer.convert_tokens_to_ids(self.image_token) + ) + self.video_token_id = ( + tokenizer.video_token_id + if getattr(tokenizer, "video_token_id", None) + else tokenizer.convert_tokens_to_ids(self.video_token) + ) + self.vision_start_token = ( + "<|vision_start|>" if not hasattr(tokenizer, "vision_start_token") else tokenizer.vision_start_token + ) + self.vision_end_token = ( + "<|vision_end|>" if not hasattr(tokenizer, "vision_end_token") else tokenizer.vision_end_token + ) + self.vision_start_token_id = ( + tokenizer.vision_start_token_id + if getattr(tokenizer, "vision_start_token_id", None) + else tokenizer.convert_tokens_to_ids(self.vision_start_token) + ) + self.vision_end_token_id = ( + tokenizer.vision_end_token_id + if getattr(tokenizer, "vision_end_token_id", None) + else tokenizer.convert_tokens_to_ids(self.vision_end_token) + ) + + def __call__( + self, + images: ImageInput = None, + text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None, + videos: VideoInput = None, + **kwargs: Unpack[Qwen3VLProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to + Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`str`, `list[str]`, `list[list[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + videos (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`): + The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch + tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. 
+ + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`. + - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. + - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`. + """ + output_kwargs = self._merge_kwargs( + Qwen3VLProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + if images is not None: + image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) + image_grid_thw = image_inputs["image_grid_thw"] + else: + image_inputs = {} + image_grid_thw = None + + if videos is not None: + videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"]) + video_grid_thw = videos_inputs["video_grid_thw"] + # If user has not requested video metadata, pop it + if "return_metadata" not in kwargs: + video_metadata = videos_inputs.pop("video_metadata") + else: + video_metadata = videos_inputs["video_metadata"] + video_grid_thw = videos_inputs["video_grid_thw"] + else: + videos_inputs = {} + video_grid_thw = None + + if not isinstance(text, list): + text = [text] + + text = text.copy() # below lines change text in-place + if image_grid_thw is not None: + merge_length = self.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.image_token in text[i]: + num_image_tokens = image_grid_thw[index].prod() // merge_length + text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.image_token) + + if video_grid_thw is not None: + merge_length = self.video_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.video_token in text[i]: + metadata = video_metadata[i] + if metadata.fps is None: + logger.warning_once( + "Qwen3VL requires frame timestamps to construct prompts, but the `fps` of the input video could not be inferred. " + "Probably `video_metadata` was missing from inputs and you passed pre-sampled frames. " + "Defaulting to `fps=24`. Please provide `video_metadata` for more accurate results." 
+ ) + metadata.fps = 24 if metadata.fps is None else metadata.fps + + # if timestamps are not provided, calculate them + curr_timestamp = self._calculate_timestamps( + metadata.frames_indices, + metadata.fps, + self.video_processor.merge_size, + ) + + video_placeholder = "" + frame_seqlen = video_grid_thw[index][1:].prod() // merge_length + for frame_idx in range(video_grid_thw[index][0]): + curr_time = curr_timestamp[frame_idx] + video_placeholder += f"<{curr_time:.1f} seconds>" + video_placeholder += ( + self.vision_start_token + "<|placeholder|>" * frame_seqlen + self.vision_end_token + ) + if f"{self.vision_start_token}{self.video_token}{self.vision_end_token}" in text[i]: + text[i] = text[i].replace( + f"{self.vision_start_token}{self.video_token}{self.vision_end_token}", video_placeholder, 1 + ) + else: + # vllm may input video token directly + text[i] = text[i].replace(self.video_token, video_placeholder, 1) + index += 1 + + text[i] = text[i].replace("<|placeholder|>", self.video_token) + + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None) + text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) + self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"]) + + if return_mm_token_type_ids: + array_ids = np.array(text_inputs["input_ids"]) + mm_token_type_ids = np.zeros_like(text_inputs["input_ids"]) + mm_token_type_ids[array_ids == self.image_token_id] = 1 + text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist() + + return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors) + + def _get_num_multimodal_tokens(self, image_sizes=None, video_sizes=None, **kwargs): + """ + Computes the number of placeholder tokens needed for multimodal inputs with the given sizes. + Args: + image_sizes (`list[list[int]]`, *optional*): + The input sizes formatted as (height, width) per each image. + video_sizes (`list[list[int]]`, *optional*): + The input sizes formatted as (num_frames, height, width) per each video. + Returns: + `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided + input modalities, along with other useful data. + """ + + vision_data = {} + if image_sizes is not None: + images_kwargs = Qwen3VLProcessorKwargs._defaults.get("images_kwargs", {}) + images_kwargs.update(kwargs) + merge_size = images_kwargs.get("merge_size", None) or self.image_processor.merge_size + + num_image_patches = [ + self.image_processor.get_number_of_image_patches(*image_size, images_kwargs) + for image_size in image_sizes + ] + num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches] + vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches}) + + if video_sizes is not None: + videos_kwargs = Qwen3VLProcessorKwargs._defaults.get("videos_kwargs", {}) + videos_kwargs.update(kwargs) + num_video_patches = [ + self.video_processor.get_number_of_video_patches(*video_size, videos_kwargs) + for video_size in video_sizes + ] + num_video_tokens = [(num_patches // merge_size**2) for num_patches in num_video_patches] + vision_data["num_video_tokens"] = num_video_tokens + + return MultiModalData(**vision_data) + + def post_process_image_text_to_text( + self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs + ): + """ + Post-process the output of the model to decode the text. 
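
`_get_num_multimodal_tokens` above derives placeholder counts from raw patch counts, since patches are merged `merge_size x merge_size` before reaching the language model. A quick arithmetic sketch with an invented patch grid:

```py
# Invented example: an image whose resized patch grid is 28 x 42 at patch_size 16.
grid_h, grid_w, merge_size = 28, 42, 2

num_patches = grid_h * grid_w              # 1176 ViT patches
num_tokens = num_patches // merge_size**2  # 294 <|image_pad|> tokens seen by the language model
print(num_patches, num_tokens)
```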
+ + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + or `(sequence_length,)`. + skip_special_tokens (`bool`, *optional*, defaults to `True`): + Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): + Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method. + **kwargs: + Additional arguments to be passed to the tokenizer's `batch_decode method`. + + Returns: + `list[str]`: The decoded text. + """ + return self.tokenizer.batch_decode( + generated_outputs, + skip_special_tokens=skip_special_tokens, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + def _calculate_timestamps(self, indices: Union[list[int], np.ndarray], video_fps: float, merge_size: int = 2): + if not isinstance(indices, list): + indices = indices.tolist() + if len(indices) % merge_size != 0: + indices.extend(indices[-1] for _ in range(merge_size - len(indices) % merge_size)) + timestamps = [idx / video_fps for idx in indices] + # @JJJYmmm frames are merged by self.merge_size, \ + # so we need to average the timestamps between the first/last frame within the temporal patch + timestamps = [ + (timestamps[i] + timestamps[i + merge_size - 1]) / 2 for i in range(0, len(timestamps), merge_size) + ] + return timestamps + + +__all__ = ["Qwen3VLProcessor"] diff --git a/src/transformers/models/qwen3_vl/video_processing_qwen3_vl.py b/src/transformers/models/qwen3_vl/video_processing_qwen3_vl.py new file mode 100644 index 000000000000..c4648788c9dc --- /dev/null +++ b/src/transformers/models/qwen3_vl/video_processing_qwen3_vl.py @@ -0,0 +1,276 @@ +# coding=utf-8 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""video processor class for Qwen3-VL.""" + +import math +from typing import Optional, Union + +import numpy as np +import torch + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ChannelDimension, PILImageResampling, SizeDict, get_image_size +from ...processing_utils import Unpack, VideosKwargs +from ...utils import TensorType, add_start_docstrings, logging +from ...video_processing_utils import BASE_VIDEO_PROCESSOR_DOCSTRING, BaseVideoProcessor +from ...video_utils import VideoMetadata, group_videos_by_shape, reorder_videos + + +logger = logging.get_logger(__name__) + + +def smart_resize( + num_frames: int, + height: int, + width: int, + temporal_factor: int = 2, + factor: int = 32, + min_pixels: int = 128 * 128, + max_pixels: int = 16 * 16 * 2 * 2 * 2 * 6144, +): + if num_frames < temporal_factor: + raise ValueError(f"t:{num_frames} must be larger than temporal_factor:{temporal_factor}") + if height < factor or width < factor: + raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}") + elif max(height, width) / min(height, width) > 200: + raise ValueError( + f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}" + ) + h_bar = round(height / factor) * factor + w_bar = round(width / factor) * factor + t_bar = round(num_frames / temporal_factor) * temporal_factor + + if t_bar * h_bar * w_bar > max_pixels: + beta = math.sqrt((num_frames * height * width) / max_pixels) + h_bar = max(factor, math.floor(height / beta / factor) * factor) + w_bar = max(factor, math.floor(width / beta / factor) * factor) + elif t_bar * h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (num_frames * height * width)) + h_bar = math.ceil(height * beta / factor) * factor + w_bar = math.ceil(width * beta / factor) * factor + + return h_bar, w_bar + + +class Qwen3VLVideoProcessorInitKwargs(VideosKwargs): + patch_size: Optional[int] + temporal_patch_size: Optional[int] + merge_size: Optional[int] + min_frames: Optional[int] + max_frames: Optional[int] + + +@add_start_docstrings( + "Constructs a fast Qwen3-VL image processor that dynamically resizes videos based on the original videos.", + BASE_VIDEO_PROCESSOR_DOCSTRING, + """ + patch_size (`int`, *optional*, defaults to 16): + The spacial patch size of the vision encoder. + temporal_patch_size (`int`, *optional*, defaults to 2): + The temporal patch size of the vision encoder. + merge_size (`int`, *optional*, defaults to 2): + The merge size of the vision encoder to llm encoder. 
+ """, +) +class Qwen3VLVideoProcessor(BaseVideoProcessor): + resample = PILImageResampling.BICUBIC + size = {"shortest_edge": 128 * 32 * 32, "longest_edge": 32 * 32 * 768} + image_mean = [0.5, 0.5, 0.5] + image_std = [0.5, 0.5, 0.5] + do_resize = True + do_rescale = True + do_normalize = True + do_convert_rgb = True + patch_size = 16 + temporal_patch_size = 2 + merge_size = 2 + fps = 2 + min_frames = 4 + max_frames = 768 + do_sample_frames = True + valid_kwargs = Qwen3VLVideoProcessorInitKwargs + model_input_names = ["pixel_values_videos", "video_grid_thw"] + + def __init__(self, **kwargs: Unpack[Qwen3VLVideoProcessorInitKwargs]): + super().__init__(**kwargs) + if self.size is not None and ( + self.size.get("shortest_edge", None) is None or self.size.get("longest_edge", None) is None + ): + raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.") + + def _further_process_kwargs( + self, + size: Optional[SizeDict] = None, + **kwargs, + ) -> dict: + """ + Update kwargs that need further processing before being validated + Can be overridden by subclasses to customize the processing of kwargs. + """ + if size is not None and ("shortest_edge" not in size or "longest_edge" not in size): + raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.") + + return super()._further_process_kwargs(size=size, **kwargs) + + def sample_frames( + self, + metadata: VideoMetadata, + num_frames: Optional[int] = None, + fps: Optional[Union[int, float]] = None, + **kwargs, + ): + """ + Default sampling function which uniformly samples the desired number of frames between 0 and total number of frames. + If `fps` is passed along with metadata, `fps` frames per second are sampled uniformty. Arguments `num_frames` + and `fps` are mutually exclusive. + + Args: + video (`torch.Tensor`): + Video that need to be sampled. + metadata (`VideoMetadata`): + Metadata of the video containing information about total duration, fps and total number of frames. + num_frames (`int`, *optional*): + Maximum number of frames to sample. Defaults to `self.num_frames`. + fps (`int` or `float`, *optional*): + Target frames to sample per second. Defaults to `self.fps`. + Returns: + torch.Tensor: + Sampled video frames. + """ + if fps is not None and num_frames is not None: + raise ValueError("`num_frames` and `fps` are mutually exclusive arguments, please use only one!") + + total_num_frames = metadata.total_num_frames + fps = fps if fps is not None else self.fps + + # If num_frames is not given but fps is, calculate num_frames from fps + if num_frames is None and fps is not None: + if metadata.fps is None: + metadata.fps = 24 + logger.warning_once( + "Asked to sample `fps` frames per second but no video metadata was provided which is required when sampling with `fps`. " + "Defaulting to `fps=24`. Please provide `video_metadata` for more accurate results." 
+ ) + num_frames = int(total_num_frames / metadata.fps * fps) + num_frames = min(min(max(num_frames, self.min_frames), self.max_frames), total_num_frames) + + if num_frames is None: + num_frames = min(max(total_num_frames, self.min_frames), self.max_frames) + + indices = np.linspace(0, total_num_frames - 1, num_frames).round().astype(int) + + return indices + + def _preprocess( + self, + videos: list[torch.Tensor], + do_convert_rgb: bool = True, + do_resize: bool = True, + size: Optional[SizeDict] = None, + interpolation: PILImageResampling = PILImageResampling.BICUBIC, + do_rescale: bool = True, + rescale_factor: float = 1 / 255.0, + do_normalize: bool = True, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + patch_size: Optional[int] = None, + temporal_patch_size: Optional[int] = None, + merge_size: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs, + ): + grouped_videos, grouped_videos_index = group_videos_by_shape(videos) + resized_videos_grouped = {} + + for shape, stacked_videos in grouped_videos.items(): + B, T, C, H, W = stacked_videos.shape + num_frames, height, width = T, H, W + if do_resize: + resized_height, resized_width = smart_resize( + num_frames=num_frames, + height=height, + width=width, + temporal_factor=temporal_patch_size, + factor=patch_size * merge_size, + min_pixels=size.shortest_edge, + max_pixels=size.longest_edge, + ) + stacked_videos = stacked_videos.view(B * T, C, H, W) + stacked_videos = self.resize( + stacked_videos, + size=SizeDict(height=resized_height, width=resized_width), + interpolation=interpolation, + ) + stacked_videos = stacked_videos.view(B, T, C, resized_height, resized_width) + resized_videos_grouped[shape] = stacked_videos + resized_videos = reorder_videos(resized_videos_grouped, grouped_videos_index) + + # Group videos by size for further processing + # Needed in case do_resize is False, or resize returns videos with different sizes + grouped_videos, grouped_videos_index = group_videos_by_shape(resized_videos) + processed_videos_grouped = {} + processed_grids = {} + for shape, stacked_videos in grouped_videos.items(): + resized_height, resized_width = get_image_size(stacked_videos[0], channel_dim=ChannelDimension.FIRST) + + # Fused rescale and normalize + stacked_videos = self.rescale_and_normalize( + stacked_videos, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + patches = stacked_videos + + # Check that videos have `num_frames` divisible by `temporal_patch_size` + if patches.shape[1] % temporal_patch_size != 0: + repeats = patches[:, -1:].repeat(1, temporal_patch_size - 1, 1, 1, 1) + patches = torch.cat([patches, repeats], dim=1) + batch_size, grid_t, channel = patches.shape[:3] + grid_t = grid_t // temporal_patch_size + grid_h, grid_w = resized_height // patch_size, resized_width // patch_size + + patches = patches.view( + batch_size, + grid_t, + temporal_patch_size, + channel, + grid_h // merge_size, + merge_size, + patch_size, + grid_w // merge_size, + merge_size, + patch_size, + ) + patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9) + flatten_patches = patches.reshape( + batch_size, + grid_t * grid_h * grid_w, + channel * temporal_patch_size * patch_size * patch_size, + ) + + processed_videos_grouped[shape] = flatten_patches + processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size + + processed_videos = reorder_videos(processed_videos_grouped, grouped_videos_index) + processed_grids = 
reorder_videos(processed_grids, grouped_videos_index) + pixel_values_videos = torch.cat(processed_videos, dim=0) + video_grid_thw = torch.tensor(processed_grids) + data = { + "pixel_values_videos": pixel_values_videos, + "video_grid_thw": video_grid_thw, + } + + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["Qwen3VLVideoProcessor"] diff --git a/src/transformers/models/qwen3_vl_moe/__init__.py b/src/transformers/models/qwen3_vl_moe/__init__.py new file mode 100644 index 000000000000..a4000cb27272 --- /dev/null +++ b/src/transformers/models/qwen3_vl_moe/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_qwen3_vl_moe import * + from .modeling_qwen3_vl_moe import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py new file mode 100644 index 000000000000..c4a31e8f9f92 --- /dev/null +++ b/src/transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py @@ -0,0 +1,331 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen3_vl_moe.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation + + +class Qwen3VLMoeTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen3VLMoeTextModel`]. It is used to instantiate a + Qwen3-VL-MOE model according to the specified arguments, defining the model architecture. 
Instantiating a configuration
+    with the defaults will yield a similar configuration to that of
+    Qwen3-VL-30B-A3B-Instruct [Qwen/Qwen3-VL-30B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Instruct).
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 151936):
+            Vocabulary size of the Qwen3VLMoe model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`Qwen3VLMoeTextModel`]
+        hidden_size (`int`, *optional*, defaults to 2048):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 5632):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 24):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        num_key_value_heads (`int`, *optional*, defaults to 16):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details check out [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `16`.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 128000):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether the model's input and output word embeddings should be tied.
+        rope_theta (`float`, *optional*, defaults to 5000000.0):
+            The base period of the RoPE embeddings.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        decoder_sparse_step (`int`, *optional*, defaults to 1):
+            The frequency of the MoE layer.
+        moe_intermediate_size (`int`, *optional*, defaults to 1408):
+            Intermediate size of the routed expert.
+        num_experts_per_tok (`int`, *optional*, defaults to 4):
+            Number of selected experts.
+        num_experts (`int`, *optional*, defaults to 60):
+            Number of routed experts.
+        norm_topk_prob (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the topk probabilities.
+ mlp_only_layers (`List[int]`, *optional*, defaults to `[]`): + Indicate which layers use Qwen3VLMoeMLP rather than Qwen3VLMoeSparseMoeBlock + The list contains layer index, from 0 to num_layers-1 if we have num_layers layers + If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + head_dim (`int`, *optional*): + The dimension of the head. If not specified, will default to `hidden_size // num_attention_heads`. 
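+
+    Checkpoints for this model may also carry `mrope_section` / `mrope_interleaved` keys inside `rope_scaling`
+    (both are excluded from the standard RoPE validation). The dict below is only an illustrative sketch; the exact
+    values shipped with a given checkpoint may differ:
+
+    ```python
+    rope_scaling = {
+        "rope_type": "default",
+        "mrope_section": [24, 20, 20],  # frequency split across the temporal / height / width axes
+        "mrope_interleaved": True,      # interleaved T/H/W layout used by Qwen3VLMoeTextRotaryEmbedding
+    }
+    ```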
+ + ```python + >>> from transformers import Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeConfig + + >>> # Initializing a Qwen3VLMoe style configuration + >>> configuration = Qwen3VLMoeConfig() + + >>> # Initializing a model from the Qwen3-VL-30B-A3B style configuration + >>> model = Qwen3VLMoeForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen3_vl_moe_text" + base_config_key = "text_config" + keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `Qwen3VLMoe` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=151936, + hidden_size=2048, + intermediate_size=5632, + num_hidden_layers=24, + num_attention_heads=16, + num_key_value_heads=16, + hidden_act="silu", + max_position_embeddings=128000, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=5000000.0, + attention_bias=False, + attention_dropout=0.0, + decoder_sparse_step=1, + moe_intermediate_size=1408, + num_experts_per_tok=4, + num_experts=60, + norm_topk_prob=True, + mlp_only_layers=None, + rope_scaling=None, + head_dim=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.rope_scaling = rope_scaling + self.head_dim = head_dim or hidden_size // num_attention_heads + + rope_config_validation(self, ignore_keys={"mrope_section", "mrope_interleaved"}) + + # MoE arguments + self.decoder_sparse_step = decoder_sparse_step + self.moe_intermediate_size = moe_intermediate_size + self.num_experts_per_tok = num_experts_per_tok + self.num_experts = num_experts + self.norm_topk_prob = norm_topk_prob + self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + +class Qwen3VLMoeVisionConfig(PretrainedConfig): + model_type = "qwen3_vl_moe" + base_config_key = "vision_config" + + def __init__( + self, + depth=27, + hidden_size=1152, + hidden_act="gelu_pytorch_tanh", + intermediate_size=4304, + num_heads=16, + in_channels=3, + patch_size=16, + spatial_merge_size=2, + temporal_patch_size=2, + out_hidden_size=3584, + num_position_embeddings=2304, + deepstack_visual_indexes=[8, 16, 24], + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + + self.depth = depth + self.hidden_size = hidden_size + self.hidden_act = 
hidden_act + self.intermediate_size = intermediate_size + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.out_hidden_size = out_hidden_size + self.num_position_embeddings = num_position_embeddings + self.initializer_range = initializer_range + self.deepstack_visual_indexes = deepstack_visual_indexes + + +class Qwen3VLMoeConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen3VLMoeModel`]. It is used to instantiate a + Qwen3-VL-MOE model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen3-VL-30B-A3B-Instruct [Qwen/Qwen3-VL-30B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLMoeTextConfig`): + The config object or dictionary of the text backbone. + vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLMoeVisionConfig`): + The config object or dictionary of the vision backbone. + image_token_id (`int`, *optional*, defaults to 151655): + The image token index to encode the image prompt. + video_token_id (`int`, *optional*, defaults to 151656): + The video token index to encode the image prompt. + vision_start_token_id (`int`, *optional*, defaults to 151652): + The start token index to encode the image prompt. + vision_end_token_id (`int`, *optional*, defaults to 151653): + The end token index to encode the image prompt. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie the word embeddings. 
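+
+    The `text_config` and `vision_config` arguments also accept plain dictionaries, which are converted into the
+    corresponding sub-config classes in `__init__`. A minimal sketch with made-up sizes (not a released checkpoint):
+
+    ```python
+    config = Qwen3VLMoeConfig(
+        text_config={"num_hidden_layers": 4, "num_experts": 8},
+        vision_config={"depth": 4, "deepstack_visual_indexes": [1, 2, 3]},
+    )
+    ```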
+ + ```python + >>> from transformers import Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeConfig + + >>> # Initializing a Qwen3-VL-MOE style configuration + >>> configuration = Qwen3VLMoeConfig() + + >>> # Initializing a model from the Qwen3-VL-30B-A3B style configuration + >>> model = Qwen3VLMoeForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen3_vl_moe" + sub_configs = {"vision_config": Qwen3VLMoeVisionConfig, "text_config": Qwen3VLMoeTextConfig} + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + text_config=None, + vision_config=None, + image_token_id=151655, + video_token_id=151656, + vision_start_token_id=151652, + vision_end_token_id=151653, + tie_word_embeddings=False, + **kwargs, + ): + if isinstance(vision_config, dict): + self.vision_config = self.sub_configs["vision_config"](**vision_config) + elif vision_config is None: + self.vision_config = self.sub_configs["vision_config"]() + + if isinstance(text_config, dict): + self.text_config = self.sub_configs["text_config"](**text_config) + elif text_config is None: + self.text_config = self.sub_configs["text_config"]() + + self.image_token_id = image_token_id + self.video_token_id = video_token_id + self.vision_start_token_id = vision_start_token_id + self.vision_end_token_id = vision_end_token_id + super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings) + + +__all__ = ["Qwen3VLMoeConfig", "Qwen3VLMoeTextConfig"] diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py new file mode 100644 index 000000000000..74b793f096f3 --- /dev/null +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -0,0 +1,1711 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen3_vl_moe.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass +from typing import Any, Callable, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling +from ...utils.deprecation import deprecate_kwarg +from ...utils.generic import OutputRecorder, check_model_inputs +from .configuration_qwen3_vl_moe import Qwen3VLMoeConfig, Qwen3VLMoeTextConfig, Qwen3VLMoeVisionConfig + + +@use_kernel_forward_from_hub("RMSNorm") +class Qwen3VLMoeTextRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Qwen3VLMoeTextRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Qwen3VLMoeTextRouter(nn.Linear): + def __init__(self, config): + super().__init__(config.hidden_size, config.num_experts, bias=False) + self.hidden_size = config.hidden_size + self.top_k = config.num_experts_per_tok + # since all the models use norm_topk_prob, we don't need to have a extra check for it + # self.norm_topk_prob = config.norm_topk_prob + + def forward(self, hidden_states): + hidden_states = hidden_states.reshape(-1, self.hidden_size) + router_logits = super().forward(hidden_states) + routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float) + routing_weights, router_indices = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(hidden_states.dtype) + router_weights = torch.zeros_like(router_logits).scatter_(1, router_indices, routing_weights) + return router_weights, router_logits, router_indices + + +class Qwen3VLMoeTextExperts(nn.Module): + def __init__(self, config): + super().__init__() + self.num_experts = config.num_experts + self.intermediate_size = config.moe_intermediate_size + self.hidden_size = config.hidden_size + self.expert_dim = self.intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim)) + self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size))) + self.act_fn = ACT2FN[config.hidden_act] + + def forward( + self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor + ) -> torch.Tensor: + """ + When training it is more efficient to just loop over the experts and compute the output for each expert 
+ as otherwise the memory would explode. + + For inference we can sacrifice some memory and compute the output for all experts at once. By repeating the inputs. + + Args: + hidden_states (torch.Tensor): (batch_size * token_num, hidden_size) + routing_weights (torch.Tensor): (batch_size * token_num, num_experts) + router_indices (torch.Tensor): (batch_size * token_num, top_k) + Returns: + torch.Tensor + """ + batch_size = hidden_states.shape[0] + hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size) + if self.training: + next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts) + expert_mask = expert_mask.permute(2, 1, 0) + # we sum on the top_k and on the sequence length to get which experts + # are hit this time around + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in expert_hit[:]: + with torch.no_grad(): + _, token_idx = torch.where(expert_mask[expert_idx[0]]) + current_state = hidden_states[token_idx] + gate_up = current_state @ self.gate_up_proj[expert_idx] + gate, up = gate_up.chunk(2, dim=-1) + gated_output = up * self.act_fn(gate) + out = gated_output @ self.down_proj[expert_idx] + weighted_output = out[0] * routing_weights[token_idx, expert_idx, None] + next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype)) + next_states = next_states.view(batch_size, -1, self.hidden_size) + else: + hidden_states = hidden_states.repeat(self.num_experts, 1) + hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) + gate_up = torch.bmm(hidden_states, self.gate_up_proj) + gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors + next_states = torch.bmm((up * self.act_fn(gate)), self.down_proj) + next_states = next_states.reshape(self.num_experts, batch_size, -1, self.hidden_size) + next_states = ( + next_states * routing_weights.transpose(0, 1).view(self.num_experts, batch_size, -1)[..., None] + ) + next_states = next_states.sum(dim=0) + return next_states + + +class Qwen3VLMoeTextSparseMoeBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.num_experts = config.num_experts + self.gate = Qwen3VLMoeTextRouter(config) + self.experts = Qwen3VLMoeTextExperts(config) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + router_weights, router_logits, router_indices = self.gate(hidden_states) + routed_out = self.experts(hidden_states, router_weights, router_indices) + return routed_out, router_logits + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class Qwen3VLMoeTextAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Qwen3VLMoeTextConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.q_norm = Qwen3VLMoeTextRMSNorm( + self.head_dim, eps=config.rms_norm_eps + ) # unlike olmo, only on the head dim! + self.k_norm = Qwen3VLMoeTextRMSNorm( + self.head_dim, eps=config.rms_norm_eps + ) # thus post q_norm does not need reshape + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Qwen3VLMoeTextMLP(nn.Module): + def __init__(self, config, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = intermediate_size if intermediate_size is not None else 
config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class Qwen3VLMoeTextDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Qwen3VLMoeTextConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Qwen3VLMoeTextAttention(config, layer_idx) + + if (layer_idx not in config.mlp_only_layers) and ( + config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0 + ): + self.mlp = Qwen3VLMoeTextSparseMoeBlock(config) + else: + self.mlp = Qwen3VLMoeTextMLP(config, intermediate_size=config.intermediate_size) + + self.input_layernorm = Qwen3VLMoeTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen3VLMoeTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> torch.FloatTensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, + and should not be returned during inference. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + # For the MoE layers, we need to unpack + if isinstance(hidden_states, tuple): + hidden_states, _ = hidden_states + hidden_states = residual + hidden_states + + return hidden_states + + +@auto_docstring +class Qwen3VLMoePreTrainedModel(PreTrainedModel): + config: Qwen3VLMoeConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen3VLMoeTextDecoderLayer", "Qwen3VLMoeVisionBlock"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) + _supports_attention_backend = True + _can_record_outputs = { + "router_logits": OutputRecorder(Qwen3VLMoeTextSparseMoeBlock, index=1), + "hidden_states": Qwen3VLMoeTextDecoderLayer, + "attentions": Qwen3VLMoeTextAttention, + } + + def _init_weights(self, module): + """Initialize the weights.""" + super()._init_weights(module) + if hasattr(self.config, "initializer_range"): + std = self.config.initializer_range + else: + std = getattr(self.config.get_text_config(), "initializer_range", 0.02) + if isinstance(module, Qwen3VLMoeTextExperts): + module.gate_up_proj.data.normal_(mean=0.0, std=std) + module.down_proj.data.normal_(mean=0.0, std=std) + + +class Qwen3VLMoeVisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True) + self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state))) + + +class Qwen3VLMoeVisionPatchEmbed(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.patch_size = config.patch_size + self.temporal_patch_size = config.temporal_patch_size + self.in_channels = config.in_channels + self.embed_dim = config.hidden_size + + kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size] + self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view( + -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + return hidden_states + + +class Qwen3VLMoeVisionRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() 
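+        # the buffer built below holds standard RoPE inverse frequencies: theta ** (-2 * i / dim) per rotary dim pair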
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class Qwen3VLMoeVisionPatchMerger(nn.Module): + def __init__(self, config: Qwen3VLMoeVisionConfig, use_postshuffle_norm=False) -> None: + super().__init__() + self.hidden_size = config.hidden_size * (config.spatial_merge_size**2) + self.use_postshuffle_norm = use_postshuffle_norm + self.norm = nn.LayerNorm(self.hidden_size if use_postshuffle_norm else config.hidden_size, eps=1e-6) + self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size) + self.act_fn = nn.GELU() + self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.norm(x.view(-1, self.hidden_size) if self.use_postshuffle_norm else x).view(-1, self.hidden_size) + x = self.linear_fc2(self.act_fn(self.linear_fc1(x))) + return x + + +def apply_rotary_pos_emb_vision( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + orig_q_dtype = q.dtype + orig_k_dtype = k.dtype + q, k = q.float(), k.float() + cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float() + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + q_embed = q_embed.to(orig_q_dtype) + k_embed = k_embed.to(orig_k_dtype) + return q_embed, k_embed + + +class Qwen3VLMoeVisionAttention(nn.Module): + def __init__(self, config: Qwen3VLMoeVisionConfig) -> None: + super().__init__() + self.dim = config.hidden_size + self.num_heads = config.num_heads + self.head_dim = self.dim // self.num_heads + self.num_key_value_groups = 1 # needed for eager attention + self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True) + self.proj = nn.Linear(self.dim, self.dim) + self.scaling = self.head_dim**-0.5 + self.config = config + self.attention_dropout = 0.0 + self.is_causal = False + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + query_states, key_states, value_states = ( + self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + ) + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) + + query_states = query_states.transpose(0, 1).unsqueeze(0) + key_states = key_states.transpose(0, 1).unsqueeze(0) + value_states = value_states.transpose(0, 1).unsqueeze(0) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + if self.config._attn_implementation == "flash_attention_2": + # Flash Attention 2: Use cu_seqlens for variable length attention + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + attn_output, _ = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + cu_seq_lens_q=cu_seqlens, + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + 
max_length_k=max_seqlen, + is_causal=False, + **kwargs, + ) + else: + # Other implementations: Process each chunk separately + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + splits = [ + torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states) + ] + + attn_outputs = [ + attention_interface( + self, + q, + k, + v, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + is_causal=False, + **kwargs, + )[0] + for q, k, v in zip(*splits) + ] + attn_output = torch.cat(attn_outputs, dim=1) + + attn_output = attn_output.reshape(seq_length, -1).contiguous() + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen3VLMoeVisionBlock(GradientCheckpointingLayer): + def __init__(self, config, attn_implementation: str = "sdpa") -> None: + super().__init__() + self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6) + self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6) + self.attn = Qwen3VLMoeVisionAttention(config=config) + self.mlp = Qwen3VLMoeVisionMLP(config=config) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class Qwen3VLMoeVisionModel(Qwen3VLMoePreTrainedModel): + config: Qwen3VLMoeVisionConfig + _no_split_modules = ["Qwen3VLMoeVisionBlock"] + + def __init__(self, config, *inputs, **kwargs) -> None: + super().__init__(config, *inputs, **kwargs) + self.spatial_merge_size = config.spatial_merge_size + self.patch_size = config.patch_size + self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + + self.patch_embed = Qwen3VLMoeVisionPatchEmbed( + config=config, + ) + + self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size) + self.num_grid_per_side = int(config.num_position_embeddings**0.5) + + head_dim = config.hidden_size // config.num_heads + self.rotary_pos_emb = Qwen3VLMoeVisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList([Qwen3VLMoeVisionBlock(config) for _ in range(config.depth)]) + self.merger = Qwen3VLMoeVisionPatchMerger( + config=config, + use_postshuffle_norm=False, + ) + + self.deepstack_visual_indexes = config.deepstack_visual_indexes + self.deepstack_merger_list = nn.ModuleList( + [ + Qwen3VLMoeVisionPatchMerger( + config=config, + use_postshuffle_norm=True, + ) + for _ in range(len(config.deepstack_visual_indexes)) + ] + ) + + self.gradient_checkpointing = False + + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + merge_size = self.spatial_merge_size + + max_hw = int(grid_thw[:, 1:].max().item()) + freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2) + device = freq_table.device + + total_tokens = int(torch.prod(grid_thw, dim=1).sum().item()) + pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) + + offset = 0 + for num_frames, height, width in grid_thw: + merged_h, merged_w = height // merge_size, width // merge_size + + block_rows = torch.arange(merged_h, device=device) # block row indices + block_cols = torch.arange(merged_w, device=device) # block col indices + intra_row = 
torch.arange(merge_size, device=device) # intra-block row offsets + intra_col = torch.arange(merge_size, device=device) # intra-block col offsets + + # Compute full-resolution positions + row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] + col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] + + row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + + coords = torch.stack((row_idx, col_idx), dim=-1) + + if num_frames > 1: + coords = coords.repeat(num_frames, 1) + + num_tokens = coords.shape[0] + pos_ids[offset : offset + num_tokens] = coords + offset += num_tokens + + embeddings = freq_table[pos_ids] # lookup rotary embeddings + embeddings = embeddings.flatten(1) + return embeddings + + def fast_pos_embed_interpolate(self, grid_thw): + grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2] + + idx_list = [[] for _ in range(4)] + weight_list = [[] for _ in range(4)] + + for t, h, w in zip(grid_ts, grid_hs, grid_ws): + h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) + w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) + + h_idxs_floor = h_idxs.int() + w_idxs_floor = w_idxs.int() + h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + + dh = h_idxs - h_idxs_floor + dw = w_idxs - w_idxs_floor + + base_h = h_idxs_floor * self.num_grid_per_side + base_h_ceil = h_idxs_ceil * self.num_grid_per_side + + indices = [ + (base_h[None].T + w_idxs_floor[None]).flatten(), + (base_h[None].T + w_idxs_ceil[None]).flatten(), + (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), + (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), + ] + + weights = [ + ((1 - dh)[None].T * (1 - dw)[None]).flatten(), + ((1 - dh)[None].T * dw[None]).flatten(), + (dh[None].T * (1 - dw)[None]).flatten(), + (dh[None].T * dw[None]).flatten(), + ] + + for i in range(4): + idx_list[i].extend(indices[i].tolist()) + weight_list[i].extend(weights[i].tolist()) + + idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=self.pos_embed.weight.device) + weight_tensor = torch.tensor( + weight_list, dtype=self.pos_embed.weight.dtype, device=self.pos_embed.weight.device + ) + pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None] + patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] + + patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)]) + + patch_pos_embeds_permute = [] + merge_size = self.config.spatial_merge_size + for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): + pos_embed = pos_embed.repeat(t, 1) + pos_embed = ( + pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1) + .permute(0, 1, 3, 2, 4, 5) + .flatten(0, 4) + ) + patch_pos_embeds_permute.append(pos_embed) + patch_pos_embeds = torch.cat(patch_pos_embeds_permute) + return patch_pos_embeds + + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Args: + hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): + The final hidden states of the model. + grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): + The temporal, height and width of feature shape of each image in LLM. + + Returns: + `torch.Tensor`: hidden_states. 
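+            `list[torch.Tensor]`: the deepstack features, one merged feature map per entry in
+                `deepstack_visual_indexes`, returned together with `hidden_states` as a tuple.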
+ """ + hidden_states = self.patch_embed(hidden_states) + + pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + hidden_states = hidden_states + pos_embeds + + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 for more information + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + deepstack_feature_lists = [] + for layer_num, blk in enumerate(self.blocks): + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings, + **kwargs, + ) + if layer_num in self.deepstack_visual_indexes: + deepstack_feature = self.deepstack_merger_list[self.deepstack_visual_indexes.index(layer_num)]( + hidden_states + ) + deepstack_feature_lists.append(deepstack_feature) + + hidden_states = self.merger(hidden_states) + + return hidden_states, deepstack_feature_lists + + +class Qwen3VLMoeTextRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: Qwen3VLMoeTextConfig, device=None): + super().__init__() + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", "default") + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + self.mrope_section = config.rope_scaling.get("mrope_section", [24, 20, 20]) + + def apply_interleaved_mrope(self, freqs, mrope_section): + """Apply interleaved MRoPE to 3D rotary embeddings. + Reorganizes frequency layout from chunked [TTT...HHH...WWW] to + interleaved [THTHWHTHW...TT], preserving frequency continuity. + args: + x: (3, bs, seq_len, head_dim // 2) + mrope_section: (3,) + returns: + x_t: (bs, seq_len, head_dim // 2) + """ + freqs_t = freqs[0] # just overwrite the first dimension T + for dim, offset in enumerate((1, 2), start=1): # H, W + length = mrope_section[dim] * 3 + idx = slice(offset, length, 3) + freqs_t[..., idx] = freqs[dim, ..., idx] + return freqs_t + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + # In contrast to other models, Qwen3VLMoe has different position ids for the grids + # So we expand the inv_freq to shape (3, ...) 
+ if position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + freqs = self.apply_interleaved_mrope(freqs, self.mrope_section) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +@auto_docstring( + custom_intro=( + "Text part of Qwen3VLMoe, " + "not a pure text-only model, as DeepStack integrates visual features into the early hidden states." + ) +) +class Qwen3VLMoeTextModel(Qwen3VLMoePreTrainedModel): + config: Qwen3VLMoeTextConfig + _no_split_modules = ["Qwen3VLMoeTextDecoderLayer"] + + def __init__(self, config: Qwen3VLMoeTextConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Qwen3VLMoeTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Qwen3VLMoeTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen3VLMoeTextRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + # args for deepstack + visual_pos_masks: Optional[torch.Tensor] = None, + deepstack_visual_embeds: Optional[list[torch.Tensor]] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[tuple, BaseModelOutputWithPast]: + r""" + visual_pos_masks (`torch.Tensor` of shape `(batch_size, seqlen)`, *optional*): + The mask of the visual positions. + deepstack_visual_embeds (`list[torch.Tensor]`, *optional*): + The deepstack visual embeddings. The shape is (num_layers, visual_seqlen, embed_dim). + The feature is extracted from the different visual encoder layers, and fed to the decoder + hidden states. It's from the paper DeepStack(https://arxiv.org/abs/2406.04334). 
+ """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + # torch.jit.trace() doesn't support cache objects in the output + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + past_key_values = DynamicCache(config=self.config) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # the hard coded `3` is for temporal, height and width. + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + if position_ids.ndim == 3 and position_ids.shape[0] == 4: + text_position_ids = position_ids[0] + position_ids = position_ids[1:] + else: + text_position_ids = position_ids[0] + + attention_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=text_position_ids, + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + for layer_idx, decoder_layer in enumerate(self.layers): + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=text_position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = layer_outputs + + # add visual features to the hidden states of first several layers + if deepstack_visual_embeds is not None and layer_idx in range(len(deepstack_visual_embeds)): + hidden_states = self._deepstack_process( + hidden_states, + visual_pos_masks, + deepstack_visual_embeds[layer_idx], + ) + + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + def _deepstack_process( + self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor, visual_embeds: torch.Tensor + ): + visual_pos_masks = visual_pos_masks.to(hidden_states.device) + visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype) + local_this = hidden_states[visual_pos_masks, :].clone() + visual_embeds + hidden_states[visual_pos_masks, :] = local_this + return hidden_states + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Llava outputs, with hidden states and attentions. + """ +) +class Qwen3VLMoeModelOutputWithPast(ModelOutput): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. 
+ """ + + last_hidden_state: Optional[torch.FloatTensor] = None + past_key_values: Optional[list[torch.FloatTensor]] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + +@auto_docstring +class Qwen3VLMoeModel(Qwen3VLMoePreTrainedModel): + base_model_prefix = "" + _checkpoint_conversion_mapping = {} + # Reference: fix gemma3 grad acc #37208 + accepts_loss_kwargs = False + config: Qwen3VLMoeConfig + _no_split_modules = ["Qwen3VLMoeTextDecoderLayer", "Qwen3VLMoeVisionBlock"] + + def __init__(self, config): + super().__init__(config) + self.visual = Qwen3VLMoeVisionModel._from_config(config.vision_config) + self.language_model = Qwen3VLMoeTextModel._from_config(config.text_config) + self.rope_deltas = None # cache rope_deltas here + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def set_decoder(self, decoder): + self.language_model = decoder + + def get_decoder(self): + return self.language_model + + def get_rope_index( + self, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Different from the original implementation, Qwen3VLMoe use timestamps rather than absolute time position ids.""" + + # Since we use timestamps to seperate videos, like , the video_grid_thw should also be split + if video_grid_thw is not None: + video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0) + video_grid_thw[:, 0] = 1 + + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + 
image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + # t_index is always 0 because llm_grid_t is always 1 (we use timestamps to encode the temporal information for videos) + t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + def get_video_features( + self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None + ): + """ + Encodes videos into continuous embeddings that can be forwarded to the language model. The deepstack visual features are also returned. + + Args: + pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input videos. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + """ + # Same implementation as for images + return self.get_image_features(pixel_values_videos, video_grid_thw) + + def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. The deepstack visual features are also returned. 
+ + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + """ + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds, deepstack_image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() + image_embeds = torch.split(image_embeds, split_sizes) + return image_embeds, deepstack_image_embeds + + def get_placeholder_mask( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: Optional[torch.FloatTensor] = None, + video_features: Optional[torch.FloatTensor] = None, + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + special_video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_video_mask = special_video_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + special_video_mask = input_ids == self.config.video_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}" + ) + + n_video_tokens = special_video_mask.sum() + special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel(): + raise ValueError( + f"Videos features and video tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}" + ) + + return special_image_mask, special_video_mask + + @auto_docstring + @can_return_tuple + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, Qwen3VLMoeModelOutputWithPast]: + r""" + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. 
+ """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + image_mask = None + video_mask = None + + if pixel_values is not None: + image_embeds, deepstack_image_embeds = self.get_image_features(pixel_values, image_grid_thw) + image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + image_mask, _ = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + video_embeds, deepstack_video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) + video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + _, video_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + visual_pos_masks = None + deepstack_visual_embeds = None + if image_mask is not None and video_mask is not None: + # aggregate visual_pos_masks and deepstack_visual_embeds + image_mask = image_mask[..., 0] + video_mask = video_mask[..., 0] + visual_pos_masks = image_mask | video_mask + deepstack_visual_embeds = [] + image_mask_joint = image_mask[visual_pos_masks] + video_mask_joint = video_mask[visual_pos_masks] + for img_embed, vid_embed in zip(deepstack_image_embeds, deepstack_video_embeds): + embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1]).to(img_embed.device) + embed_joint[image_mask_joint, :] = img_embed + embed_joint[video_mask_joint, :] = vid_embed + deepstack_visual_embeds.append(embed_joint) + elif image_mask is not None: + image_mask = image_mask[..., 0] + visual_pos_masks = image_mask + deepstack_visual_embeds = deepstack_image_embeds + elif video_mask is not None: + video_mask = video_mask[..., 0] + visual_pos_masks = video_mask + deepstack_visual_embeds = deepstack_video_embeds + + if position_ids is None: + attention_mask_tensor = ( + attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"] + ) + if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4: + attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2) + # Only apply conversion for floating point tensors (inverted masks) + if attention_mask_tensor.dtype.is_floating_point: + attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min + attention_mask_tensor = (1.0 - attention_mask_tensor).int() + + # Calculate RoPE index once per generation in the pre-fill stage only. 
+ # When compiling, we can't check tensor values thus we check only input length + # It is safe to assume that `length!=1` means we're in pre-fill because compiled + # models currently cannot do asssisted decoding + prefill_compiled_stage = is_torchdynamo_compiling() and ( + (input_ids is not None and input_ids.shape[1] != 1) + or (inputs_embeds is not None and inputs_embeds.shape[1] != 1) + ) + prefill_noncompiled_stage = not is_torchdynamo_compiling() and ( + (cache_position is not None and cache_position[0] == 0) + or (past_key_values is None or past_key_values.get_seq_length() == 0) + ) + if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None: + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + attention_mask=attention_mask_tensor, + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = ( + (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) + if cache_position is not None + else 0 + ) + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + outputs = self.language_model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + visual_pos_masks=visual_pos_masks, + deepstack_visual_embeds=deepstack_visual_embeds, + **kwargs, + ) + + return Qwen3VLMoeModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Qwen3VLMoe causal language model (or autoregressive) outputs. + """ +) +class Qwen3VLMoeCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. 
+ """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[list[torch.FloatTensor]] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + +class Qwen3VLMoeForConditionalGeneration(Qwen3VLMoePreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = {} + _tied_weights_keys = ["lm_head.weight"] + # Reference: fix gemma3 grad acc #37208 + accepts_loss_kwargs = False + config: Qwen3VLMoeConfig + + def __init__(self, config): + super().__init__(config) + self.model = Qwen3VLMoeModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def set_decoder(self, decoder): + self.model.set_decoder(decoder) + + def get_decoder(self): + return self.model.get_decoder() + + def get_video_features( + self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None + ): + return self.model.get_video_features(pixel_values_videos, video_grid_thw) + + def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): + return self.model.get_image_features(pixel_values, image_grid_thw) + + # Make modules available through conditional class for BC + @property + def language_model(self): + return self.model.language_model + + @property + def visual(self): + return self.model.visual + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, Qwen3VLMoeCausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. 
+ + Example: + TODO: Add example + """ + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + + return Qwen3VLMoeCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=outputs.rope_deltas, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + use_cache=use_cache, + **kwargs, + ) + + # Qwen3VLMoe position_ids are prepareed with rope_deltas in forward + model_inputs["position_ids"] = None + + if cache_position[0] != 0: + model_inputs["pixel_values"] = None + model_inputs["pixel_values_videos"] = None + + return model_inputs + + def _get_image_nums_and_video_nums( + self, + input_ids: Optional[torch.LongTensor], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Get the number of images and videos for each sample to calculate the separation length of the sample tensor. + These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. 
+ + Returns: + image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`) + video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`) + """ + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + + if inputs_embeds is not None: + vision_start_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + image_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + video_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + else: + vision_start_mask = input_ids == vision_start_token_id + image_mask = input_ids == image_token_id + video_mask = input_ids == video_token_id + + vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1) + image_nums = torch.sum(vision_first_mask & image_mask, dim=1) + video_nums = torch.sum(vision_first_mask & video_mask, dim=1) + + return image_nums, video_nums + + def _expand_inputs_for_generation( + self, + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: Optional[torch.LongTensor] = None, + **model_kwargs, + ) -> tuple[torch.LongTensor, dict[str, Any]]: + # Overwritten -- Support for expanding tensors without a batch size dimension + # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t + # pixel_values.shape[0] is sum(seqlen_images for samples) + # image_grid_thw.shape[0] is sum(num_images for samples) + + if expand_size == 1: + return input_ids, model_kwargs + + visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"] + + def _expand_dict_for_generation_visual(dict_to_expand): + image_grid_thw = model_kwargs.get("image_grid_thw", None) + video_grid_thw = model_kwargs.get("video_grid_thw", None) + image_nums, video_nums = self._get_image_nums_and_video_nums( + input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None) + ) + + def _repeat_interleave_samples(x, lengths, repeat_times): + samples = torch.split(x, lengths) + repeat_args = [repeat_times] + [1] * (x.dim() - 1) + result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0) + return result + + for key in dict_to_expand: + if key == "pixel_values": + # split images into samples + samples = torch.split(image_grid_thw, list(image_nums)) + # compute the sequence length of images for each sample + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "image_grid_thw": + # get the num of images for each sample + lengths = list(image_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "pixel_values_videos": + samples = torch.split(video_grid_thw, list(video_nums)) + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "video_grid_thw": + lengths = list(video_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, 
repeat_times=expand_size + ) + elif key == "second_per_grid_ts": + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=list(video_nums), repeat_times=expand_size + ) + return dict_to_expand + + def _expand_dict_for_generation(dict_to_expand): + for key in dict_to_expand: + if ( + key != "cache_position" + and dict_to_expand[key] is not None + and isinstance(dict_to_expand[key], torch.Tensor) + and key not in visual_keys + ): + dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) + return dict_to_expand + + model_kwargs = _expand_dict_for_generation_visual(model_kwargs) + + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + + model_kwargs = _expand_dict_for_generation(model_kwargs) + + if is_encoder_decoder: + if model_kwargs.get("encoder_outputs") is None: + raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") + model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) + + return input_ids, model_kwargs + + +__all__ = [ + "Qwen3VLMoeVisionModel", + "Qwen3VLMoeForConditionalGeneration", + "Qwen3VLMoeModel", + "Qwen3VLMoePreTrainedModel", + "Qwen3VLMoeTextModel", +] diff --git a/src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py new file mode 100644 index 000000000000..456d7c60aa89 --- /dev/null +++ b/src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py @@ -0,0 +1,434 @@ +# coding=utf-8 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Qwen3-VL-MOE model.""" + +import torch +import torch.nn as nn + +from ...activations import ACT2FN +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation +from ...modeling_utils import PreTrainedModel +from ...utils import logging +from ..qwen3_moe.modeling_qwen3_moe import ( + Qwen3MoeDecoderLayer, + Qwen3MoePreTrainedModel, + Qwen3MoeRMSNorm, +) +from ..qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig, Qwen3VLVisionConfig +from ..qwen3_vl.modeling_qwen3_vl import ( + Qwen3VLForConditionalGeneration, + Qwen3VLModel, + Qwen3VLTextAttention, + Qwen3VLTextModel, + Qwen3VLVisionModel, +) + + +logger = logging.get_logger(__name__) + + +class Qwen3VLMoeTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen3VLMoeTextModel`]. It is used to instantiate a + Qwen3-VL-MOE model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen3-VL-30B-A3B-Instruct [Qwen/Qwen3-VL-30B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. 
Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 151936): + Vocabulary size of the Qwen2MoE model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen2MoeModel`] + hidden_size (`int`, *optional*, defaults to 2048): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 5632): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 16): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 128000): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 5000000.0): + The base period of the RoPE embeddings. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + decoder_sparse_step (`int`, *optional*, defaults to 1): + The frequency of the MoE layer. + moe_intermediate_size (`int`, *optional*, defaults to 1408): + Intermediate size of the routed expert. + num_experts_per_tok (`int`, *optional*, defaults to 4): + Number of selected experts. + num_experts (`int`, *optional*, defaults to 60): + Number of routed experts. + norm_topk_prob (`bool`, *optional*, defaults to `True`): + Whether to normalize the topk probabilities. + mlp_only_layers (`List[int]`, *optional*, defaults to `[]`): + Indicate which layers use Qwen3VLMoeMLP rather than Qwen3VLMoeSparseMoeBlock + The list contains layer index, from 0 to num_layers-1 if we have num_layers layers + If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity. 
+ rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + head_dim (`int`, *optional*): + The dimension of the head. If not specified, will default to `hidden_size // num_attention_heads`. 
+ + ```python + >>> from transformers import Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeConfig + + >>> # Initializing a Qwen3VLMoe style configuration + >>> configuration = Qwen3VLMoeConfig() + + >>> # Initializing a model from the Qwen3-VL-30B-A3B style configuration + >>> model = Qwen3VLMoeForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen3_vl_moe_text" + base_config_key = "text_config" + keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `Qwen3VLMoe` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=151936, + hidden_size=2048, + intermediate_size=5632, + num_hidden_layers=24, + num_attention_heads=16, + num_key_value_heads=16, + hidden_act="silu", + max_position_embeddings=128000, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=5000000.0, + attention_bias=False, + attention_dropout=0.0, + decoder_sparse_step=1, + moe_intermediate_size=1408, + num_experts_per_tok=4, + num_experts=60, + norm_topk_prob=True, + mlp_only_layers=None, + rope_scaling=None, + head_dim=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.rope_scaling = rope_scaling + self.head_dim = head_dim or hidden_size // num_attention_heads + + rope_config_validation(self, ignore_keys={"mrope_section", "mrope_interleaved"}) + + # MoE arguments + self.decoder_sparse_step = decoder_sparse_step + self.moe_intermediate_size = moe_intermediate_size + self.num_experts_per_tok = num_experts_per_tok + self.num_experts = num_experts + self.norm_topk_prob = norm_topk_prob + self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + +class Qwen3VLMoeVisionConfig(Qwen3VLVisionConfig): + pass + + +class Qwen3VLMoeConfig(Qwen3VLConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen3VLMoeModel`]. It is used to instantiate a + Qwen3-VL-MOE model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen3-VL-30B-A3B-Instruct [Qwen/Qwen3-VL-30B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Instruct). 
+ + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLMoeTextConfig`): + The config object or dictionary of the text backbone. + vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLMoeVisionConfig`): + The config object or dictionary of the vision backbone. + image_token_id (`int`, *optional*, defaults to 151655): + The image token index to encode the image prompt. + video_token_id (`int`, *optional*, defaults to 151656): + The video token index to encode the image prompt. + vision_start_token_id (`int`, *optional*, defaults to 151652): + The start token index to encode the image prompt. + vision_end_token_id (`int`, *optional*, defaults to 151653): + The end token index to encode the image prompt. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie the word embeddings. + + ```python + >>> from transformers import Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeConfig + + >>> # Initializing a Qwen3-VL-MOE style configuration + >>> configuration = Qwen3VLMoeConfig() + + >>> # Initializing a model from the Qwen3-VL-30B-A3B style configuration + >>> model = Qwen3VLMoeForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen3_vl_moe" + sub_configs = {"vision_config": Qwen3VLMoeVisionConfig, "text_config": Qwen3VLMoeTextConfig} + + +class Qwen3VLMoeTextRMSNorm(Qwen3MoeRMSNorm): + pass + + +class Qwen3VLMoeTextRouter(nn.Linear): + def __init__(self, config): + super().__init__(config.hidden_size, config.num_experts, bias=False) + self.hidden_size = config.hidden_size + self.top_k = config.num_experts_per_tok + # since all the models use norm_topk_prob, we don't need to have a extra check for it + # self.norm_topk_prob = config.norm_topk_prob + + def forward(self, hidden_states): + hidden_states = hidden_states.reshape(-1, self.hidden_size) + router_logits = super().forward(hidden_states) + routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float) + routing_weights, router_indices = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(hidden_states.dtype) + router_weights = torch.zeros_like(router_logits).scatter_(1, router_indices, routing_weights) + return router_weights, router_logits, router_indices + + +class Qwen3VLMoeTextExperts(nn.Module): + def __init__(self, config): + super().__init__() + self.num_experts = config.num_experts + self.intermediate_size = config.moe_intermediate_size + self.hidden_size = config.hidden_size + self.expert_dim = self.intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim)) + self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size))) + self.act_fn = ACT2FN[config.hidden_act] + + def forward( + self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor + ) -> torch.Tensor: + """ + When training it is more efficient to just loop over the experts and compute the output for each expert + as otherwise the memory would explode. + + For inference we can sacrifice some memory and compute the output for all experts at once. 
By repeating the inputs. + + Args: + hidden_states (torch.Tensor): (batch_size * token_num, hidden_size) + routing_weights (torch.Tensor): (batch_size * token_num, num_experts) + router_indices (torch.Tensor): (batch_size * token_num, top_k) + Returns: + torch.Tensor + """ + batch_size = hidden_states.shape[0] + hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size) + if self.training: + next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts) + expert_mask = expert_mask.permute(2, 1, 0) + # we sum on the top_k and on the sequence length to get which experts + # are hit this time around + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in expert_hit[:]: + with torch.no_grad(): + _, token_idx = torch.where(expert_mask[expert_idx[0]]) + current_state = hidden_states[token_idx] + gate_up = current_state @ self.gate_up_proj[expert_idx] + gate, up = gate_up.chunk(2, dim=-1) + gated_output = up * self.act_fn(gate) + out = gated_output @ self.down_proj[expert_idx] + weighted_output = out[0] * routing_weights[token_idx, expert_idx, None] + next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype)) + next_states = next_states.view(batch_size, -1, self.hidden_size) + else: + hidden_states = hidden_states.repeat(self.num_experts, 1) + hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) + gate_up = torch.bmm(hidden_states, self.gate_up_proj) + gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors + next_states = torch.bmm((up * self.act_fn(gate)), self.down_proj) + next_states = next_states.reshape(self.num_experts, batch_size, -1, self.hidden_size) + next_states = ( + next_states * routing_weights.transpose(0, 1).view(self.num_experts, batch_size, -1)[..., None] + ) + next_states = next_states.sum(dim=0) + return next_states + + +class Qwen3VLMoeTextSparseMoeBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.num_experts = config.num_experts + self.gate = Qwen3VLMoeTextRouter(config) + self.experts = Qwen3VLMoeTextExperts(config) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + router_weights, router_logits, router_indices = self.gate(hidden_states) + routed_out = self.experts(hidden_states, router_weights, router_indices) + return routed_out, router_logits + + +class Qwen3VLMoeTextAttention(Qwen3VLTextAttention): + pass + + +class Qwen3VLMoeTextDecoderLayer(Qwen3MoeDecoderLayer): + pass + + +class Qwen3VLMoePreTrainedModel(Qwen3MoePreTrainedModel): + config: Qwen3VLMoeConfig + _no_split_modules = ["Qwen3VLMoeTextDecoderLayer", "Qwen3VLMoeVisionBlock"] + + def _init_weights(self, module): + """Initialize the weights.""" + PreTrainedModel._init_weights(self, module) + if hasattr(self.config, "initializer_range"): + std = self.config.initializer_range + else: + std = getattr(self.config.get_text_config(), "initializer_range", 0.02) + if isinstance(module, Qwen3VLMoeTextExperts): + module.gate_up_proj.data.normal_(mean=0.0, std=std) + module.down_proj.data.normal_(mean=0.0, std=std) + + +class Qwen3VLMoeVisionModel(Qwen3VLVisionModel): + pass + + +class Qwen3VLMoeTextModel(Qwen3VLTextModel): + pass + + +class Qwen3VLMoeModel(Qwen3VLModel): + pass + + +class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration): + pass + + 
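+# A minimal, self-contained sketch of the routing math implemented by
+# `Qwen3VLMoeTextSparseMoeBlock` above, using hypothetical toy sizes (this helper is not part
+# of the model API and is never called by the library): the router keeps the top
+# `num_experts_per_tok` experts per token with renormalized weights, and at inference the
+# experts run as one batched matmul whose per-expert outputs are mixed by those weights.
+# At training time the same block instead loops only over the experts that were actually hit
+# (see `Qwen3VLMoeTextExperts.forward`) to keep memory bounded.
+def _toy_sparse_moe_example():
+    from types import SimpleNamespace
+
+    torch.manual_seed(0)
+    toy_config = SimpleNamespace(
+        hidden_size=8,
+        moe_intermediate_size=16,
+        num_experts=4,
+        num_experts_per_tok=2,
+        hidden_act="silu",
+    )
+    block = Qwen3VLMoeTextSparseMoeBlock(toy_config).eval()
+    # the expert weights are allocated with `torch.empty`, so give them values for the demo
+    nn.init.normal_(block.experts.gate_up_proj, std=0.02)
+    nn.init.normal_(block.experts.down_proj, std=0.02)
+
+    hidden_states = torch.randn(2, 5, toy_config.hidden_size)  # (batch, seq_len, hidden)
+    with torch.no_grad():
+        routed_out, router_logits = block(hidden_states)
+    # routed_out: (2, 5, 8), each token a mixture of its top-2 experts; router_logits: (10, 4)
+    return routed_out.shape, router_logits.shape
+
+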
+__all__ = [ + "Qwen3VLMoeConfig", + "Qwen3VLMoeTextConfig", + "Qwen3VLMoeVisionModel", + "Qwen3VLMoeForConditionalGeneration", + "Qwen3VLMoeModel", + "Qwen3VLMoePreTrainedModel", + "Qwen3VLMoeTextModel", +] diff --git a/tests/models/qwen3_vl/__init__.py b/tests/models/qwen3_vl/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/qwen3_vl/test_modeling_qwen3_vl.py b/tests/models/qwen3_vl/test_modeling_qwen3_vl.py new file mode 100644 index 000000000000..35031bf542aa --- /dev/null +++ b/tests/models/qwen3_vl/test_modeling_qwen3_vl.py @@ -0,0 +1,299 @@ +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch Qwen3-VL model.""" + +import copy +import unittest + +from transformers import ( + Qwen3VLConfig, + Qwen3VLForConditionalGeneration, + Qwen3VLModel, + is_torch_available, +) +from transformers.testing_utils import ( + require_torch, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ( + ModelTesterMixin, + floats_tensor, + ids_tensor, +) + + +if is_torch_available(): + import torch + + +class Qwen3VLVisionText2TextModelTester: + def __init__( + self, + parent, + batch_size=3, + seq_length=7, + num_channels=3, + ignore_index=-100, + image_size=16, + text_config={ + "bos_token_id": 0, + "eos_token_id": 1, + "pad_token_id": 2, + "hidden_act": "silu", + "head_dim": 8, + "hidden_size": 32, + "vocab_size": 99, + "intermediate_size": 37, + "max_position_embeddings": 512, + "model_type": "qwen3_vl", + "num_attention_heads": 4, + "num_hidden_layers": 4, + "num_key_value_heads": 2, + "rope_theta": 10000, + "tie_word_embeddings": True, + "rope_scaling": {"rope_type": "default", "mrope_section": [16, 8, 8], "mrope_interleaved": True}, + }, + vision_config={ + "depth": 2, + "in_chans": 3, + "hidden_act": "gelu_pytorch_tanh", + "intermediate_size": 32, + "out_hidden_size": 32, + "hidden_size": 32, + "num_heads": 4, + "patch_size": 16, + "spatial_merge_size": 1, + "temporal_patch_size": 2, + "num_position_embeddings": 16, + "deepstack_visual_indexes": [0, 1], + }, + image_token_id=3, + video_token_id=4, + vision_start_token_id=5, + vision_end_token_id=6, + tie_word_embeddings=True, + is_training=True, + ): + self.parent = parent + self.ignore_index = ignore_index + self.is_training = is_training + + self.vision_config = vision_config + self.text_config = text_config + + self.vocab_size = text_config["vocab_size"] + self.bos_token_id = text_config["bos_token_id"] + self.eos_token_id = text_config["eos_token_id"] + self.pad_token_id = text_config["pad_token_id"] + self.head_dim = text_config["head_dim"] + self.hidden_size = text_config["hidden_size"] + self.intermediate_size = text_config["intermediate_size"] + self.num_hidden_layers = text_config["num_hidden_layers"] + self.num_attention_heads = text_config["num_attention_heads"] + 
self.num_key_value_heads = text_config["num_key_value_heads"] + self.rope_theta = text_config["rope_theta"] + self.rope_scaling = text_config["rope_scaling"] + self.hidden_act = text_config["hidden_act"] + self.max_position_embeddings = text_config["max_position_embeddings"] + self.model_type = text_config["model_type"] + + self.vision_start_token_id = vision_start_token_id + self.vision_end_token_id = vision_end_token_id + self.image_token_id = image_token_id + self.video_token_id = video_token_id + self.tie_word_embeddings = tie_word_embeddings + + self.batch_size = batch_size + self.num_channels = num_channels + self.image_size = image_size + self.num_image_tokens = 32 + self.seq_length = seq_length + self.num_image_tokens + + def get_config(self): + return Qwen3VLConfig( + text_config=self.text_config, + vision_config=self.vision_config, + image_token_id=self.image_token_id, + video_token_id=self.video_token_id, + vision_start_token_id=self.vision_start_token_id, + vision_end_token_id=self.vision_end_token_id, + tie_word_embeddings=self.tie_word_embeddings, + ) + + def prepare_config_and_inputs(self): + config = self.get_config() + patch_size = config.vision_config.patch_size + temporal_patch_size = config.vision_config.temporal_patch_size + pixel_values = floats_tensor( + [ + self.batch_size * (self.image_size**2) // (patch_size**2), + self.num_channels * (patch_size**2) * temporal_patch_size, + ] + ) + + return config, pixel_values + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) + + input_ids[:, -1] = self.pad_token_id + input_ids[input_ids == self.video_token_id] = self.pad_token_id + input_ids[input_ids == self.image_token_id] = self.pad_token_id + input_ids[input_ids == self.vision_start_token_id] = self.pad_token_id + input_ids[:, self.num_image_tokens] = self.image_token_id + input_ids[:, self.num_image_tokens - 1] = self.vision_start_token_id + inputs_dict = { + "pixel_values": pixel_values, + "image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size, device=torch_device), + "input_ids": input_ids, + "attention_mask": attention_mask, + } + return config, inputs_dict + + +@require_torch +class Qwen3VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + """ + Model tester for `Qwen3VLForConditionalGeneration`. + """ + + all_model_classes = ( + ( + Qwen3VLModel, + Qwen3VLForConditionalGeneration, + ) + if is_torch_available() + else () + ) + test_pruning = False + test_head_masking = False + + def setUp(self): + self.model_tester = Qwen3VLVisionText2TextModelTester(self) + self.config_tester = ConfigTester(self, config_class=Qwen3VLConfig, has_text_modality=False) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_mismatching_num_image_tokens(self): + """ + Tests that VLMs through an error with explicit message saying what is wrong + when number of images don't match number of image tokens in the text. + Also we need to test multi-image cases when one prompr has multiple image tokens. 
+ """ + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + model = model_class(config).to(torch_device) + _ = model(**input_dict) # successful forward with no modifications + curr_input_dict = copy.deepcopy(input_dict) + + # remove one image but leave the image token in text + patch_size = config.vision_config.patch_size + one_img_length = (self.model_tester.image_size**2) // (patch_size**2) + curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-one_img_length:, ...] + curr_input_dict["image_grid_thw"] = curr_input_dict["image_grid_thw"][-1:, ...] + with self.assertRaises(ValueError): + _ = model(**curr_input_dict) + + # simulate multi-image case by concatenating inputs where each has exactly one image/image-token + input_ids = curr_input_dict["input_ids"][:1] + pixel_values = curr_input_dict["pixel_values"][:one_img_length] + image_grid_thw = curr_input_dict["image_grid_thw"][:1] + input_ids = torch.cat([input_ids, input_ids], dim=0) + + # one image and two image tokens raise an error + with self.assertRaises(ValueError): + _ = model( + input_ids=input_ids, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) + + # two images and two image tokens don't raise an error + pixel_values = torch.cat([pixel_values, pixel_values], dim=0) + image_grid_thw = torch.cat([image_grid_thw, image_grid_thw], dim=0) + _ = model( + input_ids=input_ids, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) + + def test_video_forward(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + B = self.model_tester.batch_size + C = config.vision_config.in_chans + T = config.vision_config.temporal_patch_size + P = config.vision_config.patch_size + + input_ids = ids_tensor([B, self.model_tester.seq_length], self.model_tester.vocab_size) + + F = 4 + patch_H = self.model_tester.image_size // P + patch_W = self.model_tester.image_size // P + patch_T = F // T + patches_per_video = patch_T * patch_H * patch_W + pathed_per_frame = patch_H * patch_W + pixel_values_videos = floats_tensor( + [ + # first dim: batch_size * num_patches + B * patches_per_video, + # second dim: in_channels * temporal_patch_size * patch_size^2 + C * T * (P**2), + ] + ) + + # qwen3vl use timestamps for video, so split it into patch_T sub-videos + video_grid_thw = torch.tensor([[1, patch_H, patch_W] for _ in range(patch_T)] * B) + + # sanity check + self.assertEqual(pixel_values_videos.shape[0], video_grid_thw.prod(dim=1).sum().item()) + + # Insert video token sequence + input_ids[:, -1] = self.model_tester.pad_token_id + input_ids[input_ids == self.model_tester.video_token_id] = self.model_tester.pad_token_id + input_ids[input_ids == self.model_tester.image_token_id] = self.model_tester.pad_token_id + input_ids[input_ids == self.model_tester.vision_start_token_id] = self.model_tester.pad_token_id + input_ids[:, self.model_tester.num_image_tokens] = self.model_tester.video_token_id + + insertion_point = self.model_tester.num_image_tokens + + self.assertLessEqual((B * patches_per_video) + insertion_point, self.model_tester.seq_length) + for b in range(B): + # each frame is separated by a vision_start_token_id + for frame_idx in range(patch_T): + input_ids[b, insertion_point + frame_idx * (pathed_per_frame + 1)] = ( + self.model_tester.vision_start_token_id + ) + input_ids[ + b, + insertion_point + frame_idx * (pathed_per_frame + 1) + 1 : insertion_point + + (frame_idx + 1) * (pathed_per_frame + 1), + ] = 
self.model_tester.video_token_id + + for model_class in self.all_model_classes: + # TODO:we should remove this because we use timestamps for video + model = model_class(config).to(torch_device) + outputs = model( + input_ids=input_ids, + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + ) + self.assertIsNotNone(outputs) diff --git a/tests/models/qwen3_vl/test_processing_qwen3_vl.py b/tests/models/qwen3_vl/test_processing_qwen3_vl.py new file mode 100644 index 000000000000..87636dcf607d --- /dev/null +++ b/tests/models/qwen3_vl/test_processing_qwen3_vl.py @@ -0,0 +1,379 @@ +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import shutil +import tempfile +import unittest + +import numpy as np +import pytest + +from transformers import AutoProcessor, Qwen2TokenizerFast +from transformers.testing_utils import require_av, require_torch, require_torchvision, require_vision +from transformers.utils import is_torch_available, is_vision_available + +from ...test_processing_common import ProcessorTesterMixin + + +if is_vision_available(): + from transformers import Qwen2VLImageProcessorFast, Qwen3VLProcessor + +if is_torch_available(): + import torch + + +@require_vision +@require_torch +@require_torchvision +@unittest.skip("The checkpoint is not yet released") +class Qwen3VLProcessorTest(ProcessorTesterMixin, unittest.TestCase): + processor_class = Qwen3VLProcessor + + @classmethod + def setUpClass(cls): + cls.tmpdirname = tempfile.mkdtemp() + processor = Qwen3VLProcessor.from_pretrained( + "Qwen/Qwen3-VL-4B-Instruct", patch_size=4, max_pixels=56 * 56, min_pixels=28 * 28 + ) + processor.save_pretrained(cls.tmpdirname) + cls.image_token = processor.image_token + + def get_tokenizer(self, **kwargs): + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer + + def get_image_processor(self, **kwargs): + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor + + def get_video_processor(self, **kwargs): + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).video_processor + + def get_processor(self, **kwargs): + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tmpdirname, ignore_errors=True) + + # Copied from tests.models.llava.test_processing_llava.LlavaProcessorTest.test_get_num_vision_tokens + def test_get_num_vision_tokens(self): + "Tests general functionality of the helper used internally in vLLM" + + processor = self.get_processor() + + output = processor._get_num_multimodal_tokens(image_sizes=[(100, 100), (300, 100), (500, 30)]) + self.assertTrue("num_image_tokens" in output) + self.assertEqual(len(output["num_image_tokens"]), 3) + + self.assertTrue("num_image_patches" in output) + self.assertEqual(len(output["num_image_patches"]), 3) + + def test_save_load_pretrained_default(self): + tokenizer = self.get_tokenizer() + image_processor = 
self.get_image_processor() + video_processor = self.get_video_processor() + + processor = Qwen3VLProcessor( + tokenizer=tokenizer, image_processor=image_processor, video_processor=video_processor + ) + processor.save_pretrained(self.tmpdirname) + processor = Qwen3VLProcessor.from_pretrained(self.tmpdirname, use_fast=True) + + self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab()) + self.assertEqual(processor.image_processor.to_json_string(), image_processor.to_json_string()) + self.assertIsInstance(processor.tokenizer, Qwen2TokenizerFast) + self.assertIsInstance(processor.image_processor, Qwen2VLImageProcessorFast) + + def test_image_processor(self): + image_processor = self.get_image_processor() + tokenizer = self.get_tokenizer() + video_processor = self.get_video_processor() + + processor = Qwen3VLProcessor( + tokenizer=tokenizer, image_processor=image_processor, video_processor=video_processor + ) + + image_input = self.prepare_image_inputs() + + input_image_proc = image_processor(image_input, return_tensors="pt") + input_processor = processor(images=image_input, text="dummy", return_tensors="pt") + + for key in input_image_proc: + self.assertAlmostEqual(input_image_proc[key].sum(), input_processor[key].sum(), delta=1e-2) + + def test_processor(self): + image_processor = self.get_image_processor() + tokenizer = self.get_tokenizer() + video_processor = self.get_video_processor() + + processor = Qwen3VLProcessor( + tokenizer=tokenizer, image_processor=image_processor, video_processor=video_processor + ) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + inputs = processor(text=input_str, images=image_input) + + self.assertListEqual( + list(inputs.keys()), + ["input_ids", "attention_mask", "pixel_values", "image_grid_thw"], + ) + + # test if it raises when no input is passed + with pytest.raises(ValueError): + processor() + + # test if it raises when no text is passed + with pytest.raises(TypeError): + processor(images=image_input) + + def test_model_input_names(self): + image_processor = self.get_image_processor() + tokenizer = self.get_tokenizer() + video_processor = self.get_video_processor() + + processor = Qwen3VLProcessor( + tokenizer=tokenizer, image_processor=image_processor, video_processor=video_processor + ) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + video_inputs = self.prepare_video_inputs() + + inputs = processor(text=input_str, images=image_input, videos=video_inputs, do_sample_frames=False) + + self.assertListEqual(list(inputs.keys()), processor.model_input_names) + + @require_torch + @require_av + def _test_apply_chat_template( + self, + modality: str, + batch_size: int, + return_tensors: str, + input_name: str, + processor_name: str, + input_data: list[str], + ): + processor = self.get_processor() + if processor.chat_template is None: + self.skipTest("Processor has no chat template") + + if processor_name not in self.processor_class.attributes: + self.skipTest(f"{processor_name} attribute not present in {self.processor_class}") + + batch_messages = [ + [ + { + "role": "user", + "content": [{"type": "text", "text": "Describe this."}], + }, + ] + ] * batch_size + + # Test that jinja can be applied + formatted_prompt = processor.apply_chat_template(batch_messages, add_generation_prompt=True, tokenize=False) + self.assertEqual(len(formatted_prompt), batch_size) + + # Test that tokenizing with template and directly with `self.tokenizer` gives same output + formatted_prompt_tokenized = 
processor.apply_chat_template( + batch_messages, add_generation_prompt=True, tokenize=True, return_tensors=return_tensors + ) + add_special_tokens = True + if processor.tokenizer.bos_token is not None and formatted_prompt[0].startswith(processor.tokenizer.bos_token): + add_special_tokens = False + tok_output = processor.tokenizer( + formatted_prompt, return_tensors=return_tensors, add_special_tokens=add_special_tokens + ) + expected_output = tok_output.input_ids + self.assertListEqual(expected_output.tolist(), formatted_prompt_tokenized.tolist()) + + # Test that kwargs passed to processor's `__call__` are actually used + tokenized_prompt_100 = processor.apply_chat_template( + batch_messages, + add_generation_prompt=True, + tokenize=True, + padding="max_length", + truncation=True, + return_tensors=return_tensors, + max_length=100, + ) + self.assertEqual(len(tokenized_prompt_100[0]), 100) + + # Test that `return_dict=True` returns text related inputs in the dict + out_dict_text = processor.apply_chat_template( + batch_messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors=return_tensors, + ) + self.assertTrue(all(key in out_dict_text for key in ["input_ids", "attention_mask"])) + self.assertEqual(len(out_dict_text["input_ids"]), batch_size) + self.assertEqual(len(out_dict_text["attention_mask"]), batch_size) + + # Test that with modality URLs and `return_dict=True`, we get modality inputs in the dict + for idx, url in enumerate(input_data[:batch_size]): + batch_messages[idx][0]["content"] = [batch_messages[idx][0]["content"][0], {"type": modality, "url": url}] + + out_dict = processor.apply_chat_template( + batch_messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors=return_tensors, + max_frames=2, # by default no more than 2 frames, otherwise too slow + ) + input_name = getattr(self, input_name) + self.assertTrue(input_name in out_dict) + self.assertEqual(len(out_dict["input_ids"]), batch_size) + self.assertEqual(len(out_dict["attention_mask"]), batch_size) + + if modality == "video": + # qwen pixels don't scale with bs same way as other models, calculate expected video token count based on video_grid_thw + expected_video_token_count = 0 + for thw in out_dict["video_grid_thw"]: + expected_video_token_count += thw[0] * thw[1] * thw[2] + mm_len = expected_video_token_count + else: + mm_len = batch_size * 192 + self.assertEqual(len(out_dict[input_name]), mm_len) + + return_tensor_to_type = {"pt": torch.Tensor, "np": np.ndarray, None: list} + for k in out_dict: + self.assertIsInstance(out_dict[k], return_tensor_to_type[return_tensors]) + + @require_av + @unittest.skip("qwen3_vl can't sample frames from image frames directly, user can use `qwen-vl-utils`") + def test_apply_chat_template_video_1(self): + pass + + @require_av + @unittest.skip("qwen3_vl can't sample frames from image frames directly, user can use `qwen-vl-utils`") + def test_apply_chat_template_video_2(self): + pass + + @require_av + def test_apply_chat_template_video_frame_sampling(self): + processor = self.get_processor() + if processor.chat_template is None: + self.skipTest("Processor has no chat template") + + signature = inspect.signature(processor.__call__) + if "videos" not in {*signature.parameters.keys()} or ( + signature.parameters.get("videos") is not None + and signature.parameters["videos"].annotation == inspect._empty + ): + self.skipTest("Processor doesn't accept videos at input") + + messages = [ + [ + { + "role": "user", + "content": [ + 
{"type": "video"}, + {"type": "text", "text": "What is shown in this video?"}, + ], + }, + ] + ] + + formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) + self.assertEqual(len(formatted_prompt), 1) + + formatted_prompt_tokenized = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True) + expected_output = processor.tokenizer(formatted_prompt, return_tensors=None).input_ids + self.assertListEqual(expected_output, formatted_prompt_tokenized) + + out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True) + self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"]) + + # Add video URL for return dict and load with `num_frames` arg + messages[0][0]["content"][0] = { + "type": "video", + "url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4", + } + num_frames = 3 + out_dict_with_video = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + num_frames=num_frames, + ) + self.assertTrue(self.videos_input_name in out_dict_with_video) + self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 360) + + # Load with `fps` arg + fps = 1 + out_dict_with_video = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + fps=fps, + ) + self.assertTrue(self.videos_input_name in out_dict_with_video) + self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 900) + + # Load with `fps` and `num_frames` args, should raise an error + with self.assertRaises(ValueError): + out_dict_with_video = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + fps=fps, + num_frames=num_frames, + ) + + # Load without any arg should load the whole video + out_dict_with_video = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + ) + self.assertTrue(self.videos_input_name in out_dict_with_video) + self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 27000) + + # Load video as a list of frames (i.e. images). 
NOTE: each frame should have same size + # because we assume they come from one video + messages[0][0]["content"][0] = { + "type": "video", + "url": [ + "https://www.ilankelman.org/stopsigns/australia.jpg", + "https://www.ilankelman.org/stopsigns/australia.jpg", + ], + } + out_dict_with_video = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + do_sample_frames=False, + ) + self.assertTrue(self.videos_input_name in out_dict_with_video) + self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 160) + + def test_kwargs_overrides_custom_image_processor_kwargs(self): + processor = self.get_processor() + self.skip_processor_without_typed_kwargs(processor) + + input_str = self.prepare_text_inputs() + image_input = self.prepare_image_inputs() + inputs = processor(text=input_str, images=image_input, max_pixels=56 * 56 * 4, return_tensors="pt") + self.assertEqual(inputs[self.images_input_name].shape[0], 612) + inputs = processor(text=input_str, images=image_input, return_tensors="pt") + self.assertEqual(inputs[self.images_input_name].shape[0], 100) diff --git a/tests/models/qwen3_vl/test_video_processing_qwen3_vl.py b/tests/models/qwen3_vl/test_video_processing_qwen3_vl.py new file mode 100644 index 000000000000..9230f0f9502e --- /dev/null +++ b/tests/models/qwen3_vl/test_video_processing_qwen3_vl.py @@ -0,0 +1,330 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
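+# These tests build small synthetic clips with the tester's patch/merge settings and compare the
+# processor's output shape against the flattened patch sequence predicted by smart_resize.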
+ +import unittest + +import numpy as np + +from transformers.image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD +from transformers.testing_utils import require_torch, require_vision +from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available + +from ...test_video_processing_common import VideoProcessingTestMixin, prepare_video_inputs + + +if is_torch_available(): + from PIL import Image + +if is_vision_available() and is_torchvision_available(): + from transformers import Qwen3VLVideoProcessor + from transformers.models.qwen3_vl.video_processing_qwen3_vl import smart_resize + + +class Qwen3VLVideoProcessingTester: + def __init__( + self, + parent, + batch_size=5, + num_frames=8, + num_channels=3, + min_resolution=32, + max_resolution=80, + temporal_patch_size=2, + patch_size=16, + merge_size=2, + do_resize=True, + size=None, + do_normalize=True, + image_mean=IMAGENET_STANDARD_MEAN, + image_std=IMAGENET_STANDARD_STD, + do_convert_rgb=True, + ): + size = size if size is not None else {"longest_edge": 20, "shortest_edge": 10} + self.parent = parent + self.batch_size = batch_size + self.num_frames = num_frames + self.num_channels = num_channels + self.min_resolution = min_resolution + self.max_resolution = max_resolution + self.do_resize = do_resize + self.size = size + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + self.do_convert_rgb = do_convert_rgb + self.temporal_patch_size = temporal_patch_size + self.patch_size = patch_size + self.merge_size = merge_size + + def prepare_video_processor_dict(self): + return { + "do_resize": self.do_resize, + "size": self.size, + "do_normalize": self.do_normalize, + "image_mean": self.image_mean, + "image_std": self.image_std, + "do_convert_rgb": self.do_convert_rgb, + "do_sample_frames": True, + } + + def prepare_video_metadata(self, videos): + video_metadata = [] + for video in videos: + if isinstance(video, list): + num_frames = len(video) + elif hasattr(video, "shape"): + if len(video.shape) == 4: # (T, H, W, C) + num_frames = video.shape[0] + else: + num_frames = 1 + else: + num_frames = self.num_frames + + metadata = { + "fps": 2, + "duration": num_frames / 2, + "total_num_frames": num_frames, + } + video_metadata.append(metadata) + return video_metadata + + def expected_output_video_shape(self, videos): + grid_t = self.num_frames // self.temporal_patch_size + hidden_dim = self.num_channels * self.temporal_patch_size * self.patch_size * self.patch_size + seq_len = 0 + for video in videos: + if isinstance(video, list) and isinstance(video[0], Image.Image): + video = np.stack([np.array(frame) for frame in video]) + elif hasattr(video, "shape"): + pass + else: + video = np.array(video) + + if hasattr(video, "shape") and len(video.shape) >= 3: + if len(video.shape) == 4: + t, height, width = video.shape[:3] + elif len(video.shape) == 3: + height, width = video.shape[:2] + t = 1 + else: + t, height, width = self.num_frames, self.min_resolution, self.min_resolution + else: + t, height, width = self.num_frames, self.min_resolution, self.min_resolution + + resized_height, resized_width = smart_resize( + t, + height, + width, + factor=self.patch_size * self.merge_size, + min_pixels=self.size["shortest_edge"], + max_pixels=self.size["longest_edge"], + ) + grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size + seq_len += grid_t * grid_h * grid_w + return [seq_len, hidden_dim] + + def prepare_video_inputs(self, 
equal_resolution=False, return_tensors="pil"): + videos = prepare_video_inputs( + batch_size=self.batch_size, + num_frames=self.num_frames, + num_channels=self.num_channels, + min_resolution=self.min_resolution, + max_resolution=self.max_resolution, + equal_resolution=equal_resolution, + return_tensors=return_tensors, + ) + return videos + + +@require_torch +@require_vision +class Qwen3VLVideoProcessingTest(VideoProcessingTestMixin, unittest.TestCase): + fast_video_processing_class = Qwen3VLVideoProcessor if is_torchvision_available() else None + input_name = "pixel_values_videos" + + def setUp(self): + super().setUp() + self.video_processor_tester = Qwen3VLVideoProcessingTester(self) + + @property + def video_processor_dict(self): + return self.video_processor_tester.prepare_video_processor_dict() + + def test_video_processor_from_dict_with_kwargs(self): + video_processor = self.fast_video_processing_class.from_dict(self.video_processor_dict) + self.assertEqual(video_processor.size, {"longest_edge": 20, "shortest_edge": 10}) + + video_processor = self.fast_video_processing_class.from_dict( + self.video_processor_dict, size={"longest_edge": 42, "shortest_edge": 42} + ) + self.assertEqual(video_processor.size, {"longest_edge": 42, "shortest_edge": 42}) + + def test_call_pil(self): + for video_processing_class in self.video_processor_list: + video_processing = video_processing_class(**self.video_processor_dict) + video_inputs = self.video_processor_tester.prepare_video_inputs( + equal_resolution=False, return_tensors="pil" + ) + + for video in video_inputs: + self.assertIsInstance(video[0], Image.Image) + + video_metadata = self.video_processor_tester.prepare_video_metadata(video_inputs) + encoded_videos = video_processing( + video_inputs[0], video_metadata=[video_metadata[0]], return_tensors="pt" + )[self.input_name] + expected_output_video_shape = self.video_processor_tester.expected_output_video_shape([video_inputs[0]]) + self.assertEqual(list(encoded_videos.shape), expected_output_video_shape) + encoded_videos = video_processing(video_inputs, video_metadata=video_metadata, return_tensors="pt")[ + self.input_name + ] + expected_output_video_shape = self.video_processor_tester.expected_output_video_shape(video_inputs) + self.assertEqual(list(encoded_videos.shape), expected_output_video_shape) + + def test_call_numpy(self): + for video_processing_class in self.video_processor_list: + video_processing = video_processing_class(**self.video_processor_dict) + video_inputs = self.video_processor_tester.prepare_video_inputs( + equal_resolution=False, return_tensors="np" + ) + + video_metadata = self.video_processor_tester.prepare_video_metadata(video_inputs) + encoded_videos = video_processing( + video_inputs[0], video_metadata=[video_metadata[0]], return_tensors="pt" + )[self.input_name] + expected_output_video_shape = self.video_processor_tester.expected_output_video_shape([video_inputs[0]]) + self.assertEqual(list(encoded_videos.shape), expected_output_video_shape) + + encoded_videos = video_processing(video_inputs, video_metadata=video_metadata, return_tensors="pt")[ + self.input_name + ] + expected_output_video_shape = self.video_processor_tester.expected_output_video_shape(video_inputs) + self.assertEqual(list(encoded_videos.shape), expected_output_video_shape) + + def test_call_pytorch(self): + for video_processing_class in self.video_processor_list: + video_processing = video_processing_class(**self.video_processor_dict) + video_inputs = self.video_processor_tester.prepare_video_inputs( + 
equal_resolution=False, return_tensors="pt" + ) + video_metadata = self.video_processor_tester.prepare_video_metadata(video_inputs) + encoded_videos = video_processing( + video_inputs[0], video_metadata=[video_metadata[0]], return_tensors="pt" + )[self.input_name] + expected_output_video_shape = self.video_processor_tester.expected_output_video_shape([video_inputs[0]]) + self.assertEqual(list(encoded_videos.shape), expected_output_video_shape) + encoded_videos = video_processing(video_inputs, video_metadata=video_metadata, return_tensors="pt")[ + self.input_name + ] + expected_output_video_shape = self.video_processor_tester.expected_output_video_shape(video_inputs) + self.assertEqual(list(encoded_videos.shape), expected_output_video_shape) + + @unittest.skip("Skip for now, the test needs adjustment for Qwen3VL") + def test_call_numpy_4_channels(self): + for video_processing_class in self.video_processor_list: + # Test that can process videos which have an arbitrary number of channels + # Initialize video_processing + video_processor = video_processing_class(**self.video_processor_dict) + + # create random numpy tensors + self.video_processor_tester.num_channels = 4 + video_inputs = self.video_processor_tester.prepare_video_inputs( + equal_resolution=False, return_tensors="np" + ) + + # Test not batched input + encoded_videos = video_processor( + video_inputs[0], + return_tensors="pt", + input_data_format="channels_last", + image_mean=0, + image_std=1, + )[self.input_name] + expected_output_video_shape = self.video_processor_tester.expected_output_video_shape([video_inputs[0]]) + self.assertEqual(list(encoded_videos.shape), expected_output_video_shape) + + # Test batched + encoded_videos = video_processor( + video_inputs, + return_tensors="pt", + input_data_format="channels_last", + image_mean=0, + image_std=1, + )[self.input_name] + expected_output_video_shape = self.video_processor_tester.expected_output_video_shape(video_inputs) + self.assertEqual(list(encoded_videos.shape), expected_output_video_shape) + + def test_nested_input(self): + """Tests that the processor can work with nested list where each video is a list of arrays""" + for video_processing_class in self.video_processor_list: + video_processing = video_processing_class(**self.video_processor_dict) + video_inputs = self.video_processor_tester.prepare_video_inputs( + equal_resolution=False, return_tensors="np" + ) + + video_inputs_nested = [list(video) for video in video_inputs] + video_metadata = self.video_processor_tester.prepare_video_metadata(video_inputs) + + # Test not batched input + encoded_videos = video_processing( + video_inputs_nested[0], video_metadata=[video_metadata[0]], return_tensors="pt" + )[self.input_name] + expected_output_video_shape = self.video_processor_tester.expected_output_video_shape([video_inputs[0]]) + self.assertEqual(list(encoded_videos.shape), expected_output_video_shape) + + # Test batched + encoded_videos = video_processing(video_inputs_nested, video_metadata=video_metadata, return_tensors="pt")[ + self.input_name + ] + expected_output_video_shape = self.video_processor_tester.expected_output_video_shape(video_inputs) + self.assertEqual(list(encoded_videos.shape), expected_output_video_shape) + + def test_call_sample_frames(self): + for video_processing_class in self.video_processor_list: + video_processor_dict = self.video_processor_dict.copy() + video_processing = video_processing_class(**video_processor_dict) + + prev_num_frames = self.video_processor_tester.num_frames + 
self.video_processor_tester.num_frames = 8 + prev_min_resolution = getattr(self.video_processor_tester, "min_resolution", None) + prev_max_resolution = getattr(self.video_processor_tester, "max_resolution", None) + self.video_processor_tester.min_resolution = 56 + self.video_processor_tester.max_resolution = 112 + + video_inputs = self.video_processor_tester.prepare_video_inputs( + equal_resolution=False, + return_tensors="torch", + ) + + metadata = [[{"total_num_frames": 8, "fps": 4}]] + batched_metadata = metadata * len(video_inputs) + + encoded_videos = video_processing(video_inputs[0], return_tensors="pt", video_metadata=metadata)[ + self.input_name + ] + encoded_videos_batched = video_processing( + video_inputs, return_tensors="pt", video_metadata=batched_metadata + )[self.input_name] + + self.assertIsNotNone(encoded_videos) + self.assertIsNotNone(encoded_videos_batched) + self.assertEqual(len(encoded_videos.shape), 2) + self.assertEqual(len(encoded_videos_batched.shape), 2) + + self.video_processor_tester.num_frames = prev_num_frames + if prev_min_resolution is not None: + self.video_processor_tester.min_resolution = prev_min_resolution + if prev_max_resolution is not None: + self.video_processor_tester.max_resolution = prev_max_resolution diff --git a/tests/models/qwen3_vl_moe/__init__.py b/tests/models/qwen3_vl_moe/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py b/tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py new file mode 100644 index 000000000000..adae69a81fa8 --- /dev/null +++ b/tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py @@ -0,0 +1,298 @@ +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
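+# These tests exercise the Qwen3VLMoe model with a deliberately small expert configuration
+# (8 experts, 4 active per token) so the suite stays lightweight.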
+"""Testing suite for the PyTorch Qwen3VLMoe model.""" + +import copy +import unittest + +from transformers import ( + Qwen3VLMoeConfig, + Qwen3VLMoeForConditionalGeneration, + Qwen3VLMoeModel, + is_torch_available, +) +from transformers.testing_utils import ( + require_torch, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ( + ModelTesterMixin, + floats_tensor, + ids_tensor, +) + + +if is_torch_available(): + import torch + + +class Qwen3VLMoeVisionText2TextModelTester: + def __init__( + self, + parent, + batch_size=3, + seq_length=7, + num_channels=3, + ignore_index=-100, + image_size=16, + text_config={ + "bos_token_id": 0, + "eos_token_id": 1, + "pad_token_id": 2, + "hidden_act": "silu", + "hidden_size": 32, + "vocab_size": 99, + "intermediate_size": 37, + "max_position_embeddings": 512, + "model_type": "qwen3_vl_moe", + "num_attention_heads": 4, + "num_key_value_heads": 2, + "num_hidden_layers": 4, + "moe_intermediate_size": 16, + "num_experts_per_tok": 4, + "num_experts": 8, + "rope_theta": 10000, + "tie_word_embeddings": True, + "rope_scaling": {"rope_type": "default", "mrope_section": [16, 8, 8], "mrope_interleaved": True}, + }, + vision_config={ + "depth": 2, + "in_chans": 3, + "hidden_act": "gelu_pytorch_tanh", + "intermediate_size": 32, + "out_hidden_size": 32, + "hidden_size": 32, + "num_heads": 4, + "patch_size": 16, + "spatial_merge_size": 1, + "temporal_patch_size": 2, + "num_position_embeddings": 16, + "deepstack_visual_indexes": [0, 1], + }, + image_token_id=3, + video_token_id=4, + vision_start_token_id=5, + vision_end_token_id=6, + tie_word_embeddings=True, + is_training=True, + ): + self.parent = parent + self.ignore_index = ignore_index + self.is_training = is_training + + self.vision_config = vision_config + self.text_config = text_config + + self.vocab_size = text_config["vocab_size"] + self.bos_token_id = text_config["bos_token_id"] + self.eos_token_id = text_config["eos_token_id"] + self.pad_token_id = text_config["pad_token_id"] + self.hidden_size = text_config["hidden_size"] + self.intermediate_size = text_config["intermediate_size"] + self.num_hidden_layers = text_config["num_hidden_layers"] + self.num_attention_heads = text_config["num_attention_heads"] + self.num_key_value_heads = text_config["num_key_value_heads"] + self.rope_theta = text_config["rope_theta"] + self.rope_scaling = text_config["rope_scaling"] + self.hidden_act = text_config["hidden_act"] + self.max_position_embeddings = text_config["max_position_embeddings"] + self.model_type = text_config["model_type"] + + self.vision_start_token_id = vision_start_token_id + self.vision_end_token_id = vision_end_token_id + self.image_token_id = image_token_id + self.video_token_id = video_token_id + self.tie_word_embeddings = tie_word_embeddings + + self.batch_size = batch_size + self.num_channels = num_channels + self.image_size = image_size + self.num_image_tokens = 32 + self.seq_length = seq_length + self.num_image_tokens + + def get_config(self): + return Qwen3VLMoeConfig( + text_config=self.text_config, + vision_config=self.vision_config, + image_token_id=self.image_token_id, + video_token_id=self.video_token_id, + vision_start_token_id=self.vision_start_token_id, + vision_end_token_id=self.vision_end_token_id, + tie_word_embeddings=self.tie_word_embeddings, + ) + + def prepare_config_and_inputs(self): + config = self.get_config() + patch_size = config.vision_config.patch_size + 
temporal_patch_size = config.vision_config.temporal_patch_size
+        pixel_values = floats_tensor(
+            [
+                self.batch_size * (self.image_size**2) // (patch_size**2),
+                self.num_channels * (patch_size**2) * temporal_patch_size,
+            ]
+        )
+
+        return config, pixel_values
+
+    def prepare_config_and_inputs_for_common(self):
+        config_and_inputs = self.prepare_config_and_inputs()
+        config, pixel_values = config_and_inputs
+        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+        attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
+
+        # clear any randomly drawn special tokens, then place a vision_start token followed by a single image token
+        input_ids[:, -1] = self.pad_token_id
+        input_ids[input_ids == self.video_token_id] = self.pad_token_id
+        input_ids[input_ids == self.image_token_id] = self.pad_token_id
+        input_ids[input_ids == self.vision_start_token_id] = self.pad_token_id
+        input_ids[:, self.num_image_tokens] = self.image_token_id
+        input_ids[:, self.num_image_tokens - 1] = self.vision_start_token_id
+        inputs_dict = {
+            "pixel_values": pixel_values,
+            "image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size, device=torch_device),
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+        }
+        return config, inputs_dict
+
+
+@require_torch
+class Qwen3VLMoeModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+    """
+    Model tester for `Qwen3VLMoeForConditionalGeneration`.
+    """
+
+    all_model_classes = (
+        (
+            Qwen3VLMoeModel,
+            Qwen3VLMoeForConditionalGeneration,
+        )
+        if is_torch_available()
+        else ()
+    )
+    test_pruning = False
+    test_head_masking = False
+
+    def setUp(self):
+        self.model_tester = Qwen3VLMoeVisionText2TextModelTester(self)
+        self.config_tester = ConfigTester(self, config_class=Qwen3VLMoeConfig, has_text_modality=False)
+
+    def test_config(self):
+        self.config_tester.run_common_tests()
+
+    def test_mismatching_num_image_tokens(self):
+        """
+        Tests that VLMs raise an error with an explicit message saying what is wrong
+        when the number of images doesn't match the number of image tokens in the text.
+        Also covers multi-image cases where one prompt has multiple image tokens.
+        """
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            _ = model(**input_dict)  # successful forward with no modifications
+            curr_input_dict = copy.deepcopy(input_dict)
+
+            # remove one image but leave the image token in text
+            patch_size = config.vision_config.patch_size
+            one_img_length = (self.model_tester.image_size**2) // (patch_size**2)
+            curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-one_img_length:, ...]
+            curr_input_dict["image_grid_thw"] = curr_input_dict["image_grid_thw"][-1:, ...]
+ with self.assertRaises(ValueError): + _ = model(**curr_input_dict) + + # simulate multi-image case by concatenating inputs where each has exactly one image/image-token + input_ids = curr_input_dict["input_ids"][:1] + pixel_values = curr_input_dict["pixel_values"][:one_img_length] + image_grid_thw = curr_input_dict["image_grid_thw"][:1] + input_ids = torch.cat([input_ids, input_ids], dim=0) + + # one image and two image tokens raise an error + with self.assertRaises(ValueError): + _ = model( + input_ids=input_ids, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) + + # two images and two image tokens don't raise an error + pixel_values = torch.cat([pixel_values, pixel_values], dim=0) + image_grid_thw = torch.cat([image_grid_thw, image_grid_thw], dim=0) + _ = model( + input_ids=input_ids, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) + + def test_video_forward(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + B = self.model_tester.batch_size + C = config.vision_config.in_chans + T = config.vision_config.temporal_patch_size + P = config.vision_config.patch_size + + input_ids = ids_tensor([B, self.model_tester.seq_length], self.model_tester.vocab_size) + + F = 4 + patch_H = self.model_tester.image_size // P + patch_W = self.model_tester.image_size // P + patch_T = F // T + patches_per_video = patch_T * patch_H * patch_W + pathed_per_frame = patch_H * patch_W + pixel_values_videos = floats_tensor( + [ + # first dim: batch_size * num_patches + B * patches_per_video, + # second dim: in_channels * temporal_patch_size * patch_size^2 + C * T * (P**2), + ] + ) + video_grid_thw = torch.tensor([[1, patch_H, patch_W] for _ in range(patch_T)] * B) + + # sanity check + self.assertEqual(pixel_values_videos.shape[0], video_grid_thw.prod(dim=1).sum().item()) + + # Insert video token sequence + input_ids[:, -1] = self.model_tester.pad_token_id + input_ids[input_ids == self.model_tester.video_token_id] = self.model_tester.pad_token_id + input_ids[input_ids == self.model_tester.image_token_id] = self.model_tester.pad_token_id + input_ids[input_ids == self.model_tester.vision_start_token_id] = self.model_tester.pad_token_id + input_ids[:, self.model_tester.num_image_tokens] = self.model_tester.video_token_id + + insertion_point = self.model_tester.num_image_tokens + + self.assertLessEqual((B * patches_per_video) + insertion_point, self.model_tester.seq_length) + for b in range(B): + # each frame is separated by a vision_start_token_id + for frame_idx in range(patch_T): + input_ids[b, insertion_point + frame_idx * (pathed_per_frame + 1)] = ( + self.model_tester.vision_start_token_id + ) + input_ids[ + b, + insertion_point + frame_idx * (pathed_per_frame + 1) + 1 : insertion_point + + (frame_idx + 1) * (pathed_per_frame + 1), + ] = self.model_tester.video_token_id + + for model_class in self.all_model_classes: + # TODO:we should remove this because we use timestamps for video + model = model_class(config).to(torch_device) + outputs = model( + input_ids=input_ids, + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + ) + self.assertIsNotNone(outputs) diff --git a/utils/check_repo.py b/utils/check_repo.py index 8a73468a1e49..e932e5bfc24c 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -71,6 +71,8 @@ "Qwen2AudioEncoder", "Qwen2VisionTransformerPretrainedModel", "Qwen2_5_VisionTransformerPretrainedModel", + "Qwen3VLVisionModel", + "Qwen3VLMoeVisionModel", "SwitchTransformersStack", "TFDPRSpanPredictor", 
"MaskFormerSwinModel", @@ -151,13 +153,17 @@ "ChameleonVQVAE", # VQVAE here is used only for encoding (discretizing) and is tested as part of bigger model "Qwen2VLModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2VLForConditionalGeneration. "Qwen2_5_VLModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5_VLForConditionalGeneration. - "Qwen2_5OmniForConditionalGeneration", # Not a regular model. Testted in Qwen2_5OmniModelIntegrationTest - "Qwen2_5OmniTalkerForConditionalGeneration", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntegrationTest. - "Qwen2_5OmniTalkerModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntegrationTest. - "Qwen2_5OmniThinkerTextModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntegrationTest. - "Qwen2_5OmniToken2WavModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntegrationTest. - "Qwen2_5OmniToken2WavDiTModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntegrationTest. - "Qwen2_5OmniToken2WavBigVGANModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntegrationTest. + "Qwen3VLModel", # Building part of bigger (tested) model. Tested implicitly through Qwen3VLForConditionalGeneration. + "Qwen3VLMoeModel", # Building part of bigger (tested) model. Tested implicitly through Qwen3VLMoeForConditionalGeneration. + "Qwen3VLTextModel", # Building part of bigger (tested) model. + "Qwen3VLMoeTextModel", # Building part of bigger (tested) model. + "Qwen2_5OmniForConditionalGeneration", # Not a regular model. Testted in Qwen2_5OmniModelIntergrationTest + "Qwen2_5OmniTalkerForConditionalGeneration", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntergrationTest. + "Qwen2_5OmniTalkerModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntergrationTest. + "Qwen2_5OmniThinkerTextModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntergrationTest. + "Qwen2_5OmniToken2WavModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntergrationTest. + "Qwen2_5OmniToken2WavDiTModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntergrationTest. + "Qwen2_5OmniToken2WavBigVGANModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntergrationTest. "MllamaTextModel", # Building part of bigger (tested) model. # TODO: add tests "MllamaVisionModel", # Building part of bigger (tested) model. # TODO: add tests "Llama4TextModel", # Building part of bigger (tested) model. # TODO: add tests