From febab1027f4e75addf168311c84c922ea581262a Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 14 Nov 2024 18:42:07 -0500 Subject: [PATCH 1/3] pretty outputs --- .../lib/cli/common/utils.py | 29 ++++++++++++++----- .../lib/cli/eval/run_benchmark.py | 27 +++++++++++++++-- .../lib/cli/llama_stack_client.py | 2 +- 3 files changed, 47 insertions(+), 11 deletions(-) diff --git a/src/llama_stack_client/lib/cli/common/utils.py b/src/llama_stack_client/lib/cli/common/utils.py index b05a035a..6d52d793 100644 --- a/src/llama_stack_client/lib/cli/common/utils.py +++ b/src/llama_stack_client/lib/cli/common/utils.py @@ -3,15 +3,28 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from tabulate import tabulate +from rich.console import Console +from rich.table import Table -def print_table_from_response(response, headers=()): - if not headers: - headers = sorted(response[0].__dict__.keys()) +def create_bar_chart(data, labels, title=""): + """Create a bar chart using Rich Table.""" - rows = [] - for spec in response: - rows.append([spec.__dict__[headers[i]] for i in range(len(headers))]) + console = Console() + table = Table(title=title) + table.add_column("Score") + table.add_column("Count") - print(tabulate(rows, headers=headers, tablefmt="grid")) + max_value = max(data) + total_count = sum(data) + + # Define a list of colors to cycle through + colors = ["green", "blue", "red", "yellow", "magenta", "cyan"] + + for i, (label, value) in enumerate(zip(labels, data)): + bar_length = int((value / max_value) * 20) # Adjust bar length as needed + bar = "█" * bar_length + " " * (20 - bar_length) + color = colors[i % len(colors)] + table.add_row(label, f"[{color}]{bar}[/] {value}/{total_count}") + + console.print(table) diff --git a/src/llama_stack_client/lib/cli/eval/run_benchmark.py b/src/llama_stack_client/lib/cli/eval/run_benchmark.py index 24f1c791..7beeee0e 100644 --- a/src/llama_stack_client/lib/cli/eval/run_benchmark.py +++ b/src/llama_stack_client/lib/cli/eval/run_benchmark.py @@ -9,8 +9,11 @@ from typing import Optional import click +from rich import print as rprint from tqdm.rich import tqdm +from ..common.utils import create_bar_chart + @click.command("run_benchmark") @click.argument("eval-task-ids", nargs=-1, required=True) @@ -28,9 +31,20 @@ @click.option( "--num-examples", required=False, help="Number of examples to evaluate on, useful for debugging", default=None ) +@click.option( + "--visualize", + is_flag=True, + default=False, + help="Visualize evaluation results after completion", +) @click.pass_context def run_benchmark( - ctx, eval_task_ids: tuple[str, ...], eval_task_config: str, output_dir: str, num_examples: Optional[int] + ctx, + eval_task_ids: tuple[str, ...], + eval_task_config: str, + output_dir: str, + num_examples: Optional[int], + visualize: bool, ): """Run a evaluation benchmark""" @@ -79,4 +93,13 @@ def run_benchmark( with open(output_file, "w") as f: json.dump(output_res, f, indent=2) - print(f"Results saved to: {output_file}") + rprint(f"[green]✓[/green] Results saved to: [blue]{output_file}[/blue]!\n") + + if visualize: + for scoring_fn in ["llm-as-judge::llm_as_judge_base"]: + res = output_res[scoring_fn] + assert len(res) > 0 and "score" in res[0] + scores = [r["score"] for r in res] + unique_scores = sorted(list(set([r["score"] for r in res]))) + counts = [scores.count(s) for s in unique_scores] + create_bar_chart(counts, unique_scores, title=f"ScoringFunction = {scoring_fn}") diff --git a/src/llama_stack_client/lib/cli/llama_stack_client.py b/src/llama_stack_client/lib/cli/llama_stack_client.py index f6ef91bd..a8b8f30e 100644 --- a/src/llama_stack_client/lib/cli/llama_stack_client.py +++ b/src/llama_stack_client/lib/cli/llama_stack_client.py @@ -57,7 +57,7 @@ def cli(ctx, endpoint: str, config: str | None): base_url=endpoint, provider_data={ "fireworks_api_key": os.environ.get("FIREWORKS_API_KEY", ""), - "togethers_api_key": os.environ.get("TOGETHERS_API_KEY", ""), + "together_api_key": os.environ.get("TOGETHER_API_KEY", ""), }, ) ctx.obj = {"client": client} From 82db984d3e371618b524281eb442ecf7e673705e Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 14 Nov 2024 22:26:59 -0500 Subject: [PATCH 2/3] requirements, fix --- pyproject.toml | 1 - requirements-dev.lock | 4 ---- requirements.lock | 4 ---- src/llama_stack_client/lib/cli/__init__.py | 7 +++++++ src/llama_stack_client/lib/cli/eval/run_benchmark.py | 6 +++--- 5 files changed, 10 insertions(+), 12 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c8bf36a5..0ce8fb4d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,6 @@ dependencies = [ "distro>=1.7.0, <2", "sniffio", "cached-property; python_version < '3.8'", - "tabulate>=0.9.0", ] requires-python = ">= 3.7" classifiers = [ diff --git a/requirements-dev.lock b/requirements-dev.lock index 518599b1..6a4b3f93 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -90,10 +90,6 @@ sniffio==1.3.0 # via anyio # via httpx # via llama-stack-client -tabulate==0.9.0 - # via llama-stack-client -termcolor==2.4.0 - # via llama-stack-client time-machine==2.9.0 tomli==2.0.1 # via mypy diff --git a/requirements.lock b/requirements.lock index 23271295..7a439f22 100644 --- a/requirements.lock +++ b/requirements.lock @@ -39,10 +39,6 @@ sniffio==1.3.0 # via anyio # via httpx # via llama-stack-client -tabulate==0.9.0 - # via llama-stack-client -termcolor==2.4.0 - # via llama-stack-client typing-extensions==4.8.0 # via anyio # via llama-stack-client diff --git a/src/llama_stack_client/lib/cli/__init__.py b/src/llama_stack_client/lib/cli/__init__.py index 756f351d..77737e7d 100644 --- a/src/llama_stack_client/lib/cli/__init__.py +++ b/src/llama_stack_client/lib/cli/__init__.py @@ -3,3 +3,10 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. + +# Ignore tqdm experimental warning +import warnings + +from tqdm import TqdmExperimentalWarning + +warnings.filterwarnings("ignore", category=TqdmExperimentalWarning) diff --git a/src/llama_stack_client/lib/cli/eval/run_benchmark.py b/src/llama_stack_client/lib/cli/eval/run_benchmark.py index 7beeee0e..d1163587 100644 --- a/src/llama_stack_client/lib/cli/eval/run_benchmark.py +++ b/src/llama_stack_client/lib/cli/eval/run_benchmark.py @@ -96,10 +96,10 @@ def run_benchmark( rprint(f"[green]✓[/green] Results saved to: [blue]{output_file}[/blue]!\n") if visualize: - for scoring_fn in ["llm-as-judge::llm_as_judge_base"]: + for scoring_fn in scoring_functions: res = output_res[scoring_fn] assert len(res) > 0 and "score" in res[0] - scores = [r["score"] for r in res] - unique_scores = sorted(list(set([r["score"] for r in res]))) + scores = [str(r["score"]) for r in res] + unique_scores = sorted(list(set(scores))) counts = [scores.count(s) for s in unique_scores] create_bar_chart(counts, unique_scores, title=f"ScoringFunction = {scoring_fn}") From b6d8d1021aec1ec0fe65e58f48c4ec0efb6fba8b Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 14 Nov 2024 22:32:50 -0500 Subject: [PATCH 3/3] format --- src/llama_stack_client/lib/cli/eval/run_benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama_stack_client/lib/cli/eval/run_benchmark.py b/src/llama_stack_client/lib/cli/eval/run_benchmark.py index d1163587..9cde1fb1 100644 --- a/src/llama_stack_client/lib/cli/eval/run_benchmark.py +++ b/src/llama_stack_client/lib/cli/eval/run_benchmark.py @@ -102,4 +102,4 @@ def run_benchmark( scores = [str(r["score"]) for r in res] unique_scores = sorted(list(set(scores))) counts = [scores.count(s) for s in unique_scores] - create_bar_chart(counts, unique_scores, title=f"ScoringFunction = {scoring_fn}") + create_bar_chart(counts, unique_scores, title=f"{scoring_fn}")