diff --git a/src/llama_stack_client/lib/cli/eval/run_benchmark.py b/src/llama_stack_client/lib/cli/eval/run_benchmark.py
index ac03c564..933b1338 100644
--- a/src/llama_stack_client/lib/cli/eval/run_benchmark.py
+++ b/src/llama_stack_client/lib/cli/eval/run_benchmark.py
@@ -13,15 +13,22 @@
 from tqdm.rich import tqdm
 
 from ..common.utils import create_bar_chart
+from .utils import (
+    aggregate_accuracy,
+    aggregate_average,
+    aggregate_categorical_count,
+    aggregate_median,
+)
 
 
 @click.command("run-benchmark")
-@click.argument("eval-task-ids", nargs=-1, required=True)
+@click.argument("benchmark-ids", nargs=-1, required=True)
 @click.option(
-    "--eval-task-config",
+    "--model-id",
     required=True,
-    help="Path to the eval task config file in JSON format",
-    type=click.Path(exists=True),
+    help="model id to run the benchmark eval on",
+    default=None,
+    type=str,
 )
 @click.option(
     "--output-dir",
@@ -35,6 +42,34 @@
     default=None,
     type=int,
 )
+@click.option(
+    "--temperature",
+    required=False,
+    help="temperature in the sampling params to run generation",
+    default=0.0,
+    type=float,
+)
+@click.option(
+    "--max-tokens",
+    required=False,
+    help="max-tokens in the sampling params to run generation",
+    default=4096,
+    type=int,
+)
+@click.option(
+    "--top-p",
+    required=False,
+    help="top-p in the sampling params to run generation",
+    default=0.9,
+    type=float,
+)
+@click.option(
+    "--repeat-penalty",
+    required=False,
+    help="repeat-penalty in the sampling params to run generation",
+    default=1.0,
+    type=float,
+)
 @click.option(
     "--visualize",
     is_flag=True,
@@ -44,36 +79,50 @@
 @click.pass_context
 def run_benchmark(
     ctx,
-    eval_task_ids: tuple[str, ...],
-    eval_task_config: str,
+    benchmark_ids: tuple[str, ...],
+    model_id: str,
     output_dir: str,
     num_examples: Optional[int],
+    temperature: float,
+    max_tokens: int,
+    top_p: float,
+    repeat_penalty: float,
     visualize: bool,
 ):
     """Run a evaluation benchmark task"""
 
     client = ctx.obj["client"]
 
-    for eval_task_id in eval_task_ids:
-        eval_task = client.eval_tasks.retrieve(name=eval_task_id)
-        scoring_functions = eval_task.scoring_functions
-        dataset_id = eval_task.dataset_id
+    for benchmark_id in benchmark_ids:
+        benchmark = client.benchmarks.retrieve(benchmark_id=benchmark_id)
+        scoring_functions = benchmark.scoring_functions
+        dataset_id = benchmark.dataset_id
 
         rows = client.datasetio.get_rows_paginated(
-            dataset_id=dataset_id, rows_in_page=-1 if num_examples is None else num_examples
+            dataset_id=dataset_id,
+            rows_in_page=-1 if num_examples is None else num_examples,
         )
 
-        with open(eval_task_config, "r") as f:
-            eval_task_config = json.load(f)
-
         output_res = {}
 
-        for r in tqdm(rows.rows):
-            eval_res = client.eval.evaluate_rows(
-                task_id=eval_task_id,
+        for i, r in enumerate(tqdm(rows.rows)):
+            eval_res = client.eval.evaluate_rows_alpha(
+                benchmark_id=benchmark_id,
                 input_rows=[r],
                 scoring_functions=scoring_functions,
-                task_config=eval_task_config,
+                task_config={
+                    "type": "benchmark",
+                    "eval_candidate": {
+                        "type": "model",
+                        "model": model_id,
+                        "sampling_params": {
+                            "temperature": temperature,
+                            "max_tokens": max_tokens,
+                            "top_p": top_p,
+                            "repeat_penalty": repeat_penalty,
+                        },
+                    },
+                },
             )
             for k in r.keys():
                 if k not in output_res:
@@ -90,10 +139,31 @@ def run_benchmark(
                     output_res[scoring_fn] = []
                 output_res[scoring_fn].append(eval_res.scores[scoring_fn].score_rows[0])
 
+                aggregation_functions = client.scoring_functions.retrieve(
+                    scoring_fn_id=scoring_fn
+                ).params.aggregation_functions
+
+                # only output the aggregation result for the last row
+                if i == len(rows.rows) - 1:
+                    for aggregation_function in aggregation_functions:
+                        scoring_results = output_res[scoring_fn]
+                        if aggregation_function == "categorical_count":
+                            output_res[scoring_fn].append(aggregate_categorical_count(scoring_results))
+                        elif aggregation_function == "average":
+                            output_res[scoring_fn].append(aggregate_average(scoring_results))
+                        elif aggregation_function == "median":
+                            output_res[scoring_fn].append(aggregate_median(scoring_results))
+                        elif aggregation_function == "accuracy":
+                            output_res[scoring_fn].append(aggregate_accuracy(scoring_results))
+                        else:
+                            raise NotImplementedError(
+                                f"Aggregation function {aggregation_function} is not supported yet"
+                            )
+
         # Create output directory if it doesn't exist
         os.makedirs(output_dir, exist_ok=True)
 
         # Save results to JSON file
-        output_file = os.path.join(output_dir, f"{eval_task_id}_results.json")
+        output_file = os.path.join(output_dir, f"{benchmark_id}_results.json")
         with open(output_file, "w") as f:
             json.dump(output_res, f, indent=2)
 
@@ -101,9 +171,23 @@ def run_benchmark(
 
     if visualize:
         for scoring_fn in scoring_functions:
-            res = output_res[scoring_fn]
-            assert len(res) > 0 and "score" in res[0]
-            scores = [str(r["score"]) for r in res]
-            unique_scores = sorted(list(set(scores)))
-            counts = [scores.count(s) for s in unique_scores]
-            create_bar_chart(counts, unique_scores, title=f"{scoring_fn}")
+            aggregation_functions = client.scoring_functions.retrieve(
+                scoring_fn_id=scoring_fn
+            ).params.aggregation_functions
+
+            for aggregation_function in aggregation_functions:
+                res = output_res[scoring_fn]
+                assert len(res) > 0 and "score" in res[0]
+                if aggregation_function == "categorical_count":
+                    scores = [str(r["score"]) for r in res]
+                    unique_scores = sorted(list(set(scores)))
+                    counts = [scores.count(s) for s in unique_scores]
+                    create_bar_chart(
+                        counts,
+                        unique_scores,
+                        title=f"{scoring_fn}-{aggregation_function}",
+                    )
+                else:
+                    raise NotImplementedError(
+                        f"Aggregation function {aggregation_function} is not supported for visualization yet"
+                    )
diff --git a/src/llama_stack_client/lib/cli/eval/utils.py b/src/llama_stack_client/lib/cli/eval/utils.py
new file mode 100644
index 00000000..102b8817
--- /dev/null
+++ b/src/llama_stack_client/lib/cli/eval/utils.py
@@ -0,0 +1,46 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import statistics
+from typing import Any, Dict, List, Union
+
+
+def aggregate_categorical_count(
+    scoring_results: List[Dict[str, Union[bool, float, str, List[object], object, None]]],
+) -> Dict[str, Any]:
+    scores = [str(r["score"]) for r in scoring_results]
+    unique_scores = sorted(list(set(scores)))
+    return {"categorical_count": {s: scores.count(s) for s in unique_scores}}
+
+
+def aggregate_average(
+    scoring_results: List[Dict[str, Union[bool, float, str, List[object], object, None]]],
+) -> Dict[str, Any]:
+    return {
+        "average": sum(result["score"] for result in scoring_results if result["score"] is not None)
+        / len([_ for _ in scoring_results if _["score"] is not None]),
+    }
+
+
+def aggregate_median(
+    scoring_results: List[Dict[str, Union[bool, float, str, List[object], object, None]]],
+) -> Dict[str, Any]:
+    scores = [r["score"] for r in scoring_results if r["score"] is not None]
+    median = statistics.median(scores) if scores else None
+    return {"median": median}
+
+
+def aggregate_accuracy(
+    scoring_results: List[Dict[str, Union[bool, float, str, List[object], object, None]]],
+) -> Dict[str, Any]:
+    num_correct = sum(result["score"] for result in scoring_results)
+    avg_score = num_correct / len(scoring_results)
+
+    return {
+        "accuracy": avg_score,
+        "num_correct": num_correct,
+        "num_total": len(scoring_results),
+    }
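
Example invocation (a sketch, not part of the patch): this assumes the command is exposed as run-benchmark under the CLI's eval group, as the path src/llama_stack_client/lib/cli/eval/run_benchmark.py suggests, and the benchmark id and model id below are placeholders for resources already registered with the Llama Stack server:

    llama-stack-client eval run-benchmark <registered-benchmark-id> \
        --model-id <registered-model-id> \
        --output-dir ./benchmark-results \
        --num-examples 10 \
        --temperature 0.0 \
        --max-tokens 4096 \
        --top-p 0.9 \
        --repeat-penalty 1.0 \
        --visualize

Per-row results are written to <output-dir>/<benchmark-id>_results.json, with the aggregated entries appended after the last row; --visualize additionally renders a bar chart for scoring functions whose aggregation is categorical_count (other aggregations are not yet supported for visualization).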
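
The aggregation helpers added in utils.py operate on plain lists of score-row dicts, so they can be exercised in isolation. A minimal sketch with made-up score rows (the import path assumes the src/ package layout shown in the diff):

    # demo_aggregation.py -- illustrative only
    from llama_stack_client.lib.cli.eval.utils import (
        aggregate_accuracy,
        aggregate_average,
        aggregate_categorical_count,
        aggregate_median,
    )

    # Fabricated score rows, shaped like the eval_res.scores[scoring_fn].score_rows
    # entries that run_benchmark.py accumulates in output_res.
    rows = [{"score": 1.0}, {"score": 0.0}, {"score": 1.0}]

    print(aggregate_categorical_count(rows))  # {'categorical_count': {'0.0': 1, '1.0': 2}}
    print(aggregate_average(rows))            # {'average': 0.666...} (2/3)
    print(aggregate_median(rows))             # {'median': 1.0}
    print(aggregate_accuracy(rows))           # {'accuracy': 0.666..., 'num_correct': 2.0, 'num_total': 3}

Note that aggregate_average and aggregate_median skip rows whose "score" is None, while aggregate_accuracy sums the scores directly, so it expects every row to carry a numeric or boolean score.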