3434from typing import Any , Optional
3535
3636import numpy as np
37- from backend_request_func import (ASYNC_REQUEST_FUNCS , RequestFuncInput ,
37+ from backend_request_func import (ASYNC_REQUEST_FUNCS ,
38+ OPENAI_COMPATIBLE_BACKENDS , RequestFuncInput ,
3839 RequestFuncOutput )
3940from tqdm .asyncio import tqdm
4041from transformers import PreTrainedTokenizerBase
@@ -260,6 +261,7 @@ async def benchmark(
260261 goodput_config_dict : dict [str , float ],
261262 max_concurrency : Optional [int ],
262263 lora_modules : Optional [Iterable [str ]],
264+ extra_body : Optional [dict ],
263265):
264266 if backend in ASYNC_REQUEST_FUNCS :
265267 request_func = ASYNC_REQUEST_FUNCS [backend ]
@@ -287,6 +289,7 @@ async def benchmark(
287289 logprobs = logprobs ,
288290 multi_modal_content = test_mm_content ,
289291 ignore_eos = ignore_eos ,
292+ extra_body = extra_body ,
290293 )
291294
292295 test_output = await request_func (request_func_input = test_input )
@@ -313,7 +316,8 @@ async def benchmark(
313316 output_len = test_output_len ,
314317 logprobs = logprobs ,
315318 multi_modal_content = test_mm_content ,
316- ignore_eos = ignore_eos )
319+ ignore_eos = ignore_eos ,
320+ extra_body = extra_body )
317321 profile_output = await request_func (request_func_input = profile_input )
318322 if profile_output .success :
319323 print ("Profiler started" )
@@ -363,7 +367,8 @@ async def limited_request_func(request_func_input, pbar):
363367 output_len = output_len ,
364368 logprobs = logprobs ,
365369 multi_modal_content = mm_content ,
366- ignore_eos = ignore_eos )
370+ ignore_eos = ignore_eos ,
371+ extra_body = extra_body )
367372 tasks .append (
368373 asyncio .create_task (
369374 limited_request_func (request_func_input = request_func_input ,
@@ -652,6 +657,26 @@ def main(args: argparse.Namespace):
652657 raise ValueError (f"Unknown dataset: { args .dataset_name } " ) from err
653658 goodput_config_dict = check_goodput_args (args )
654659
660+ # Collect the sampling parameters.
661+ sampling_params = {
662+ k : v
663+ for k , v in {
664+ "top_p" : args .top_p ,
665+ "top_k" : args .top_k ,
666+ "min_p" : args .min_p ,
667+ "temperature" : args .temperature
668+ }.items () if v is not None
669+ }
670+
671+ # Sampling parameters are only supported by openai-compatible backend.
672+ if sampling_params and args .backend not in OPENAI_COMPATIBLE_BACKENDS :
673+ raise ValueError (
674+ "Sampling parameters are only supported by openai-compatible "
675+ "backends." )
676+
677+ if "temperature" not in sampling_params :
678+ sampling_params ["temperature" ] = 0.0 # Default to greedy decoding.
679+
655680 # Avoid GC processing "static" data - reduce pause times.
656681 gc .collect ()
657682 gc .freeze ()
@@ -678,6 +703,7 @@ def main(args: argparse.Namespace):
678703 goodput_config_dict = goodput_config_dict ,
679704 max_concurrency = args .max_concurrency ,
680705 lora_modules = args .lora_modules ,
706+ extra_body = sampling_params ,
681707 ))
682708
683709 # Save config and results to json
@@ -1000,6 +1026,33 @@ def main(args: argparse.Namespace):
10001026 "from the sampled HF dataset." ,
10011027 )
10021028
1029+ sampling_group = parser .add_argument_group ("sampling parameters" )
1030+ sampling_group .add_argument (
1031+ "--top-p" ,
1032+ type = float ,
1033+ default = None ,
1034+ help = "Top-p sampling parameter. Only has effect on openai-compatible "
1035+ "backends." )
1036+ sampling_group .add_argument (
1037+ "--top-k" ,
1038+ type = int ,
1039+ default = None ,
1040+ help = "Top-k sampling parameter. Only has effect on openai-compatible "
1041+ "backends." )
1042+ sampling_group .add_argument (
1043+ "--min-p" ,
1044+ type = float ,
1045+ default = None ,
1046+ help = "Min-p sampling parameter. Only has effect on openai-compatible "
1047+ "backends." )
1048+ sampling_group .add_argument (
1049+ "--temperature" ,
1050+ type = float ,
1051+ default = None ,
1052+ help = "Temperature sampling parameter. Only has effect on "
1053+ "openai-compatible backends. If not specified, default to greedy "
1054+ "decoding (i.e. temperature==0.0)." )
1055+
10031056 parser .add_argument (
10041057 '--tokenizer-mode' ,
10051058 type = str ,
0 commit comments