diff --git a/src/llama_stack_client/lib/cli/eval/run_benchmark.py b/src/llama_stack_client/lib/cli/eval/run_benchmark.py
index d049c104..e088137e 100644
--- a/src/llama_stack_client/lib/cli/eval/run_benchmark.py
+++ b/src/llama_stack_client/lib/cli/eval/run_benchmark.py
@@ -13,7 +13,13 @@
 from tqdm.rich import tqdm
 
 from ..common.utils import create_bar_chart
-from .utils import aggregate_accuracy, aggregate_average, aggregate_categorical_count, aggregate_median
+from .utils import (
+    aggregate_accuracy,
+    aggregate_average,
+    aggregate_weighted_average,
+    aggregate_categorical_count,
+    aggregate_median,
+)
 
 
 @click.command("run-benchmark")
@@ -94,9 +100,7 @@ def run_benchmark(
         scoring_functions = benchmark.scoring_functions
         dataset_id = benchmark.dataset_id
 
-        results = client.datasets.iterrows(
-            dataset_id=dataset_id, limit=-1 if num_examples is None else num_examples
-        )
+        results = client.datasets.iterrows(dataset_id=dataset_id, limit=-1 if num_examples is None else num_examples)
 
         output_res = {}
 
@@ -146,6 +150,8 @@ def run_benchmark(
                         output_res[scoring_fn].append(aggregate_categorical_count(scoring_results))
                     elif aggregation_function == "average":
                         output_res[scoring_fn].append(aggregate_average(scoring_results))
+                    elif aggregation_function == "weighted_average":
+                        output_res[scoring_fn].append(aggregate_weighted_average(scoring_results))
                     elif aggregation_function == "median":
                         output_res[scoring_fn].append(aggregate_median(scoring_results))
                     elif aggregation_function == "accuracy":
diff --git a/src/llama_stack_client/lib/cli/eval/utils.py b/src/llama_stack_client/lib/cli/eval/utils.py
index 102b8817..96d8d54c 100644
--- a/src/llama_stack_client/lib/cli/eval/utils.py
+++ b/src/llama_stack_client/lib/cli/eval/utils.py
@@ -24,6 +24,29 @@ def aggregate_average(
     }
 
 
+def aggregate_weighted_average(
+    scoring_results: List[Dict[str, Union[bool, float, str, List[object], object, None]]],
+) -> Dict[str, Any]:
+    """Return the weighted mean of the scores under the key ``weighted_average``.
+
+    A row is used only when both its ``score`` and its ``weight`` are not
+    ``None``; skipped rows add to neither the weighted sum nor the total
+    weight, so an unscored row cannot drag the average down.
+
+    Raises:
+        ZeroDivisionError: If no row is usable or the usable weights sum to zero.
+    """
+    weighted_sum = 0
+    total_weight = 0
+    for result in scoring_results:
+        score, weight = result["score"], result["weight"]
+        if score is None or weight is None:
+            continue
+        weighted_sum += score * weight
+        total_weight += weight
+    return {"weighted_average": weighted_sum / total_weight}
+
+
 def aggregate_median(
     scoring_results: List[Dict[str, Union[bool, float, str, List[object], object, None]]],
 ) -> Dict[str, Any]: