@@ -1005,7 +1005,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
10051005 default = "random" ,
10061006 choices = [
10071007 "sharegpt" , "burstgpt" , "sonnet" , "random" , "random-mm" , "hf" ,
1008- "custom" , "prefix_repetition"
1008+ "custom" , "prefix_repetition" , "spec_bench"
10091009 ],
10101010 help = "Name of the dataset to benchmark on." ,
10111011 )
@@ -1038,6 +1038,22 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
10381038 "Skip applying chat template to prompt, used only for custom dataset." ,
10391039 )
10401040
1041+ spec_bench_group = parser .add_argument_group ("spec bench dataset options" )
1042+ spec_bench_group .add_argument (
1043+ "--spec-bench-output-len" ,
1044+ type = int ,
1045+ default = 256 ,
1046+ help =
1047+ "Num of output tokens per request, used only for spec bench dataset." ,
1048+ )
1049+ spec_bench_group .add_argument (
1050+ "--spec-bench-category" ,
1051+ type = str ,
1052+ default = None ,
1053+ help =
1054+ "Category for spec bench dataset. If None, use all categories." ,
1055+ )
1056+
10411057 sonnet_group = parser .add_argument_group ("sonnet dataset options" )
10421058 sonnet_group .add_argument (
10431059 "--sonnet-input-len" ,
@@ -1348,6 +1364,14 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
13481364 else :
13491365 # For datasets that follow a similar structure, use a mapping.
13501366 dataset_mapping = {
1367+ "spec_bench" :
1368+ lambda : SpecBench (dataset_path = args .dataset_path ,
1369+ category = args .spec_bench_category ).sample (
1370+ num_requests = args .num_prompts ,
1371+ tokenizer = tokenizer ,
1372+ output_len = args .spec_bench_output_len ,
1373+ request_id_prefix = args .request_id_prefix ,
1374+ ),
13511375 "sharegpt" :
13521376 lambda : ShareGPTDataset (random_seed = args .seed ,
13531377 dataset_path = args .dataset_path ).sample (
@@ -1482,6 +1506,14 @@ def sample(
14821506 request_id_prefix : str = "" ,
14831507 ** kwargs ,
14841508 ) -> list :
1509+ # load all data if needed
1510+ self .num_available_samples = len (self .data )
1511+ if num_requests <= 0 :
1512+ num_requests = self .num_available_samples
1513+ logger .info ("num_requests is set to 0 or negative, "
1514+ "so using all available samples: %d" ,
1515+ num_requests )
1516+
14851517 sampled_requests = []
14861518 for i , item in enumerate (self .data ):
14871519 if len (sampled_requests ) >= num_requests :
@@ -1513,6 +1545,52 @@ def sample(
15131545 return sampled_requests
15141546
15151547
1548+ # -----------------------------------------------------------------------------
1549+ # Spec Bench Dataset Implementation
1550+ # -----------------------------------------------------------------------------
1551+
1552+
1553+ class SpecBench (CustomDataset ):
1554+ """
1555+ Implements the SpecBench dataset: https://github.com/hemingkx/Spec-Bench
1556+ Download the dataset using:
1557+ wget https://raw.githubusercontent.com/hemingkx/Spec-Bench/refs/heads/main/data/spec_bench/question.jsonl
1558+ """ # noqa: E501
1559+
1560+ def __init__ (self , ** kwargs ) -> None :
1561+ self .category = kwargs .pop ("category" , None )
1562+ super ().__init__ (** kwargs )
1563+ self .load_data ()
1564+
1565+ def load_data (self ) -> None :
1566+ if self .dataset_path is None :
1567+ raise ValueError ("dataset_path must be provided for loading data." )
1568+
1569+ self .data = []
1570+
1571+ # Load the JSONL file
1572+ jsonl_data = pd .read_json (path_or_buf = self .dataset_path ,
1573+ lines = True )
1574+
1575+ # check if the JSONL file has a 'turns' column
1576+ if "turns" not in jsonl_data .columns :
1577+ raise ValueError ("JSONL file must contain a 'turns' column." )
1578+
1579+ for _ , row in jsonl_data .iterrows ():
1580+ # sample only from a specific category if specified
1581+ if (not self .category ) or (self .category == row ['category' ]):
1582+ prompt = row ["turns" ][0 ]
1583+ self .data .append ({"prompt" : prompt })
1584+
1585+ random .seed (self .random_seed )
1586+ random .shuffle (self .data )
1587+
1588+ def sample (self , ** kwargs ) -> list :
1589+ # leverage CustomDataset sample
1590+ kwargs ["skip_chat_template" ] = False
1591+ return super ().sample (** kwargs )
1592+
1593+
15161594# -----------------------------------------------------------------------------
15171595# Sonnet Dataset Implementation
15181596# -----------------------------------------------------------------------------
0 commit comments