@@ -199,6 +199,56 @@ def sample_sonnet_requests(
199199 return sampled_requests
200200
201201
def sample_mmmu_pro_vision_requests(
    dataset,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int] = None,
) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]:
    """Build up to ``num_requests`` image requests from MMMU-Pro vision rows.

    Every request uses the same fixed "direct" prompt; only the attached
    image differs. Each image is converted to RGB, encoded as JPEG and
    embedded as a base64 ``data:`` URL inside an ``image_url`` content dict.

    Args:
        dataset: Iterable of rows supporting ``row["image"]``. Rows are
            consumed in iteration order and no more rows are pulled than
            requests are produced.
        num_requests: Maximum number of requests to produce. Zero or a
            negative value yields an empty list without touching
            ``dataset`` or ``tokenizer``.
        tokenizer: Called as ``tokenizer(prompt)``; the length of the
            result's ``input_ids`` is reported as the prompt length.
        fixed_output_len: Output length recorded for every request. When
            ``None``, a notice is printed and 128 is used.

    Returns:
        A list of ``(prompt, prompt_len, output_len, mm_content)`` tuples.
        NOTE(review): the declared return annotation lists the second
        element as ``str`` although the integer prompt length is stored;
        the annotation is kept as-is to match ``sample_hf_requests``.

    Raises:
        AssertionError: If a row's ``"image"`` value is not an ``Image``
            instance.
    """
    sampled_requests: List[Tuple[str, int, int,
                                 Dict[str, Collection[str]]]] = []
    # Guard against non-positive counts: an equality-based stop test would
    # never trigger for a negative value and would drain the whole dataset.
    if num_requests <= 0:
        return sampled_requests

    # MMMU-Pro vision direct prompt
    # Ref: https://github.com/MMMU-Benchmark/MMMU/blob/6ce42f4d8f70c1841c67867152648974415b5cac/mmmu-pro/prompts.yaml#L5
    prompt = (
        "Answer with the option letter from the given choices directly. "
        "The last line of your response should be of the following "
        "format: 'Answer: $LETTER' (without quotes) where LETTER is one of "
        "options.")

    # The prompt is identical for every request, so tokenize it only once
    # instead of once per sampled row.
    prompt_len = len(tokenizer(prompt).input_ids)

    if fixed_output_len is None:
        # Default max output len is set to 128
        print("--hf-output-len is not provided. Using default value 128.")
        fixed_output_len = 128
    output_len = fixed_output_len

    for data in dataset:
        assert isinstance(
            data["image"],
            Image), ("Input image format must be `PIL.Image.Image`, "
                     f"given {type(data['image'])}.")
        image: Image = data["image"]
        # Convert to RGB before saving as JPEG.
        image = image.convert("RGB")
        with io.BytesIO() as image_data:
            image.save(image_data, format='JPEG')
            image_base64 = base64.b64encode(
                image_data.getvalue()).decode("utf-8")

        mm_content = {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/jpeg;base64,{image_base64}"
            },
        }
        sampled_requests.append((prompt, prompt_len, output_len, mm_content))

        # Stop right after the last needed sample so that no extra row is
        # pulled from a (possibly streaming) dataset.
        if len(sampled_requests) >= num_requests:
            break

    return sampled_requests
250+
251+
202252def sample_hf_requests (
203253 dataset_path : str ,
204254 dataset_subset : str ,
@@ -208,6 +258,21 @@ def sample_hf_requests(
208258 random_seed : int ,
209259 fixed_output_len : Optional [int ] = None ,
210260) -> List [Tuple [str , str , int , Optional [Dict [str , Collection [str ]]]]]:
261+
262+ # Special case for MMMU-Pro vision dataset
263+ if dataset_path == 'MMMU/MMMU_Pro' and dataset_subset == 'vision' :
264+ assert dataset_split == "test"
265+ dataset = load_dataset (dataset_path ,
266+ name = dataset_subset ,
267+ split = dataset_split ,
268+ streaming = True )
269+ assert "image" in dataset .features , (
270+ "MMMU/MMMU_Pro vision dataset must have 'image' column." )
271+ filter_func = lambda x : isinstance (x ["image" ], Image )
272+ dataset = dataset .shuffle (seed = random_seed ).filter (filter_func )
273+ return sample_mmmu_pro_vision_requests (dataset , num_requests ,
274+ tokenizer , fixed_output_len )
275+
211276 dataset = load_dataset (dataset_path ,
212277 name = dataset_subset ,
213278 split = dataset_split ,
0 commit comments