Skip to content

Commit 3feeeb9

Browse files
[Spec Decode][Benchmark] Add Spec Bench Dataset for benchmarking (#23563)
Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
1 parent 6f4a82f commit 3feeeb9

File tree

1 file changed

+79
-1
lines changed

1 file changed

+79
-1
lines changed

vllm/benchmarks/datasets.py

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1020,7 +1020,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
10201020
default="random",
10211021
choices=[
10221022
"sharegpt", "burstgpt", "sonnet", "random", "random-mm", "hf",
1023-
"custom", "prefix_repetition"
1023+
"custom", "prefix_repetition", "spec_bench"
10241024
],
10251025
help="Name of the dataset to benchmark on.",
10261026
)
@@ -1053,6 +1053,22 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
10531053
"Skip applying chat template to prompt, used only for custom dataset.",
10541054
)
10551055

1056+
spec_bench_group = parser.add_argument_group("spec bench dataset options")
1057+
spec_bench_group.add_argument(
1058+
"--spec-bench-output-len",
1059+
type=int,
1060+
default=256,
1061+
help=
1062+
"Num of output tokens per request, used only for spec bench dataset.",
1063+
)
1064+
spec_bench_group.add_argument(
1065+
"--spec-bench-category",
1066+
type=str,
1067+
default=None,
1068+
help=
1069+
"Category for spec bench dataset. If None, use all categories.",
1070+
)
1071+
10561072
sonnet_group = parser.add_argument_group("sonnet dataset options")
10571073
sonnet_group.add_argument(
10581074
"--sonnet-input-len",
@@ -1404,6 +1420,14 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
14041420
else:
14051421
# For datasets that follow a similar structure, use a mapping.
14061422
dataset_mapping = {
1423+
"spec_bench":
1424+
lambda: SpecBench(dataset_path=args.dataset_path,
1425+
category=args.spec_bench_category).sample(
1426+
num_requests=args.num_prompts,
1427+
tokenizer=tokenizer,
1428+
output_len=args.spec_bench_output_len,
1429+
request_id_prefix=args.request_id_prefix,
1430+
),
14071431
"sharegpt": lambda: ShareGPTDataset(
14081432
random_seed=args.seed, dataset_path=args.dataset_path
14091433
).sample(
@@ -1541,6 +1565,14 @@ def sample(
15411565
request_id_prefix: str = "",
15421566
**kwargs,
15431567
) -> list:
1568+
# load all data if needed
1569+
self.num_available_samples = len(self.data)
1570+
if num_requests <= 0:
1571+
num_requests = self.num_available_samples
1572+
logger.info("num_requests is set to 0 or negative, "
1573+
"so using all available samples: %d",
1574+
num_requests)
1575+
15441576
sampled_requests = []
15451577
for i, item in enumerate(self.data):
15461578
if len(sampled_requests) >= num_requests:
@@ -1572,6 +1604,52 @@ def sample(
15721604
return sampled_requests
15731605

15741606

1607+
# -----------------------------------------------------------------------------
1608+
# Spec Bench Dataset Implementation
1609+
# -----------------------------------------------------------------------------
1610+
1611+
1612+
class SpecBench(CustomDataset):
    """
    Implements the SpecBench dataset: https://github.com/hemingkx/Spec-Bench
    Download the dataset using:
    wget https://raw.githubusercontent.com/hemingkx/Spec-Bench/refs/heads/main/data/spec_bench/question.jsonl

    Each JSONL record must provide a 'turns' list; only the first turn is
    used as the prompt. When a category is given, only records whose
    'category' field equals it are kept.
    """  # noqa: E501

    def __init__(self, **kwargs) -> None:
        """Create the dataset and load its prompts.

        Args:
            **kwargs: ``category`` (optional) restricts the prompts to a
                single category; every other keyword is forwarded
                unchanged to the parent constructor.
        """
        # Set the filter before the parent constructor runs so that
        # load_data() can read it whenever it is invoked.
        self.category = kwargs.pop("category", None)
        super().__init__(**kwargs)
        # NOTE(review): if the parent constructor already calls
        # load_data(), this call reads the file a second time - confirm.
        self.load_data()

    def load_data(self) -> None:
        """Read the JSONL file and populate ``self.data`` with prompts.

        Raises:
            ValueError: if ``dataset_path`` is None, if the file has no
                'turns' column, or if a category filter is requested but
                the file has no 'category' column.
        """
        if self.dataset_path is None:
            raise ValueError("dataset_path must be provided for loading data.")

        # Load the JSONL file (one JSON object per line).
        jsonl_data = pd.read_json(path_or_buf=self.dataset_path,
                                  lines=True)

        # check if the JSONL file has a 'turns' column
        if "turns" not in jsonl_data.columns:
            raise ValueError("JSONL file must contain a 'turns' column.")

        # sample only from a specific category if specified
        if self.category:
            # Fail with a clear message instead of a bare KeyError when the
            # column needed by the filter is absent.
            if "category" not in jsonl_data.columns:
                raise ValueError(
                    "JSONL file must contain a 'category' column when a "
                    "category is specified.")
            jsonl_data = jsonl_data[jsonl_data["category"] == self.category]
            if jsonl_data.empty:
                # A misspelled category would otherwise yield an empty
                # benchmark without any hint about the cause.
                logger.warning(
                    "No spec bench samples found for category '%s'.",
                    self.category)

        # Only the first turn of each record is used as the prompt.
        self.data = [{"prompt": turns[0]} for turns in jsonl_data["turns"]]

        # Seed the module-level RNG so the shuffled order is reproducible
        # for a given random_seed.
        random.seed(self.random_seed)
        random.shuffle(self.data)

    def sample(self, **kwargs) -> list:
        """Sample requests through the parent class implementation.

        ``skip_chat_template`` is always forced to False, overriding any
        value supplied by the caller; all other keyword arguments are
        forwarded unchanged.

        Returns:
            The list produced by the parent ``sample`` method.
        """
        # leverage CustomDataset sample
        kwargs["skip_chat_template"] = False
        return super().sample(**kwargs)
1651+
1652+
15751653
# -----------------------------------------------------------------------------
15761654
# Sonnet Dataset Implementation
15771655
# -----------------------------------------------------------------------------

0 commit comments

Comments
 (0)