diff --git a/src/llama_stack_client/lib/cli/datasets/list.py b/src/llama_stack_client/lib/cli/datasets/list.py index 731dc4fd..6b229629 100644 --- a/src/llama_stack_client/lib/cli/datasets/list.py +++ b/src/llama_stack_client/lib/cli/datasets/list.py @@ -5,8 +5,8 @@ # the root directory of this source tree. import click - -from llama_stack_client.lib.cli.common.utils import print_table_from_response +from rich.table import Table +from rich.console import Console @click.command("list") @@ -14,9 +14,15 @@ def list_datasets(ctx): """Show available datasets on distribution endpoint""" client = ctx.obj["client"] - + console = Console() headers = ["identifier", "provider_id", "metadata", "type"] datasets_list_response = client.datasets.list() if datasets_list_response: - print_table_from_response(datasets_list_response, headers) + table = Table() + for header in headers: + table.add_column(header) + + for item in datasets_list_response: + table.add_row(*[str(getattr(item, header)) for header in headers]) + console.print(table) diff --git a/src/llama_stack_client/lib/cli/eval_tasks/list.py b/src/llama_stack_client/lib/cli/eval_tasks/list.py index e0a1418f..841b8aaa 100644 --- a/src/llama_stack_client/lib/cli/eval_tasks/list.py +++ b/src/llama_stack_client/lib/cli/eval_tasks/list.py @@ -5,8 +5,8 @@ # the root directory of this source tree. import click - -from llama_stack_client.lib.cli.common.utils import print_table_from_response +from rich.table import Table +from rich.console import Console @click.command("list") @@ -15,11 +15,17 @@ def list_eval_tasks(ctx): """Show available eval tasks on distribution endpoint""" client = ctx.obj["client"] - + console = Console() headers = [] eval_tasks_list_response = client.eval_tasks.list() if eval_tasks_list_response and len(eval_tasks_list_response) > 0: headers = sorted(eval_tasks_list_response[0].__dict__.keys()) if eval_tasks_list_response: - print_table_from_response(eval_tasks_list_response, headers) + table = Table() + for header in headers: + table.add_column(header) + + for item in eval_tasks_list_response: + table.add_row(*[str(getattr(item, header)) for header in headers]) + console.print(table) diff --git a/src/llama_stack_client/lib/cli/memory_banks/memory_banks.py b/src/llama_stack_client/lib/cli/memory_banks/memory_banks.py index 4361a1c8..d432bfa5 100644 --- a/src/llama_stack_client/lib/cli/memory_banks/memory_banks.py +++ b/src/llama_stack_client/lib/cli/memory_banks/memory_banks.py @@ -7,7 +7,8 @@ import click from typing import Optional import yaml -from llama_stack_client.lib.cli.common.utils import print_table_from_response +from rich.table import Table +from rich.console import Console @click.group() @@ -22,14 +23,20 @@ def list(ctx): """Show available memory banks on distribution endpoint""" client = ctx.obj["client"] - + console = Console() memory_banks_list_response = client.memory_banks.list() headers = [] if memory_banks_list_response and len(memory_banks_list_response) > 0: headers = sorted(memory_banks_list_response[0].__dict__.keys()) if memory_banks_list_response: - print_table_from_response(memory_banks_list_response, headers) + table = Table() + for header in headers: + table.add_column(header) + + for item in memory_banks_list_response: + table.add_row(*[str(getattr(item, header)) for header in headers]) + console.print(table) @memory_banks.command() diff --git a/src/llama_stack_client/lib/cli/models/models.py b/src/llama_stack_client/lib/cli/models/models.py index 1e22795e..bf694ec9 100644 --- a/src/llama_stack_client/lib/cli/models/models.py +++ b/src/llama_stack_client/lib/cli/models/models.py @@ -5,8 +5,8 @@ # the root directory of this source tree. import click -from tabulate import tabulate -from llama_stack_client.lib.cli.common.utils import print_table_from_response +from rich.table import Table +from rich.console import Console from typing import Optional @@ -20,11 +20,23 @@ def models(): @click.pass_context def list_models(ctx): client = ctx.obj["client"] + console = Console() headers = ["identifier", "provider_id", "provider_resource_id", "metadata"] response = client.models.list() if response: - print_table_from_response(response, headers) + table = Table() + for header in headers: + table.add_column(header) + + for item in response: + table.add_row( + str(getattr(item, headers[0])), + str(getattr(item, headers[1])), + str(getattr(item, headers[2])), + str(getattr(item, headers[3])), + ) + console.print(table) @click.command(name="get") @@ -33,6 +45,7 @@ def list_models(ctx): def get_model(ctx, model_id: str): """Show available llama models at distribution endpoint""" client = ctx.obj["client"] + console = Console() models_get_response = client.models.retrieve(identifier=model_id) @@ -44,10 +57,12 @@ def get_model(ctx, model_id: str): return headers = sorted(models_get_response.__dict__.keys()) - rows = [] - rows.append([models_get_response.__dict__[headers[i]] for i in range(len(headers))]) + table = Table() + for header in headers: + table.add_column(header) - click.echo(tabulate(rows, headers=headers, tablefmt="grid")) + table.add_row(*[str(models_get_response.__dict__[header]) for header in headers]) + console.print(table) @click.command(name="register", help="Register a new model at distribution endpoint") diff --git a/src/llama_stack_client/lib/cli/providers/list.py b/src/llama_stack_client/lib/cli/providers/list.py index 2babd42d..974ede00 100644 --- a/src/llama_stack_client/lib/cli/providers/list.py +++ b/src/llama_stack_client/lib/cli/providers/list.py @@ -1,5 +1,6 @@ import click -from tabulate import tabulate +from rich.table import Table +from rich.console import Console @click.command("list") @@ -7,14 +8,16 @@ def list_providers(ctx): """Show available providers on distribution endpoint""" client = ctx.obj["client"] - + console = Console() headers = ["API", "Provider ID", "Provider Type"] providers_response = client.providers.list() - rows = [] + table = Table() + for header in headers: + table.add_column(header) for k, v in providers_response.items(): for provider_info in v: - rows.append([k, provider_info.provider_id, provider_info.provider_type]) + table.add_row(k, provider_info.provider_id, provider_info.provider_type) - click.echo(tabulate(rows, headers=headers, tablefmt="grid")) + console.print(table) diff --git a/src/llama_stack_client/lib/cli/scoring_functions/list.py b/src/llama_stack_client/lib/cli/scoring_functions/list.py index 73abd0bc..55e4369d 100644 --- a/src/llama_stack_client/lib/cli/scoring_functions/list.py +++ b/src/llama_stack_client/lib/cli/scoring_functions/list.py @@ -5,8 +5,8 @@ # the root directory of this source tree. import click - -from llama_stack_client.lib.cli.common.utils import print_table_from_response +from rich.table import Table +from rich.console import Console @click.command("list") @@ -15,7 +15,7 @@ def list_scoring_functions(ctx): """Show available scoring functions on distribution endpoint""" client = ctx.obj["client"] - + console = Console() headers = [ "identifier", "provider_id", @@ -25,4 +25,10 @@ def list_scoring_functions(ctx): scoring_functions_list_response = client.scoring_functions.list() if scoring_functions_list_response: - print_table_from_response(scoring_functions_list_response, headers) + table = Table() + for header in headers: + table.add_column(header) + + for item in scoring_functions_list_response: + table.add_row(*[str(getattr(item, header)) for header in headers]) + console.print(table) diff --git a/src/llama_stack_client/lib/cli/shields/shields.py b/src/llama_stack_client/lib/cli/shields/shields.py index 574a2dbe..ae963520 100644 --- a/src/llama_stack_client/lib/cli/shields/shields.py +++ b/src/llama_stack_client/lib/cli/shields/shields.py @@ -7,7 +7,8 @@ import click from typing import Optional import yaml -from llama_stack_client.lib.cli.common.utils import print_table_from_response +from rich.table import Table +from rich.console import Console @click.group() @@ -21,6 +22,7 @@ def shields(): def list(ctx): """Show available safety shields on distribution endpoint""" client = ctx.obj["client"] + console = Console() shields_list_response = client.shields.list() headers = [] @@ -28,7 +30,13 @@ def list(ctx): headers = sorted(shields_list_response[0].__dict__.keys()) if shields_list_response: - print_table_from_response(shields_list_response, headers) + table = Table() + for header in headers: + table.add_column(header) + + for item in shields_list_response: + table.add_row(*[str(getattr(item, header)) for header in headers]) + console.print(table) @shields.command()