 import torch
 from accelerate.utils import gather_object
 
-from trl.data_utils import is_conversational
+from trl.data_utils import (
+    apply_chat_template,
+    is_conversational,
+    prepare_multimodal_messages,
+)
 from trl.trainer.grpo_trainer import GRPOTrainer
 from trl.trainer.utils import nanmax, nanmin, nanstd, pad
 
@@ -80,19 +84,36 @@ def _generate_and_score_completions(
         if images is not None and all(img_list == [] for img_list in images):
             images = None
 
-        (
-            prompt_ids,
-            completion_ids,
-            prompt_mask,
-            completion_mask,
-            num_items_in_batch,
-            sampling_per_token_logps,
-            forward_kwargs,
-        ) = self._generate(prompts, images)
+        # If the prompts are conversational and the inputs contain images, we need to convert the prompts from
+        # [{"role": "user", "content": "What color is the sky?"}] to
+        # [{"role": "user", "content": [{"type": "image", "image": <Image>}, {"type": "text", "text": "What color is the sky?"}]}]
+        if images is not None:
+            prompts = [prepare_multimodal_messages(prompt, image_list) for prompt, image_list in zip(prompts, images)]
+
+        prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list, extra_fields = (
+            self._generate(prompts)
+        )
+
+        # Convert lists of token IDs to padded tensors
+        prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
+        prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids]
+        prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left")
+        prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left")
+        completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list]
+        completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids]
+        completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right")
+        completion_mask = pad(completion_mask, padding_value=0, padding_side="right")
+        if sampling_per_token_logps_list is not None:
+            sampling_per_token_logps = [torch.tensor(logps, device=device) for logps in sampling_per_token_logps_list]
+            sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0, padding_side="right")
+        else:
+            sampling_per_token_logps = None
 
-        # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need
-        # to re-tokenize completions if the reward is computed from tokens.
-        completion_ids_list = [row[mask_row].tolist() for row, mask_row in zip(completion_ids, completion_mask.bool())]
+        # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
+        if self.mask_truncated_completions:
+            eos_and_pad = [self.eos_token_id, self.pad_token_id]
+            is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device)
+            completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int()
 
         # Concatenate prompt_mask with completion_mask for logit computation
         prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)  # (B, P+C)
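
To get a rough sense of what the padding above produces, here is a minimal sketch with made-up token IDs (assuming `pad_token_id == 0`); `pad` is the helper imported from `trl.trainer.utils` at the top of the diff. Prompts are left-padded so every prompt ends flush against its completion, and completions are right-padded:

```python
import torch
from trl.trainer.utils import pad

# Two prompts of unequal length; pad_token_id assumed to be 0 for illustration
prompt_ids = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
pad(prompt_ids, padding_value=0, padding_side="left")
# tensor([[5, 6, 7],
#         [0, 8, 9]])  <- left padding keeps each prompt flush against its completion

completion_ids = [torch.tensor([1, 2]), torch.tensor([3, 4, 5])]
pad(completion_ids, padding_value=0, padding_side="right")
# tensor([[1, 2, 0],
#         [3, 4, 5]])
```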
@@ -103,6 +124,25 @@ def _generate_and_score_completions(
 
         num_images = [len(img_list) for img_list in images] if images is not None else None
 
+        # Get forward_kwargs for models with multimodal inputs
+        if images is not None:
+            prompts_text = [
+                apply_chat_template({"prompt": prompt}, self.processing_class, **self.chat_template_kwargs)["prompt"]
+                for prompt in prompts
+            ]
+            prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt")
+            prompt_inputs = super()._prepare_inputs(prompt_inputs)
+            forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
+        else:
+            forward_kwargs = {}
+
+        # If token_type_ids are used, extend them with zeros for the completion part
+        if "token_type_ids" in forward_kwargs:
+            token_type_ids = forward_kwargs["token_type_ids"]
+            forward_kwargs["token_type_ids"] = torch.cat(
+                [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1
+            )
+
         with torch.no_grad():
             # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of
             # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the
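
The `token_type_ids` extension is easiest to see on a toy tensor (a minimal sketch; shapes and values are invented): generated completion tokens are plain text, so zeros are appended after the processor's prompt-side type ids.

```python
import torch

token_type_ids = torch.tensor([[0, 1, 1, 0]])         # (B, P) from the processor
completion_ids = torch.zeros(1, 3, dtype=torch.long)  # (B, C) generated tokens
torch.cat([token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1)
# tensor([[0, 1, 1, 0, 0, 0, 0]])
```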
@@ -171,6 +211,15 @@ def _generate_and_score_completions(
         else:
             completions = completions_text
 
+        # Merge extra_fields from rollout_func into inputs for reward functions
+        if extra_fields:
+            for i, inp in enumerate(inputs):
+                for key, values in extra_fields.items():
+                    if isinstance(values, list) and i < len(values):
+                        inp[key] = values[i]
+                    elif not isinstance(values, list):
+                        inp[key] = values
+
         # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is
         # important because rewards will be normalized per group, and completions are distributed. We will later slice
         # rewards_per_func to extract each process's subset.
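
The merge loop above distributes per-sample lists by index and copies scalars to every sample. A minimal sketch with hypothetical `extra_fields` keys (the key names are invented, not part of the API):

```python
inputs = [{"prompt": "a"}, {"prompt": "b"}]
extra_fields = {"tool_calls": [2, 0], "rollout_id": "run-1"}  # hypothetical keys

for i, inp in enumerate(inputs):
    for key, values in extra_fields.items():
        if isinstance(values, list) and i < len(values):
            inp[key] = values[i]   # per-sample list: take the i-th entry
        elif not isinstance(values, list):
            inp[key] = values      # scalar: copy to every sample

# inputs == [{"prompt": "a", "tool_calls": 2, "rollout_id": "run-1"},
#            {"prompt": "b", "tool_calls": 0, "rollout_id": "run-1"}]
```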
@@ -185,7 +234,7 @@ def _generate_and_score_completions(
         # Normalize the rewards to compute the advantages
         mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
         advantages = rewards - mean_grouped_rewards
-        std_rewards = None
+
         if self.scale_rewards in ["group", "none"]:
             # If self.scale_rewards = "none", we'll still log group level std
             std_rewards = rewards.view(-1, self.num_generations).std(dim=1)
@@ -209,10 +258,7 @@ def _generate_and_score_completions(
         )
         all_process_advantages = advantages.clone()  # keep the aggregated advantages for logging
         advantages = advantages[process_slice]
-        if std_rewards is None:
-            std_rewards = rewards.view(-1, self.num_generations).std(dim=1)
-            std_rewards = std_rewards.repeat_interleave(self.num_generations, dim=0)
-        std_rewards = std_rewards[process_slice] if std_rewards is not None else None
+        std_rewards = std_rewards[process_slice]
 
         # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
         for i, reward_func_name in enumerate(self.reward_func_names):
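
For reference, the group-wise normalization these two hunks simplify works like this (a minimal sketch with `num_generations = 2` and made-up rewards; since `std_rewards` is now always set in the `scale_rewards` branch, the per-process slice no longer needs a `None` guard):

```python
import torch

num_generations = 2
rewards = torch.tensor([1.0, 3.0, 0.0, 4.0])  # two prompts x two completions

mean_grouped = rewards.view(-1, num_generations).mean(dim=1)
mean_grouped = mean_grouped.repeat_interleave(num_generations, dim=0)
advantages = rewards - mean_grouped            # tensor([-1., 1., -2., 2.])

std_rewards = rewards.view(-1, num_generations).std(dim=1)
std_rewards = std_rewards.repeat_interleave(num_generations, dim=0)
# tensor([1.4142, 1.4142, 2.8284, 2.8284])
```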
@@ -306,7 +352,7 @@ def _generate_and_score_completions(
         if "token_type_ids" in forward_kwargs:
             output["token_type_ids"] = forward_kwargs["token_type_ids"]
         if images is not None:
-            output["images"] = images
+            output["num_images"] = num_images
         return output
 
     def slice_group_data(