Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions src/llama_stack_client/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
benchmarks,
toolgroups,
vector_dbs,
batch_inference,
scoring_functions,
synthetic_data_generation,
)
Expand Down Expand Up @@ -74,7 +73,6 @@ class LlamaStackClient(SyncAPIClient):
tools: tools.ToolsResource
tool_runtime: tool_runtime.ToolRuntimeResource
agents: agents.AgentsResource
batch_inference: batch_inference.BatchInferenceResource
datasets: datasets.DatasetsResource
eval: eval.EvalResource
inspect: inspect.InspectResource
Expand Down Expand Up @@ -155,7 +153,6 @@ def __init__(
self.tools = tools.ToolsResource(self)
self.tool_runtime = tool_runtime.ToolRuntimeResource(self)
self.agents = agents.AgentsResource(self)
self.batch_inference = batch_inference.BatchInferenceResource(self)
self.datasets = datasets.DatasetsResource(self)
self.eval = eval.EvalResource(self)
self.inspect = inspect.InspectResource(self)
Expand Down Expand Up @@ -288,7 +285,6 @@ class AsyncLlamaStackClient(AsyncAPIClient):
tools: tools.AsyncToolsResource
tool_runtime: tool_runtime.AsyncToolRuntimeResource
agents: agents.AsyncAgentsResource
batch_inference: batch_inference.AsyncBatchInferenceResource
datasets: datasets.AsyncDatasetsResource
eval: eval.AsyncEvalResource
inspect: inspect.AsyncInspectResource
Expand Down Expand Up @@ -369,7 +365,6 @@ def __init__(
self.tools = tools.AsyncToolsResource(self)
self.tool_runtime = tool_runtime.AsyncToolRuntimeResource(self)
self.agents = agents.AsyncAgentsResource(self)
self.batch_inference = batch_inference.AsyncBatchInferenceResource(self)
self.datasets = datasets.AsyncDatasetsResource(self)
self.eval = eval.AsyncEvalResource(self)
self.inspect = inspect.AsyncInspectResource(self)
Expand Down Expand Up @@ -503,7 +498,6 @@ def __init__(self, client: LlamaStackClient) -> None:
self.tools = tools.ToolsResourceWithRawResponse(client.tools)
self.tool_runtime = tool_runtime.ToolRuntimeResourceWithRawResponse(client.tool_runtime)
self.agents = agents.AgentsResourceWithRawResponse(client.agents)
self.batch_inference = batch_inference.BatchInferenceResourceWithRawResponse(client.batch_inference)
self.datasets = datasets.DatasetsResourceWithRawResponse(client.datasets)
self.eval = eval.EvalResourceWithRawResponse(client.eval)
self.inspect = inspect.InspectResourceWithRawResponse(client.inspect)
Expand Down Expand Up @@ -531,7 +525,6 @@ def __init__(self, client: AsyncLlamaStackClient) -> None:
self.tools = tools.AsyncToolsResourceWithRawResponse(client.tools)
self.tool_runtime = tool_runtime.AsyncToolRuntimeResourceWithRawResponse(client.tool_runtime)
self.agents = agents.AsyncAgentsResourceWithRawResponse(client.agents)
self.batch_inference = batch_inference.AsyncBatchInferenceResourceWithRawResponse(client.batch_inference)
self.datasets = datasets.AsyncDatasetsResourceWithRawResponse(client.datasets)
self.eval = eval.AsyncEvalResourceWithRawResponse(client.eval)
self.inspect = inspect.AsyncInspectResourceWithRawResponse(client.inspect)
Expand Down Expand Up @@ -561,7 +554,6 @@ def __init__(self, client: LlamaStackClient) -> None:
self.tools = tools.ToolsResourceWithStreamingResponse(client.tools)
self.tool_runtime = tool_runtime.ToolRuntimeResourceWithStreamingResponse(client.tool_runtime)
self.agents = agents.AgentsResourceWithStreamingResponse(client.agents)
self.batch_inference = batch_inference.BatchInferenceResourceWithStreamingResponse(client.batch_inference)
self.datasets = datasets.DatasetsResourceWithStreamingResponse(client.datasets)
self.eval = eval.EvalResourceWithStreamingResponse(client.eval)
self.inspect = inspect.InspectResourceWithStreamingResponse(client.inspect)
Expand Down Expand Up @@ -591,7 +583,6 @@ def __init__(self, client: AsyncLlamaStackClient) -> None:
self.tools = tools.AsyncToolsResourceWithStreamingResponse(client.tools)
self.tool_runtime = tool_runtime.AsyncToolRuntimeResourceWithStreamingResponse(client.tool_runtime)
self.agents = agents.AsyncAgentsResourceWithStreamingResponse(client.agents)
self.batch_inference = batch_inference.AsyncBatchInferenceResourceWithStreamingResponse(client.batch_inference)
self.datasets = datasets.AsyncDatasetsResourceWithStreamingResponse(client.datasets)
self.eval = eval.AsyncEvalResourceWithStreamingResponse(client.eval)
self.inspect = inspect.AsyncInspectResourceWithStreamingResponse(client.inspect)
Expand Down
123 changes: 0 additions & 123 deletions src/llama_stack_client/_decoders/jsonl.py

This file was deleted.

2 changes: 1 addition & 1 deletion src/llama_stack_client/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,7 @@ def set_pydantic_config(typ: Any, config: pydantic.ConfigDict) -> None:
setattr(typ, "__pydantic_config__", config) # noqa: B010


# our use of subclassing here causes weirdness for type checkers,
# our use of subclassing here causes weirdness for type checkers,
# so we just pretend that we don't subclass
if TYPE_CHECKING:
GenericModel = BaseModel
Expand Down
22 changes: 0 additions & 22 deletions src/llama_stack_client/_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from ._constants import RAW_RESPONSE_HEADER, OVERRIDE_CAST_TO_HEADER
from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type
from ._exceptions import LlamaStackClientError, APIResponseValidationError
from ._decoders.jsonl import JSONLDecoder, AsyncJSONLDecoder

if TYPE_CHECKING:
from ._models import FinalRequestOptions
Expand Down Expand Up @@ -139,27 +138,6 @@ def _parse(self, *, to: type[_T] | None = None) -> R | _T:

origin = get_origin(cast_to) or cast_to

if inspect.isclass(origin):
if issubclass(cast(Any, origin), JSONLDecoder):
return cast(
R,
cast("type[JSONLDecoder[Any]]", cast_to)(
raw_iterator=self.http_response.iter_bytes(chunk_size=64),
line_type=extract_type_arg(cast_to, 0),
http_response=self.http_response,
),
)

if issubclass(cast(Any, origin), AsyncJSONLDecoder):
return cast(
R,
cast("type[AsyncJSONLDecoder[Any]]", cast_to)(
raw_iterator=self.http_response.aiter_bytes(chunk_size=64),
line_type=extract_type_arg(cast_to, 0),
http_response=self.http_response,
),
)

if self._is_sse_stream:
if to:
if not is_stream_class_type(to):
Expand Down
49 changes: 47 additions & 2 deletions src/llama_stack_client/_utils/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
import pathlib
from typing import Any, Mapping, TypeVar, cast
from datetime import date, datetime
from typing_extensions import Literal, get_args, override, get_type_hints
from typing_extensions import Literal, get_args, override, get_type_hints as _get_type_hints

import anyio
import pydantic

from ._utils import (
is_list,
is_given,
lru_cache,
is_mapping,
is_iterable,
)
Expand Down Expand Up @@ -108,6 +110,7 @@ class Params(TypedDict, total=False):
return cast(_T, transformed)


@lru_cache(maxsize=8096)
def _get_annotated_type(type_: type) -> type | None:
"""If the given type is an `Annotated` type then it is returned, if not `None` is returned.

Expand All @@ -126,7 +129,7 @@ def _get_annotated_type(type_: type) -> type | None:
def _maybe_transform_key(key: str, type_: type) -> str:
"""Transform the given `data` based on the annotations provided in `type_`.

Note: this function only looks at `Annotated` types that contain `PropertyInfo` metadata.
Note: this function only looks at `Annotated` types that contain `PropertyInfo` metadata.
"""
annotated_type = _get_annotated_type(type_)
if annotated_type is None:
Expand All @@ -142,6 +145,10 @@ def _maybe_transform_key(key: str, type_: type) -> str:
return key


def _no_transform_needed(annotation: type) -> bool:
return annotation == float or annotation == int


def _transform_recursive(
data: object,
*,
Expand Down Expand Up @@ -184,6 +191,15 @@ def _transform_recursive(
return cast(object, data)

inner_type = extract_type_arg(stripped_type, 0)
if _no_transform_needed(inner_type):
# for some types there is no need to transform anything, so we can get a small
# perf boost from skipping that work.
#
# but we still need to convert to a list to ensure the data is json-serializable
if is_list(data):
return data
return list(data)

return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]

if is_union_type(stripped_type):
Expand Down Expand Up @@ -245,6 +261,11 @@ def _transform_typeddict(
result: dict[str, object] = {}
annotations = get_type_hints(expected_type, include_extras=True)
for key, value in data.items():
if not is_given(value):
# we don't need to include `NotGiven` values here as they'll
# be stripped out before the request is sent anyway
continue

type_ = annotations.get(key)
if type_ is None:
# we do not have a type annotation for this field, leave it as is
Expand Down Expand Up @@ -332,6 +353,15 @@ async def _async_transform_recursive(
return cast(object, data)

inner_type = extract_type_arg(stripped_type, 0)
if _no_transform_needed(inner_type):
# for some types there is no need to transform anything, so we can get a small
# perf boost from skipping that work.
#
# but we still need to convert to a list to ensure the data is json-serializable
if is_list(data):
return data
return list(data)

return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]

if is_union_type(stripped_type):
Expand Down Expand Up @@ -393,10 +423,25 @@ async def _async_transform_typeddict(
result: dict[str, object] = {}
annotations = get_type_hints(expected_type, include_extras=True)
for key, value in data.items():
if not is_given(value):
# we don't need to include `NotGiven` values here as they'll
# be stripped out before the request is sent anyway
continue

type_ = annotations.get(key)
if type_ is None:
# we do not have a type annotation for this field, leave it as is
result[key] = value
else:
result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(value, annotation=type_)
return result


@lru_cache(maxsize=8096)
def get_type_hints(
    obj: Any,
    globalns: dict[str, Any] | None = None,
    localns: Mapping[str, Any] | None = None,
    include_extras: bool = False,
) -> dict[str, Any]:
    """Memoized wrapper around ``typing_extensions.get_type_hints``.

    The same annotated types are resolved repeatedly while transforming
    request payloads (see ``_transform_typeddict``), so results are cached
    with a bounded LRU cache.

    NOTE(review): assuming ``lru_cache`` re-exports ``functools.lru_cache``,
    the cache keys on *all* arguments, so passing a ``dict`` for ``globalns``
    or ``localns`` would raise ``TypeError`` (unhashable). Visible callers
    only pass ``obj`` and ``include_extras`` — confirm no other call sites
    pass namespaces.
    """
    return _get_type_hints(obj, globalns=globalns, localns=localns, include_extras=include_extras)
2 changes: 2 additions & 0 deletions src/llama_stack_client/_utils/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
get_origin,
)

from ._utils import lru_cache
from .._types import InheritsGeneric
from .._compat import is_union as _is_union

Expand Down Expand Up @@ -66,6 +67,7 @@ def is_type_alias_type(tp: Any, /) -> TypeIs[typing_extensions.TypeAliasType]:


# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]]
@lru_cache(maxsize=8096)
def strip_annotated_type(typ: type) -> type:
if is_required_type(typ) or is_annotated_type(typ):
return strip_annotated_type(cast(type, get_args(typ)[0]))
Expand Down
14 changes: 13 additions & 1 deletion src/llama_stack_client/lib/inference/event_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,24 @@ def print(self, flush=True):


class InferenceStreamLogEventPrinter:
    """Converts inference stream chunks into printable events, wrapping any
    reasoning ("thinking") deltas in magenta ``<thinking>…</thinking>`` markers
    and regular text deltas in yellow."""

    def __init__(self):
        # Tracks whether we are currently inside a reasoning ("thinking") span,
        # so the opening/closing markers are emitted exactly once per span.
        self.is_thinking = False

    def yield_printable_events(self, chunk):
        """Yield ``InferenceStreamPrintableEvent``s for a single stream chunk.

        ``start`` emits the assistant prompt prefix; ``progress`` emits either
        reasoning deltas (opening a ``<thinking>`` marker on the first one) or
        plain text deltas (closing the marker if one is open); ``complete``
        emits an empty event to terminate the line.
        """
        event = chunk.event
        if event.event_type == "start":
            yield InferenceStreamPrintableEvent("Assistant> ", color="cyan", end="")
        elif event.event_type == "progress":
            if event.delta.type == "reasoning":
                if not self.is_thinking:
                    yield InferenceStreamPrintableEvent("<thinking> ", color="magenta", end="")
                    self.is_thinking = True
                yield InferenceStreamPrintableEvent(event.delta.reasoning, color="magenta", end="")
            else:
                if self.is_thinking:
                    yield InferenceStreamPrintableEvent("</thinking>", color="magenta", end="")
                    self.is_thinking = False
                yield InferenceStreamPrintableEvent(event.delta.text, color="yellow", end="")
        elif event.event_type == "complete":
            yield InferenceStreamPrintableEvent("")

Expand Down
Loading