Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions src/llama_stack_client/lib/cli/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# the root directory of this source tree.
from rich.console import Console
from rich.table import Table
from rich.panel import Panel
from functools import wraps


def create_bar_chart(data, labels, title=""):
Expand All @@ -28,3 +30,24 @@ def create_bar_chart(data, labels, title=""):
table.add_row(label, f"[{color}]{bar}[/] {value}/{total_count}")

console.print(table)


def handle_client_errors(operation_name):
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
console = Console()
console.print(
Panel.fit(
f"[bold red]Failed to {operation_name}[/bold red]\n\n"
f"[yellow]Error Type:[/yellow] {e.__class__.__name__}\n"
f"[yellow]Details:[/yellow] {str(e)}"
)
)

return wrapper

return decorator
3 changes: 3 additions & 0 deletions src/llama_stack_client/lib/cli/datasets/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
from rich.console import Console
from rich.table import Table

from ..common.utils import handle_client_errors


@click.command("list")
@click.pass_context
@handle_client_errors("list datasets")
def list_datasets(ctx):
"""Show available datasets on distribution endpoint"""
client = ctx.obj["client"]
Expand Down
3 changes: 3 additions & 0 deletions src/llama_stack_client/lib/cli/datasets/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import click
import yaml

from ..common.utils import handle_client_errors


