@@ -537,6 +537,7 @@ async def benchmark(
537537 ignore_eos : bool ,
538538 goodput_config_dict : Dict [str , float ],
539539 max_concurrency : Optional [int ],
540+ lora_modules : Optional [List [str ]],
540541):
541542 if backend in ASYNC_REQUEST_FUNCS :
542543 request_func = ASYNC_REQUEST_FUNCS [backend ]
@@ -562,6 +563,7 @@ async def benchmark(
562563 multi_modal_content = test_mm_content ,
563564 ignore_eos = ignore_eos ,
564565 )
566+
565567 test_output = await request_func (request_func_input = test_input )
566568 if not test_output .success :
567569 raise ValueError (
@@ -570,6 +572,11 @@ async def benchmark(
570572 else :
571573 print ("Initial test run completed. Starting main benchmark run..." )
572574
575+ if lora_modules :
576+ # For each input request, choose a LoRA module at random.
577+ lora_modules = iter (
578+ [random .choice (lora_modules ) for _ in range (len (input_requests ))])
579+
573580 if profile :
574581 print ("Starting profiler..." )
575582 profile_input = RequestFuncInput (model = model_id ,
@@ -616,8 +623,13 @@ async def limited_request_func(request_func_input, pbar):
616623 tasks : List [asyncio .Task ] = []
617624 async for request in get_request (input_requests , request_rate , burstiness ):
618625 prompt , prompt_len , output_len , mm_content = request
619- request_func_input = RequestFuncInput (model = model_id ,
620- model_name = model_name ,
626+ req_model_id , req_model_name = model_id , model_name
627+ if lora_modules :
628+ req_lora_module = next (lora_modules )
629+ req_model_id , req_model_name = req_lora_module , req_lora_module
630+
631+ request_func_input = RequestFuncInput (model = req_model_id ,
632+ model_name = req_model_name ,
621633 prompt = prompt ,
622634 api_url = api_url ,
623635 prompt_len = prompt_len ,
@@ -900,6 +912,7 @@ def main(args: argparse.Namespace):
900912 ignore_eos = args .ignore_eos ,
901913 goodput_config_dict = goodput_config_dict ,
902914 max_concurrency = args .max_concurrency ,
915+ lora_modules = args .lora_modules ,
903916 ))
904917
905918 # Save config and results to json
@@ -1237,5 +1250,12 @@ def main(args: argparse.Namespace):
12371250 "If not specified, the model name will be the "
12381251 "same as the ``--model`` argument. " )
12391252
1253+ parser .add_argument ("--lora-modules" ,
1254+ nargs = '+' ,
1255+ default = None ,
1256+ help = "A subset of LoRA module names passed in when "
1257+ "launching the server. For each request, the "
1258+ "script chooses a LoRA module at random." )
1259+
12401260 args = parser .parse_args ()
12411261 main (args )
0 commit comments