5858
5959class TaskType (Enum ):
6060 GENERATION = "generation"
61- EMBEDDING = "embedding "
61+ POOLING = "pooling "
6262
6363
6464@dataclass
@@ -1084,10 +1084,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
10841084 parser .add_argument (
10851085 "--percentile-metrics" ,
10861086 type = str ,
1087- default = "ttft,tpot,itl" ,
1087+ default = None ,
10881088 help = "Comma-separated list of selected metrics to report percentils. "
10891089 "This argument specifies the metrics to report percentiles. "
1090- 'Allowed metric names are "ttft", "tpot", "itl", "e2el". ' ,
1090+ 'Allowed metric names are "ttft", "tpot", "itl", "e2el". '
1091+ 'If not specified, defaults to "ttft,tpot,itl" for generative models '
1092+ 'and "e2el" for pooling models.' ,
10911093 )
10921094 parser .add_argument (
10931095 "--metric-percentiles" ,
@@ -1310,7 +1312,11 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
13101312 goodput_config_dict = check_goodput_args (args )
13111313
13121314 backend = args .backend
1313- task_type = TaskType .EMBEDDING if "embeddings" in backend else TaskType .GENERATION
1315+ task_type = (
1316+ TaskType .POOLING
1317+ if "embeddings" in backend or "rerank" in backend
1318+ else TaskType .GENERATION
1319+ )
13141320
13151321 # Collect the sampling parameters.
13161322 if task_type == TaskType .GENERATION :
@@ -1336,12 +1342,17 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
13361342
13371343 if "temperature" not in sampling_params :
13381344 sampling_params ["temperature" ] = 0.0 # Default to greedy decoding.
1345+
1346+ default_percentile_metrics = "ttft,tpot,itl"
13391347 else :
13401348 sampling_params = {}
1349+ default_percentile_metrics = "e2el"
13411350
13421351 extra_body = args .extra_body or {}
13431352 extra_body = {** sampling_params , ** extra_body }
13441353
1354+ percentile_metrics : str = args .percentile_metrics or default_percentile_metrics
1355+
13451356 # Avoid GC processing "static" data - reduce pause times.
13461357 gc .collect ()
13471358 gc .freeze ()
@@ -1360,7 +1371,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
13601371 burstiness = args .burstiness ,
13611372 disable_tqdm = args .disable_tqdm ,
13621373 profile = args .profile ,
1363- selected_percentile_metrics = args . percentile_metrics .split ("," ),
1374+ selected_percentile_metrics = percentile_metrics .split ("," ),
13641375 selected_percentiles = [float (p ) for p in args .metric_percentiles .split ("," )],
13651376 ignore_eos = args .ignore_eos ,
13661377 goodput_config_dict = goodput_config_dict ,
0 commit comments