Skip to content

Commit d562f98

Browse files
committed
add spec benc
Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
1 parent a9082a4 commit d562f98

File tree

1 file changed

+79
-1
lines changed

1 file changed

+79
-1
lines changed

vllm/benchmarks/datasets.py

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1005,7 +1005,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
10051005
default="random",
10061006
choices=[
10071007
"sharegpt", "burstgpt", "sonnet", "random", "random-mm", "hf",
1008-
"custom", "prefix_repetition"
1008+
"custom", "prefix_repetition", "spec_bench"
10091009
],
10101010
help="Name of the dataset to benchmark on.",
10111011
)
@@ -1038,6 +1038,22 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
10381038
"Skip applying chat template to prompt, used only for custom dataset.",
10391039
)
10401040

1041+
spec_bench_group = parser.add_argument_group("spec bench dataset options")
1042+
spec_bench_group.add_argument(
1043+
"--spec-bench-output-len",
1044+
type=int,
1045+
default=256,
1046+
help=
1047+
"Num of output tokens per request, used only for spec bench dataset.",
1048+
)
1049+
spec_bench_group.add_argument(
1050+
"--spec-bench-category",
1051+
type=str,
1052+
default=None,
1053+
help=
1054+
"Category for spec bench dataset. If None, use all categories.",
1055+
)
1056+
10411057
sonnet_group = parser.add_argument_group("sonnet dataset options")
10421058
sonnet_group.add_argument(
10431059
"--sonnet-input-len",
@@ -1348,6 +1364,14 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
13481364
else:
13491365
# For datasets that follow a similar structure, use a mapping.
13501366
dataset_mapping = {
1367+
"spec_bench":
1368+
lambda: SpecBench(dataset_path=args.dataset_path,
1369+
category=args.spec_bench_category).sample(
1370+
num_requests=args.num_prompts,
1371+
tokenizer=tokenizer,
1372+
output_len=args.spec_bench_output_len,
1373+
request_id_prefix=args.request_id_prefix,
1374+
),
13511375
"sharegpt":
13521376
lambda: ShareGPTDataset(random_seed=args.seed,
13531377
dataset_path=args.dataset_path).sample(
@@ -1482,6 +1506,14 @@ def sample(
14821506
request_id_prefix: str = "",
14831507
**kwargs,
14841508
) -> list:
1509+
# load all data if needed
1510+
self.num_available_samples = len(self.data)
1511+
if num_requests <= 0:
1512+
num_requests = self.num_available_samples
1513+
logger.info("num_requests is set to 0 or negative, "
1514+
"so using all available samples: %d",
1515+
num_requests)
1516+
14851517
sampled_requests = []
14861518
for i, item in enumerate(self.data):
14871519
if len(sampled_requests) >= num_requests:
@@ -1513,6 +1545,52 @@ def sample(
15131545
return sampled_requests
15141546

15151547

1548+
# -----------------------------------------------------------------------------
1549+
# Spec Bench Dataset Implementation
1550+
# -----------------------------------------------------------------------------
1551+
1552+
1553+
class SpecBench(CustomDataset):
1554+
"""
1555+
Implements the SpecBench dataset: https://github.com/hemingkx/Spec-Bench
1556+
Download the dataset using:
1557+
wget https://raw.githubusercontent.com/hemingkx/Spec-Bench/refs/heads/main/data/spec_bench/question.jsonl
1558+
""" # noqa: E501
1559+
1560+
def __init__(self, **kwargs) -> None:
1561+
self.category = kwargs.pop("category", None)
1562+
super().__init__(**kwargs)
1563+
self.load_data()
1564+
1565+
def load_data(self) -> None:
1566+
if self.dataset_path is None:
1567+
raise ValueError("dataset_path must be provided for loading data.")
1568+
1569+
self.data = []
1570+
1571+
# Load the JSONL file
1572+
jsonl_data = pd.read_json(path_or_buf=self.dataset_path,
1573+
lines=True)
1574+
1575+
# check if the JSONL file has a 'turns' column
1576+
if "turns" not in jsonl_data.columns:
1577+
raise ValueError("JSONL file must contain a 'turns' column.")
1578+
1579+
for _, row in jsonl_data.iterrows():
1580+
# sample only from a specific category if specified
1581+
if (not self.category) or (self.category == row['category']):
1582+
prompt = row["turns"][0]
1583+
self.data.append({"prompt": prompt})
1584+
1585+
random.seed(self.random_seed)
1586+
random.shuffle(self.data)
1587+
1588+
def sample(self, **kwargs) -> list:
1589+
# leverage CustomDataset sample
1590+
kwargs["skip_chat_template"] = False
1591+
return super().sample(**kwargs)
1592+
1593+
15161594
# -----------------------------------------------------------------------------
15171595
# Sonnet Dataset Implementation
15181596
# -----------------------------------------------------------------------------

0 commit comments

Comments
 (0)