Skip to content

Commit b62ac6c

Browse files
authored
feat: add weighted_average aggregation function support (#208)
## What does this PR do? add weighted_average aggreagtion function support for llamastack/llama-stack#1708
1 parent 5746f91 commit b62ac6c

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

src/llama_stack_client/lib/cli/eval/run_benchmark.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,13 @@
1313
from tqdm.rich import tqdm
1414

1515
from ..common.utils import create_bar_chart
16-
from .utils import aggregate_accuracy, aggregate_average, aggregate_categorical_count, aggregate_median
16+
from .utils import (
17+
aggregate_accuracy,
18+
aggregate_average,
19+
aggregate_weighted_average,
20+
aggregate_categorical_count,
21+
aggregate_median,
22+
)
1723

1824

1925
@click.command("run-benchmark")
@@ -144,6 +150,8 @@ def run_benchmark(
144150
output_res[scoring_fn].append(aggregate_categorical_count(scoring_results))
145151
elif aggregation_function == "average":
146152
output_res[scoring_fn].append(aggregate_average(scoring_results))
153+
elif aggregation_function == "weighted_average":
154+
output_res[scoring_fn].append(aggregate_weighted_average(scoring_results))
147155
elif aggregation_function == "median":
148156
output_res[scoring_fn].append(aggregate_median(scoring_results))
149157
elif aggregation_function == "accuracy":

src/llama_stack_client/lib/cli/eval/utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,19 @@ def aggregate_average(
2424
}
2525

2626

27+
def aggregate_weighted_average(
28+
scoring_results: List[Dict[str, Union[bool, float, str, List[object], object, None]]],
29+
) -> Dict[str, Any]:
30+
return {
31+
"weighted_average": sum(
32+
result["score"] * result["weight"]
33+
for result in scoring_results
34+
if result["score"] is not None and result["weight"] is not None
35+
)
36+
/ sum(result["weight"] for result in scoring_results if result["weight"] is not None),
37+
}
38+
39+
2740
def aggregate_median(
2841
scoring_results: List[Dict[str, Union[bool, float, str, List[object], object, None]]],
2942
) -> Dict[str, Any]:

0 commit comments

Comments
 (0)