Skip to content

Commit 237900d

Browse files
Fix bug with VLM processors in prompt-completion text-only training (#4553)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
1 parent 52ed4df commit 237900d

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

tests/test_sft_trainer.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1532,10 +1532,14 @@ def test_train_vlm_gemma_3n(self):
15321532
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
15331533
],
15341534
)
1535+
@pytest.mark.parametrize(
1536+
"dataset_config",
1537+
["conversational_language_modeling", "conversational_prompt_completion", "standard_prompt_completion"],
1538+
)
15351539
@require_vision
1536-
def test_train_vlm_text_only_data(self, model_id):
1540+
def test_train_vlm_text_only_data(self, model_id, dataset_config):
15371541
# Get the dataset
1538-
dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")
1542+
dataset = load_dataset("trl-internal-testing/zen", dataset_config, split="train")
15391543

15401544
# Initialize the trainer
15411545
training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none")

trl/trainer/sft_trainer.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -935,9 +935,10 @@ def add_eos(example, eos_token):
935935
example["completion"] = example["completion"] + eos_token
936936
return example
937937

938+
eos_token = processing_class.tokenizer.eos_token if self._is_vlm else processing_class.eos_token
938939
dataset = dataset.map(
939940
add_eos,
940-
fn_kwargs={"eos_token": processing_class.eos_token},
941+
fn_kwargs={"eos_token": eos_token},
941942
remove_columns="messages" if "messages" in column_names else None, # renamed to "text"
942943
**map_kwargs,
943944
)
@@ -988,6 +989,14 @@ def tokenize_fn(example, processing_class, dataset_text_field, assistant_only_lo
988989
prompt_completion_ids = processing_class(text=example["prompt"] + example["completion"])[
989990
"input_ids"
990991
]
992+
# Fix transformers inconsistency: for VLMs, processing_class returns lists of lists
993+
# even for single examples, while for LLMs it returns lists of ints.
994+
prompt_ids = prompt_ids[0] if isinstance(prompt_ids[0], list) else prompt_ids
995+
prompt_completion_ids = (
996+
prompt_completion_ids[0]
997+
if isinstance(prompt_completion_ids[0], list)
998+
else prompt_completion_ids
999+
)
9911000

9921001
# Check if the tokenized prompt starts with the tokenized prompt+completion
9931002
if not prompt_completion_ids[: len(prompt_ids)] == prompt_ids:

0 commit comments

Comments
 (0)