@@ -1450,6 +1450,13 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
         ):
             dataset_class = MLPerfDataset
             args.hf_split = "train"
+        elif (
+            args.dataset_path in MMStarDataset.SUPPORTED_DATASET_PATHS
+            or args.hf_name in MMStarDataset.SUPPORTED_DATASET_PATHS
+        ):
+            dataset_class = MMStarDataset
+            args.hf_split = "val"
+            args.hf_subset = None
         else:
             supported_datasets = set([
                 dataset_name for cls in HuggingFaceDataset.__subclasses__()
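The new `elif` mirrors the dispatch pattern of the other HuggingFace-backed datasets in `get_samples`: match on either `--dataset-path` or `--hf-name`, then pin the split. Below is a minimal, self-contained sketch of what this branch does to the parsed args; the argparse `Namespace` stands in for the benchmark's parsed CLI arguments, and the `"default"` subset value is a made-up placeholder.

```python
from argparse import Namespace

# Mirrors MMStarDataset.SUPPORTED_DATASET_PATHS from the diff below.
SUPPORTED_DATASET_PATHS = {"Lin-Chen/MMStar"}

# Stand-in for the benchmark's parsed CLI args (hypothetical values).
args = Namespace(dataset_path="Lin-Chen/MMStar", hf_name=None,
                 hf_split=None, hf_subset="default")

if (args.dataset_path in SUPPORTED_DATASET_PATHS
        or args.hf_name in SUPPORTED_DATASET_PATHS):
    # MMStar publishes a single "val" split (per its dataset card), so
    # the dispatch pins it and drops any user-supplied subset.
    args.hf_split = "val"
    args.hf_subset = None

print(args.hf_split, args.hf_subset)  # -> val None
```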
@@ -2721,3 +2728,76 @@ def _generate_exact_length_tokens(target_length: int) -> list[int]:
 
         random.shuffle(requests)
         return requests
+
+
+# -----------------------------------------------------------------------------
+# MMStar Dataset Implementation
+# -----------------------------------------------------------------------------
+
+
+class MMStarDataset(HuggingFaceDataset):
+    """
+    Lin-Chen/MMStar: https://huggingface.co/datasets/Lin-Chen/MMStar
+    refer to: https://github.com/sgl-project/SpecForge/pull/106
+    """
+    DEFAULT_OUTPUT_LEN = 128
+    SUPPORTED_DATASET_PATHS = {"Lin-Chen/MMStar"}
+    IS_MULTIMODAAL = True
+
+    def sample(
+        self,
+        tokenizer: PreTrainedTokenizerBase,
+        num_requests: int,
+        output_len: Optional[int] = None,
+        enable_multimodal_chat: bool = False,
+        request_id_prefix: str = "",
+        no_oversample: bool = False,
+        **kwargs,
+    ) -> list[SampleRequest]:
+        # If --hf-output-len is not set, use the default output length.
+        output_len = (output_len
+                      if output_len is not None else self.DEFAULT_OUTPUT_LEN)
+        sampled_requests: list[SampleRequest] = []
+
+        for ind, item in enumerate(self.data):
+            if len(sampled_requests) >= num_requests:
+                break
+            # Split the question text from the options
+            # (keep only the part before "Options:").
+            full_q: str = item.get("question", "")
+            question_text = full_q.split("Options:", 1)[0].strip()
+
+            # Multimodal image content.
+            mm_content = process_image(item["image"])
+
+            # Compute the prompt token length (note: this is the plain-text
+            # length when enable_multimodal_chat is False).
+            prompt_len = len(tokenizer(question_text).input_ids)
+
+            if enable_multimodal_chat:
+                # If multimodal content should be embedded in the chat
+                # message, convert to [{"role": "user", "content": [...]}].
+                prompt = self.apply_multimodal_chat_transformation(
+                    question_text, mm_content
+                )
+                mm_for_request = None  # Already embedded in chat content.
+            else:
+                # Default: the prompt is plain text; the image stays in
+                # mm_content for the benchmark to assemble.
+                prompt = question_text
+                mm_for_request = mm_content
+
+            sampled_requests.append(
+                SampleRequest(
+                    prompt=prompt,
+                    prompt_len=prompt_len,
+                    expected_output_len=output_len,
+                    multi_modal_data=mm_for_request,
+                    request_id=request_id_prefix + str(ind),
+                )
+            )
+
+        self.maybe_oversample_requests(
+            sampled_requests, num_requests, request_id_prefix, no_oversample
+        )
+        return sampled_requests
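For reviewers who want to exercise the class directly, here is a minimal sketch. It assumes the `HuggingFaceDataset` base class in this file accepts `dataset_path` and `dataset_split` keywords and materializes the split into `self.data`, as the other subclasses rely on; the tokenizer checkpoint is an arbitrary placeholder.

```python
from transformers import AutoTokenizer

# Placeholder checkpoint; any HF tokenizer suffices for the
# prompt_len computation inside sample().
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

# Assumed base-class constructor kwargs; see HuggingFaceDataset above.
dataset = MMStarDataset(dataset_path="Lin-Chen/MMStar",
                        dataset_split="val")

requests = dataset.sample(
    tokenizer=tokenizer,
    num_requests=8,
    request_id_prefix="mmstar-",
)
for req in requests:
    # req.prompt is the question with the "Options:" block stripped;
    # the decoded image travels separately in req.multi_modal_data.
    print(req.request_id, req.prompt_len, req.expected_output_len)
```

Dropping the "Options:" block makes each prompt open-ended rather than multiple-choice, which seems deliberate for a serving benchmark that measures throughput and latency rather than answer accuracy.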