134 changes: 109 additions & 25 deletions src/llama_stack_client/lib/cli/eval/run_benchmark.py
@@ -13,15 +13,22 @@
from tqdm.rich import tqdm

from ..common.utils import create_bar_chart
from .utils import (
aggregate_accuracy,
aggregate_average,
aggregate_categorical_count,
aggregate_median,
)


@click.command("run-benchmark")
@click.argument("eval-task-ids", nargs=-1, required=True)
@click.argument("benchmark-ids", nargs=-1, required=True)
@click.option(
"--eval-task-config",
"--model-id",
required=True,
help="Path to the eval task config file in JSON format",
type=click.Path(exists=True),
help="model id to run the benchmark eval on",
default=None,
type=str,
)
@click.option(
"--output-dir",
@@ -35,6 +42,34 @@
default=None,
type=int,
)
@click.option(
"--temperature",
required=False,
help="temperature in the sampling params to run generation",
default=0.0,
type=float,
)
@click.option(
"--max-tokens",
required=False,
help="max-tokens in the sampling params to run generation",
default=4096,
type=int,
)
@click.option(
"--top-p",
required=False,
help="top-p in the sampling params to run generation",
default=0.9,
type=float,
)
@click.option(
"--repeat-penalty",
required=False,
help="repeat-penalty in the sampling params to run generation",
default=1.0,
type=float,
)
@click.option(
"--visualize",
is_flag=True,
@@ -44,36 +79,50 @@
@click.pass_context
def run_benchmark(
ctx,
eval_task_ids: tuple[str, ...],
eval_task_config: str,
benchmark_ids: tuple[str, ...],
model_id: str,
output_dir: str,
num_examples: Optional[int],
temperature: float,
max_tokens: int,
top_p: float,
repeat_penalty: float,
visualize: bool,
):
"""Run a evaluation benchmark task"""

client = ctx.obj["client"]

for eval_task_id in eval_task_ids:
eval_task = client.eval_tasks.retrieve(name=eval_task_id)
scoring_functions = eval_task.scoring_functions
dataset_id = eval_task.dataset_id
for benchmark_id in benchmark_ids:
benchmark = client.benchmarks.retrieve(benchmark_id=benchmark_id)
scoring_functions = benchmark.scoring_functions
dataset_id = benchmark.dataset_id

rows = client.datasetio.get_rows_paginated(
dataset_id=dataset_id, rows_in_page=-1 if num_examples is None else num_examples
dataset_id=dataset_id,
rows_in_page=-1 if num_examples is None else num_examples,
)

with open(eval_task_config, "r") as f:
eval_task_config = json.load(f)

output_res = {}

for r in tqdm(rows.rows):
eval_res = client.eval.evaluate_rows(
task_id=eval_task_id,
for i, r in enumerate(tqdm(rows.rows)):
eval_res = client.eval.evaluate_rows_alpha(
benchmark_id=benchmark_id,
input_rows=[r],
scoring_functions=scoring_functions,
task_config=eval_task_config,
task_config={
"type": "benchmark",
"eval_candidate": {
"type": "model",
"model": model_id,
"sampling_params": {
"temperature": temperature,
"max_tokens": max_tokens,
"top_p": top_p,
"repeat_penalty": repeat_penalty,
},
},
},
)
for k in r.keys():
if k not in output_res:
@@ -90,20 +139,55 @@ def run_benchmark(
output_res[scoring_fn] = []
output_res[scoring_fn].append(eval_res.scores[scoring_fn].score_rows[0])

aggregation_functions = client.scoring_functions.retrieve(
scoring_fn_id=scoring_fn
).params.aggregation_functions

# only output the aggregation result for the last row
if i == len(rows.rows) - 1:
# snapshot the per-row scores so later aggregation functions do not see summaries appended by earlier ones
scoring_results = list(output_res[scoring_fn])
for aggregation_function in aggregation_functions:
if aggregation_function == "categorical_count":
output_res[scoring_fn].append(aggregate_categorical_count(scoring_results))
elif aggregation_function == "average":
output_res[scoring_fn].append(aggregate_average(scoring_results))
elif aggregation_function == "median":
output_res[scoring_fn].append(aggregate_median(scoring_results))
elif aggregation_function == "accuracy":
output_res[scoring_fn].append(aggregate_accuracy(scoring_results))
else:
raise NotImplementedError(
f"Aggregation function {aggregation_function} is not supported yet"
)

# Create output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)
# Save results to JSON file
output_file = os.path.join(output_dir, f"{eval_task_id}_results.json")
output_file = os.path.join(output_dir, f"{benchmark_id}_results.json")
with open(output_file, "w") as f:
json.dump(output_res, f, indent=2)

rprint(f"[green]✓[/green] Results saved to: [blue]{output_file}[/blue]!\n")

if visualize:
for scoring_fn in scoring_functions:
res = output_res[scoring_fn]
assert len(res) > 0 and "score" in res[0]
scores = [str(r["score"]) for r in res]
unique_scores = sorted(list(set(scores)))
counts = [scores.count(s) for s in unique_scores]
create_bar_chart(counts, unique_scores, title=f"{scoring_fn}")
aggregation_functions = client.scoring_functions.retrieve(
scoring_fn_id=scoring_fn
).params.aggregation_functions

for aggregation_function in aggregation_functions:
res = output_res[scoring_fn]
assert len(res) > 0 and "score" in res[0]
if aggregation_function == "categorical_count":
# skip the aggregation summaries appended after the per-row scores
scores = [str(r["score"]) for r in res if "score" in r]
unique_scores = sorted(list(set(scores)))
counts = [scores.count(s) for s in unique_scores]
create_bar_chart(
counts,
unique_scores,
title=f"{scoring_fn}-{aggregation_function}",
)
else:
raise NotImplementedError(
f"Aggregation function {aggregation_function} ius not supported for visualization yet"
)
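
Note for reviewers: the snippet below is a minimal sketch of the direct client calls that `run-benchmark` now performs, useful for trying the change outside the CLI. The server URL, benchmark ID, and model ID are placeholders, not values taken from this PR.

```python
# Minimal sketch of the calls run-benchmark now makes (placeholder IDs/URL).
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # assumed local server

benchmark_id = "meta-reference-mmlu"  # placeholder benchmark ID
benchmark = client.benchmarks.retrieve(benchmark_id=benchmark_id)

rows = client.datasetio.get_rows_paginated(
    dataset_id=benchmark.dataset_id,
    rows_in_page=5,  # small sample for illustration
)

eval_res = client.eval.evaluate_rows_alpha(
    benchmark_id=benchmark_id,
    input_rows=rows.rows,
    scoring_functions=benchmark.scoring_functions,
    task_config={
        "type": "benchmark",
        "eval_candidate": {
            "type": "model",
            "model": "meta-llama/Llama-3.1-8B-Instruct",  # placeholder model ID
            "sampling_params": {
                "temperature": 0.0,
                "max_tokens": 4096,
                "top_p": 0.9,
                "repeat_penalty": 1.0,
            },
        },
    },
)
print(eval_res.scores)
```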
45 changes: 45 additions & 0 deletions src/llama_stack_client/lib/cli/eval/utils.py
@@ -0,0 +1,45 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import statistics

from typing import Any, Dict, List, Union


def aggregate_categorical_count(
scoring_results: List[Dict[str, Union[bool, float, str, List[object], object, None]]],
) -> Dict[str, Any]:
scores = [str(r["score"]) for r in scoring_results]
unique_scores = sorted(list(set(scores)))
return {"categorical_count": {s: scores.count(s) for s in unique_scores}}


def aggregate_average(
scoring_results: List[Dict[str, Union[bool, float, str, List[object], object, None]]],
) -> Dict[str, Any]:
return {
"average": sum(result["score"] for result in scoring_results if result["score"] is not None)
/ len([_ for _ in scoring_results if _["score"] is not None]),
}


def aggregate_median(
scoring_results: List[Dict[str, Union[bool, float, str, List[object], object, None]]],
) -> Dict[str, Any]:
scores = [r["score"] for r in scoring_results if r["score"] is not None]
median = statistics.median(scores) if scores else None
return {"median": median}


def aggregate_accuracy(
scoring_results: List[Dict[str, Union[bool, float, str, List[object], object, None]]],
) -> Dict[str, Any]:
num_correct = sum(result["score"] for result in scoring_results)
avg_score = num_correct / len(scoring_results)

return {
"accuracy": avg_score,
"num_correct": num_correct,
"num_total": len(scoring_results),
}
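
As a quick sanity check of these helpers, here is a small sketch with hypothetical score rows (shaped like the `score_rows[0]` dicts collected by `run-benchmark`); the expected outputs are shown as comments.

```python
# Hypothetical score rows; not test data from this PR.
from llama_stack_client.lib.cli.eval.utils import (
    aggregate_accuracy,
    aggregate_average,
    aggregate_categorical_count,
    aggregate_median,
)

scoring_results = [{"score": 1.0}, {"score": 0.0}, {"score": 1.0}, {"score": 1.0}]

print(aggregate_accuracy(scoring_results))
# {'accuracy': 0.75, 'num_correct': 3.0, 'num_total': 4}
print(aggregate_average(scoring_results))
# {'average': 0.75}
print(aggregate_median(scoring_results))
# {'median': 1.0}
print(aggregate_categorical_count(scoring_results))
# {'categorical_count': {'0.0': 1, '1.0': 3}}
```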