@@ -771,6 +771,60 @@ def sample(self,
         return sampled_requests


+# -----------------------------------------------------------------------------
+# MT-Bench Dataset Implementation
+# -----------------------------------------------------------------------------
+
+
+class MTBenchDataset(HuggingFaceDataset):
+    """
+    MT-Bench Dataset.
+    https://huggingface.co/datasets/philschmid/mt-bench
+
+    We create a single-turn dataset for MT-Bench.
+    This is similar to the spec decoding benchmark setup in vLLM:
+    https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18
+    """  # noqa: E501
+
+    DEFAULT_OUTPUT_LEN = 256  # avg len used in SD bench in vLLM
+    SUPPORTED_DATASET_PATHS = {
+        "philschmid/mt-bench",
+    }
+
+    def sample(self,
+               tokenizer: PreTrainedTokenizerBase,
+               num_requests: int,
+               output_len: Optional[int] = None,
+               enable_multimodal_chat: bool = False,
+               **kwargs) -> list:
+        output_len = (output_len
+                      if output_len is not None else self.DEFAULT_OUTPUT_LEN)
+        sampled_requests = []
+
+        for item in self.data:
+            if len(sampled_requests) >= num_requests:
+                break
+            prompt = item['turns'][0]
+
+            # apply template
+            prompt = tokenizer.apply_chat_template([{
+                "role": "user",
+                "content": prompt
+            }],
+                                                   add_generation_prompt=True,
+                                                   tokenize=False)
+
+            prompt_len = len(tokenizer(prompt).input_ids)
+            sampled_requests.append(
+                SampleRequest(
+                    prompt=prompt,
+                    prompt_len=prompt_len,
+                    expected_output_len=output_len,
+                ))
+        self.maybe_oversample_requests(sampled_requests, num_requests)
+        return sampled_requests
+
+
 # -----------------------------------------------------------------------------
 # AIMO Dataset Implementation
 # -----------------------------------------------------------------------------
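
Below is a hypothetical usage sketch for the MTBenchDataset class added above; it is not part of the diff. The constructor keyword arguments (dataset_path, dataset_split) and the "train" split name are assumptions about the HuggingFaceDataset base class, so verify them against the actual signature before use.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")

# Constructor arguments assumed from the HuggingFaceDataset base class.
dataset = MTBenchDataset(dataset_path="philschmid/mt-bench",
                         dataset_split="train")

# Sample 80 single-turn prompts; each SampleRequest holds the chat-templated
# first turn, its prompt token count, and expected_output_len (256 by default
# unless output_len is passed).
requests = dataset.sample(tokenizer=tokenizer, num_requests=80)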