Skip to content

Conversation

@NDNM1408
Copy link

This PR adds a Qwen3VLGRPOTrainer,.

Motivation

When I was training Qwen3-VL using GRPOTrainer, I encountered the error IndexError: metadata = video_metadata[index].
While debugging, I found that the original GRPOTrainer class, when given video inputs, produced messages where video fields and fps values appeared inside "text"-type chunks with fps=None and video=None, which caused the bug.
To fix this, I created a new class that normalizes the messages before generation and prevents this error.

What this PR changes

  • Adds a new class: Qwen3VLVideoGRPOTrainer(GRPOTrainer) in grpo_trainer.py.

    • Overrides _generate_single_turn only; all GRPO/DAPO logic (rewards, advantages, clipping, logging, etc.) remains unchanged.
    • Cleans and normalizes the multi-modal conversation structure while preserving:
      • {"type": "video", "video": ..., "fps": ...}
      • {"type": "text", "text": ...}
    • For each conversation in the batch:
      • Calls self.processing_class.apply_chat_template(...) directly on the Qwen3-VL conversation.
      • Sends the resulting inputs to the correct device.
      • Runs model.generate(..., generation_config=self.generation_config).
      • Splits prompt_ids and completion_ids using the prompt length, and returns them in the format expected by the GRPO pipeline.
    • Currently supports only the standard transformers.generate path and explicitly errors if use_vllm=True or use_transformers_paged=True, to keep behavior simple and predictable for this first iteration.
  • Exports Qwen3VLGRPOTrainer:

    • From src/trl/trainers/__init__.py.
    • From src/trl/__init__.py.

This design keeps the change isolated and backward compatible: existing users of GRPOTrainer are unaffected, while Qwen3-VL users can opt into the specialized trainer.

Usage

Example (simplified):

from trl import Qwen3VLVideoGRPOTrainer, GRPOConfig

config = GRPOConfig(
    output_dir="qwen3vl-video-grpo",
    loss_type="dapo",
    use_vllm=False,
    use_transformers_paged=False,
    num_generations=4,
    # other GRPOConfig parameters...
)

trainer = Qwen3VLVideoGRPOTrainer(
    model="Qwen/Qwen3-VL-8B-Instruct",
    reward_funcs=my_reward_fn,       
    args=config,
    train_dataset=my_qwen3vl_video_dataset, 
)

trainer.train()

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant