diff --git a/src/llama_stack_client/lib/cli/common/utils.py b/src/llama_stack_client/lib/cli/common/utils.py index 6d52d793..4f7893af 100644 --- a/src/llama_stack_client/lib/cli/common/utils.py +++ b/src/llama_stack_client/lib/cli/common/utils.py @@ -5,6 +5,8 @@ # the root directory of this source tree. from rich.console import Console from rich.table import Table +from rich.panel import Panel +from functools import wraps def create_bar_chart(data, labels, title=""): @@ -28,3 +30,24 @@ def create_bar_chart(data, labels, title=""): table.add_row(label, f"[{color}]{bar}[/] {value}/{total_count}") console.print(table) + + +def handle_client_errors(operation_name): + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as e: + console = Console() + console.print( + Panel.fit( + f"[bold red]Failed to {operation_name}[/bold red]\n\n" + f"[yellow]Error Type:[/yellow] {e.__class__.__name__}\n" + f"[yellow]Details:[/yellow] {str(e)}" + ) + ) + + return wrapper + + return decorator diff --git a/src/llama_stack_client/lib/cli/datasets/list.py b/src/llama_stack_client/lib/cli/datasets/list.py index 71beed95..d6be700c 100644 --- a/src/llama_stack_client/lib/cli/datasets/list.py +++ b/src/llama_stack_client/lib/cli/datasets/list.py @@ -8,9 +8,12 @@ from rich.console import Console from rich.table import Table +from ..common.utils import handle_client_errors + @click.command("list") @click.pass_context +@handle_client_errors("list datasets") def list_datasets(ctx): """Show available datasets on distribution endpoint""" client = ctx.obj["client"] diff --git a/src/llama_stack_client/lib/cli/datasets/register.py b/src/llama_stack_client/lib/cli/datasets/register.py index 5887cf42..f7bbb31a 100644 --- a/src/llama_stack_client/lib/cli/datasets/register.py +++ b/src/llama_stack_client/lib/cli/datasets/register.py @@ -12,6 +12,8 @@ import click import yaml +from ..common.utils import handle_client_errors + def data_url_from_file(file_path: str) -> str: if not os.path.exists(file_path): @@ -38,6 +40,7 @@ def data_url_from_file(file_path: str) -> str: ) @click.option("--schema", type=str, help="JSON schema of the dataset", required=True) @click.pass_context +@handle_client_errors("register dataset") def register( ctx, dataset_id: str, diff --git a/src/llama_stack_client/lib/cli/eval_tasks/eval_tasks.py b/src/llama_stack_client/lib/cli/eval_tasks/eval_tasks.py index f393ffb9..ca5af744 100644 --- a/src/llama_stack_client/lib/cli/eval_tasks/eval_tasks.py +++ b/src/llama_stack_client/lib/cli/eval_tasks/eval_tasks.py @@ -11,6 +11,7 @@ import click import yaml +from ..common.utils import handle_client_errors from .list import list_eval_tasks @@ -28,6 +29,7 @@ def eval_tasks(): @click.option("--provider-eval-task-id", help="Provider's eval task ID", default=None) @click.option("--metadata", type=str, help="Metadata for the eval task in JSON format") @click.pass_context +@handle_client_errors("register eval task") def register( ctx, eval_task_id: str, diff --git a/src/llama_stack_client/lib/cli/eval_tasks/list.py b/src/llama_stack_client/lib/cli/eval_tasks/list.py index de104e67..6c054e8e 100644 --- a/src/llama_stack_client/lib/cli/eval_tasks/list.py +++ b/src/llama_stack_client/lib/cli/eval_tasks/list.py @@ -9,8 +9,12 @@ from rich.table import Table +from ..common.utils import handle_client_errors + + @click.command("list") @click.pass_context +@handle_client_errors("list eval tasks") def list_eval_tasks(ctx): """Show available eval tasks on distribution endpoint""" diff --git a/src/llama_stack_client/lib/cli/memory_banks/memory_banks.py b/src/llama_stack_client/lib/cli/memory_banks/memory_banks.py index 6ee4391f..d38fb6bb 100644 --- a/src/llama_stack_client/lib/cli/memory_banks/memory_banks.py +++ b/src/llama_stack_client/lib/cli/memory_banks/memory_banks.py @@ -11,6 +11,8 @@ from rich.console import Console from rich.table import Table +from ..common.utils import handle_client_errors + @click.group() def memory_banks(): @@ -20,36 +22,66 @@ def memory_banks(): @click.command("list") @click.pass_context +@handle_client_errors("list memory banks") def list(ctx): """Show available memory banks on distribution endpoint""" client = ctx.obj["client"] console = Console() memory_banks_list_response = client.memory_banks.list() - headers = [] - if memory_banks_list_response and len(memory_banks_list_response) > 0: - headers = sorted(memory_banks_list_response[0].__dict__.keys()) if memory_banks_list_response: table = Table() - for header in headers: - table.add_column(header) + # Add our specific columns + table.add_column("identifier") + table.add_column("provider_id") + table.add_column("provider_resource_id") + table.add_column("memory_bank_type") + table.add_column("params") for item in memory_banks_list_response: - table.add_row(*[str(getattr(item, header)) for header in headers]) + # Create a dict of all attributes + item_dict = item.__dict__ + + # Extract our main columns + identifier = str(item_dict.pop("identifier", "")) + provider_id = str(item_dict.pop("provider_id", "")) + provider_resource_id = str(item_dict.pop("provider_resource_id", "")) + memory_bank_type = str(item_dict.pop("memory_bank_type", "")) + # Convert remaining attributes to YAML string for params column + params = yaml.dump(item_dict, default_flow_style=False) + + table.add_row(identifier, provider_id, provider_resource_id, memory_bank_type, params) + console.print(table) @memory_banks.command() -@click.option("--memory-bank-id", required=True, help="Id of the memory bank") +@click.argument("memory-bank-id") @click.option("--type", type=click.Choice(["vector", "keyvalue", "keyword", "graph"]), required=True) @click.option("--provider-id", help="Provider ID for the memory bank", default=None) @click.option("--provider-memory-bank-id", help="Provider's memory bank ID", default=None) -@click.option("--chunk-size", type=int, help="Chunk size in tokens (for vector type)", default=512) -@click.option("--embedding-model", type=str, help="Embedding model (for vector type)", default="all-MiniLM-L6-v2") -@click.option("--overlap-size", type=int, help="Overlap size in tokens (for vector type)", default=64) +@click.option( + "--chunk-size", + type=int, + help="Chunk size in tokens (for vector type)", + default=512, +) +@click.option( + "--embedding-model", + type=str, + help="Embedding model (for vector type)", + default="all-MiniLM-L6-v2", +) +@click.option( + "--overlap-size", + type=int, + help="Overlap size in tokens (for vector type)", + default=64, +) @click.pass_context -def create( +@handle_client_errors("register memory bank") +def register( ctx, memory_bank_id: str, type: str, @@ -65,18 +97,24 @@ def create( config = None if type == "vector": config = { - "type": "vector", + "memory_bank_type": "vector", "chunk_size_in_tokens": chunk_size, "embedding_model": embedding_model, } if overlap_size: config["overlap_size_in_tokens"] = overlap_size elif type == "keyvalue": - config = {"type": "keyvalue"} + config = {"memory_bank_type": "keyvalue"} elif type == "keyword": - config = {"type": "keyword"} + config = {"memory_bank_type": "keyword"} elif type == "graph": - config = {"type": "graph"} + config = {"memory_bank_type": "graph"} + + from rich import print as rprint + from rich.pretty import pprint + + rprint("\n[bold blue]Memory Bank Configuration:[/bold blue]") + pprint(config, expand_all=True) response = client.memory_banks.register( memory_bank_id=memory_bank_id, @@ -88,6 +126,18 @@ def create( click.echo(yaml.dump(response.dict())) +@memory_banks.command() +@click.argument("memory-bank-id") +@click.pass_context +@handle_client_errors("delete memory bank") +def unregister(ctx, memory_bank_id: str): + """Delete a memory bank""" + client = ctx.obj["client"] + client.memory_banks.unregister(memory_bank_id=memory_bank_id) + click.echo(f"Memory bank '{memory_bank_id}' deleted successfully") + + # Register subcommands memory_banks.add_command(list) -memory_banks.add_command(create) +memory_banks.add_command(register) +memory_banks.add_command(unregister) diff --git a/src/llama_stack_client/lib/cli/models/models.py b/src/llama_stack_client/lib/cli/models/models.py index 536391d6..f080e314 100644 --- a/src/llama_stack_client/lib/cli/models/models.py +++ b/src/llama_stack_client/lib/cli/models/models.py @@ -10,6 +10,8 @@ from rich.console import Console from rich.table import Table +from ..common.utils import handle_client_errors + @click.group() def models(): @@ -19,6 +21,7 @@ def models(): @click.command(name="list", help="Show available llama models at distribution endpoint") @click.pass_context +@handle_client_errors("list models") def list_models(ctx): client = ctx.obj["client"] console = Console() @@ -43,6 +46,7 @@ def list_models(ctx): @click.command(name="get") @click.argument("model_id") @click.pass_context +@handle_client_errors("get model details") def get_model(ctx, model_id: str): """Show available llama models at distribution endpoint""" client = ctx.obj["client"] @@ -51,9 +55,10 @@ def get_model(ctx, model_id: str): models_get_response = client.models.retrieve(identifier=model_id) if not models_get_response: - click.echo( + console.print( f"Model {model_id} is not found at distribution endpoint. " - "Please ensure endpoint is serving specified model." + "Please ensure endpoint is serving specified model.", + style="bold red", ) return @@ -72,62 +77,36 @@ def get_model(ctx, model_id: str): @click.option("--provider-model-id", help="Provider's model ID", default=None) @click.option("--metadata", help="JSON metadata for the model", default=None) @click.pass_context +@handle_client_errors("register model") def register_model( ctx, model_id: str, provider_id: Optional[str], provider_model_id: Optional[str], metadata: Optional[str] ): """Register a new model at distribution endpoint""" client = ctx.obj["client"] + console = Console() - try: - response = client.models.register( - model_id=model_id, provider_id=provider_id, provider_model_id=provider_model_id, metadata=metadata - ) - if response: - click.echo(f"Successfully registered model {model_id}") - except Exception as e: - click.echo(f"Failed to register model: {str(e)}") - - -@click.command(name="update", help="Update an existing model at distribution endpoint") -@click.argument("model_id") -@click.option("--provider-id", help="Provider ID for the model", default=None) -@click.option("--provider-model-id", help="Provider's model ID", default=None) -@click.option("--metadata", help="JSON metadata for the model", default=None) -@click.pass_context -def update_model( - ctx, model_id: str, provider_id: Optional[str], provider_model_id: Optional[str], metadata: Optional[str] -): - """Update an existing model at distribution endpoint""" - client = ctx.obj["client"] - - try: - response = client.models.update( - model_id=model_id, provider_id=provider_id, provider_model_id=provider_model_id, metadata=metadata - ) - if response: - click.echo(f"Successfully updated model {model_id}") - except Exception as e: - click.echo(f"Failed to update model: {str(e)}") + response = client.models.register( + model_id=model_id, provider_id=provider_id, provider_model_id=provider_model_id, metadata=metadata + ) + if response: + console.print(f"[green]Successfully registered model {model_id}[/green]") -@click.command(name="delete", help="Delete a model from distribution endpoint") +@click.command(name="unregister", help="Unregister a model from distribution endpoint") @click.argument("model_id") @click.pass_context -def delete_model(ctx, model_id: str): - """Delete a model from distribution endpoint""" +@handle_client_errors("unregister model") +def unregister_model(ctx, model_id: str): client = ctx.obj["client"] + console = Console() - try: - response = client.models.delete(model_id=model_id) - if response: - click.echo(f"Successfully deleted model {model_id}") - except Exception as e: - click.echo(f"Failed to delete model: {str(e)}") + response = client.models.unregister(model_id=model_id) + if response: + console.print(f"[green]Successfully deleted model {model_id}[/green]") # Register subcommands models.add_command(list_models) models.add_command(get_model) models.add_command(register_model) -models.add_command(update_model) -models.add_command(delete_model) +models.add_command(unregister_model) diff --git a/src/llama_stack_client/lib/cli/providers/list.py b/src/llama_stack_client/lib/cli/providers/list.py index dedbd8c9..de5ad6ea 100644 --- a/src/llama_stack_client/lib/cli/providers/list.py +++ b/src/llama_stack_client/lib/cli/providers/list.py @@ -2,9 +2,10 @@ from rich.console import Console from rich.table import Table - +from ..common.utils import handle_client_errors @click.command("list") @click.pass_context +@handle_client_errors("list providers") def list_providers(ctx): """Show available providers on distribution endpoint""" client = ctx.obj["client"] diff --git a/src/llama_stack_client/lib/cli/scoring_functions/list.py b/src/llama_stack_client/lib/cli/scoring_functions/list.py index cab57980..7a2ad17c 100644 --- a/src/llama_stack_client/lib/cli/scoring_functions/list.py +++ b/src/llama_stack_client/lib/cli/scoring_functions/list.py @@ -8,9 +8,12 @@ from rich.console import Console from rich.table import Table +from ..common.utils import handle_client_errors + @click.command("list") @click.pass_context +@handle_client_errors("list scoring functions") def list_scoring_functions(ctx): """Show available scoring functions on distribution endpoint""" diff --git a/src/llama_stack_client/lib/cli/shields/shields.py b/src/llama_stack_client/lib/cli/shields/shields.py index a95f3cf5..ec12b4ab 100644 --- a/src/llama_stack_client/lib/cli/shields/shields.py +++ b/src/llama_stack_client/lib/cli/shields/shields.py @@ -11,6 +11,8 @@ from rich.console import Console from rich.table import Table +from ..common.utils import handle_client_errors + @click.group() def shields(): @@ -20,6 +22,7 @@ def shields(): @click.command("list") @click.pass_context +@handle_client_errors("list shields") def list(ctx): """Show available safety shields on distribution endpoint""" client = ctx.obj["client"] @@ -46,6 +49,7 @@ def list(ctx): @click.option("--provider-shield-id", help="Provider's shield ID", default=None) @click.option("--params", type=str, help="JSON configuration parameters for the shield", default=None) @click.pass_context +@handle_client_errors("register shield") def register( ctx, shield_id: str,