
Commit 26b7c25

aweers and qgallouedec authored
Add support for token_type_ids in DPOTrainer (#4285)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
1 parent aa25c26 commit 26b7c25

File tree

2 files changed: +40 -4 lines

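For context: Gemma3-style vision-language processors return a `token_type_ids` tensor that marks which sequence positions are image tokens, and the model expects it in the forward pass. A minimal sketch of where the field comes from, using the tiny checkpoint this commit adds to the test matrix (the message format is an assumption about its chat template, not part of the diff):

```python
# Sketch: a Gemma3-style processor emits token_type_ids alongside input_ids.
import numpy as np
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("trl-internal-testing/tiny-Gemma3ForConditionalGeneration")
image = Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8))
messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Describe this."}]}]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=prompt, images=image, return_tensors="pt")
print("token_type_ids" in inputs)  # expected: True -- image positions are marked
```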

tests/test_dpo_trainer.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -1422,6 +1422,7 @@ class TestDPOVisionTrainer(TrlTestCase):
             # ("trl-internal-testing/tiny-PaliGemmaForConditionalGeneration",),
             ("trl-internal-testing/tiny-LlavaForConditionalGeneration",),
             ("trl-internal-testing/tiny-LlavaNextForConditionalGeneration",),
+            ("trl-internal-testing/tiny-Gemma3ForConditionalGeneration",),
         ]
     )
     def test_vdpo_trainer(self, model_id):
```
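Roughly what the parametrized test now exercises for the new checkpoint; a condensed sketch rather than the test's actual body. The toy dataset, the model-loading class, and the training arguments below are illustrative assumptions; only the model id comes from the diff.

```python
# Condensed sketch of a vision DPO run with the tiny Gemma3 checkpoint.
import numpy as np
from datasets import Dataset
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor
from trl import DPOConfig, DPOTrainer

model_id = "trl-internal-testing/tiny-Gemma3ForConditionalGeneration"
model = AutoModelForImageTextToText.from_pretrained(model_id)  # loading class is an assumption
processor = AutoProcessor.from_pretrained(model_id)

# Toy preference example in TRL's conversational vision format.
dataset = Dataset.from_dict({
    "images": [[Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8))]],
    "prompt": [[{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color?"}]}]],
    "chosen": [[{"role": "assistant", "content": [{"type": "text", "text": "Black."}]}]],
    "rejected": [[{"role": "assistant", "content": [{"type": "text", "text": "Pink."}]}]],
})

trainer = DPOTrainer(
    model=model,
    args=DPOConfig(output_dir="dpo-gemma3-test", max_steps=1, report_to="none"),
    train_dataset=dataset,
    processing_class=processor,
)
trainer.train()
```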

trl/trainer/dpo_trainer.py

Lines changed: 39 additions & 4 deletions
```diff
@@ -177,6 +177,9 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d
         if "ref_chosen_logps" in examples[0] and "ref_rejected_logps" in examples[0]:
             output["ref_chosen_logps"] = ref_chosen_logps
             output["ref_rejected_logps"] = ref_rejected_logps
+        if "token_type_ids" in examples[0]:
+            token_type_ids = [torch.tensor(example["token_type_ids"]) for example in examples]
+            output["token_type_ids"] = pad(token_type_ids, padding_value=0, padding_side="left")
 
         return output
```
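The collator change above left-pads the per-example `token_type_ids` to a common length, mirroring how `input_ids` are padded. A toy run of that step, assuming TRL's `pad` helper from `trl.trainer.utils` (the values are made up):

```python
# Left-padding ragged token_type_ids into one batch tensor, as in torch_call.
import torch
from trl.trainer.utils import pad

examples = [
    {"token_type_ids": [1, 1, 0, 0, 0]},  # image tokens marked 1, text tokens 0
    {"token_type_ids": [1, 1, 1, 0]},
]
token_type_ids = [torch.tensor(ex["token_type_ids"]) for ex in examples]
batch = pad(token_type_ids, padding_value=0, padding_side="left")
print(batch)
# tensor([[1, 1, 0, 0, 0],
#         [0, 1, 1, 1, 0]])
```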

```diff
@@ -790,6 +793,8 @@ def process_row(
         output["pixel_attention_mask"] = processed_features["pixel_attention_mask"][0]
     if "image_sizes" in processed_features:
         output["image_sizes"] = processed_features["image_sizes"][0]
+    if "token_type_ids" in processed_features:
+        output["token_type_ids"] = processed_features["token_type_ids"][0]
 
     return output
```

```diff
@@ -804,6 +809,7 @@ def _set_signature_columns_if_needed(self):
                 "chosen_input_ids",
                 "rejected_input_ids",
                 "image_sizes",
+                "token_type_ids",
                 "ref_chosen_logps",
                 "ref_rejected_logps",
             ]
```
```diff
@@ -991,6 +997,8 @@ def concatenated_inputs(
             )
         if "image_sizes" in batch:
             output["image_sizes"] = torch.cat([batch["image_sizes"], batch["image_sizes"]], dim=0)
+        if "token_type_ids" in batch:
+            output["token_type_ids"] = torch.cat((batch["token_type_ids"], batch["token_type_ids"]))
 
         # Concatenate the chosen and rejected completions
         max_completion_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
```
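Because the chosen and rejected responses share the same prompt, `concatenated_inputs` can serve both halves of the doubled batch with one copy of `token_type_ids`, duplicated along the batch dimension. A toy illustration with made-up values:

```python
# One prompt's token_type_ids serve both the chosen and the rejected copy.
import torch

token_type_ids = torch.tensor([[1, 1, 0, 0]])          # batch of 1 prompt
doubled = torch.cat((token_type_ids, token_type_ids))  # default dim=0
print(doubled.shape)  # torch.Size([2, 4]): chosen copy, then rejected copy
```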
```diff
@@ -1516,6 +1524,9 @@ def concatenated_forward(
         # Concatenate the prompt and completion inputs
         input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1)
         attention_mask = torch.cat((prompt_attention_mask, completion_attention_mask), dim=1)
+        if "token_type_ids" in concatenated_batch:
+            prompt_token_type_ids = concatenated_batch["token_type_ids"]
+            token_type_ids = pad_to_length(prompt_token_type_ids, input_ids.shape[1], 0)
         # Mask the prompt but not the completion for the loss
         loss_mask = torch.cat(
             (torch.zeros_like(prompt_attention_mask), completion_attention_mask),
```
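Completion tokens are plain text, so the prompt-level `token_type_ids` only need right-padding with zeros up to the full prompt-plus-completion length. A toy run, assuming TRL's `pad_to_length` helper from `trl.trainer.utils`:

```python
# Extending prompt token_type_ids with zeros to cover the completion tokens.
import torch
from trl.trainer.utils import pad_to_length

prompt_token_type_ids = torch.tensor([[1, 1, 0, 0]])  # prompt only (values made up)
full_length = 7                                        # prompt + completion
token_type_ids = pad_to_length(prompt_token_type_ids, full_length, 0)
print(token_type_ids)  # tensor([[1, 1, 0, 0, 0, 0, 0]])
```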
```diff
@@ -1528,19 +1539,35 @@ def concatenated_forward(
                 # Flush left to reduce the memory usage
                 # [[0, 0, x, x, x, x], -> [[x, x, x, x],
                 #  [0, x, x, x, 0, 0]]     [x, x, x, 0]]
-                attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
+                if "token_type_ids" in concatenated_batch:
+                    attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
+                        attention_mask, input_ids, loss_mask, token_type_ids
+                    )
+                else:
+                    attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
                 attention_mask = attention_mask[:, : self.max_length]
                 input_ids = input_ids[:, : self.max_length]
                 loss_mask = loss_mask[:, : self.max_length]
             elif self.truncation_mode == "keep_end":
                 # Flush right before truncating left, then flush left
                 # [[0, 0, x, x, x, x], -> [[0, 0, x, x],
                 #  [0, x, x, x, 0, 0]]     [0, x, x, x]]
-                attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask)
+                if "token_type_ids" in concatenated_batch:
+                    attention_mask, input_ids, loss_mask, token_type_ids = flush_right(
+                        attention_mask, input_ids, loss_mask, token_type_ids
+                    )
+                    token_type_ids = token_type_ids[:, -self.max_length :]
+                else:
+                    attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask)
                 input_ids = input_ids[:, -self.max_length :]
                 attention_mask = attention_mask[:, -self.max_length :]
                 loss_mask = loss_mask[:, -self.max_length :]
-                attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
+                if "token_type_ids" in concatenated_batch:
+                    attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
+                        attention_mask, input_ids, loss_mask, token_type_ids
+                    )
+                else:
+                    attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
             else:
                 raise ValueError(
                     f"Unknown truncation mode: '{self.truncation_mode}'. Should be one of ['keep_end', "
```

(In the `keep_end` branch, the first call must be `flush_right`, matching the comment and the `else` arm; the extracted text showed `flush_left` there, which would flush the wrong side before the left-truncation.)
```diff
@@ -1550,7 +1577,15 @@ def concatenated_forward(
             # Flush left to reduce the memory usage
             # [[0, 0, x, x, x, x], -> [[x, x, x, x],
             #  [0, x, x, x, 0, 0]]     [x, x, x, 0]]
-            attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
+            if "token_type_ids" in concatenated_batch:
+                attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
+                    attention_mask, input_ids, loss_mask, token_type_ids
+                )
+            else:
+                attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
+
+        if "token_type_ids" in concatenated_batch:
+            model_kwargs["token_type_ids"] = token_type_ids
 
         if self.use_logits_to_keep:
             # Compute logits_to_keep based on loss_mask pattern:
```
