@@ -1020,7 +1020,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
10201020 default = "random" ,
10211021 choices = [
10221022 "sharegpt" , "burstgpt" , "sonnet" , "random" , "random-mm" , "hf" ,
1023- "custom" , "prefix_repetition"
1023+ "custom" , "prefix_repetition" , "spec_bench"
10241024 ],
10251025 help = "Name of the dataset to benchmark on." ,
10261026 )
@@ -1053,6 +1053,22 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
10531053 "Skip applying chat template to prompt, used only for custom dataset." ,
10541054 )
10551055
1056+ spec_bench_group = parser .add_argument_group ("spec bench dataset options" )
1057+ spec_bench_group .add_argument (
1058+ "--spec-bench-output-len" ,
1059+ type = int ,
1060+ default = 256 ,
1061+ help =
1062+ "Num of output tokens per request, used only for spec bench dataset." ,
1063+ )
1064+ spec_bench_group .add_argument (
1065+ "--spec-bench-category" ,
1066+ type = str ,
1067+ default = None ,
1068+ help =
1069+ "Category for spec bench dataset. If None, use all categories." ,
1070+ )
1071+
10561072 sonnet_group = parser .add_argument_group ("sonnet dataset options" )
10571073 sonnet_group .add_argument (
10581074 "--sonnet-input-len" ,
@@ -1404,6 +1420,14 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
14041420 else :
14051421 # For datasets that follow a similar structure, use a mapping.
14061422 dataset_mapping = {
1423+ "spec_bench" :
1424+ lambda : SpecBench (dataset_path = args .dataset_path ,
1425+ category = args .spec_bench_category ).sample (
1426+ num_requests = args .num_prompts ,
1427+ tokenizer = tokenizer ,
1428+ output_len = args .spec_bench_output_len ,
1429+ request_id_prefix = args .request_id_prefix ,
1430+ ),
14071431 "sharegpt" : lambda : ShareGPTDataset (
14081432 random_seed = args .seed , dataset_path = args .dataset_path
14091433 ).sample (
@@ -1541,6 +1565,14 @@ def sample(
15411565 request_id_prefix : str = "" ,
15421566 ** kwargs ,
15431567 ) -> list :
1568+ # load all data if needed
1569+ self .num_available_samples = len (self .data )
1570+ if num_requests <= 0 :
1571+ num_requests = self .num_available_samples
1572+ logger .info ("num_requests is set to 0 or negative, "
1573+ "so using all available samples: %d" ,
1574+ num_requests )
1575+
15441576 sampled_requests = []
15451577 for i , item in enumerate (self .data ):
15461578 if len (sampled_requests ) >= num_requests :
@@ -1572,6 +1604,52 @@ def sample(
15721604 return sampled_requests
15731605
15741606
1607+ # -----------------------------------------------------------------------------
1608+ # Spec Bench Dataset Implementation
1609+ # -----------------------------------------------------------------------------
1610+
1611+
1612+ class SpecBench (CustomDataset ):
1613+ """
1614+ Implements the SpecBench dataset: https://github.com/hemingkx/Spec-Bench
1615+ Download the dataset using:
1616+ wget https://raw.githubusercontent.com/hemingkx/Spec-Bench/refs/heads/main/data/spec_bench/question.jsonl
1617+ """ # noqa: E501
1618+
1619+ def __init__ (self , ** kwargs ) -> None :
1620+ self .category = kwargs .pop ("category" , None )
1621+ super ().__init__ (** kwargs )
1622+ self .load_data ()
1623+
1624+ def load_data (self ) -> None :
1625+ if self .dataset_path is None :
1626+ raise ValueError ("dataset_path must be provided for loading data." )
1627+
1628+ self .data = []
1629+
1630+ # Load the JSONL file
1631+ jsonl_data = pd .read_json (path_or_buf = self .dataset_path ,
1632+ lines = True )
1633+
1634+ # check if the JSONL file has a 'turns' column
1635+ if "turns" not in jsonl_data .columns :
1636+ raise ValueError ("JSONL file must contain a 'turns' column." )
1637+
1638+ for _ , row in jsonl_data .iterrows ():
1639+ # sample only from a specific category if specified
1640+ if (not self .category ) or (self .category == row ['category' ]):
1641+ prompt = row ["turns" ][0 ]
1642+ self .data .append ({"prompt" : prompt })
1643+
1644+ random .seed (self .random_seed )
1645+ random .shuffle (self .data )
1646+
1647+ def sample (self , ** kwargs ) -> list :
1648+ # leverage CustomDataset sample
1649+ kwargs ["skip_chat_template" ] = False
1650+ return super ().sample (** kwargs )
1651+
1652+
15751653# -----------------------------------------------------------------------------
15761654# Sonnet Dataset Implementation
15771655# -----------------------------------------------------------------------------
0 commit comments