Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions src/llama_stack_client/lib/cli/datasets/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
# the root directory of this source tree.

import click
import yaml
from typing import Optional
import json

from llama_models.llama3.api.datatypes import URL
from .list import list_datasets


Expand All @@ -15,5 +19,49 @@ def datasets():
pass


@datasets.command()
@click.option("--dataset-id", required=True, help="Id of the dataset")
@click.option("--provider-id", help="Provider ID for the dataset", default=None)
@click.option("--provider-dataset-id", help="Provider's dataset ID", default=None)
@click.option("--metadata", type=str, help="Metadata of the dataset")
@click.option("--url", type=str, help="URL of the dataset", required=True)
@click.option("--schema", type=str, help="JSON schema of the dataset", required=True)
@click.pass_context
def register(
    ctx,
    dataset_id: str,
    provider_id: Optional[str],
    provider_dataset_id: Optional[str],
    metadata: Optional[str],
    url: str,
    schema: str,
):
    """Create a new dataset"""
    client = ctx.obj["client"]

    # --schema is required and must be a JSON object describing the dataset columns.
    try:
        dataset_schema = json.loads(schema)
    except json.JSONDecodeError as err:
        # Chain the decode error so the original parse failure is preserved as the cause.
        raise click.BadParameter("Schema must be valid JSON") from err

    # --metadata is optional; parse it in place only when provided.
    if metadata:
        try:
            metadata = json.loads(metadata)
        except json.JSONDecodeError as err:
            raise click.BadParameter("Metadata must be valid JSON") from err

    response = client.datasets.register(
        dataset_id=dataset_id,
        dataset_schema=dataset_schema,
        url={"uri": url},
        provider_id=provider_id,
        provider_dataset_id=provider_dataset_id,
        metadata=metadata,
    )
    if response:
        # Pretty-print the registered dataset as YAML for the CLI user.
        click.echo(yaml.dump(response.dict()))


# Register subcommands on the `datasets` click group so they are
# reachable as `... datasets list` and `... datasets register`.
datasets.add_command(list_datasets)
datasets.add_command(register)
42 changes: 42 additions & 0 deletions src/llama_stack_client/lib/cli/eval_tasks/eval_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@


import click
import json
import yaml
from typing import Optional

from .list import list_eval_tasks

Expand All @@ -16,5 +19,44 @@ def eval_tasks():
pass


@eval_tasks.command()
@click.option("--eval-task-id", required=True, help="ID of the eval task")
@click.option("--dataset-id", required=True, help="ID of the dataset to evaluate")
@click.option("--scoring-functions", required=True, multiple=True, help="Scoring functions to use for evaluation")
@click.option("--provider-id", help="Provider ID for the eval task", default=None)
@click.option("--provider-eval-task-id", help="Provider's eval task ID", default=None)
@click.option("--metadata", type=str, help="Metadata for the eval task in JSON format")
@click.pass_context
def register(
    ctx,
    eval_task_id: str,
    dataset_id: str,
    scoring_functions: tuple[str, ...],
    provider_id: Optional[str],
    provider_eval_task_id: Optional[str],
    metadata: Optional[str],
):
    """Register a new eval task"""
    client = ctx.obj["client"]

    # --metadata arrives as a raw JSON string; decode it when supplied.
    parsed_metadata = metadata
    if parsed_metadata:
        try:
            parsed_metadata = json.loads(parsed_metadata)
        except json.JSONDecodeError:
            raise click.BadParameter("Metadata must be valid JSON")

    registration = client.eval_tasks.register(
        eval_task_id=eval_task_id,
        dataset_id=dataset_id,
        scoring_functions=scoring_functions,
        provider_id=provider_id,
        provider_eval_task_id=provider_eval_task_id,
        metadata=parsed_metadata,
    )
    if registration:
        # Echo the server's view of the new eval task as YAML.
        click.echo(yaml.dump(registration.dict()))


# Register subcommands on the `eval_tasks` click group so they are
# reachable as `... eval_tasks list` and `... eval_tasks register`.
eval_tasks.add_command(list_eval_tasks)
eval_tasks.add_command(register)
6 changes: 4 additions & 2 deletions src/llama_stack_client/lib/cli/eval_tasks/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ def list_eval_tasks(ctx):

client = ctx.obj["client"]

headers = ["identifier", "provider_id", "description", "type"]

headers = []
eval_tasks_list_response = client.eval_tasks.list()
if eval_tasks_list_response and len(eval_tasks_list_response) > 0:
headers = sorted(eval_tasks_list_response[0].__dict__.keys())

if eval_tasks_list_response:
print_table_from_response(eval_tasks_list_response, headers)
28 changes: 0 additions & 28 deletions src/llama_stack_client/lib/cli/memory_banks/list.py

This file was deleted.

72 changes: 69 additions & 3 deletions src/llama_stack_client/lib/cli/memory_banks/memory_banks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
# the root directory of this source tree.

import click

from .list import list_memory_banks
from typing import Optional
import yaml
from llama_stack_client.lib.cli.common.utils import print_table_from_response


@click.group()
Expand All @@ -15,5 +16,70 @@ def memory_banks():
pass


