@@ -46,7 +46,7 @@ class SampleRequest:
     Represents a single inference request for benchmarking.
     """
 
-    prompt: str
+    prompt: Union[str, Any]
     prompt_len: int
     expected_output_len: int
     multi_modal_data: Optional[Union[MultiModalDataDict, dict]] = None
@@ -84,6 +84,20 @@ def __init__(
                             if random_seed is not None else self.DEFAULT_SEED)
         self.data = None
 
+    def apply_multimodal_chat_transformation(
+            self,
+            prompt: str,
+            mm_content: Optional[MultiModalDataDict] = None) -> list[dict]:
+        """
+        Transform a prompt and optional multimodal content into a chat format.
+        This method is used for chat models that expect a specific
+        conversation format.
+        """
+        content = [{"text": prompt, "type": "text"}]
+        if mm_content is not None:
+            content.append(mm_content)
+        return [{"role": "user", "content": content}]
+
     def load_data(self) -> None:
         """
         Load data from the dataset path into self.data.
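
For reference, a minimal sketch of the chat shape the new helper produces; `dataset` stands for any BenchmarkDataset instance, and the image dict is a hypothetical stand-in for whatever process_image() actually returns:

# Sketch only: the mm_content dict below is a placeholder, not the real
# process_image() output format.
messages = dataset.apply_multimodal_chat_transformation(
    "Describe this image.",
    {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}})
# messages == [{
#     "role": "user",
#     "content": [
#         {"text": "Describe this image.", "type": "text"},
#         {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
#     ]}]
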
@@ -338,6 +352,7 @@ def sample(self,
                lora_path: Optional[str] = None,
                max_loras: Optional[int] = None,
                output_len: Optional[int] = None,
+               enable_multimodal_chat: bool = False,
                **kwargs) -> list:
         samples: list = []
         for entry in self.data:
@@ -358,6 +373,9 @@ def sample(self,
                                      skip_min_output_len_check=output_len
                                      is not None):
                 continue
+            if enable_multimodal_chat:
+                prompt = self.apply_multimodal_chat_transformation(
+                    prompt, None)
             samples.append(
                 SampleRequest(
                     prompt=prompt,
@@ -550,34 +568,32 @@ def load_data(self) -> None:
             split=self.dataset_split,
             streaming=True,
         )
-
-        if "conversations" not in self.data.features:
-            raise ValueError("HF Dataset must have a 'conversations' column.")
-
+        if self.data.features is None or "conversations" \
+                not in self.data.features:
+            raise ValueError(
+                "HuggingFaceDataset currently only supports datasets with "
+                "a 'conversations' column like lmms-lab/LLaVA-OneVision-Data. "
+                "Please consider contributing if you would like to add "
+                "support for additional dataset formats.")
         # Shuffle and filter examples with at least 2 conversations.
         self.data = self.data.shuffle(seed=self.random_seed).filter(
             lambda x: len(x["conversations"]) >= 2)
 
     def sample(self,
                tokenizer: PreTrainedTokenizerBase,
                num_requests: int,
-               lora_path: Optional[str] = None,
-               max_loras: Optional[int] = None,
                output_len: Optional[int] = None,
+               enable_multimodal_chat: bool = False,
                **kwargs) -> list:
         sampled_requests = []
         dynamic_output = output_len is None
 
         for item in self.data:
             if len(sampled_requests) >= num_requests:
                 break
-
             conv = item["conversations"]
             prompt, completion = conv[0]["value"], conv[1]["value"]
 
-            lora_request, tokenizer = self.get_random_lora_request(
-                tokenizer, lora_path=lora_path, max_loras=max_loras)
-
             prompt_ids = tokenizer(prompt).input_ids
             completion_ids = tokenizer(completion).input_ids
             prompt_len = len(prompt_ids)
@@ -587,16 +603,20 @@ def sample(self,
             if dynamic_output and not is_valid_sequence(
                     prompt_len, completion_len):
                 continue
-
             mm_content = process_image(
                 item["image"]) if "image" in item else None
+            if enable_multimodal_chat:
+                # Note: when chat is enabled, the recorded prompt_len is no
+                # longer accurate; the request output is used to count the
+                # actual prompt and output lengths.
+                prompt = self.apply_multimodal_chat_transformation(
+                    prompt, mm_content)
             sampled_requests.append(
                 SampleRequest(
                     prompt=prompt,
                     prompt_len=prompt_len,
                     expected_output_len=output_len,
                     multi_modal_data=mm_content,
-                    lora_request=lora_request,
                 ))
         return sampled_requests
 
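
A hedged usage sketch of the reworked sample() path. The constructor keywords and the explicit load_data() call are inferred from the code above and may differ from the real class; the subset name is hypothetical:

# Sketch: keyword names inferred from load_data() above; not authoritative.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-1.5-7b-hf")
dataset = HuggingFaceDataset(dataset_path="lmms-lab/LLaVA-OneVision-Data",
                             dataset_split="train",
                             dataset_subset="some_subset")  # hypothetical name
dataset.load_data()  # may already run in __init__; shown here for clarity
requests = dataset.sample(tokenizer=tokenizer,
                          num_requests=8,
                          enable_multimodal_chat=True)
# Each request.prompt is now a chat-format list of messages, while
# request.prompt_len still counts only the raw text tokens (see the
# note in the diff above).
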
@@ -606,7 +626,7 @@ def sample(self,
 # -----------------------------------------------------------------------------
 
 
-class VisionArenaDataset(BenchmarkDataset):
+class VisionArenaDataset(HuggingFaceDataset):
     """
     Vision Arena Dataset.
     """
@@ -617,14 +637,9 @@ class VisionArenaDataset(BenchmarkDataset):
 
     def __init__(
         self,
-        dataset_split: str,
-        dataset_subset: Optional[str] = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
-        self.dataset_split = dataset_split
-        self.dataset_subset = dataset_subset
-
         if self.dataset_path != self.VISION_ARENA_DATASET_PATH:
             raise ValueError(f"Only support Vision Arena dataset.\
                     This data path {self.dataset_path} is not valid.")
@@ -645,18 +660,24 @@ def load_data(self) -> None:
     def sample(self,
                tokenizer: PreTrainedTokenizerBase,
                num_requests: int,
-               output_len: int = DEFAULT_OUTPUT_LEN,
+               output_len: Optional[int] = None,
+               enable_multimodal_chat: bool = False,
                **kwargs) -> list:
-        # TODO (jenniferzhao): Add support for offline benchmark sampling
         output_len = (output_len
                       if output_len is not None else self.DEFAULT_OUTPUT_LEN)
         sampled_requests = []
         for item in self.data:
             if len(sampled_requests) >= num_requests:
                 break
             prompt = item["turns"][0][0]["content"]
-            prompt_len = len(tokenizer(prompt).input_ids)
             mm_content = process_image(item["images"][0])
+            prompt_len = len(tokenizer(prompt).input_ids)
+            if enable_multimodal_chat:
+                # Note: when chat is enabled, the recorded prompt_len is no
+                # longer accurate; the request output is used to count the
+                # actual prompt length.
+                prompt = self.apply_multimodal_chat_transformation(
+                    prompt, mm_content)
             sampled_requests.append(
                 SampleRequest(
                     prompt=prompt,
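
Finally, a matching sketch for the refactored VisionArenaDataset, which now inherits the loading and chat-transformation plumbing from HuggingFaceDataset; the keyword names and split value are assumptions based on the diff above:

# Sketch: VisionArenaDataset now reuses HuggingFaceDataset machinery,
# including apply_multimodal_chat_transformation().
dataset = VisionArenaDataset(
    dataset_path=VisionArenaDataset.VISION_ARENA_DATASET_PATH,
    dataset_split="train")  # split name is an assumption
requests = dataset.sample(tokenizer=tokenizer,  # tokenizer from the sketch above
                          num_requests=4,
                          output_len=None,  # falls back to DEFAULT_OUTPUT_LEN
                          enable_multimodal_chat=True)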