diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py
index 882b68ac9e2f..2e0c02d31767 100644
--- a/vllm/benchmarks/datasets.py
+++ b/vllm/benchmarks/datasets.py
@@ -1020,7 +1020,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
         default="random",
         choices=[
             "sharegpt", "burstgpt", "sonnet", "random", "random-mm", "hf",
-            "custom", "prefix_repetition"
+            "custom", "prefix_repetition", "spec_bench"
         ],
         help="Name of the dataset to benchmark on.",
     )
@@ -1053,6 +1053,22 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
         "Skip applying chat template to prompt, used only for custom dataset.",
     )
 
+    spec_bench_group = parser.add_argument_group("spec bench dataset options")
+    spec_bench_group.add_argument(
+        "--spec-bench-output-len",
+        type=int,
+        default=256,
+        help=
+        "Num of output tokens per request, used only for spec bench dataset.",
+    )
+    spec_bench_group.add_argument(
+        "--spec-bench-category",
+        type=str,
+        default=None,
+        help=
+        "Category for spec bench dataset. If None, use all categories.",
+    )
+
     sonnet_group = parser.add_argument_group("sonnet dataset options")
     sonnet_group.add_argument(
         "--sonnet-input-len",
@@ -1404,6 +1420,14 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
     else:
         # For datasets that follow a similar structure, use a mapping.
         dataset_mapping = {
+            "spec_bench":
+            lambda: SpecBench(dataset_path=args.dataset_path,
+                              category=args.spec_bench_category).sample(
+                num_requests=args.num_prompts,
+                tokenizer=tokenizer,
+                output_len=args.spec_bench_output_len,
+                request_id_prefix=args.request_id_prefix,
+            ),
             "sharegpt": lambda: ShareGPTDataset(
                 random_seed=args.seed, dataset_path=args.dataset_path
             ).sample(
@@ -1541,6 +1565,14 @@ def sample(
         request_id_prefix: str = "",
         **kwargs,
     ) -> list:
+        # use all available samples if num_requests is not positive
+        self.num_available_samples = len(self.data)
+        if num_requests <= 0:
+            num_requests = self.num_available_samples
+            logger.info("num_requests is set to 0 or negative, "
+                        "so using all available samples: %d",
+                        num_requests)
+
         sampled_requests = []
         for i, item in enumerate(self.data):
             if len(sampled_requests) >= num_requests:
@@ -1572,6 +1604,52 @@ def sample(
         return sampled_requests
 
 
+# -----------------------------------------------------------------------------
+# Spec Bench Dataset Implementation
+# -----------------------------------------------------------------------------
+
+
+class SpecBench(CustomDataset):
+    """
+    Implements the SpecBench dataset: https://github.com/hemingkx/Spec-Bench
+    Download the dataset using:
+    wget https://raw.githubusercontent.com/hemingkx/Spec-Bench/refs/heads/main/data/spec_bench/question.jsonl
+    """  # noqa: E501
+
+    def __init__(self, **kwargs) -> None:
+        self.category = kwargs.pop("category", None)
+        super().__init__(**kwargs)
+        self.load_data()
+
+    def load_data(self) -> None:
+        if self.dataset_path is None:
+            raise ValueError("dataset_path must be provided for loading data.")
+
+        self.data = []
+
+        # Load the JSONL file
+        jsonl_data = pd.read_json(path_or_buf=self.dataset_path,
+                                  lines=True)
+
+        # check if the JSONL file has a 'turns' column
+        if "turns" not in jsonl_data.columns:
+            raise ValueError("JSONL file must contain a 'turns' column.")
+
+        for _, row in jsonl_data.iterrows():
+            # sample only from a specific category if specified
+            if (not self.category) or (self.category == row['category']):
+                prompt = row["turns"][0]
+                self.data.append({"prompt": prompt})
+
+        random.seed(self.random_seed)
+        random.shuffle(self.data)
+
+    def sample(self, **kwargs) -> list:
+        # leverage CustomDataset.sample(); always apply the chat template
+        kwargs["skip_chat_template"] = False
+        return super().sample(**kwargs)
+
+
 # -----------------------------------------------------------------------------
 # Sonnet Dataset Implementation
 # -----------------------------------------------------------------------------
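
Usage sketch (not part of the diff): a minimal, standalone exercise of the new SpecBench class once this diff is applied. Assumptions: question.jsonl has been downloaded per the class docstring; the tokenizer name below is a placeholder for any HuggingFace model that ships a chat template (SpecBench.sample forces skip_chat_template=False, so a chat template is required); and "writing" is assumed to be one of the category values in the Spec-Bench question file. Passing num_requests=0 also exercises the new CustomDataset.sample fallback that consumes every available sample.

    # Illustrative only; model name and category value are assumptions,
    # not part of the diff.
    from transformers import AutoTokenizer

    from vllm.benchmarks.datasets import SpecBench

    # Placeholder tokenizer; any model with a chat template works.
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")

    # category=None (the default) keeps every Spec-Bench category.
    dataset = SpecBench(dataset_path="question.jsonl", category="writing")

    # num_requests <= 0 now falls back to all available samples, per the
    # CustomDataset.sample() change above.
    requests = dataset.sample(
        num_requests=0,
        tokenizer=tokenizer,
        output_len=256,
    )
    print(f"sampled {len(requests)} requests")

From the benchmark CLI, the same path is reached with --dataset-name spec_bench and --dataset-path question.jsonl, optionally combined with the new --spec-bench-output-len and --spec-bench-category flags.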