11 changes: 3 additions & 8 deletions src/transformers/models/qwen3_vl/modeling_qwen3_vl.py
@@ -37,7 +37,7 @@
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling
+from ...utils import TransformersKwargs, auto_docstring, is_torchdynamo_compiling
 from ...utils.deprecation import deprecate_kwarg
 from ...utils.generic import check_model_inputs
 from .configuration_qwen3_vl import Qwen3VLConfig, Qwen3VLTextConfig, Qwen3VLVisionConfig
@@ -1104,7 +1104,7 @@ def get_placeholder_mask(
         return special_image_mask, special_video_mask

     @auto_docstring
-    @can_return_tuple
+    @check_model_inputs
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -1235,8 +1235,6 @@ def forward(
         return Qwen3VLModelOutputWithPast(
             last_hidden_state=outputs.last_hidden_state,
             past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
             rope_deltas=self.rope_deltas,
         )

@@ -1313,8 +1311,7 @@ def language_model(self):
     def visual(self):
         return self.model.visual

-    @can_return_tuple
-    @auto_docstring
+    @check_model_inputs
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -1372,8 +1369,6 @@ def forward(
             loss=loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
             rope_deltas=outputs.rope_deltas,
         )

9 changes: 3 additions & 6 deletions src/transformers/models/qwen3_vl/modular_qwen3_vl.py
@@ -33,7 +33,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import ProcessingKwargs, Unpack, VideosKwargs
 from ...tokenization_utils_base import PreTokenizedInput, TextInput
-from ...utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
+from ...utils import auto_docstring, is_torchdynamo_compiling, logging
 from ...utils.generic import check_model_inputs
 from ...video_utils import VideoInput
 from ..qwen2_5_vl.modeling_qwen2_5_vl import (
@@ -1006,7 +1006,7 @@ def get_video_features(
         return self.get_image_features(pixel_values_videos, video_grid_thw)

     @auto_docstring
-    @can_return_tuple
+    @check_model_inputs
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -1137,8 +1135,6 @@ def forward(
         return Qwen3VLModelOutputWithPast(
             last_hidden_state=outputs.last_hidden_state,
             past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
             rope_deltas=self.rope_deltas,
         )

@@ -1151,6 +1149,7 @@ class Qwen3VLForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
     config: Qwen3VLConfig
     _checkpoint_conversion_mapping = {}

+    @check_model_inputs
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -1208,8 +1207,6 @@ def forward(
             loss=loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
             rope_deltas=outputs.rope_deltas,
         )

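Both files make the same change (modular_qwen3_vl.py is the source from which modeling_qwen3_vl.py is generated): the forward passes swap `@can_return_tuple` for `@check_model_inputs`, and the returned output objects no longer copy `hidden_states`/`attentions` by hand, since the decorator records those tensors when a caller requests them. A minimal caller-side sketch of the behavior this preserves, assuming the hook-based recording that other models using `@check_model_inputs` rely on; the checkpoint id below is a placeholder, not something named in this PR:

```python
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor

# Placeholder checkpoint id, for illustration only.
ckpt = "Qwen/Qwen3-VL-Instruct"
model = AutoModelForImageTextToText.from_pretrained(ckpt, attn_implementation="eager")
processor = AutoProcessor.from_pretrained(ckpt)

inputs = processor(text="Describe the weather today.", return_tensors="pt")
with torch.no_grad():
    # output_hidden_states / output_attentions still take effect after this PR:
    # @check_model_inputs collects the per-layer tensors, so the forward body
    # no longer needs to copy outputs.hidden_states / outputs.attentions itself.
    out = model(**inputs, output_hidden_states=True, output_attentions=True)

print(len(out.hidden_states), len(out.attentions))
```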
4 changes: 4 additions & 0 deletions src/transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py
@@ -80,6 +80,8 @@ class Qwen3VLMoeTextConfig(PretrainedConfig):
             Number of routed experts.
         norm_topk_prob (`bool`, *optional*, defaults to `True`):
             Whether to normalize the topk probabilities.
+        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
+            The aux loss factor for the total loss.
         mlp_only_layers (`List[int]`, *optional*, defaults to `[]`):
             Indicate which layers use Qwen3VLMoeMLP rather than Qwen3VLMoeSparseMoeBlock
             The list contains layer index, from 0 to num_layers-1 if we have num_layers layers
@@ -178,6 +180,7 @@ def __init__(
         num_experts_per_tok=4,
         num_experts=60,
         norm_topk_prob=True,
+        router_aux_loss_coef=0.001,
         mlp_only_layers=None,
         rope_scaling=None,
         head_dim=None,
@@ -213,6 +216,7 @@ def __init__(
         self.num_experts_per_tok = num_experts_per_tok
         self.num_experts = num_experts
         self.norm_topk_prob = norm_topk_prob
+        self.router_aux_loss_coef = router_aux_loss_coef
         self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers

         super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
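The configuration change simply threads a new `router_aux_loss_coef` field through `Qwen3VLMoeTextConfig`. A minimal sketch of setting the field, assuming the class is exported from the top-level `transformers` namespace; the loss line in the comment is the conventional pattern for such a coefficient in Transformers MoE models, not code taken from this diff:

```python
from transformers import Qwen3VLMoeTextConfig

# The default of 0.001 matches the docstring added above.
config = Qwen3VLMoeTextConfig()
print(config.router_aux_loss_coef)  # 0.001

# Override it like any other config argument.
config = Qwen3VLMoeTextConfig(router_aux_loss_coef=0.01)

# Conventional use of such a coefficient (sketch only, not this file's code):
#   total_loss = lm_loss + config.router_aux_loss_coef * load_balancing_loss
```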