20 changes: 20 additions & 0 deletions docs/source/dataset_formats.md
@@ -1043,3 +1043,23 @@ An example of a conversational vision dataset is the [openbmb/RLAIF-V-Dataset](h
width="100%"
height="560px"
></iframe>

> [!NOTE]
> Mixing text-only and vision-language examples in the same dataset is possible, but it requires `transformers` version 4.57.0 or later. Example:
>
> ```python
> import PIL.Image
> from datasets import Dataset
>
> dataset = Dataset.from_dict({
> "prompt": [
> [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky in the image?"}]}],
> [{"role": "user", "content": [{"type": "text", "text": "What is the capital of France?"}]}],
> ],
> "completion": [
> [{"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]}],
> [{"role": "assistant", "content": [{"type": "text", "text": "Paris."}]}],
> ],
> "images": [
> [PIL.Image.open("path/to/sky_image1.png")],
> [],
> ],
> })
> ```
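>
> A minimal training sketch on such a mixed dataset, assuming `SFTTrainer` with the tiny Qwen2.5-VL test checkpoint used in this change's tests (swap in your own VLM); `max_length=None` avoids truncating image tokens:
>
> ```python
> from trl import SFTConfig, SFTTrainer
>
> trainer = SFTTrainer(
>     model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",  # placeholder tiny test model
>     args=SFTConfig(max_length=None, report_to="none"),  # max_length=None keeps image tokens intact
>     train_dataset=dataset,  # the mixed text/image dataset built above
> )
> trainer.train()
> ```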
3 changes: 0 additions & 3 deletions tests/test_grpo_trainer.py
@@ -1560,9 +1560,6 @@ def reward_func(completions, **kwargs):
def test_training_vlm_multi_image(self):
dataset = load_dataset("trl-internal-testing/zen-multi-image", "conversational_prompt_only", split="train")

# For now, mixing image+text and text-only examples is not supported, so we filter out text-only examples
dataset = dataset.filter(lambda x: len(x["images"]) > 0)

def reward_func(completions, **kwargs):
"""Reward function that rewards longer completions."""
return [float(len(completion[0]["content"])) for completion in completions]
3 changes: 0 additions & 3 deletions tests/test_rloo_trainer.py
@@ -1301,9 +1301,6 @@ def reward_func(completions, **kwargs):
def test_training_vlm_multi_image(self):
dataset = load_dataset("trl-internal-testing/zen-multi-image", "conversational_prompt_only", split="train")

# For now, mixing image+text and text-only examples is not supported, so we filter out text-only examples
dataset = dataset.filter(lambda x: len(x["images"]) > 0)

def reward_func(completions, **kwargs):
"""Reward function that rewards longer completions."""
return [float(len(completion[0]["content"])) for completion in completions]
40 changes: 40 additions & 0 deletions tests/test_sft_trainer.py
@@ -17,7 +17,9 @@

import pytest
import torch
import transformers
from datasets import load_dataset
from packaging.version import parse as parse_version
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.testing_utils import require_flash_attn, require_liger_kernel
@@ -1302,6 +1304,44 @@ def test_train_vlm(self, model_id):
continue
assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"

@pytest.mark.xfail(
parse_version(transformers.__version__) < parse_version("4.57.0"),
reason="Mixing text-only and image+text examples is only supported in transformers >= 4.57.0",
strict=False,
)
@require_vision
def test_train_vlm_multi_image(self):
# Get the dataset
dataset = load_dataset(
"trl-internal-testing/zen-multi-image", "conversational_prompt_completion", split="train"
)

# Initialize the trainer
training_args = SFTConfig(
output_dir=self.tmp_dir,
max_length=None, # For VLMs, truncating can remove image tokens, leading to errors
report_to="none",
)
trainer = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
args=training_args,
train_dataset=dataset,
)

# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

# Train the model
trainer.train()

# Check that the training loss is not None
assert trainer.state.log_history[-1]["train_loss"] is not None

# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"

@require_vision
def test_train_vlm_prompt_completion(self):
# Get the dataset
3 changes: 3 additions & 0 deletions trl/experimental/gfpo/gfpo_trainer.py
@@ -76,6 +76,9 @@ def _generate_and_score_completions(self, inputs):
images = [[example.get("image")] if example.get("image") is not None else None for example in inputs]
else:
images = None
# Transformers requires at least one image in the batch; otherwise it raises an error
if images is not None and all(img_list == [] for img_list in images):
images = None

(
prompt_ids_list,
@@ -76,6 +76,9 @@ def _generate_and_score_completions(
images = [[example.get("image")] if example.get("image") is not None else None for example in inputs]
else:
images = None
# Transformers requires at least one image in the batch; otherwise it raises an error
if images is not None and all(img_list == [] for img_list in images):
images = None

(
prompt_ids,
3 changes: 3 additions & 0 deletions trl/trainer/grpo_trainer.py
@@ -1400,6 +1400,9 @@ def _generate_and_score_completions(
images = [[example.get("image")] if example.get("image") is not None else None for example in inputs]
else:
images = None
# Transformers requires at least one image in the batch; otherwise it raises an error
if images is not None and all(img_list == [] for img_list in images):
images = None

(
prompt_ids_list,
3 changes: 3 additions & 0 deletions trl/trainer/rloo_trainer.py
@@ -1379,6 +1379,9 @@ def _generate_and_score_completions(
images = [[example.get("image")] if example.get("image") is not None else None for example in inputs]
else:
images = None
# Transformers requires at least one image in the batch; otherwise it raises an error
if images is not None and all(img_list == [] for img_list in images):
images = None

prompt_ids_list, completion_ids_list, forward_kwargs = self._generate(prompts, images)

6 changes: 6 additions & 0 deletions trl/trainer/sft_trainer.py
@@ -348,6 +348,9 @@ def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:

def _collate_language_modeling(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
images = [example["images"] for example in examples]
# Transformers requires at least one image in the batch; otherwise it raises an error
if all(img_list == [] for img_list in images):
images = None

if "messages" in examples[0]: # conversational case
for example in examples:
@@ -388,6 +391,9 @@ def _collate_prompt_completion(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
"prompt-completion data yet."
)
images = [example["images"] for example in examples]
# Transformers requires at least one image in the batch; otherwise it raises an error
if all(img_list == [] for img_list in images):
images = None
if is_conversational(examples[0]): # conversational case
for example in examples:
prepare_multimodal_messages(example["prompt"] + example["completion"], len(example["images"]))