@click.command("list")
@click.pass_context
def list(ctx):
"""Show available memory banks on distribution endpoint"""

client = ctx.obj["client"]

memory_banks_list_response = client.memory_banks.list()
headers = []
if memory_banks_list_response and len(memory_banks_list_response) > 0:
headers = sorted(memory_banks_list_response[0].__dict__.keys())

if memory_banks_list_response:
print_table_from_response(memory_banks_list_response, headers)


@memory_banks.command()
@click.option("--memory-bank-id", required=True, help="Id of the memory bank")
@click.option("--type", type=click.Choice(["vector", "keyvalue", "keyword", "graph"]), required=True)
@click.option("--provider-id", help="Provider ID for the memory bank", default=None)
@click.option("--provider-memory-bank-id", help="Provider's memory bank ID", default=None)
@click.option("--chunk-size", type=int, help="Chunk size in tokens (for vector type)", default=512)
@click.option("--embedding-model", type=str, help="Embedding model (for vector type)", default="all-MiniLM-L6-v2")
@click.option("--overlap-size", type=int, help="Overlap size in tokens (for vector type)", default=64)
@click.pass_context
def create(
    ctx,
    memory_bank_id: str,
    type: str,
    provider_id: Optional[str],
    provider_memory_bank_id: Optional[str],
    chunk_size: Optional[int],
    embedding_model: Optional[str],
    overlap_size: Optional[int],
):
    """Create a new memory bank"""
    client = ctx.obj["client"]

    # Only the "vector" type carries extra configuration; the other
    # types ("keyvalue", "keyword", "graph") are tag-only params.
    params = None
    if type == "vector":
        params = {
            "type": "vector",
            "chunk_size_in_tokens": chunk_size,
            "embedding_model": embedding_model,
        }
        if overlap_size:
            params["overlap_size_in_tokens"] = overlap_size
    elif type in ("keyvalue", "keyword", "graph"):
        params = {"type": type}

    response = client.memory_banks.register(
        memory_bank_id=memory_bank_id,
        params=params,
        provider_id=provider_id,
        provider_memory_bank_id=provider_memory_bank_id,
    )
    if response:
        # Show the registered bank back to the user as YAML.
        click.echo(yaml.dump(response.dict()))


# Register subcommands. The old standalone `list_memory_banks` module
# (memory_banks/list.py) was deleted in favor of the inline `list`
# command above, so only commands defined in this file are attached —
# referencing the removed name would raise at import time.
memory_banks.add_command(list)
memory_banks.add_command(create)
31 changes: 0 additions & 31 deletions src/llama_stack_client/lib/cli/models/get.py

This file was deleted.

20 changes: 0 additions & 20 deletions src/llama_stack_client/lib/cli/models/list.py

This file was deleted.

62 changes: 60 additions & 2 deletions src/llama_stack_client/lib/cli/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
# the root directory of this source tree.

import click
from llama_stack_client.lib.cli.models.get import get_model
from llama_stack_client.lib.cli.models.list import list_models
from tabulate import tabulate
from llama_stack_client.lib.cli.common.utils import print_table_from_response
from typing import Optional


@click.group()
Expand All @@ -15,6 +16,63 @@ def models():
pass


@click.command(name="list", help="Show available llama models at distribution endpoint")
@click.pass_context
def list_models(ctx):
client = ctx.obj["client"]

headers = ["identifier", "provider_id", "provider_resource_id", "metadata"]
response = client.models.list()
if response:
print_table_from_response(response, headers)


@click.command(name="get")
@click.argument("model_id")
@click.pass_context
def get_model(ctx, model_id: str):
"""Show available llama models at distribution endpoint"""
client = ctx.obj["client"]

models_get_response = client.models.retrieve(identifier=model_id)

if not models_get_response:
click.echo(
f"Model {model_id} is not found at distribution endpoint. "
"Please ensure endpoint is serving specified model."
)
return

headers = sorted(models_get_response.__dict__.keys())
rows = []
rows.append([models_get_response.__dict__[headers[i]] for i in range(len(headers))])

click.echo(tabulate(rows, headers=headers, tablefmt="grid"))


@click.command(name="register", help="Register a new model at distribution endpoint")
@click.argument("model_id")
@click.option("--provider-id", help="Provider ID for the model", default=None)
@click.option("--provider-model-id", help="Provider's model ID", default=None)
@click.option("--metadata", help="JSON metadata for the model", default=None)
@click.pass_context
def register_model(
ctx, model_id: str, provider_id: Optional[str], provider_model_id: Optional[str], metadata: Optional[str]
):
"""Register a new model at distribution endpoint"""
client = ctx.obj["client"]

try:
response = client.models.register(
model_id=model_id, provider_id=provider_id, provider_model_id=provider_model_id, metadata=metadata
)
if response:
click.echo(f"Successfully registered model {model_id}")
except Exception as e:
click.echo(f"Failed to register model: {str(e)}")


# Register subcommands on the `models` click group so they are
# reachable as `... models list`, `... models get`, `... models register`.
models.add_command(list_models)
models.add_command(get_model)
models.add_command(register_model)
Loading