def data_url_from_file(file_path: str) -> str:
if not os.path.exists(file_path):
Expand All @@ -38,6 +40,7 @@ def data_url_from_file(file_path: str) -> str:
)
@click.option("--schema", type=str, help="JSON schema of the dataset", required=True)
@click.pass_context
@handle_client_errors("register dataset")
def register(
ctx,
dataset_id: str,
Expand Down
2 changes: 2 additions & 0 deletions src/llama_stack_client/lib/cli/eval_tasks/eval_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import click
import yaml

from ..common.utils import handle_client_errors
from .list import list_eval_tasks


Expand All @@ -28,6 +29,7 @@ def eval_tasks():
@click.option("--provider-eval-task-id", help="Provider's eval task ID", default=None)
@click.option("--metadata", type=str, help="Metadata for the eval task in JSON format")
@click.pass_context
@handle_client_errors("register eval task")
def register(
ctx,
eval_task_id: str,
Expand Down
4 changes: 4 additions & 0 deletions src/llama_stack_client/lib/cli/eval_tasks/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@
from rich.table import Table


from ..common.utils import handle_client_errors


@click.command("list")
@click.pass_context
@handle_client_errors("list eval tasks")
def list_eval_tasks(ctx):
"""Show available eval tasks on distribution endpoint"""

Expand Down
82 changes: 66 additions & 16 deletions src/llama_stack_client/lib/cli/memory_banks/memory_banks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from rich.console import Console
from rich.table import Table

from ..common.utils import handle_client_errors


@click.group()
def memory_banks():
Expand All @@ -20,36 +22,66 @@ def memory_banks():

@click.command("list")
@click.pass_context
@handle_client_errors("list memory banks")
def list(ctx):
"""Show available memory banks on distribution endpoint"""

client = ctx.obj["client"]
console = Console()
memory_banks_list_response = client.memory_banks.list()
headers = []
if memory_banks_list_response and len(memory_banks_list_response) > 0:
headers = sorted(memory_banks_list_response[0].__dict__.keys())

if memory_banks_list_response:
table = Table()
for header in headers:
table.add_column(header)
# Add our specific columns
table.add_column("identifier")
table.add_column("provider_id")
table.add_column("provider_resource_id")
table.add_column("memory_bank_type")
table.add_column("params")

for item in memory_banks_list_response:
table.add_row(*[str(getattr(item, header)) for header in headers])
# Create a dict of all attributes
item_dict = item.__dict__

# Extract our main columns
identifier = str(item_dict.pop("identifier", ""))
provider_id = str(item_dict.pop("provider_id", ""))
provider_resource_id = str(item_dict.pop("provider_resource_id", ""))
memory_bank_type = str(item_dict.pop("memory_bank_type", ""))
# Convert remaining attributes to YAML string for params column
params = yaml.dump(item_dict, default_flow_style=False)

table.add_row(identifier, provider_id, provider_resource_id, memory_bank_type, params)

console.print(table)


@memory_banks.command()
@click.option("--memory-bank-id", required=True, help="Id of the memory bank")
@click.argument("memory-bank-id")
@click.option("--type", type=click.Choice(["vector", "keyvalue", "keyword", "graph"]), required=True)
@click.option("--provider-id", help="Provider ID for the memory bank", default=None)
@click.option("--provider-memory-bank-id", help="Provider's memory bank ID", default=None)
@click.option("--chunk-size", type=int, help="Chunk size in tokens (for vector type)", default=512)
@click.option("--embedding-model", type=str, help="Embedding model (for vector type)", default="all-MiniLM-L6-v2")
@click.option("--overlap-size", type=int, help="Overlap size in tokens (for vector type)", default=64)
@click.option(
"--chunk-size",
type=int,
help="Chunk size in tokens (for vector type)",
default=512,
)
@click.option(
"--embedding-model",
type=str,
help="Embedding model (for vector type)",
default="all-MiniLM-L6-v2",
)
@click.option(
"--overlap-size",
type=int,
help="Overlap size in tokens (for vector type)",
default=64,
)
@click.pass_context
def create(
@handle_client_errors("register memory bank")
def register(
ctx,
memory_bank_id: str,
type: str,
Expand All @@ -65,18 +97,24 @@ def create(
config = None
if type == "vector":
config = {
"type": "vector",
"memory_bank_type": "vector",
"chunk_size_in_tokens": chunk_size,
"embedding_model": embedding_model,
}
if overlap_size:
config["overlap_size_in_tokens"] = overlap_size
elif type == "keyvalue":
config = {"type": "keyvalue"}
config = {"memory_bank_type": "keyvalue"}
elif type == "keyword":
config = {"type": "keyword"}
config = {"memory_bank_type": "keyword"}
elif type == "graph":
config = {"type": "graph"}
config = {"memory_bank_type": "graph"}

from rich import print as rprint
from rich.pretty import pprint

rprint("\n[bold blue]Memory Bank Configuration:[/bold blue]")
pprint(config, expand_all=True)

response = client.memory_banks.register(
memory_bank_id=memory_bank_id,
Expand All @@ -88,6 +126,18 @@ def create(
click.echo(yaml.dump(response.dict()))


@memory_banks.command()
@click.argument("memory-bank-id")
@click.pass_context
@handle_client_errors("delete memory bank")
def unregister(ctx, memory_bank_id: str):
"""Delete a memory bank"""
client = ctx.obj["client"]
client.memory_banks.unregister(memory_bank_id=memory_bank_id)
click.echo(f"Memory bank '{memory_bank_id}' deleted successfully")


# Register subcommands
memory_banks.add_command(list)
memory_banks.add_command(create)
memory_banks.add_command(register)
memory_banks.add_command(unregister)
65 changes: 22 additions & 43 deletions src/llama_stack_client/lib/cli/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from rich.console import Console
from rich.table import Table

from ..common.utils import handle_client_errors


@click.group()
def models():
Expand All @@ -19,6 +21,7 @@ def models():

@click.command(name="list", help="Show available llama models at distribution endpoint")
@click.pass_context
@handle_client_errors("list models")
def list_models(ctx):
client = ctx.obj["client"]
console = Console()
Expand All @@ -43,6 +46,7 @@ def list_models(ctx):
@click.command(name="get")
@click.argument("model_id")
@click.pass_context
@handle_client_errors("get model details")
def get_model(ctx, model_id: str):
"""Show available llama models at distribution endpoint"""
client = ctx.obj["client"]
Expand All @@ -51,9 +55,10 @@ def get_model(ctx, model_id: str):
models_get_response = client.models.retrieve(identifier=model_id)

if not models_get_response:
click.echo(
console.print(
f"Model {model_id} is not found at distribution endpoint. "
"Please ensure endpoint is serving specified model."
"Please ensure endpoint is serving specified model.",
style="bold red",
)
return

Expand All @@ -72,62 +77,36 @@ def get_model(ctx, model_id: str):
@click.option("--provider-model-id", help="Provider's model ID", default=None)
@click.option("--metadata", help="JSON metadata for the model", default=None)
@click.pass_context
@handle_client_errors("register model")
def register_model(
ctx, model_id: str, provider_id: Optional[str], provider_model_id: Optional[str], metadata: Optional[str]
):
"""Register a new model at distribution endpoint"""
client = ctx.obj["client"]
console = Console()

try:
response = client.models.register(
model_id=model_id, provider_id=provider_id, provider_model_id=provider_model_id, metadata=metadata
)
if response:
click.echo(f"Successfully registered model {model_id}")
except Exception as e:
click.echo(f"Failed to register model: {str(e)}")


@click.command(name="update", help="Update an existing model at distribution endpoint")
@click.argument("model_id")
@click.option("--provider-id", help="Provider ID for the model", default=None)
@click.option("--provider-model-id", help="Provider's model ID", default=None)
@click.option("--metadata", help="JSON metadata for the model", default=None)
@click.pass_context
def update_model(
ctx, model_id: str, provider_id: Optional[str], provider_model_id: Optional[str], metadata: Optional[str]
):
"""Update an existing model at distribution endpoint"""
client = ctx.obj["client"]

try:
response = client.models.update(
model_id=model_id, provider_id=provider_id, provider_model_id=provider_model_id, metadata=metadata
)
if response:
click.echo(f"Successfully updated model {model_id}")
except Exception as e:
click.echo(f"Failed to update model: {str(e)}")
response = client.models.register(
model_id=model_id, provider_id=provider_id, provider_model_id=provider_model_id, metadata=metadata
)
if response:
console.print(f"[green]Successfully registered model {model_id}[/green]")


@click.command(name="delete", help="Delete a model from distribution endpoint")
@click.command(name="unregister", help="Unregister a model from distribution endpoint")
@click.argument("model_id")
@click.pass_context
def delete_model(ctx, model_id: str):
"""Delete a model from distribution endpoint"""
@handle_client_errors("unregister model")
def unregister_model(ctx, model_id: str):
client = ctx.obj["client"]
console = Console()

try:
response = client.models.delete(model_id=model_id)
if response:
click.echo(f"Successfully deleted model {model_id}")
except Exception as e:
click.echo(f"Failed to delete model: {str(e)}")
response = client.models.unregister(model_id=model_id)
if response:
console.print(f"[green]Successfully deleted model {model_id}[/green]")


# Register subcommands
models.add_command(list_models)
models.add_command(get_model)
models.add_command(register_model)
models.add_command(update_model)
models.add_command(delete_model)
models.add_command(unregister_model)
3 changes: 2 additions & 1 deletion src/llama_stack_client/lib/cli/providers/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
from rich.console import Console
from rich.table import Table


from ..common.utils import handle_client_errors
@click.command("list")
@click.pass_context
@handle_client_errors("list providers")
def list_providers(ctx):
"""Show available providers on distribution endpoint"""
client = ctx.obj["client"]
Expand Down
3 changes: 3 additions & 0 deletions src/llama_stack_client/lib/cli/scoring_functions/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
from rich.console import Console
from rich.table import Table

from ..common.utils import handle_client_errors


@click.command("list")
@click.pass_context
@handle_client_errors("list scoring functions")
def list_scoring_functions(ctx):
"""Show available scoring functions on distribution endpoint"""

Expand Down
4 changes: 4 additions & 0 deletions src/llama_stack_client/lib/cli/shields/shields.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from rich.console import Console
from rich.table import Table

from ..common.utils import handle_client_errors


@click.group()
def shields():
Expand All @@ -20,6 +22,7 @@ def shields():

@click.command("list")
@click.pass_context
@handle_client_errors("list shields")
def list(ctx):
"""Show available safety shields on distribution endpoint"""
client = ctx.obj["client"]
Expand All @@ -46,6 +49,7 @@ def list(ctx):
@click.option("--provider-shield-id", help="Provider's shield ID", default=None)
@click.option("--params", type=str, help="JSON configuration parameters for the shield", default=None)
@click.pass_context
@handle_client_errors("register shield")
def register(
ctx,
shield_id: str,
Expand Down