41 changes: 24 additions & 17 deletions src/llama_stack_client/_client.py
@@ -35,11 +35,12 @@
class LlamaStackClient(SyncAPIClient):
agents: resources.AgentsResource
batch_inferences: resources.BatchInferencesResource
datasets: resources.DatasetsResource
eval: resources.EvalResource
inspect: resources.InspectResource
inference: resources.InferenceResource
memory: resources.MemoryResource
memory_banks: resources.MemoryBanksResource
datasets: resources.DatasetsResource
models: resources.ModelsResource
post_training: resources.PostTrainingResource
providers: resources.ProvidersResource
@@ -51,7 +52,7 @@ class LlamaStackClient(SyncAPIClient):
datasetio: resources.DatasetioResource
scoring: resources.ScoringResource
scoring_functions: resources.ScoringFunctionsResource
eval: resources.EvalResource
eval_tasks: resources.EvalTasksResource
with_raw_response: LlamaStackClientWithRawResponse
with_streaming_response: LlamaStackClientWithStreamedResponse

@@ -85,7 +86,6 @@ def __init__(
base_url = os.environ.get("LLAMA_STACK_CLIENT_BASE_URL")
if base_url is None:
base_url = f"http://any-hosted-llama-stack.com"

if provider_data:
if default_headers is None:
default_headers = {}
@@ -104,11 +104,12 @@

self.agents = resources.AgentsResource(self)
self.batch_inferences = resources.BatchInferencesResource(self)
self.datasets = resources.DatasetsResource(self)
self.eval = resources.EvalResource(self)
self.inspect = resources.InspectResource(self)
self.inference = resources.InferenceResource(self)
self.memory = resources.MemoryResource(self)
self.memory_banks = resources.MemoryBanksResource(self)
self.datasets = resources.DatasetsResource(self)
self.models = resources.ModelsResource(self)
self.post_training = resources.PostTrainingResource(self)
self.providers = resources.ProvidersResource(self)
@@ -120,7 +121,7 @@ def __init__(
self.datasetio = resources.DatasetioResource(self)
self.scoring = resources.ScoringResource(self)
self.scoring_functions = resources.ScoringFunctionsResource(self)
self.eval = resources.EvalResource(self)
self.eval_tasks = resources.EvalTasksResource(self)
self.with_raw_response = LlamaStackClientWithRawResponse(self)
self.with_streaming_response = LlamaStackClientWithStreamedResponse(self)

@@ -224,11 +225,12 @@ def _make_status_error(
class AsyncLlamaStackClient(AsyncAPIClient):
agents: resources.AsyncAgentsResource
batch_inferences: resources.AsyncBatchInferencesResource
datasets: resources.AsyncDatasetsResource
eval: resources.AsyncEvalResource
inspect: resources.AsyncInspectResource
inference: resources.AsyncInferenceResource
memory: resources.AsyncMemoryResource
memory_banks: resources.AsyncMemoryBanksResource
datasets: resources.AsyncDatasetsResource
models: resources.AsyncModelsResource
post_training: resources.AsyncPostTrainingResource
providers: resources.AsyncProvidersResource
@@ -240,7 +242,7 @@ class AsyncLlamaStackClient(AsyncAPIClient):
datasetio: resources.AsyncDatasetioResource
scoring: resources.AsyncScoringResource
scoring_functions: resources.AsyncScoringFunctionsResource
eval: resources.AsyncEvalResource
eval_tasks: resources.AsyncEvalTasksResource
with_raw_response: AsyncLlamaStackClientWithRawResponse
with_streaming_response: AsyncLlamaStackClientWithStreamedResponse

@@ -293,11 +295,12 @@ def __init__(

self.agents = resources.AsyncAgentsResource(self)
self.batch_inferences = resources.AsyncBatchInferencesResource(self)
self.datasets = resources.AsyncDatasetsResource(self)
self.eval = resources.AsyncEvalResource(self)
self.inspect = resources.AsyncInspectResource(self)
self.inference = resources.AsyncInferenceResource(self)
self.memory = resources.AsyncMemoryResource(self)
self.memory_banks = resources.AsyncMemoryBanksResource(self)
self.datasets = resources.AsyncDatasetsResource(self)
self.models = resources.AsyncModelsResource(self)
self.post_training = resources.AsyncPostTrainingResource(self)
self.providers = resources.AsyncProvidersResource(self)
@@ -309,7 +312,7 @@ def __init__(
self.datasetio = resources.AsyncDatasetioResource(self)
self.scoring = resources.AsyncScoringResource(self)
self.scoring_functions = resources.AsyncScoringFunctionsResource(self)
self.eval = resources.AsyncEvalResource(self)
self.eval_tasks = resources.AsyncEvalTasksResource(self)
self.with_raw_response = AsyncLlamaStackClientWithRawResponse(self)
self.with_streaming_response = AsyncLlamaStackClientWithStreamedResponse(self)

@@ -414,11 +417,12 @@ class LlamaStackClientWithRawResponse:
def __init__(self, client: LlamaStackClient) -> None:
self.agents = resources.AgentsResourceWithRawResponse(client.agents)
self.batch_inferences = resources.BatchInferencesResourceWithRawResponse(client.batch_inferences)
self.datasets = resources.DatasetsResourceWithRawResponse(client.datasets)
self.eval = resources.EvalResourceWithRawResponse(client.eval)
self.inspect = resources.InspectResourceWithRawResponse(client.inspect)
self.inference = resources.InferenceResourceWithRawResponse(client.inference)
self.memory = resources.MemoryResourceWithRawResponse(client.memory)
self.memory_banks = resources.MemoryBanksResourceWithRawResponse(client.memory_banks)
self.datasets = resources.DatasetsResourceWithRawResponse(client.datasets)
self.models = resources.ModelsResourceWithRawResponse(client.models)
self.post_training = resources.PostTrainingResourceWithRawResponse(client.post_training)
self.providers = resources.ProvidersResourceWithRawResponse(client.providers)
@@ -432,18 +436,19 @@ def __init__(self, client: LlamaStackClient) -> None:
self.datasetio = resources.DatasetioResourceWithRawResponse(client.datasetio)
self.scoring = resources.ScoringResourceWithRawResponse(client.scoring)
self.scoring_functions = resources.ScoringFunctionsResourceWithRawResponse(client.scoring_functions)
self.eval = resources.EvalResourceWithRawResponse(client.eval)
self.eval_tasks = resources.EvalTasksResourceWithRawResponse(client.eval_tasks)


class AsyncLlamaStackClientWithRawResponse:
def __init__(self, client: AsyncLlamaStackClient) -> None:
self.agents = resources.AsyncAgentsResourceWithRawResponse(client.agents)
self.batch_inferences = resources.AsyncBatchInferencesResourceWithRawResponse(client.batch_inferences)
self.datasets = resources.AsyncDatasetsResourceWithRawResponse(client.datasets)
self.eval = resources.AsyncEvalResourceWithRawResponse(client.eval)
self.inspect = resources.AsyncInspectResourceWithRawResponse(client.inspect)
self.inference = resources.AsyncInferenceResourceWithRawResponse(client.inference)
self.memory = resources.AsyncMemoryResourceWithRawResponse(client.memory)
self.memory_banks = resources.AsyncMemoryBanksResourceWithRawResponse(client.memory_banks)
self.datasets = resources.AsyncDatasetsResourceWithRawResponse(client.datasets)
self.models = resources.AsyncModelsResourceWithRawResponse(client.models)
self.post_training = resources.AsyncPostTrainingResourceWithRawResponse(client.post_training)
self.providers = resources.AsyncProvidersResourceWithRawResponse(client.providers)
@@ -457,18 +462,19 @@ def __init__(self, client: AsyncLlamaStackClient) -> None:
self.datasetio = resources.AsyncDatasetioResourceWithRawResponse(client.datasetio)
self.scoring = resources.AsyncScoringResourceWithRawResponse(client.scoring)
self.scoring_functions = resources.AsyncScoringFunctionsResourceWithRawResponse(client.scoring_functions)
self.eval = resources.AsyncEvalResourceWithRawResponse(client.eval)
self.eval_tasks = resources.AsyncEvalTasksResourceWithRawResponse(client.eval_tasks)


class LlamaStackClientWithStreamedResponse:
def __init__(self, client: LlamaStackClient) -> None:
self.agents = resources.AgentsResourceWithStreamingResponse(client.agents)
self.batch_inferences = resources.BatchInferencesResourceWithStreamingResponse(client.batch_inferences)
self.datasets = resources.DatasetsResourceWithStreamingResponse(client.datasets)
self.eval = resources.EvalResourceWithStreamingResponse(client.eval)
self.inspect = resources.InspectResourceWithStreamingResponse(client.inspect)
self.inference = resources.InferenceResourceWithStreamingResponse(client.inference)
self.memory = resources.MemoryResourceWithStreamingResponse(client.memory)
self.memory_banks = resources.MemoryBanksResourceWithStreamingResponse(client.memory_banks)
self.datasets = resources.DatasetsResourceWithStreamingResponse(client.datasets)
self.models = resources.ModelsResourceWithStreamingResponse(client.models)
self.post_training = resources.PostTrainingResourceWithStreamingResponse(client.post_training)
self.providers = resources.ProvidersResourceWithStreamingResponse(client.providers)
@@ -482,18 +488,19 @@ def __init__(self, client: LlamaStackClient) -> None:
self.datasetio = resources.DatasetioResourceWithStreamingResponse(client.datasetio)
self.scoring = resources.ScoringResourceWithStreamingResponse(client.scoring)
self.scoring_functions = resources.ScoringFunctionsResourceWithStreamingResponse(client.scoring_functions)
self.eval = resources.EvalResourceWithStreamingResponse(client.eval)
self.eval_tasks = resources.EvalTasksResourceWithStreamingResponse(client.eval_tasks)


class AsyncLlamaStackClientWithStreamedResponse:
def __init__(self, client: AsyncLlamaStackClient) -> None:
self.agents = resources.AsyncAgentsResourceWithStreamingResponse(client.agents)
self.batch_inferences = resources.AsyncBatchInferencesResourceWithStreamingResponse(client.batch_inferences)
self.datasets = resources.AsyncDatasetsResourceWithStreamingResponse(client.datasets)
self.eval = resources.AsyncEvalResourceWithStreamingResponse(client.eval)
self.inspect = resources.AsyncInspectResourceWithStreamingResponse(client.inspect)
self.inference = resources.AsyncInferenceResourceWithStreamingResponse(client.inference)
self.memory = resources.AsyncMemoryResourceWithStreamingResponse(client.memory)
self.memory_banks = resources.AsyncMemoryBanksResourceWithStreamingResponse(client.memory_banks)
self.datasets = resources.AsyncDatasetsResourceWithStreamingResponse(client.datasets)
self.models = resources.AsyncModelsResourceWithStreamingResponse(client.models)
self.post_training = resources.AsyncPostTrainingResourceWithStreamingResponse(client.post_training)
self.providers = resources.AsyncProvidersResourceWithStreamingResponse(client.providers)
@@ -507,7 +514,7 @@ def __init__(self, client: AsyncLlamaStackClient) -> None:
self.datasetio = resources.AsyncDatasetioResourceWithStreamingResponse(client.datasetio)
self.scoring = resources.AsyncScoringResourceWithStreamingResponse(client.scoring)
self.scoring_functions = resources.AsyncScoringFunctionsResourceWithStreamingResponse(client.scoring_functions)
self.eval = resources.AsyncEvalResourceWithStreamingResponse(client.eval)
self.eval_tasks = resources.AsyncEvalTasksResourceWithStreamingResponse(client.eval_tasks)


Client = LlamaStackClient
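Net effect in this file: `datasets` and `eval` move up in the declaration order, and the duplicate `eval` slot at the bottom of each class becomes the new `eval_tasks` resource, mirrored across the sync, async, raw-response, and streaming wrappers. A minimal sketch of how the renamed resource surfaces, assuming a reachable distribution endpoint and that `EvalTasksResource` exposes a `list()` method (neither is shown in this diff):

    from llama_stack_client import LlamaStackClient

    client = LlamaStackClient(base_url="http://localhost:5000")  # hypothetical endpoint

    # `client.eval` still drives evaluation runs; `client.eval_tasks` is the
    # newly split-out resource, also reachable through the wrappers above.
    tasks = client.eval_tasks.list()  # assumed method, for illustration
    raw = client.with_raw_response.eval_tasks  # EvalTasksResourceWithRawResponse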
6 changes: 4 additions & 2 deletions src/llama_stack_client/_compat.py
@@ -2,7 +2,7 @@

from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast, overload
from datetime import date, datetime
from typing_extensions import Self
from typing_extensions import Self, Literal

import pydantic
from pydantic.fields import FieldInfo
@@ -137,9 +137,11 @@ def model_dump(
exclude_unset: bool = False,
exclude_defaults: bool = False,
warnings: bool = True,
mode: Literal["json", "python"] = "python",
) -> dict[str, Any]:
if PYDANTIC_V2:
if PYDANTIC_V2 or hasattr(model, "model_dump"):
return model.model_dump(
mode=mode,
exclude=exclude,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
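The compat shim now accepts a Pydantic-v2-style `mode` argument and takes the v2 branch whenever the model actually defines `model_dump`. For reference, the two modes differ only in how rich types are rendered; `Event` below is a hypothetical model, not part of this SDK:

    from datetime import datetime

    import pydantic  # Pydantic v2 semantics shown

    class Event(pydantic.BaseModel):
        at: datetime

    e = Event(at=datetime(2024, 1, 1, 12, 0))
    print(e.model_dump(mode="python"))  # {'at': datetime.datetime(2024, 1, 1, 12, 0)}
    print(e.model_dump(mode="json"))    # {'at': '2024-01-01T12:00:00'}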
9 changes: 6 additions & 3 deletions src/llama_stack_client/_models.py
@@ -37,6 +37,7 @@
PropertyInfo,
is_list,
is_given,
json_safe,
lru_cache,
is_mapping,
parse_date,
@@ -279,8 +280,8 @@ def model_dump(
Returns:
A dictionary representation of the model.
"""
if mode != "python":
raise ValueError("mode is only supported in Pydantic v2")
if mode not in {"json", "python"}:
raise ValueError("mode must be either 'json' or 'python'")
if round_trip != False:
raise ValueError("round_trip is only supported in Pydantic v2")
if warnings != True:
@@ -289,7 +290,7 @@
raise ValueError("context is only supported in Pydantic v2")
if serialize_as_any != False:
raise ValueError("serialize_as_any is only supported in Pydantic v2")
return super().dict( # pyright: ignore[reportDeprecated]
dumped = super().dict( # pyright: ignore[reportDeprecated]
include=include,
exclude=exclude,
by_alias=by_alias,
Expand All @@ -298,6 +299,8 @@ def model_dump(
exclude_none=exclude_none,
)

return cast(dict[str, Any], json_safe(dumped)) if mode == "json" else dumped

@override
def model_dump_json(
self,
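Since Pydantic v1 has no `mode` parameter, `mode="json"` is emulated here: the output of the deprecated `.dict()` call is post-processed with `json_safe` (added in `_utils.py` below). A condensed, hedged sketch of that fallback path, with the pass-through arguments elided:

    from llama_stack_client._utils import json_safe  # re-exported below

    def dump_v1_compat(model, mode: str = "python"):
        # Simplification of the override above, not the actual signature.
        dumped = model.dict()  # Pydantic v1 API (deprecated under v2)
        return json_safe(dumped) if mode == "json" else dumped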
1 change: 1 addition & 0 deletions src/llama_stack_client/_utils/__init__.py
@@ -6,6 +6,7 @@
is_list as is_list,
is_given as is_given,
is_tuple as is_tuple,
json_safe as json_safe,
lru_cache as lru_cache,
is_mapping as is_mapping,
is_tuple_t as is_tuple_t,
9 changes: 7 additions & 2 deletions src/llama_stack_client/_utils/_transform.py
@@ -173,6 +173,11 @@ def _transform_recursive(
# Iterable[T]
or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
):
# dicts are technically iterable, but iterating a dict only yields its keys,
# which is rarely what the annotation intends, so we don't transform it.
if isinstance(data, dict):
return cast(object, data)

inner_type = extract_type_arg(stripped_type, 0)
return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]

@@ -186,7 +191,7 @@
return data

if isinstance(data, pydantic.BaseModel):
return model_dump(data, exclude_unset=True)
return model_dump(data, exclude_unset=True, mode="json")

annotated_type = _get_annotated_type(annotation)
if annotated_type is None:
@@ -324,7 +329,7 @@ async def _async_transform_recursive(
return data

if isinstance(data, pydantic.BaseModel):
return model_dump(data, exclude_unset=True)
return model_dump(data, exclude_unset=True, mode="json")

annotated_type = _get_annotated_type(annotation)
if annotated_type is None:
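Two behavior changes land in this file: dicts that reach the iterable branch are now passed through untouched, and dumped Pydantic models are made JSON-safe before transformation. The dict guard matters because a plain dict is iterable, but only over its keys, so without the early return a metadata dict annotated as `Iterable[T]` would be flattened into a list of key names:

    data = {"provider_id": "meta-reference", "rows": 3}
    print([item for item in data])  # ['provider_id', 'rows'], values silently dropped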
17 changes: 17 additions & 0 deletions src/llama_stack_client/_utils/_utils.py
@@ -16,6 +16,7 @@
overload,
)
from pathlib import Path
from datetime import date, datetime
from typing_extensions import TypeGuard

import sniffio
@@ -395,3 +396,19 @@ def lru_cache(*, maxsize: int | None = 128) -> Callable[[CallableT], CallableT]:
maxsize=maxsize,
)
return cast(Any, wrapper) # type: ignore[no-any-return]


def json_safe(data: object) -> object:
"""Translates a mapping / sequence recursively in the same fashion
as `pydantic` v2's `model_dump(mode="json")`.
"""
if is_mapping(data):
return {json_safe(key): json_safe(value) for key, value in data.items()}

if is_iterable(data) and not isinstance(data, (str, bytes, bytearray)):
return [json_safe(item) for item in data]

if isinstance(data, (datetime, date)):
return data.isoformat()

return data
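A quick check of `json_safe` as defined above: mappings and iterables are walked recursively, non-string iterables such as tuples come back as lists, and `date`/`datetime` values become ISO-8601 strings:

    from datetime import date, datetime

    print(json_safe({
        "created": datetime(2024, 1, 1, 12, 0),
        "tags": ("a", "b"),
        "nested": {"d": date(2024, 1, 2)},
    }))
    # {'created': '2024-01-01T12:00:00', 'tags': ['a', 'b'], 'nested': {'d': '2024-01-02'}}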
9 changes: 9 additions & 0 deletions src/llama_stack_client/lib/cli/datasets/__init__.py
@@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .datasets import DatasetsParser

__all__ = ["DatasetsParser"]
27 changes: 27 additions & 0 deletions src/llama_stack_client/lib/cli/datasets/datasets.py
@@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import argparse

from llama_stack_client.lib.cli.subcommand import Subcommand
from .list import DatasetsList


class DatasetsParser(Subcommand):
"""Parser for datasets commands"""

@classmethod
def create(cls, subparsers: argparse._SubParsersAction):
parser = subparsers.add_parser(
"datasets",
help="Manage datasets",
formatter_class=argparse.RawTextHelpFormatter,
)
parser.set_defaults(func=lambda _: parser.print_help())

# Create subcommands
datasets_subparsers = parser.add_subparsers(title="subcommands")
DatasetsList(datasets_subparsers)
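The parser registers itself on a caller-supplied `subparsers` object; a minimal wiring sketch under the assumption that the top-level CLI builds a root parser like this (the `prog` name matches the one used in `list.py` below):

    import argparse

    from llama_stack_client.lib.cli.datasets import DatasetsParser

    root = argparse.ArgumentParser(prog="llama-stack-client")
    subparsers = root.add_subparsers(title="subcommands")
    DatasetsParser.create(subparsers)

    args = root.parse_args(["datasets"])
    args.func(args)  # prints the datasets help, per set_defaults above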
45 changes: 45 additions & 0 deletions src/llama_stack_client/lib/cli/datasets/list.py
@@ -0,0 +1,45 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import argparse

from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.cli.common.utils import print_table_from_response
from llama_stack_client.lib.cli.configure import get_config
from llama_stack_client.lib.cli.subcommand import Subcommand


class DatasetsList(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"list",
prog="llama-stack-client datasets list",
description="Show available datasets on distribution endpoint",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_datasets_list_cmd)

def _add_arguments(self):
self.parser.add_argument(
"--endpoint",
type=str,
help="Llama Stack distribution endpoint",
)

def _run_datasets_list_cmd(self, args: argparse.Namespace):
args.endpoint = get_config().get("endpoint") or args.endpoint

client = LlamaStackClient(
base_url=args.endpoint,
)

headers = ["identifier", "provider_id", "metadata", "type"]

datasets_list_response = client.datasets.list()
if datasets_list_response:
print_table_from_response(datasets_list_response, headers)
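With the subcommand wired in, listing datasets from a distribution looks like the following; the URL is illustrative. Note that as written, an endpoint from `get_config()` takes precedence over the `--endpoint` flag:

    llama-stack-client datasets list --endpoint http://localhost:5000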