diff --git a/src/llama_stack_client/lib/cli/datasets/datasets.py b/src/llama_stack_client/lib/cli/datasets/datasets.py
index c37a7aa5..d3691000 100644
--- a/src/llama_stack_client/lib/cli/datasets/datasets.py
+++ b/src/llama_stack_client/lib/cli/datasets/datasets.py
@@ -5,7 +5,11 @@
 # the root directory of this source tree.

 import click
+import yaml
+from typing import Optional
+import json
+from llama_models.llama3.api.datatypes import URL

 from .list import list_datasets


@@ -15,5 +19,49 @@ def datasets():
     pass


+@datasets.command()
+@click.option("--dataset-id", required=True, help="Id of the dataset")
+@click.option("--provider-id", help="Provider ID for the dataset", default=None)
+@click.option("--provider-dataset-id", help="Provider's dataset ID", default=None)
+@click.option("--metadata", type=str, help="Metadata of the dataset")
+@click.option("--url", type=str, help="URL of the dataset", required=True)
+@click.option("--schema", type=str, help="JSON schema of the dataset", required=True)
+@click.pass_context
+def register(
+    ctx,
+    dataset_id: str,
+    provider_id: Optional[str],
+    provider_dataset_id: Optional[str],
+    metadata: Optional[str],
+    url: str,
+    schema: str,
+):
+    """Create a new dataset"""
+    client = ctx.obj["client"]
+
+    try:
+        dataset_schema = json.loads(schema)
+    except json.JSONDecodeError:
+        raise click.BadParameter("Schema must be valid JSON")
+
+    if metadata:
+        try:
+            metadata = json.loads(metadata)
+        except json.JSONDecodeError:
+            raise click.BadParameter("Metadata must be valid JSON")
+
+    response = client.datasets.register(
+        dataset_id=dataset_id,
+        dataset_schema=dataset_schema,
+        url={"uri": url},
+        provider_id=provider_id,
+        provider_dataset_id=provider_dataset_id,
+        metadata=metadata,
+    )
+    if response:
+        click.echo(yaml.dump(response.dict()))
+
+
 # Register subcommands
 datasets.add_command(list_datasets)
+datasets.add_command(register)
diff --git a/src/llama_stack_client/lib/cli/eval_tasks/eval_tasks.py b/src/llama_stack_client/lib/cli/eval_tasks/eval_tasks.py
index 1983d782..e45c9802 100644
--- a/src/llama_stack_client/lib/cli/eval_tasks/eval_tasks.py
+++ b/src/llama_stack_client/lib/cli/eval_tasks/eval_tasks.py
@@ -6,6 +6,9 @@

 import click
+import json
+import yaml
+from typing import Optional

 from .list import list_eval_tasks


@@ -16,5 +19,44 @@ def eval_tasks():
     pass


+@eval_tasks.command()
+@click.option("--eval-task-id", required=True, help="ID of the eval task")
+@click.option("--dataset-id", required=True, help="ID of the dataset to evaluate")
+@click.option("--scoring-functions", required=True, multiple=True, help="Scoring functions to use for evaluation")
+@click.option("--provider-id", help="Provider ID for the eval task", default=None)
+@click.option("--provider-eval-task-id", help="Provider's eval task ID", default=None)
+@click.option("--metadata", type=str, help="Metadata for the eval task in JSON format")
+@click.pass_context
+def register(
+    ctx,
+    eval_task_id: str,
+    dataset_id: str,
+    scoring_functions: tuple[str, ...],
+    provider_id: Optional[str],
+    provider_eval_task_id: Optional[str],
+    metadata: Optional[str],
+):
+    """Register a new eval task"""
+    client = ctx.obj["client"]
+
+    if metadata:
+        try:
+            metadata = json.loads(metadata)
+        except json.JSONDecodeError:
+            raise click.BadParameter("Metadata must be valid JSON")
+
+    response = client.eval_tasks.register(
+        eval_task_id=eval_task_id,
+        dataset_id=dataset_id,
+        scoring_functions=scoring_functions,
+        provider_id=provider_id,
+        provider_eval_task_id=provider_eval_task_id,
+        metadata=metadata,
+    )
+    if response:
+        click.echo(yaml.dump(response.dict()))
+
+
 # Register subcommands
 eval_tasks.add_command(list_eval_tasks)
+eval_tasks.add_command(register)
diff --git a/src/llama_stack_client/lib/cli/eval_tasks/list.py b/src/llama_stack_client/lib/cli/eval_tasks/list.py
index 88585a4b..e0a1418f 100644
--- a/src/llama_stack_client/lib/cli/eval_tasks/list.py
+++ b/src/llama_stack_client/lib/cli/eval_tasks/list.py
@@ -16,8 +16,10 @@ def list_eval_tasks(ctx):

     client = ctx.obj["client"]

-    headers = ["identifier", "provider_id", "description", "type"]
-
+    headers = []
     eval_tasks_list_response = client.eval_tasks.list()
+    if eval_tasks_list_response and len(eval_tasks_list_response) > 0:
+        headers = sorted(eval_tasks_list_response[0].__dict__.keys())
+
     if eval_tasks_list_response:
         print_table_from_response(eval_tasks_list_response, headers)
diff --git a/src/llama_stack_client/lib/cli/memory_banks/list.py b/src/llama_stack_client/lib/cli/memory_banks/list.py
deleted file mode 100644
index a44d7de0..00000000
--- a/src/llama_stack_client/lib/cli/memory_banks/list.py
+++ /dev/null
@@ -1,28 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import click
-
-from llama_stack_client.lib.cli.common.utils import print_table_from_response
-
-
-@click.command("list")
-@click.pass_context
-def list_memory_banks(ctx):
-    """Show available memory banks on distribution endpoint"""
-
-    client = ctx.obj["client"]
-
-    headers = [
-        "identifier",
-        "provider_id",
-        "description",
-        "type",
-    ]
-
-    memory_banks_list_response = client.memory_banks.list()
-    if memory_banks_list_response:
-        print_table_from_response(memory_banks_list_response, headers)
diff --git a/src/llama_stack_client/lib/cli/memory_banks/memory_banks.py b/src/llama_stack_client/lib/cli/memory_banks/memory_banks.py
index 088d647e..4361a1c8 100644
--- a/src/llama_stack_client/lib/cli/memory_banks/memory_banks.py
+++ b/src/llama_stack_client/lib/cli/memory_banks/memory_banks.py
@@ -5,8 +5,9 @@
 # the root directory of this source tree.

 import click
-
-from .list import list_memory_banks
+from typing import Optional
+import yaml
+from llama_stack_client.lib.cli.common.utils import print_table_from_response


 @click.group()
@@ -15,5 +16,70 @@ def memory_banks():
     pass


+@click.command("list")
+@click.pass_context
+def list(ctx):
+    """Show available memory banks on distribution endpoint"""
+
+    client = ctx.obj["client"]
+
+    memory_banks_list_response = client.memory_banks.list()
+    headers = []
+    if memory_banks_list_response and len(memory_banks_list_response) > 0:
+        headers = sorted(memory_banks_list_response[0].__dict__.keys())
+
+    if memory_banks_list_response:
+        print_table_from_response(memory_banks_list_response, headers)
+
+
+@memory_banks.command()
+@click.option("--memory-bank-id", required=True, help="Id of the memory bank")
+@click.option("--type", type=click.Choice(["vector", "keyvalue", "keyword", "graph"]), required=True)
+@click.option("--provider-id", help="Provider ID for the memory bank", default=None)
+@click.option("--provider-memory-bank-id", help="Provider's memory bank ID", default=None)
+@click.option("--chunk-size", type=int, help="Chunk size in tokens (for vector type)", default=512)
+@click.option("--embedding-model", type=str, help="Embedding model (for vector type)", default="all-MiniLM-L6-v2")
+@click.option("--overlap-size", type=int, help="Overlap size in tokens (for vector type)", default=64)
+@click.pass_context
+def create(
+    ctx,
+    memory_bank_id: str,
+    type: str,
+    provider_id: Optional[str],
+    provider_memory_bank_id: Optional[str],
+    chunk_size: Optional[int],
+    embedding_model: Optional[str],
+    overlap_size: Optional[int],
+):
+    """Create a new memory bank"""
+    client = ctx.obj["client"]
+
+    config = None
+    if type == "vector":
+        config = {
+            "type": "vector",
+            "chunk_size_in_tokens": chunk_size,
+            "embedding_model": embedding_model,
+        }
+        if overlap_size:
+            config["overlap_size_in_tokens"] = overlap_size
+    elif type == "keyvalue":
+        config = {"type": "keyvalue"}
+    elif type == "keyword":
+        config = {"type": "keyword"}
+    elif type == "graph":
+        config = {"type": "graph"}
+
+    response = client.memory_banks.register(
+        memory_bank_id=memory_bank_id,
+        params=config,
+        provider_id=provider_id,
+        provider_memory_bank_id=provider_memory_bank_id,
+    )
+    if response:
+        click.echo(yaml.dump(response.dict()))
+
+
 # Register subcommands
-memory_banks.add_command(list_memory_banks)
+memory_banks.add_command(list)
+memory_banks.add_command(create)
diff --git a/src/llama_stack_client/lib/cli/models/get.py b/src/llama_stack_client/lib/cli/models/get.py
deleted file mode 100644
index c242fb22..00000000
--- a/src/llama_stack_client/lib/cli/models/get.py
+++ /dev/null
@@ -1,31 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import click
-from tabulate import tabulate
-
-
-@click.command(name="get")
-@click.argument("model_id")
-@click.pass_context
-def get_model(ctx, model_id: str):
-    """Show available llama models at distribution endpoint"""
-    client = ctx.obj["client"]
-
-    models_get_response = client.models.retrieve(identifier=model_id)
-
-    if not models_get_response:
-        click.echo(
-            f"Model {model_id} is not found at distribution endpoint. "
-            "Please ensure endpoint is serving specified model."
-        )
-        return
-
-    headers = sorted(models_get_response.__dict__.keys())
-    rows = []
-    rows.append([models_get_response.__dict__[headers[i]] for i in range(len(headers))])
-
-    click.echo(tabulate(rows, headers=headers, tablefmt="grid"))
diff --git a/src/llama_stack_client/lib/cli/models/list.py b/src/llama_stack_client/lib/cli/models/list.py
deleted file mode 100644
index 7d335224..00000000
--- a/src/llama_stack_client/lib/cli/models/list.py
+++ /dev/null
@@ -1,20 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import click
-
-from llama_stack_client.lib.cli.common.utils import print_table_from_response
-
-
-@click.command(name="list", help="Show available llama models at distribution endpoint")
-@click.pass_context
-def list_models(ctx):
-    client = ctx.obj["client"]
-
-    headers = ["identifier", "provider_id", "provider_resource_id", "metadata"]
-    response = client.models.list()
-    if response:
-        print_table_from_response(response, headers)
diff --git a/src/llama_stack_client/lib/cli/models/models.py b/src/llama_stack_client/lib/cli/models/models.py
index 1655429a..b81a84dc 100644
--- a/src/llama_stack_client/lib/cli/models/models.py
+++ b/src/llama_stack_client/lib/cli/models/models.py
@@ -5,8 +5,9 @@
 # the root directory of this source tree.

 import click
-from llama_stack_client.lib.cli.models.get import get_model
-from llama_stack_client.lib.cli.models.list import list_models
+from tabulate import tabulate
+from llama_stack_client.lib.cli.common.utils import print_table_from_response
+from typing import Optional


 @click.group()
@@ -15,6 +16,63 @@ def models():
     pass


+@click.command(name="list", help="Show available llama models at distribution endpoint")
+@click.pass_context
+def list_models(ctx):
+    client = ctx.obj["client"]
+
+    headers = ["identifier", "provider_id", "provider_resource_id", "metadata"]
+    response = client.models.list()
+    if response:
+        print_table_from_response(response, headers)
+
+
+@click.command(name="get")
+@click.argument("model_id")
+@click.pass_context
+def get_model(ctx, model_id: str):
+    """Show details of a specific model at the distribution endpoint"""
+    client = ctx.obj["client"]
+
+    models_get_response = client.models.retrieve(identifier=model_id)
+
+    if not models_get_response:
+        click.echo(
+            f"Model {model_id} is not found at distribution endpoint. "
+            "Please ensure endpoint is serving specified model."
+        )
+        return
+
+    headers = sorted(models_get_response.__dict__.keys())
+    rows = []
+    rows.append([models_get_response.__dict__[headers[i]] for i in range(len(headers))])
+
+    click.echo(tabulate(rows, headers=headers, tablefmt="grid"))
+
+
+@click.command(name="register", help="Register a new model at distribution endpoint")
+@click.argument("model_id")
+@click.option("--provider-id", help="Provider ID for the model", default=None)
+@click.option("--provider-model-id", help="Provider's model ID", default=None)
+@click.option("--metadata", help="JSON metadata for the model", default=None)
+@click.pass_context
+def register_model(
+    ctx, model_id: str, provider_id: Optional[str], provider_model_id: Optional[str], metadata: Optional[str]
+):
+    """Register a new model at distribution endpoint"""
+    client = ctx.obj["client"]
+
+    try:
+        response = client.models.register(
+            model_id=model_id, provider_id=provider_id, provider_model_id=provider_model_id, metadata=metadata
+        )
+        if response:
+            click.echo(f"Successfully registered model {model_id}")
+    except Exception as e:
+        click.echo(f"Failed to register model: {str(e)}")
+
+
 # Register subcommands
 models.add_command(list_models)
 models.add_command(get_model)
+models.add_command(register_model)
diff --git a/src/llama_stack_client/lib/cli/scoring_functions/scoring_functions.py b/src/llama_stack_client/lib/cli/scoring_functions/scoring_functions.py
index 4dba3fb5..e363a5f9 100644
--- a/src/llama_stack_client/lib/cli/scoring_functions/scoring_functions.py
+++ b/src/llama_stack_client/lib/cli/scoring_functions/scoring_functions.py
@@ -5,6 +5,9 @@
 # the root directory of this source tree.

 import click
+import yaml
+from typing import Optional
+import json

 from .list import list_scoring_functions

@@ -15,5 +18,44 @@ def scoring_functions():
     pass


+@scoring_functions.command()
+@click.option("--scoring-fn-id", required=True, help="Id of the scoring function")
+@click.option("--description", required=True, help="Description of the scoring function")
+@click.option("--return-type", type=str, required=True, help="Return type of the scoring function")
+@click.option("--provider-id", type=str, help="Provider ID for the scoring function", default=None)
+@click.option("--provider-scoring-fn-id", type=str, help="Provider's scoring function ID", default=None)
+@click.option("--params", type=str, help="Parameters for the scoring function in JSON format", default=None)
+@click.pass_context
+def register(
+    ctx,
+    scoring_fn_id: str,
+    description: str,
+    return_type: str,
+    provider_id: Optional[str],
+    provider_scoring_fn_id: Optional[str],
+    params: Optional[str],
+):
+    """Register a new scoring function"""
+    client = ctx.obj["client"]
+
+    if params:
+        try:
+            params = json.loads(params)
+        except json.JSONDecodeError:
+            raise click.BadParameter("Parameters must be valid JSON")
+
+    response = client.scoring_functions.register(
+        scoring_fn_id=scoring_fn_id,
+        description=description,
+        return_type=json.loads(return_type),
+        provider_id=provider_id,
+        provider_scoring_fn_id=provider_scoring_fn_id,
+        params=params,
+    )
+    if response:
+        click.echo(yaml.dump(response.dict()))
+
+
 # Register subcommands
 scoring_functions.add_command(list_scoring_functions)
+scoring_functions.add_command(register)
diff --git a/src/llama_stack_client/lib/cli/shields/list.py b/src/llama_stack_client/lib/cli/shields/list.py
deleted file mode 100644
index ce112204..00000000
--- a/src/llama_stack_client/lib/cli/shields/list.py
+++ /dev/null
@@ -1,23 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import click
-
-from llama_stack_client.lib.cli.common.utils import print_table_from_response
-
-
-@click.command("list")
-@click.pass_context
-def list_shields(ctx):
-    """Show available safety shields on distribution endpoint"""
-
-    client = ctx.obj["client"]
-
-    headers = ["identifier", "provider_id", "description", "type"]
-
-    shields_list_response = client.shields.list()
-    if shields_list_response:
-        print_table_from_response(shields_list_response, headers)
diff --git a/src/llama_stack_client/lib/cli/shields/shields.py b/src/llama_stack_client/lib/cli/shields/shields.py
index a4a0373b..574a2dbe 100644
--- a/src/llama_stack_client/lib/cli/shields/shields.py
+++ b/src/llama_stack_client/lib/cli/shields/shields.py
@@ -5,8 +5,9 @@
 # the root directory of this source tree.

 import click
-
-from .list import list_shields
+from typing import Optional
+import yaml
+from llama_stack_client.lib.cli.common.utils import print_table_from_response


 @click.group()
@@ -15,5 +16,47 @@ def shields():
     pass


+@click.command("list")
+@click.pass_context
+def list(ctx):
+    """Show available safety shields on distribution endpoint"""
+    client = ctx.obj["client"]
+
+    shields_list_response = client.shields.list()
+    headers = []
+    if shields_list_response and len(shields_list_response) > 0:
+        headers = sorted(shields_list_response[0].__dict__.keys())
+
+    if shields_list_response:
+        print_table_from_response(shields_list_response, headers)
+
+
+@shields.command()
+@click.option("--shield-id", required=True, help="Id of the shield")
+@click.option("--provider-id", help="Provider ID for the shield", default=None)
+@click.option("--provider-shield-id", help="Provider's shield ID", default=None)
+@click.option("--params", type=str, help="JSON configuration parameters for the shield", default=None)
+@click.pass_context
+def register(
+    ctx,
+    shield_id: str,
+    provider_id: Optional[str],
+    provider_shield_id: Optional[str],
+    params: Optional[str],
+):
+    """Register a new safety shield"""
+    client = ctx.obj["client"]
+
+    response = client.shields.register(
+        shield_id=shield_id,
+        params=params,
+        provider_id=provider_id,
+        provider_shield_id=provider_shield_id,
+    )
+    if response:
+        click.echo(yaml.dump(response.dict()))
+
+
 # Register subcommands
-shields.add_command(list_shields)
+shields.add_command(list)
+shields.add_command(register)
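
Not part of the patch: a rough sketch of the client calls the new register subcommands wrap, which may help when reviewing. It assumes the package's LlamaStackClient entry point and a locally running distribution at the base_url shown; the model ID, dataset ID, URL, and schema below are placeholder values, not something this diff defines.

# Hypothetical usage sketch; endpoint URL and all IDs/values are assumed placeholders.
import json

from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5000")  # assumed local endpoint

# Roughly what `models register <model_id>` (models.py above) ends up calling.
client.models.register(model_id="my-model", provider_id=None, provider_model_id=None, metadata=None)

# Roughly what `datasets register` (datasets.py above) ends up calling; --schema and
# --metadata arrive as JSON strings on the CLI and are json.loads()-ed before the call.
client.datasets.register(
    dataset_id="my-dataset",
    dataset_schema=json.loads('{"expected_answer": {"type": "string"}}'),  # placeholder schema
    url={"uri": "https://example.com/data.jsonl"},  # placeholder URL
    provider_id=None,
    provider_dataset_id=None,
    metadata=None,
)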