@@ -1450,6 +1450,13 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
         ):
             dataset_class = MLPerfDataset
             args.hf_split = "train"
+        elif (
+            args.dataset_path in MMStarDataset.SUPPORTED_DATASET_PATHS
+            or args.hf_name in MMStarDataset.SUPPORTED_DATASET_PATHS
+        ):
+            dataset_class = MMStarDataset
+            args.hf_split = "val"
+            args.hf_subset = None
         else:
             supported_datasets = set([
                 dataset_name for cls in HuggingFaceDataset.__subclasses__()
@@ -2721,3 +2728,76 @@ def _generate_exact_length_tokens(target_length: int) -> list[int]:
 
         random.shuffle(requests)
         return requests
+
+
+# -----------------------------------------------------------------------------
+# MMStar Dataset Implementation
+# -----------------------------------------------------------------------------
+
+
+class MMStarDataset(HuggingFaceDataset):
+    """
+    Lin-Chen/MMStar: https://huggingface.co/datasets/Lin-Chen/MMStar
+    refer to: https://github.com/sgl-project/SpecForge/pull/106
+    """
+    DEFAULT_OUTPUT_LEN = 128
+    SUPPORTED_DATASET_PATHS = {"Lin-Chen/MMStar"}
+    IS_MULTIMODAL = True
+
+    def sample(
+        self,
+        tokenizer: PreTrainedTokenizerBase,
+        num_requests: int,
+        output_len: Optional[int] = None,
+        enable_multimodal_chat: bool = False,
+        request_id_prefix: str = "",
+        no_oversample: bool = False,
+        **kwargs,
+    ) -> list[SampleRequest]:
+        # If --hf-output-len is not set, use the default output length.
+        output_len = (output_len
+                      if output_len is not None else self.DEFAULT_OUTPUT_LEN)
+        sampled_requests: list[SampleRequest] = []
+
+        for ind, item in enumerate(self.data):
+            if len(sampled_requests) >= num_requests:
+                break
+            # Split the question text from the options
+            # (keep only the part before "Options:").
+            full_q: str = item.get("question", "")
+            question_text = full_q.split("Options:", 1)[0].strip()
+
+            # Multimodal image content.
+            mm_content = process_image(item["image"])
+
+            # Compute the prompt token length (note: this is the plain-text
+            # length when enable_multimodal_chat is False).
+            prompt_len = len(tokenizer(question_text).input_ids)
+
+            if enable_multimodal_chat:
+                # If multimodal content should be embedded in the chat
+                # message, convert to [{"role": "user", "content": [...]}].
+                prompt = self.apply_multimodal_chat_transformation(
+                    question_text, mm_content
+                )
+                mm_for_request = None  # Already embedded in chat content.
+            else:
+                # Default: the prompt is plain text and the image is passed
+                # via mm_content for the benchmark to assemble.
+                prompt = question_text
+                mm_for_request = mm_content
+
+            sampled_requests.append(
+                SampleRequest(
+                    prompt=prompt,
+                    prompt_len=prompt_len,
+                    expected_output_len=output_len,
+                    multi_modal_data=mm_for_request,
+                    request_id=request_id_prefix + str(ind),
+                )
+            )
+
+        self.maybe_oversample_requests(
+            sampled_requests, num_requests, request_id_prefix, no_oversample
+        )
+        return sampled_requests
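
Note (not part of the diff): a minimal sketch of how the new MMStarDataset could be exercised directly, for example when debugging the sampler outside the benchmark CLI. The import path, the constructor keyword arguments (dataset_path, dataset_split, dataset_subset), and the tokenizer model below are assumptions inferred from how get_samples() configures args.hf_split and args.hf_subset; in normal use the class is reached through the benchmark entry point instead.

# Sketch only -- import path and constructor kwargs are assumed,
# not taken from this diff.
from transformers import AutoTokenizer

from vllm.benchmarks.datasets import MMStarDataset  # adjust to actual module

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

dataset = MMStarDataset(
    dataset_path="Lin-Chen/MMStar",  # must be in SUPPORTED_DATASET_PATHS
    dataset_split="val",             # the split get_samples() selects
    dataset_subset=None,
)
requests = dataset.sample(
    tokenizer=tokenizer,
    num_requests=8,
    output_len=None,                 # falls back to DEFAULT_OUTPUT_LEN (128)
    enable_multimodal_chat=False,
    request_id_prefix="mmstar-",
)
for req in requests:
    print(req.request_id, req.prompt_len, req.expected_output_len)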