scripts/generate_tiny_models.py (+13 −0)
@@ -75,6 +75,8 @@
     Qwen3MoeConfig,
     Qwen3MoeForCausalLM,
     Qwen3MoeForSequenceClassification,
+    Qwen3VLConfig,
+    Qwen3VLForConditionalGeneration,
     SmolVLMForConditionalGeneration,
     T5ForConditionalGeneration,
 )
@@ -313,6 +315,7 @@ def init_weights_tiny_model(model):
     ("OpenGVLab/InternVL3-8B-hf", InternVLForConditionalGeneration),
     ("Qwen/Qwen2-VL-2B-Instruct", Qwen2VLForConditionalGeneration),
     ("Qwen/Qwen2.5-VL-3B-Instruct", Qwen2_5_VLForConditionalGeneration),
+    ("Qwen/Qwen3-VL-2B-Instruct", Qwen3VLForConditionalGeneration),
 ]:
     processor = AutoProcessor.from_pretrained(model_id)

@@ -350,6 +353,16 @@ def init_weights_tiny_model(model):
     if issubclass(model_class.config_class, Idefics2Config):
         kwargs["perceiver_config"] = {"hidden_size": 16}

+    if issubclass(model_class.config_class, Qwen3VLConfig):
+        # Drop "layer_types" so that hasattr(config, "layer_types") is False
+        # See: https://github.com/huggingface/transformers/blob/fe5ca9ddaa07fac2872407e75c7a7661216ac956/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py#L420
+        del text_config["layer_types"]
+        # "mrope_section" needs 3 elements: the model indexes mrope_section[dim] for dim in (1, 2)
+        # See: https://github.com/huggingface/transformers/blob/fe5ca9ddaa07fac2872407e75c7a7661216ac956/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py#L361
+        text_config["rope_scaling"] = {"mrope_interleaved": True, "mrope_section": [2, 2, 2], "rope_type": "default"}
+        vision_config["depth"] = 2
+        vision_config["out_hidden_size"] = 16
+
     config = AutoConfig.from_pretrained(model_id, text_config=text_config, vision_config=vision_config, **kwargs)
     model = model_class(config).to(dtype=torch.bfloat16)
     push_to_hub(model, processor, "tiny")
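
The two config tweaks above are worth unpacking. Removing "layer_types" from the text config dict means the attribute never exists on the resulting config object, so the `hasattr(config, "layer_types")` check in modeling_qwen3_vl.py falls through to its default path. And "mrope_section" must have three entries because the interleaved-MRoPE code indexes positions 1 and 2. A minimal standalone sketch of both points (the dict contents are illustrative, not the real config):

```python
# Why "mrope_section" needs 3 elements: the modeling code indexes
# mrope_section[1] and mrope_section[2], so a shorter list raises IndexError.
mrope_section = [2, 2, 2]  # tiny illustrative values, matching the diff
for dim, offset in enumerate((1, 2), start=1):
    print(f"mrope_section[{dim}] = {mrope_section[dim]}")  # dim is 1, then 2

# Why `del text_config["layer_types"]` works: a key absent from the dict
# never becomes a config attribute, so hasattr(config, "layer_types") is False.
text_config = {"hidden_size": 16, "layer_types": ["full_attention"]}  # hypothetical dict
del text_config["layer_types"]
```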
tests/test_sft_trainer.py (+10 −1)
@@ -21,6 +21,7 @@
 import transformers
 from accelerate.utils.memory import release_memory
 from datasets import load_dataset
+from packaging.version import Version
 from packaging.version import parse as parse_version
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 from transformers.testing_utils import backend_empty_cache, torch_device
@@ -1345,6 +1346,13 @@ def test_tag_added_peft(self):
         "trl-internal-testing/tiny-Qwen2VLForConditionalGeneration",
         "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
         # "trl-internal-testing/tiny-SmolVLMForConditionalGeneration", device issue from transformers, see https://github.com/huggingface/transformers/pull/39975
+        pytest.param(
+            "trl-internal-testing/tiny-Qwen3VLForConditionalGeneration",
+            marks=pytest.mark.skipif(
+                Version(transformers.__version__) < Version("4.57.0"),
+                reason="the Qwen3-VL series was introduced in transformers 4.57.0",
+            ),
+        ),
     ],
 )
 @require_vision
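
This pytest.param / skipif pattern is what gates the new model across the parametrized VLM tests: the mark attaches to a single case, leaving the rest of the matrix untouched. A self-contained sketch of the same mechanism (the test body is a hypothetical stand-in; only the gating pattern mirrors the diff):

```python
import pytest
import transformers
from packaging.version import Version


@pytest.mark.parametrize(
    "model_id",
    [
        "trl-internal-testing/tiny-Qwen2VLForConditionalGeneration",
        # Wrapping a single case in pytest.param lets us attach a mark to it
        # without affecting the other parametrized cases.
        pytest.param(
            "trl-internal-testing/tiny-Qwen3VLForConditionalGeneration",
            marks=pytest.mark.skipif(
                Version(transformers.__version__) < Version("4.57.0"),
                reason="the Qwen3-VL series was introduced in transformers 4.57.0",
            ),
        ),
    ],
)
def test_model_id_shape(model_id):  # hypothetical test body, for illustration
    assert model_id.startswith("trl-internal-testing/tiny-")
```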
@@ -1380,7 +1388,8 @@ def test_train_vlm(self, model_id):
                 model_id == "trl-internal-testing/tiny-LlavaForConditionalGeneration" and "model.vision_tower.vision_model.post_layernorm" in n or
                 model_id == "trl-internal-testing/tiny-LlavaForConditionalGeneration" and "vision_tower.vision_model.encoder.layers.1" in n or
                 model_id == "trl-internal-testing/tiny-LlavaNextForConditionalGeneration" and "model.vision_tower.vision_model.post_layernorm" in n or
-                model_id == "trl-internal-testing/tiny-LlavaNextForConditionalGeneration" and "vision_tower.vision_model.encoder.layers.1" in n
+                model_id == "trl-internal-testing/tiny-LlavaNextForConditionalGeneration" and "vision_tower.vision_model.encoder.layers.1" in n or
+                model_id == "trl-internal-testing/tiny-Qwen3VLForConditionalGeneration" and "model.visual.deepstack_merger_list" in n
             ):
                 # fmt: on
                 continue
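
The condition above exempts parameters that legitimately keep their initial values after the short training run; for Qwen3-VL, the deepstack merger modules receive no gradient in this tiny setup. A hedged sketch of the surrounding check as a standalone helper (the helper name and the snapshot/trainer arguments are assumptions modeled on the test's structure, not TRL API):

```python
import torch


def assert_trainable_params_updated(previous_trainable_params, trainer, model_id):
    """Assert training moved every trainable parameter, minus known exemptions.

    `previous_trainable_params` is assumed to be a name -> tensor snapshot
    taken before training; `trainer.model` exposes the trained model.
    """
    for n, old_param in previous_trainable_params.items():
        if (
            model_id == "trl-internal-testing/tiny-Qwen3VLForConditionalGeneration"
            and "model.visual.deepstack_merger_list" in n
        ):
            # The deepstack merger receives no gradient with this tiny
            # dataset, so its weights legitimately stay at their init values.
            continue
        new_param = trainer.model.get_parameter(n)
        assert not torch.equal(old_param, new_param), f"Parameter {n} has not changed"
```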