Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions src/llama_stack_client/lib/cli/eval/run_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@
from tqdm.rich import tqdm

from ..common.utils import create_bar_chart
from .utils import aggregate_accuracy, aggregate_average, aggregate_categorical_count, aggregate_median
from .utils import (
aggregate_accuracy,
aggregate_average,
aggregate_weighted_average,
aggregate_categorical_count,
aggregate_median,
)


@click.command("run-benchmark")
Expand Down Expand Up @@ -94,9 +100,7 @@ def run_benchmark(
scoring_functions = benchmark.scoring_functions
dataset_id = benchmark.dataset_id

results = client.datasets.iterrows(
dataset_id=dataset_id, limit=-1 if num_examples is None else num_examples
)
results = client.datasets.iterrows(dataset_id=dataset_id, limit=-1 if num_examples is None else num_examples)

output_res = {}

Expand Down Expand Up @@ -146,6 +150,8 @@ def run_benchmark(
output_res[scoring_fn].append(aggregate_categorical_count(scoring_results))
elif aggregation_function == "average":
output_res[scoring_fn].append(aggregate_average(scoring_results))
elif aggregation_function == "weighted_average":
output_res[scoring_fn].append(aggregate_weighted_average(scoring_results))
elif aggregation_function == "median":
output_res[scoring_fn].append(aggregate_median(scoring_results))
elif aggregation_function == "accuracy":
Expand Down
13 changes: 13 additions & 0 deletions src/llama_stack_client/lib/cli/eval/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,19 @@ def aggregate_average(
}


def aggregate_weighted_average(
scoring_results: List[Dict[str, Union[bool, float, str, List[object], object, None]]],
) -> Dict[str, Any]:
return {
"weighted_average": sum(
result["score"] * result["weight"]
for result in scoring_results
if result["score"] is not None and result["weight"] is not None
)
/ sum(result["weight"] for result in scoring_results if result["weight"] is not None),
}


def aggregate_median(
scoring_results: List[Dict[str, Union[bool, float, str, List[object], object, None]]],
) -> Dict[str, Any]:
Expand Down