
Commit a2de293

JJJYmmm authored and ArthurZucker committed
Add processor and integration test for qwen3vl (#41277)
* support aux loss in qwen3vlmoe
* update qwen3vl processor test!
* add integration tests for qwen3vl-30a3
* remove duplicated decorator
* code clean
* fix consistency
* do not inherit from nn.Linear for better quantization
* pass check
1 parent f8ec172 commit a2de293
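
A hedged sketch of the processor → generate path that the new processor and integration tests exercise. The checkpoint id (for the "qwen3vl-30a3" checkpoint), the image URL, and the AutoModelForImageTextToText entry point are illustrative assumptions, not taken from the test files:

```python
# Hedged sketch of the processor -> generate path covered by the new tests.
# The checkpoint id below is an assumed name for the "30a3" checkpoint; replace it
# with the real Hub repo id.
from transformers import AutoProcessor, AutoModelForImageTextToText

model_id = "Qwen/Qwen3-VL-30B-A3B-Instruct"  # hypothetical id
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(model_id, torch_dtype="auto", device_map="auto")

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://example.com/demo.png"},  # placeholder image
            {"type": "text", "text": "Describe this image."},
        ],
    }
]
inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
).to(model.device)

out = model.generate(**inputs, max_new_tokens=64)
# Decode only the newly generated tokens
print(processor.batch_decode(out[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0])
```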

File tree: 7 files changed, +623 −108 lines changed


src/transformers/models/qwen3_vl/modeling_qwen3_vl.py

Lines changed: 3 additions & 8 deletions
```diff
@@ -37,7 +37,7 @@
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling
+from ...utils import TransformersKwargs, auto_docstring, is_torchdynamo_compiling
 from ...utils.deprecation import deprecate_kwarg
 from ...utils.generic import check_model_inputs
 from .configuration_qwen3_vl import Qwen3VLConfig, Qwen3VLTextConfig, Qwen3VLVisionConfig
@@ -1104,7 +1104,7 @@ def get_placeholder_mask(
         return special_image_mask, special_video_mask

     @auto_docstring
-    @can_return_tuple
+    @check_model_inputs
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -1235,8 +1235,6 @@ def forward(
         return Qwen3VLModelOutputWithPast(
             last_hidden_state=outputs.last_hidden_state,
             past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
             rope_deltas=self.rope_deltas,
         )

@@ -1313,8 +1311,7 @@ def language_model(self):
     def visual(self):
         return self.model.visual

-    @can_return_tuple
-    @auto_docstring
+    @check_model_inputs
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -1372,8 +1369,6 @@ def forward(
             loss=loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
             rope_deltas=outputs.rope_deltas,
         )
```
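Both forward passes above swap @can_return_tuple for @check_model_inputs and stop passing hidden_states/attentions explicitly; the decorator takes over collecting those optional outputs (and the tuple-vs-dataclass return handling), which is why the return statements can drop those arguments. The snippet below is a deliberately simplified toy of that hook-based idea, not the actual transformers decorator:

```python
# Toy illustration (not the transformers implementation) of hook-based output
# collection: gather per-layer outputs only when the caller asks for them, so
# forward() does not have to thread hidden_states through every return statement.
import torch
from torch import nn


def collect_hidden_states(forward):
    def wrapper(self, *args, output_hidden_states=False, **kwargs):
        collected, handles = [], []
        if output_hidden_states:
            for layer in self.layers:
                # forward hooks fire after each sub-layer and record its output
                handles.append(
                    layer.register_forward_hook(lambda module, inputs, output: collected.append(output))
                )
        try:
            result = forward(self, *args, **kwargs)
        finally:
            for handle in handles:
                handle.remove()
        return (result, tuple(collected)) if output_hidden_states else result

    return wrapper


class TinyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(8, 8) for _ in range(2))

    @collect_hidden_states
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


model = TinyEncoder()
last, hidden = model(torch.randn(1, 8), output_hidden_states=True)
print(last.shape, len(hidden))  # torch.Size([1, 8]) 2
```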

src/transformers/models/qwen3_vl/modular_qwen3_vl.py

Lines changed: 3 additions & 6 deletions
```diff
@@ -33,7 +33,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import ProcessingKwargs, Unpack, VideosKwargs
 from ...tokenization_utils_base import PreTokenizedInput, TextInput
-from ...utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
+from ...utils import auto_docstring, is_torchdynamo_compiling, logging
 from ...utils.generic import check_model_inputs
 from ...video_utils import VideoInput
 from ..qwen2_5_vl.modeling_qwen2_5_vl import (
@@ -1006,7 +1006,7 @@ def get_video_features(
         return self.get_image_features(pixel_values_videos, video_grid_thw)

     @auto_docstring
-    @can_return_tuple
+    @check_model_inputs
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -1137,8 +1137,6 @@ def forward(
         return Qwen3VLModelOutputWithPast(
             last_hidden_state=outputs.last_hidden_state,
             past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
             rope_deltas=self.rope_deltas,
         )

@@ -1151,6 +1149,7 @@ class Qwen3VLForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
     config: Qwen3VLConfig
     _checkpoint_conversion_mapping = {}

+    @check_model_inputs
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -1208,8 +1207,6 @@ def forward(
             loss=loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
             rope_deltas=outputs.rope_deltas,
         )
```
src/transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -80,6 +80,8 @@ class Qwen3VLMoeTextConfig(PretrainedConfig):
             Number of routed experts.
         norm_topk_prob (`bool`, *optional*, defaults to `True`):
             Whether to normalize the topk probabilities.
+        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
+            The aux loss factor for the total loss.
         mlp_only_layers (`List[int]`, *optional*, defaults to `[]`):
             Indicate which layers use Qwen3VLMoeMLP rather than Qwen3VLMoeSparseMoeBlock
             The list contains layer index, from 0 to num_layers-1 if we have num_layers layers
@@ -178,6 +180,7 @@ def __init__(
         num_experts_per_tok=4,
         num_experts=60,
         norm_topk_prob=True,
+        router_aux_loss_coef=0.001,
         mlp_only_layers=None,
         rope_scaling=None,
         head_dim=None,
@@ -213,6 +216,7 @@ def __init__(
         self.num_experts_per_tok = num_experts_per_tok
         self.num_experts = num_experts
         self.norm_topk_prob = norm_topk_prob
+        self.router_aux_loss_coef = router_aux_loss_coef
         self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers

         super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
```
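The new router_aux_loss_coef follows the common MoE recipe: a load-balancing auxiliary loss computed from the router logits is scaled by this coefficient and added to the language-modeling loss (the "support aux loss in qwen3vlmoe" item in the commit message). Below is a hedged sketch of that pattern, modeled on the Switch-Transformers-style loss used by other MoE models in the library rather than copied from the Qwen3VLMoe code:

```python
# Hedged sketch of how router_aux_loss_coef typically enters the total loss in an
# MoE model (generic load-balancing loss, not the exact Qwen3VLMoe implementation).
import torch
import torch.nn.functional as F


def load_balancing_loss(router_logits: torch.Tensor, num_experts: int, top_k: int) -> torch.Tensor:
    """router_logits: (num_tokens, num_experts), e.g. concatenated over all MoE layers."""
    routing_weights = F.softmax(router_logits, dim=-1)
    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)    # (tokens, top_k)
    expert_mask = F.one_hot(selected_experts, num_experts).float()      # (tokens, top_k, experts)
    tokens_per_expert = expert_mask.mean(dim=(0, 1))                    # fraction of routing slots per expert
    router_prob_per_expert = routing_weights.mean(dim=0)                # mean router probability per expert
    return num_experts * torch.sum(tokens_per_expert * router_prob_per_expert)


router_logits = torch.randn(16, 60)                                     # num_experts=60, as in this config
aux_loss = load_balancing_loss(router_logits, num_experts=60, top_k=4)  # num_experts_per_tok=4
router_aux_loss_coef = 0.001
lm_loss = torch.tensor(2.3)                                             # placeholder LM loss
total_loss = lm_loss + router_aux_loss_coef * aux_loss
print(total_loss)
```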
