@@ -177,6 +177,9 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d
         if "ref_chosen_logps" in examples[0] and "ref_rejected_logps" in examples[0]:
             output["ref_chosen_logps"] = ref_chosen_logps
             output["ref_rejected_logps"] = ref_rejected_logps
+        if "token_type_ids" in examples[0]:
+            token_type_ids = [torch.tensor(example["token_type_ids"]) for example in examples]
+            output["token_type_ids"] = pad(token_type_ids, padding_value=0, padding_side="left")

         return output

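For reviewers, a quick sketch of what the new collator branch produces: ragged per-example `token_type_ids` lists become one left-padded tensor. The `pad` stand-in below is a minimal 1-D version assuming the semantics of `trl.trainer.utils.pad`; the real helper is more general.

```python
import torch

def pad(tensors, padding_value=0, padding_side="left"):
    # Minimal 1-D stand-in (assumption: same semantics as the TRL helper
    # for this case): stack ragged rows into one padded tensor.
    max_len = max(t.size(0) for t in tensors)
    out = torch.full((len(tensors), max_len), padding_value, dtype=tensors[0].dtype)
    for i, t in enumerate(tensors):
        if padding_side == "left":
            out[i, max_len - t.size(0) :] = t
        else:
            out[i, : t.size(0)] = t
    return out

examples = [{"token_type_ids": [0, 0, 1, 1]}, {"token_type_ids": [0, 1]}]
token_type_ids = [torch.tensor(ex["token_type_ids"]) for ex in examples]
print(pad(token_type_ids, padding_value=0, padding_side="left"))
# tensor([[0, 0, 1, 1],
#         [0, 0, 0, 1]])
```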
@@ -790,6 +793,8 @@ def process_row(
             output["pixel_attention_mask"] = processed_features["pixel_attention_mask"][0]
         if "image_sizes" in processed_features:
             output["image_sizes"] = processed_features["image_sizes"][0]
+        if "token_type_ids" in processed_features:
+            output["token_type_ids"] = processed_features["token_type_ids"][0]

         return output

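For context, a hedged sketch of why `process_row` can just index with `[0]`: multimodal processors that emit `token_type_ids` (e.g. Gemma 3's, where 1 marks image tokens) return it batched with a leading dimension of 1, like the other vision fields. The values below are invented for illustration.

```python
import torch

# Hypothetical processor output for a single prompt (values invented):
processed_features = {
    "input_ids": torch.tensor([[2, 101, 102, 5]]),
    "token_type_ids": torch.tensor([[0, 1, 1, 0]]),  # 1 = image token (assumption)
}

output = {}
if "token_type_ids" in processed_features:
    # Drop the leading batch dim so the row matches the other per-row fields.
    output["token_type_ids"] = processed_features["token_type_ids"][0]

print(output["token_type_ids"])  # tensor([0, 1, 1, 0])
```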
@@ -804,6 +809,7 @@ def _set_signature_columns_if_needed(self):
804809 "chosen_input_ids" ,
805810 "rejected_input_ids" ,
806811 "image_sizes" ,
812+ "token_type_ids" ,
807813 "ref_chosen_logps" ,
808814 "ref_rejected_logps" ,
809815 ]
@@ -991,6 +997,8 @@ def concatenated_inputs(
             )
         if "image_sizes" in batch:
             output["image_sizes"] = torch.cat([batch["image_sizes"], batch["image_sizes"]], dim=0)
+        if "token_type_ids" in batch:
+            output["token_type_ids"] = torch.cat((batch["token_type_ids"], batch["token_type_ids"]))

         # Concatenate the chosen and rejected completions
         max_completion_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
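The duplication mirrors what `concatenated_inputs` already does for the other prompt-side tensors: the prompt is shared between the two completions, so its `token_type_ids` are stacked twice so that rows 0..B-1 line up with the chosen completions and rows B..2B-1 with the rejected ones. A toy check:

```python
import torch

batch = {"token_type_ids": torch.tensor([[0, 0, 1], [0, 1, 1]])}  # batch size 2
doubled = torch.cat((batch["token_type_ids"], batch["token_type_ids"]))
print(doubled.shape)  # torch.Size([4, 3]) -- chosen half, then rejected half
print(torch.equal(doubled[:2], doubled[2:]))  # True
```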
@@ -1516,6 +1524,9 @@ def concatenated_forward(
         # Concatenate the prompt and completion inputs
         input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1)
         attention_mask = torch.cat((prompt_attention_mask, completion_attention_mask), dim=1)
+        if "token_type_ids" in concatenated_batch:
+            prompt_token_type_ids = concatenated_batch["token_type_ids"]
+            token_type_ids = pad_to_length(prompt_token_type_ids, input_ids.shape[1], 0)
         # Mask the prompt but not the completion for the loss
         loss_mask = torch.cat(
             (torch.zeros_like(prompt_attention_mask), completion_attention_mask),
@@ -1528,19 +1539,36 @@ def concatenated_forward(
             # Flush left to reduce the memory usage
             # [[0, 0, x, x, x, x], -> [[x, x, x, x],
             #  [0, x, x, x, 0, 0]]     [x, x, x, 0]]
-            attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
+            if "token_type_ids" in concatenated_batch:
+                attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
+                    attention_mask, input_ids, loss_mask, token_type_ids
+                )
+                token_type_ids = token_type_ids[:, : self.max_length]
+            else:
+                attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
             attention_mask = attention_mask[:, : self.max_length]
             input_ids = input_ids[:, : self.max_length]
             loss_mask = loss_mask[:, : self.max_length]
         elif self.truncation_mode == "keep_end":
             # Flush right before truncating left, then flush left
             # [[0, 0, x, x, x, x], -> [[0, 0, x, x],
             #  [0, x, x, x, 0, 0]]     [0, x, x, x]]
-            attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask)
+            if "token_type_ids" in concatenated_batch:
+                attention_mask, input_ids, loss_mask, token_type_ids = flush_right(
+                    attention_mask, input_ids, loss_mask, token_type_ids
+                )
+                token_type_ids = token_type_ids[:, -self.max_length :]
+            else:
+                attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask)
             input_ids = input_ids[:, -self.max_length :]
             attention_mask = attention_mask[:, -self.max_length :]
             loss_mask = loss_mask[:, -self.max_length :]
-            attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
+            if "token_type_ids" in concatenated_batch:
+                attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
+                    attention_mask, input_ids, loss_mask, token_type_ids
+                )
+            else:
+                attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
         else:
             raise ValueError(
                 f"Unknown truncation mode: '{self.truncation_mode}'. Should be one of ['keep_end', "
@@ -1550,7 +1577,15 @@ def concatenated_forward(
             # Flush left to reduce the memory usage
             # [[0, 0, x, x, x, x], -> [[x, x, x, x],
             #  [0, x, x, x, 0, 0]]     [x, x, x, 0]]
-            attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
+            if "token_type_ids" in concatenated_batch:
+                attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
+                    attention_mask, input_ids, loss_mask, token_type_ids
+                )
+            else:
+                attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
+
+        if "token_type_ids" in concatenated_batch:
+            model_kwargs["token_type_ids"] = token_type_ids

         if self.use_logits_to_keep:
             # Compute logits_to_keep based on loss_mask pattern:
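Most of the churn above is threading `token_type_ids` through `flush_left`/`flush_right`. A minimal sketch of the contract the diff relies on, assuming the behavior pictured in the inline comments (shift each row so leading padding is removed, drop all-padding columns, and apply the same shift to every extra tensor passed in); TRL's real helpers may differ in implementation details:

```python
import torch

def flush_left(mask, *tensors):
    # Simplified stand-in: shift each row left so its first attended token
    # lands in column 0, then drop columns that are padding in every row.
    out = [mask.clone()] + [t.clone() for t in tensors]
    for i in range(mask.size(0)):
        shift = int(mask[i].int().argmax())  # number of leading zeros in row i
        for t in out:
            t[i] = torch.roll(t[i], -shift)
    keep = int(out[0].sum(dim=1).max())  # width of the longest remaining row
    return tuple(t[:, :keep] for t in out)

attention_mask = torch.tensor([[0, 0, 1, 1, 1, 1], [0, 1, 1, 1, 0, 0]])
input_ids = torch.tensor([[9, 9, 5, 6, 7, 8], [9, 1, 2, 3, 9, 9]])
token_type_ids = torch.tensor([[0, 0, 0, 1, 1, 1], [0, 0, 1, 1, 0, 0]])

attention_mask, input_ids, token_type_ids = flush_left(attention_mask, input_ids, token_type_ids)
print(attention_mask)  # tensor([[1, 1, 1, 1], [1, 1, 1, 0]])
print(input_ids)       # tensor([[5, 6, 7, 8], [1, 2, 3, 9]])
print(token_type_ids)  # tensor([[0, 1, 1, 1], [0, 1, 1, 0]])
```

Because `token_type_ids` is passed positionally alongside `input_ids` and `loss_mask`, it receives exactly the same per-row shift and truncation, which is what keeps it aligned with `input_ids` before it is handed to the model via `model_kwargs`.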