diff --git a/src/llama_stack_client/_client.py b/src/llama_stack_client/_client.py
index 00922520..7066ae2a 100644
--- a/src/llama_stack_client/_client.py
+++ b/src/llama_stack_client/_client.py
@@ -41,7 +41,6 @@
benchmarks,
toolgroups,
vector_dbs,
- batch_inference,
scoring_functions,
synthetic_data_generation,
)
@@ -74,7 +73,6 @@ class LlamaStackClient(SyncAPIClient):
tools: tools.ToolsResource
tool_runtime: tool_runtime.ToolRuntimeResource
agents: agents.AgentsResource
- batch_inference: batch_inference.BatchInferenceResource
datasets: datasets.DatasetsResource
eval: eval.EvalResource
inspect: inspect.InspectResource
@@ -155,7 +153,6 @@ def __init__(
self.tools = tools.ToolsResource(self)
self.tool_runtime = tool_runtime.ToolRuntimeResource(self)
self.agents = agents.AgentsResource(self)
- self.batch_inference = batch_inference.BatchInferenceResource(self)
self.datasets = datasets.DatasetsResource(self)
self.eval = eval.EvalResource(self)
self.inspect = inspect.InspectResource(self)
@@ -288,7 +285,6 @@ class AsyncLlamaStackClient(AsyncAPIClient):
tools: tools.AsyncToolsResource
tool_runtime: tool_runtime.AsyncToolRuntimeResource
agents: agents.AsyncAgentsResource
- batch_inference: batch_inference.AsyncBatchInferenceResource
datasets: datasets.AsyncDatasetsResource
eval: eval.AsyncEvalResource
inspect: inspect.AsyncInspectResource
@@ -369,7 +365,6 @@ def __init__(
self.tools = tools.AsyncToolsResource(self)
self.tool_runtime = tool_runtime.AsyncToolRuntimeResource(self)
self.agents = agents.AsyncAgentsResource(self)
- self.batch_inference = batch_inference.AsyncBatchInferenceResource(self)
self.datasets = datasets.AsyncDatasetsResource(self)
self.eval = eval.AsyncEvalResource(self)
self.inspect = inspect.AsyncInspectResource(self)
@@ -503,7 +498,6 @@ def __init__(self, client: LlamaStackClient) -> None:
self.tools = tools.ToolsResourceWithRawResponse(client.tools)
self.tool_runtime = tool_runtime.ToolRuntimeResourceWithRawResponse(client.tool_runtime)
self.agents = agents.AgentsResourceWithRawResponse(client.agents)
- self.batch_inference = batch_inference.BatchInferenceResourceWithRawResponse(client.batch_inference)
self.datasets = datasets.DatasetsResourceWithRawResponse(client.datasets)
self.eval = eval.EvalResourceWithRawResponse(client.eval)
self.inspect = inspect.InspectResourceWithRawResponse(client.inspect)
@@ -531,7 +525,6 @@ def __init__(self, client: AsyncLlamaStackClient) -> None:
self.tools = tools.AsyncToolsResourceWithRawResponse(client.tools)
self.tool_runtime = tool_runtime.AsyncToolRuntimeResourceWithRawResponse(client.tool_runtime)
self.agents = agents.AsyncAgentsResourceWithRawResponse(client.agents)
- self.batch_inference = batch_inference.AsyncBatchInferenceResourceWithRawResponse(client.batch_inference)
self.datasets = datasets.AsyncDatasetsResourceWithRawResponse(client.datasets)
self.eval = eval.AsyncEvalResourceWithRawResponse(client.eval)
self.inspect = inspect.AsyncInspectResourceWithRawResponse(client.inspect)
@@ -561,7 +554,6 @@ def __init__(self, client: LlamaStackClient) -> None:
self.tools = tools.ToolsResourceWithStreamingResponse(client.tools)
self.tool_runtime = tool_runtime.ToolRuntimeResourceWithStreamingResponse(client.tool_runtime)
self.agents = agents.AgentsResourceWithStreamingResponse(client.agents)
- self.batch_inference = batch_inference.BatchInferenceResourceWithStreamingResponse(client.batch_inference)
self.datasets = datasets.DatasetsResourceWithStreamingResponse(client.datasets)
self.eval = eval.EvalResourceWithStreamingResponse(client.eval)
self.inspect = inspect.InspectResourceWithStreamingResponse(client.inspect)
@@ -591,7 +583,6 @@ def __init__(self, client: AsyncLlamaStackClient) -> None:
self.tools = tools.AsyncToolsResourceWithStreamingResponse(client.tools)
self.tool_runtime = tool_runtime.AsyncToolRuntimeResourceWithStreamingResponse(client.tool_runtime)
self.agents = agents.AsyncAgentsResourceWithStreamingResponse(client.agents)
- self.batch_inference = batch_inference.AsyncBatchInferenceResourceWithStreamingResponse(client.batch_inference)
self.datasets = datasets.AsyncDatasetsResourceWithStreamingResponse(client.datasets)
self.eval = eval.AsyncEvalResourceWithStreamingResponse(client.eval)
self.inspect = inspect.AsyncInspectResourceWithStreamingResponse(client.inspect)
diff --git a/src/llama_stack_client/_decoders/jsonl.py b/src/llama_stack_client/_decoders/jsonl.py
deleted file mode 100644
index ac5ac74f..00000000
--- a/src/llama_stack_client/_decoders/jsonl.py
+++ /dev/null
@@ -1,123 +0,0 @@
-from __future__ import annotations
-
-import json
-from typing_extensions import Generic, TypeVar, Iterator, AsyncIterator
-
-import httpx
-
-from .._models import construct_type_unchecked
-
-_T = TypeVar("_T")
-
-
-class JSONLDecoder(Generic[_T]):
- """A decoder for [JSON Lines](https://jsonlines.org) format.
-
- This class provides an iterator over a byte-iterator that parses each JSON Line
- into a given type.
- """
-
- http_response: httpx.Response
- """The HTTP response this decoder was constructed from"""
-
- def __init__(
- self,
- *,
- raw_iterator: Iterator[bytes],
- line_type: type[_T],
- http_response: httpx.Response,
- ) -> None:
- super().__init__()
- self.http_response = http_response
- self._raw_iterator = raw_iterator
- self._line_type = line_type
- self._iterator = self.__decode__()
-
- def close(self) -> None:
- """Close the response body stream.
-
- This is called automatically if you consume the entire stream.
- """
- self.http_response.close()
-
- def __decode__(self) -> Iterator[_T]:
- buf = b""
- for chunk in self._raw_iterator:
- for line in chunk.splitlines(keepends=True):
- buf += line
- if buf.endswith((b"\r", b"\n", b"\r\n")):
- yield construct_type_unchecked(
- value=json.loads(buf),
- type_=self._line_type,
- )
- buf = b""
-
- # flush
- if buf:
- yield construct_type_unchecked(
- value=json.loads(buf),
- type_=self._line_type,
- )
-
- def __next__(self) -> _T:
- return self._iterator.__next__()
-
- def __iter__(self) -> Iterator[_T]:
- for item in self._iterator:
- yield item
-
-
-class AsyncJSONLDecoder(Generic[_T]):
- """A decoder for [JSON Lines](https://jsonlines.org) format.
-
- This class provides an async iterator over a byte-iterator that parses each JSON Line
- into a given type.
- """
-
- http_response: httpx.Response
-
- def __init__(
- self,
- *,
- raw_iterator: AsyncIterator[bytes],
- line_type: type[_T],
- http_response: httpx.Response,
- ) -> None:
- super().__init__()
- self.http_response = http_response
- self._raw_iterator = raw_iterator
- self._line_type = line_type
- self._iterator = self.__decode__()
-
- async def close(self) -> None:
- """Close the response body stream.
-
- This is called automatically if you consume the entire stream.
- """
- await self.http_response.aclose()
-
- async def __decode__(self) -> AsyncIterator[_T]:
- buf = b""
- async for chunk in self._raw_iterator:
- for line in chunk.splitlines(keepends=True):
- buf += line
- if buf.endswith((b"\r", b"\n", b"\r\n")):
- yield construct_type_unchecked(
- value=json.loads(buf),
- type_=self._line_type,
- )
- buf = b""
-
- # flush
- if buf:
- yield construct_type_unchecked(
- value=json.loads(buf),
- type_=self._line_type,
- )
-
- async def __anext__(self) -> _T:
- return await self._iterator.__anext__()
-
- async def __aiter__(self) -> AsyncIterator[_T]:
- async for item in self._iterator:
- yield item
diff --git a/src/llama_stack_client/_models.py b/src/llama_stack_client/_models.py
index b51a1bf5..34935716 100644
--- a/src/llama_stack_client/_models.py
+++ b/src/llama_stack_client/_models.py
@@ -681,7 +681,7 @@ def set_pydantic_config(typ: Any, config: pydantic.ConfigDict) -> None:
setattr(typ, "__pydantic_config__", config) # noqa: B010
-# our use of subclasssing here causes weirdness for type checkers,
+# our use of subclassing here causes weirdness for type checkers,
# so we just pretend that we don't subclass
if TYPE_CHECKING:
GenericModel = BaseModel
diff --git a/src/llama_stack_client/_response.py b/src/llama_stack_client/_response.py
index ea35182f..1938ae74 100644
--- a/src/llama_stack_client/_response.py
+++ b/src/llama_stack_client/_response.py
@@ -30,7 +30,6 @@
from ._constants import RAW_RESPONSE_HEADER, OVERRIDE_CAST_TO_HEADER
from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type
from ._exceptions import LlamaStackClientError, APIResponseValidationError
-from ._decoders.jsonl import JSONLDecoder, AsyncJSONLDecoder
if TYPE_CHECKING:
from ._models import FinalRequestOptions
@@ -139,27 +138,6 @@ def _parse(self, *, to: type[_T] | None = None) -> R | _T:
origin = get_origin(cast_to) or cast_to
- if inspect.isclass(origin):
- if issubclass(cast(Any, origin), JSONLDecoder):
- return cast(
- R,
- cast("type[JSONLDecoder[Any]]", cast_to)(
- raw_iterator=self.http_response.iter_bytes(chunk_size=64),
- line_type=extract_type_arg(cast_to, 0),
- http_response=self.http_response,
- ),
- )
-
- if issubclass(cast(Any, origin), AsyncJSONLDecoder):
- return cast(
- R,
- cast("type[AsyncJSONLDecoder[Any]]", cast_to)(
- raw_iterator=self.http_response.aiter_bytes(chunk_size=64),
- line_type=extract_type_arg(cast_to, 0),
- http_response=self.http_response,
- ),
- )
-
if self._is_sse_stream:
if to:
if not is_stream_class_type(to):
diff --git a/src/llama_stack_client/_utils/_transform.py b/src/llama_stack_client/_utils/_transform.py
index 18afd9d8..b0cc20a7 100644
--- a/src/llama_stack_client/_utils/_transform.py
+++ b/src/llama_stack_client/_utils/_transform.py
@@ -5,13 +5,15 @@
import pathlib
from typing import Any, Mapping, TypeVar, cast
from datetime import date, datetime
-from typing_extensions import Literal, get_args, override, get_type_hints
+from typing_extensions import Literal, get_args, override, get_type_hints as _get_type_hints
import anyio
import pydantic
from ._utils import (
is_list,
+ is_given,
+ lru_cache,
is_mapping,
is_iterable,
)
@@ -108,6 +110,7 @@ class Params(TypedDict, total=False):
return cast(_T, transformed)
+@lru_cache(maxsize=8096)
def _get_annotated_type(type_: type) -> type | None:
"""If the given type is an `Annotated` type then it is returned, if not `None` is returned.
@@ -126,7 +129,7 @@ def _get_annotated_type(type_: type) -> type | None:
def _maybe_transform_key(key: str, type_: type) -> str:
"""Transform the given `data` based on the annotations provided in `type_`.
- Note: this function only looks at `Annotated` types that contain `PropertInfo` metadata.
+ Note: this function only looks at `Annotated` types that contain `PropertyInfo` metadata.
"""
annotated_type = _get_annotated_type(type_)
if annotated_type is None:
@@ -142,6 +145,10 @@ def _maybe_transform_key(key: str, type_: type) -> str:
return key
+def _no_transform_needed(annotation: type) -> bool:
+ return annotation == float or annotation == int
+
+
def _transform_recursive(
data: object,
*,
@@ -184,6 +191,15 @@ def _transform_recursive(
return cast(object, data)
inner_type = extract_type_arg(stripped_type, 0)
+ if _no_transform_needed(inner_type):
+ # for some types there is no need to transform anything, so we can get a small
+ # perf boost from skipping that work.
+ #
+ # but we still need to convert to a list to ensure the data is json-serializable
+ if is_list(data):
+ return data
+ return list(data)
+
return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
if is_union_type(stripped_type):
@@ -245,6 +261,11 @@ def _transform_typeddict(
result: dict[str, object] = {}
annotations = get_type_hints(expected_type, include_extras=True)
for key, value in data.items():
+ if not is_given(value):
+ # we don't need to include `NotGiven` values here as they'll
+ # be stripped out before the request is sent anyway
+ continue
+
type_ = annotations.get(key)
if type_ is None:
# we do not have a type annotation for this field, leave it as is
@@ -332,6 +353,15 @@ async def _async_transform_recursive(
return cast(object, data)
inner_type = extract_type_arg(stripped_type, 0)
+ if _no_transform_needed(inner_type):
+ # for some types there is no need to transform anything, so we can get a small
+ # perf boost from skipping that work.
+ #
+ # but we still need to convert to a list to ensure the data is json-serializable
+ if is_list(data):
+ return data
+ return list(data)
+
return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
if is_union_type(stripped_type):
@@ -393,6 +423,11 @@ async def _async_transform_typeddict(
result: dict[str, object] = {}
annotations = get_type_hints(expected_type, include_extras=True)
for key, value in data.items():
+ if not is_given(value):
+ # we don't need to include `NotGiven` values here as they'll
+ # be stripped out before the request is sent anyway
+ continue
+
type_ = annotations.get(key)
if type_ is None:
# we do not have a type annotation for this field, leave it as is
@@ -400,3 +435,13 @@ async def _async_transform_typeddict(
else:
result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(value, annotation=type_)
return result
+
+
+@lru_cache(maxsize=8096)
+def get_type_hints(
+ obj: Any,
+ globalns: dict[str, Any] | None = None,
+ localns: Mapping[str, Any] | None = None,
+ include_extras: bool = False,
+) -> dict[str, Any]:
+ return _get_type_hints(obj, globalns=globalns, localns=localns, include_extras=include_extras)
diff --git a/src/llama_stack_client/_utils/_typing.py b/src/llama_stack_client/_utils/_typing.py
index 278749b1..1958820f 100644
--- a/src/llama_stack_client/_utils/_typing.py
+++ b/src/llama_stack_client/_utils/_typing.py
@@ -13,6 +13,7 @@
get_origin,
)
+from ._utils import lru_cache
from .._types import InheritsGeneric
from .._compat import is_union as _is_union
@@ -66,6 +67,7 @@ def is_type_alias_type(tp: Any, /) -> TypeIs[typing_extensions.TypeAliasType]:
# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]]
+@lru_cache(maxsize=8096)
def strip_annotated_type(typ: type) -> type:
if is_required_type(typ) or is_annotated_type(typ):
return strip_annotated_type(cast(type, get_args(typ)[0]))
diff --git a/src/llama_stack_client/lib/inference/event_logger.py b/src/llama_stack_client/lib/inference/event_logger.py
index e8c63a1f..f74f1bd4 100644
--- a/src/llama_stack_client/lib/inference/event_logger.py
+++ b/src/llama_stack_client/lib/inference/event_logger.py
@@ -22,12 +22,24 @@ def print(self, flush=True):
class InferenceStreamLogEventPrinter:
+ def __init__(self):
+ self.is_thinking = False
+
def yield_printable_events(self, chunk):
event = chunk.event
if event.event_type == "start":
yield InferenceStreamPrintableEvent("Assistant> ", color="cyan", end="")
elif event.event_type == "progress":
- yield InferenceStreamPrintableEvent(event.delta.text, color="yellow", end="")
+ if event.delta.type == "reasoning":
+ if not self.is_thinking:
+ yield InferenceStreamPrintableEvent(" ", color="magenta", end="")
+ self.is_thinking = True
+ yield InferenceStreamPrintableEvent(event.delta.reasoning, color="magenta", end="")
+ else:
+ if self.is_thinking:
+ yield InferenceStreamPrintableEvent("", color="magenta", end="")
+ self.is_thinking = False
+ yield InferenceStreamPrintableEvent(event.delta.text, color="yellow", end="")
elif event.event_type == "complete":
yield InferenceStreamPrintableEvent("")
diff --git a/src/llama_stack_client/resources/__init__.py b/src/llama_stack_client/resources/__init__.py
index 865d77e0..0e3373dc 100644
--- a/src/llama_stack_client/resources/__init__.py
+++ b/src/llama_stack_client/resources/__init__.py
@@ -152,14 +152,6 @@
PostTrainingResourceWithStreamingResponse,
AsyncPostTrainingResourceWithStreamingResponse,
)
-from .batch_inference import (
- BatchInferenceResource,
- AsyncBatchInferenceResource,
- BatchInferenceResourceWithRawResponse,
- AsyncBatchInferenceResourceWithRawResponse,
- BatchInferenceResourceWithStreamingResponse,
- AsyncBatchInferenceResourceWithStreamingResponse,
-)
from .scoring_functions import (
ScoringFunctionsResource,
AsyncScoringFunctionsResource,
@@ -202,12 +194,6 @@
"AsyncAgentsResourceWithRawResponse",
"AgentsResourceWithStreamingResponse",
"AsyncAgentsResourceWithStreamingResponse",
- "BatchInferenceResource",
- "AsyncBatchInferenceResource",
- "BatchInferenceResourceWithRawResponse",
- "AsyncBatchInferenceResourceWithRawResponse",
- "BatchInferenceResourceWithStreamingResponse",
- "AsyncBatchInferenceResourceWithStreamingResponse",
"DatasetsResource",
"AsyncDatasetsResource",
"DatasetsResourceWithRawResponse",
diff --git a/src/llama_stack_client/resources/agents/turn.py b/src/llama_stack_client/resources/agents/turn.py
index 23c44677..6b1b4ae2 100644
--- a/src/llama_stack_client/resources/agents/turn.py
+++ b/src/llama_stack_client/resources/agents/turn.py
@@ -218,7 +218,9 @@ def create(
"tool_config": tool_config,
"toolgroups": toolgroups,
},
- turn_create_params.TurnCreateParams,
+ turn_create_params.TurnCreateParamsStreaming
+ if stream
+ else turn_create_params.TurnCreateParamsNonStreaming,
),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
@@ -410,7 +412,9 @@ def resume(
"tool_responses": tool_responses,
"stream": stream,
},
- turn_resume_params.TurnResumeParams,
+ turn_resume_params.TurnResumeParamsStreaming
+ if stream
+ else turn_resume_params.TurnResumeParamsNonStreaming,
),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
@@ -608,7 +612,9 @@ async def create(
"tool_config": tool_config,
"toolgroups": toolgroups,
},
- turn_create_params.TurnCreateParams,
+ turn_create_params.TurnCreateParamsStreaming
+ if stream
+ else turn_create_params.TurnCreateParamsNonStreaming,
),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
@@ -800,7 +806,9 @@ async def resume(
"tool_responses": tool_responses,
"stream": stream,
},
- turn_resume_params.TurnResumeParams,
+ turn_resume_params.TurnResumeParamsStreaming
+ if stream
+ else turn_resume_params.TurnResumeParamsNonStreaming,
),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
diff --git a/src/llama_stack_client/resources/batch_inference.py b/src/llama_stack_client/resources/batch_inference.py
deleted file mode 100644
index 92d437cb..00000000
--- a/src/llama_stack_client/resources/batch_inference.py
+++ /dev/null
@@ -1,326 +0,0 @@
-# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
-
-from __future__ import annotations
-
-from typing import List, Iterable
-from typing_extensions import Literal
-
-import httpx
-
-from ..types import batch_inference_completion_params, batch_inference_chat_completion_params
-from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
-from .._utils import (
- maybe_transform,
- async_maybe_transform,
-)
-from .._compat import cached_property
-from .._resource import SyncAPIResource, AsyncAPIResource
-from .._response import (
- to_raw_response_wrapper,
- to_streamed_response_wrapper,
- async_to_raw_response_wrapper,
- async_to_streamed_response_wrapper,
-)
-from .._base_client import make_request_options
-from ..types.shared_params.message import Message
-from ..types.shared.batch_completion import BatchCompletion
-from ..types.shared_params.response_format import ResponseFormat
-from ..types.shared_params.sampling_params import SamplingParams
-from ..types.shared_params.interleaved_content import InterleavedContent
-from ..types.batch_inference_chat_completion_response import BatchInferenceChatCompletionResponse
-
-__all__ = ["BatchInferenceResource", "AsyncBatchInferenceResource"]
-
-
-class BatchInferenceResource(SyncAPIResource):
- @cached_property
- def with_raw_response(self) -> BatchInferenceResourceWithRawResponse:
- """
- This property can be used as a prefix for any HTTP method call to return
- the raw response object instead of the parsed content.
-
- For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers
- """
- return BatchInferenceResourceWithRawResponse(self)
-
- @cached_property
- def with_streaming_response(self) -> BatchInferenceResourceWithStreamingResponse:
- """
- An alternative to `.with_raw_response` that doesn't eagerly read the response body.
-
- For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response
- """
- return BatchInferenceResourceWithStreamingResponse(self)
-
- def chat_completion(
- self,
- *,
- messages_batch: Iterable[Iterable[Message]],
- model: str,
- logprobs: batch_inference_chat_completion_params.Logprobs | NotGiven = NOT_GIVEN,
- response_format: ResponseFormat | NotGiven = NOT_GIVEN,
- sampling_params: SamplingParams | NotGiven = NOT_GIVEN,
- tool_choice: Literal["auto", "required", "none"] | NotGiven = NOT_GIVEN,
- tool_prompt_format: Literal["json", "function_tag", "python_list"] | NotGiven = NOT_GIVEN,
- tools: Iterable[batch_inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN,
- # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
- # The extra values given here take precedence over values defined on the client or passed to this method.
- extra_headers: Headers | None = None,
- extra_query: Query | None = None,
- extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> BatchInferenceChatCompletionResponse:
- """
- Args:
- response_format: Configuration for JSON schema-guided response generation.
-
- tool_choice: Whether tool use is required or automatic. This is a hint to the model which may
- not be followed. It depends on the Instruction Following capabilities of the
- model.
-
- tool_prompt_format: Prompt format for calling custom / zero shot tools.
-
- extra_headers: Send extra headers
-
- extra_query: Add additional query parameters to the request
-
- extra_body: Add additional JSON properties to the request
-
- timeout: Override the client-level default timeout for this request, in seconds
- """
- return self._post(
- "/v1/batch-inference/chat-completion",
- body=maybe_transform(
- {
- "messages_batch": messages_batch,
- "model": model,
- "logprobs": logprobs,
- "response_format": response_format,
- "sampling_params": sampling_params,
- "tool_choice": tool_choice,
- "tool_prompt_format": tool_prompt_format,
- "tools": tools,
- },
- batch_inference_chat_completion_params.BatchInferenceChatCompletionParams,
- ),
- options=make_request_options(
- extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
- ),
- cast_to=BatchInferenceChatCompletionResponse,
- )
-
- def completion(
- self,
- *,
- content_batch: List[InterleavedContent],
- model: str,
- logprobs: batch_inference_completion_params.Logprobs | NotGiven = NOT_GIVEN,
- response_format: ResponseFormat | NotGiven = NOT_GIVEN,
- sampling_params: SamplingParams | NotGiven = NOT_GIVEN,
- # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
- # The extra values given here take precedence over values defined on the client or passed to this method.
- extra_headers: Headers | None = None,
- extra_query: Query | None = None,
- extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> BatchCompletion:
- """
- Args:
- response_format: Configuration for JSON schema-guided response generation.
-
- extra_headers: Send extra headers
-
- extra_query: Add additional query parameters to the request
-
- extra_body: Add additional JSON properties to the request
-
- timeout: Override the client-level default timeout for this request, in seconds
- """
- return self._post(
- "/v1/batch-inference/completion",
- body=maybe_transform(
- {
- "content_batch": content_batch,
- "model": model,
- "logprobs": logprobs,
- "response_format": response_format,
- "sampling_params": sampling_params,
- },
- batch_inference_completion_params.BatchInferenceCompletionParams,
- ),
- options=make_request_options(
- extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
- ),
- cast_to=BatchCompletion,
- )
-
-
-class AsyncBatchInferenceResource(AsyncAPIResource):
- @cached_property
- def with_raw_response(self) -> AsyncBatchInferenceResourceWithRawResponse:
- """
- This property can be used as a prefix for any HTTP method call to return
- the raw response object instead of the parsed content.
-
- For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers
- """
- return AsyncBatchInferenceResourceWithRawResponse(self)
-
- @cached_property
- def with_streaming_response(self) -> AsyncBatchInferenceResourceWithStreamingResponse:
- """
- An alternative to `.with_raw_response` that doesn't eagerly read the response body.
-
- For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response
- """
- return AsyncBatchInferenceResourceWithStreamingResponse(self)
-
- async def chat_completion(
- self,
- *,
- messages_batch: Iterable[Iterable[Message]],
- model: str,
- logprobs: batch_inference_chat_completion_params.Logprobs | NotGiven = NOT_GIVEN,
- response_format: ResponseFormat | NotGiven = NOT_GIVEN,
- sampling_params: SamplingParams | NotGiven = NOT_GIVEN,
- tool_choice: Literal["auto", "required", "none"] | NotGiven = NOT_GIVEN,
- tool_prompt_format: Literal["json", "function_tag", "python_list"] | NotGiven = NOT_GIVEN,
- tools: Iterable[batch_inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN,
- # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
- # The extra values given here take precedence over values defined on the client or passed to this method.
- extra_headers: Headers | None = None,
- extra_query: Query | None = None,
- extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> BatchInferenceChatCompletionResponse:
- """
- Args:
- response_format: Configuration for JSON schema-guided response generation.
-
- tool_choice: Whether tool use is required or automatic. This is a hint to the model which may
- not be followed. It depends on the Instruction Following capabilities of the
- model.
-
- tool_prompt_format: Prompt format for calling custom / zero shot tools.
-
- extra_headers: Send extra headers
-
- extra_query: Add additional query parameters to the request
-
- extra_body: Add additional JSON properties to the request
-
- timeout: Override the client-level default timeout for this request, in seconds
- """
- return await self._post(
- "/v1/batch-inference/chat-completion",
- body=await async_maybe_transform(
- {
- "messages_batch": messages_batch,
- "model": model,
- "logprobs": logprobs,
- "response_format": response_format,
- "sampling_params": sampling_params,
- "tool_choice": tool_choice,
- "tool_prompt_format": tool_prompt_format,
- "tools": tools,
- },
- batch_inference_chat_completion_params.BatchInferenceChatCompletionParams,
- ),
- options=make_request_options(
- extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
- ),
- cast_to=BatchInferenceChatCompletionResponse,
- )
-
- async def completion(
- self,
- *,
- content_batch: List[InterleavedContent],
- model: str,
- logprobs: batch_inference_completion_params.Logprobs | NotGiven = NOT_GIVEN,
- response_format: ResponseFormat | NotGiven = NOT_GIVEN,
- sampling_params: SamplingParams | NotGiven = NOT_GIVEN,
- # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
- # The extra values given here take precedence over values defined on the client or passed to this method.
- extra_headers: Headers | None = None,
- extra_query: Query | None = None,
- extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> BatchCompletion:
- """
- Args:
- response_format: Configuration for JSON schema-guided response generation.
-
- extra_headers: Send extra headers
-
- extra_query: Add additional query parameters to the request
-
- extra_body: Add additional JSON properties to the request
-
- timeout: Override the client-level default timeout for this request, in seconds
- """
- return await self._post(
- "/v1/batch-inference/completion",
- body=await async_maybe_transform(
- {
- "content_batch": content_batch,
- "model": model,
- "logprobs": logprobs,
- "response_format": response_format,
- "sampling_params": sampling_params,
- },
- batch_inference_completion_params.BatchInferenceCompletionParams,
- ),
- options=make_request_options(
- extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
- ),
- cast_to=BatchCompletion,
- )
-
-
-class BatchInferenceResourceWithRawResponse:
- def __init__(self, batch_inference: BatchInferenceResource) -> None:
- self._batch_inference = batch_inference
-
- self.chat_completion = to_raw_response_wrapper(
- batch_inference.chat_completion,
- )
- self.completion = to_raw_response_wrapper(
- batch_inference.completion,
- )
-
-
-class AsyncBatchInferenceResourceWithRawResponse:
- def __init__(self, batch_inference: AsyncBatchInferenceResource) -> None:
- self._batch_inference = batch_inference
-
- self.chat_completion = async_to_raw_response_wrapper(
- batch_inference.chat_completion,
- )
- self.completion = async_to_raw_response_wrapper(
- batch_inference.completion,
- )
-
-
-class BatchInferenceResourceWithStreamingResponse:
- def __init__(self, batch_inference: BatchInferenceResource) -> None:
- self._batch_inference = batch_inference
-
- self.chat_completion = to_streamed_response_wrapper(
- batch_inference.chat_completion,
- )
- self.completion = to_streamed_response_wrapper(
- batch_inference.completion,
- )
-
-
-class AsyncBatchInferenceResourceWithStreamingResponse:
- def __init__(self, batch_inference: AsyncBatchInferenceResource) -> None:
- self._batch_inference = batch_inference
-
- self.chat_completion = async_to_streamed_response_wrapper(
- batch_inference.chat_completion,
- )
- self.completion = async_to_streamed_response_wrapper(
- batch_inference.completion,
- )
diff --git a/src/llama_stack_client/resources/datasets.py b/src/llama_stack_client/resources/datasets.py
index ed56ac80..aaa27a5e 100644
--- a/src/llama_stack_client/resources/datasets.py
+++ b/src/llama_stack_client/resources/datasets.py
@@ -119,7 +119,15 @@ def iterrows(
) -> DatasetIterrowsResponse:
"""Get a paginated list of rows from a dataset.
- Uses cursor-based pagination.
+ Uses offset-based pagination where:
+
+ - start_index: The starting index (0-based). If None, starts from beginning.
+ - limit: Number of items to return. If None or -1, returns all items.
+
+ The response includes:
+
+ - data: List of items for the current page
+ - has_more: Whether there are more items available after this set
Args:
limit: The number of rows to get.
@@ -344,7 +352,15 @@ async def iterrows(
) -> DatasetIterrowsResponse:
"""Get a paginated list of rows from a dataset.
- Uses cursor-based pagination.
+ Uses offset-based pagination where:
+
+ - start_index: The starting index (0-based). If None, starts from beginning.
+ - limit: Number of items to return. If None or -1, returns all items.
+
+ The response includes:
+
+ - data: List of items for the current page
+ - has_more: Whether there are more items available after this set
Args:
limit: The number of rows to get.
diff --git a/src/llama_stack_client/resources/inference.py b/src/llama_stack_client/resources/inference.py
index c4d85852..428956ce 100644
--- a/src/llama_stack_client/resources/inference.py
+++ b/src/llama_stack_client/resources/inference.py
@@ -11,6 +11,8 @@
inference_completion_params,
inference_embeddings_params,
inference_chat_completion_params,
+ inference_batch_completion_params,
+ inference_batch_chat_completion_params,
)
from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from .._utils import (
@@ -31,12 +33,14 @@
from ..types.completion_response import CompletionResponse
from ..types.embeddings_response import EmbeddingsResponse
from ..types.shared_params.message import Message
+from ..types.shared.batch_completion import BatchCompletion
from ..types.shared_params.response_format import ResponseFormat
from ..types.shared_params.sampling_params import SamplingParams
from ..types.shared.chat_completion_response import ChatCompletionResponse
from ..types.shared_params.interleaved_content import InterleavedContent
from ..types.chat_completion_response_stream_chunk import ChatCompletionResponseStreamChunk
from ..types.shared_params.interleaved_content_item import InterleavedContentItem
+from ..types.inference_batch_chat_completion_response import InferenceBatchChatCompletionResponse
__all__ = ["InferenceResource", "AsyncInferenceResource"]
@@ -61,6 +65,106 @@ def with_streaming_response(self) -> InferenceResourceWithStreamingResponse:
"""
return InferenceResourceWithStreamingResponse(self)
+ def batch_chat_completion(
+ self,
+ *,
+ messages_batch: Iterable[Iterable[Message]],
+ model_id: str,
+ logprobs: inference_batch_chat_completion_params.Logprobs | NotGiven = NOT_GIVEN,
+ response_format: ResponseFormat | NotGiven = NOT_GIVEN,
+ sampling_params: SamplingParams | NotGiven = NOT_GIVEN,
+ tool_config: inference_batch_chat_completion_params.ToolConfig | NotGiven = NOT_GIVEN,
+ tools: Iterable[inference_batch_chat_completion_params.Tool] | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> InferenceBatchChatCompletionResponse:
+ """
+ Args:
+ response_format: Configuration for JSON schema-guided response generation.
+
+ sampling_params: Sampling parameters.
+
+ tool_config: Configuration for tool use.
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ return self._post(
+ "/v1/inference/batch-chat-completion",
+ body=maybe_transform(
+ {
+ "messages_batch": messages_batch,
+ "model_id": model_id,
+ "logprobs": logprobs,
+ "response_format": response_format,
+ "sampling_params": sampling_params,
+ "tool_config": tool_config,
+ "tools": tools,
+ },
+ inference_batch_chat_completion_params.InferenceBatchChatCompletionParams,
+ ),
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=InferenceBatchChatCompletionResponse,
+ )
+
+ def batch_completion(
+ self,
+ *,
+ content_batch: List[InterleavedContent],
+ model_id: str,
+ logprobs: inference_batch_completion_params.Logprobs | NotGiven = NOT_GIVEN,
+ response_format: ResponseFormat | NotGiven = NOT_GIVEN,
+ sampling_params: SamplingParams | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> BatchCompletion:
+ """
+ Args:
+ response_format: Configuration for JSON schema-guided response generation.
+
+ sampling_params: Sampling parameters.
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ return self._post(
+ "/v1/inference/batch-completion",
+ body=maybe_transform(
+ {
+ "content_batch": content_batch,
+ "model_id": model_id,
+ "logprobs": logprobs,
+ "response_format": response_format,
+ "sampling_params": sampling_params,
+ },
+ inference_batch_completion_params.InferenceBatchCompletionParams,
+ ),
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=BatchCompletion,
+ )
+
@overload
def chat_completion(
self,
@@ -303,7 +407,9 @@ def chat_completion(
"tool_prompt_format": tool_prompt_format,
"tools": tools,
},
- inference_chat_completion_params.InferenceChatCompletionParams,
+ inference_chat_completion_params.InferenceChatCompletionParamsStreaming
+ if stream
+ else inference_chat_completion_params.InferenceChatCompletionParamsNonStreaming,
),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
@@ -481,7 +587,9 @@ def completion(
"sampling_params": sampling_params,
"stream": stream,
},
- inference_completion_params.InferenceCompletionParams,
+ inference_completion_params.InferenceCompletionParamsStreaming
+ if stream
+ else inference_completion_params.InferenceCompletionParamsNonStreaming,
),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
@@ -573,6 +681,106 @@ def with_streaming_response(self) -> AsyncInferenceResourceWithStreamingResponse
"""
return AsyncInferenceResourceWithStreamingResponse(self)
+ async def batch_chat_completion(
+ self,
+ *,
+ messages_batch: Iterable[Iterable[Message]],
+ model_id: str,
+ logprobs: inference_batch_chat_completion_params.Logprobs | NotGiven = NOT_GIVEN,
+ response_format: ResponseFormat | NotGiven = NOT_GIVEN,
+ sampling_params: SamplingParams | NotGiven = NOT_GIVEN,
+ tool_config: inference_batch_chat_completion_params.ToolConfig | NotGiven = NOT_GIVEN,
+ tools: Iterable[inference_batch_chat_completion_params.Tool] | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> InferenceBatchChatCompletionResponse:
+ """
+ Args:
+ response_format: Configuration for JSON schema-guided response generation.
+
+ sampling_params: Sampling parameters.
+
+ tool_config: Configuration for tool use.
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ return await self._post(
+ "/v1/inference/batch-chat-completion",
+ body=await async_maybe_transform(
+ {
+ "messages_batch": messages_batch,
+ "model_id": model_id,
+ "logprobs": logprobs,
+ "response_format": response_format,
+ "sampling_params": sampling_params,
+ "tool_config": tool_config,
+ "tools": tools,
+ },
+ inference_batch_chat_completion_params.InferenceBatchChatCompletionParams,
+ ),
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=InferenceBatchChatCompletionResponse,
+ )
+
+ async def batch_completion(
+ self,
+ *,
+ content_batch: List[InterleavedContent],
+ model_id: str,
+ logprobs: inference_batch_completion_params.Logprobs | NotGiven = NOT_GIVEN,
+ response_format: ResponseFormat | NotGiven = NOT_GIVEN,
+ sampling_params: SamplingParams | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> BatchCompletion:
+ """
+ Args:
+ response_format: Configuration for JSON schema-guided response generation.
+
+ sampling_params: Sampling parameters.
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ return await self._post(
+ "/v1/inference/batch-completion",
+ body=await async_maybe_transform(
+ {
+ "content_batch": content_batch,
+ "model_id": model_id,
+ "logprobs": logprobs,
+ "response_format": response_format,
+ "sampling_params": sampling_params,
+ },
+ inference_batch_completion_params.InferenceBatchCompletionParams,
+ ),
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=BatchCompletion,
+ )
+
@overload
async def chat_completion(
self,
@@ -815,7 +1023,9 @@ async def chat_completion(
"tool_prompt_format": tool_prompt_format,
"tools": tools,
},
- inference_chat_completion_params.InferenceChatCompletionParams,
+ inference_chat_completion_params.InferenceChatCompletionParamsStreaming
+ if stream
+ else inference_chat_completion_params.InferenceChatCompletionParamsNonStreaming,
),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
@@ -993,7 +1203,9 @@ async def completion(
"sampling_params": sampling_params,
"stream": stream,
},
- inference_completion_params.InferenceCompletionParams,
+ inference_completion_params.InferenceCompletionParamsStreaming
+ if stream
+ else inference_completion_params.InferenceCompletionParamsNonStreaming,
),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
@@ -1069,6 +1281,12 @@ class InferenceResourceWithRawResponse:
def __init__(self, inference: InferenceResource) -> None:
self._inference = inference
+ self.batch_chat_completion = to_raw_response_wrapper(
+ inference.batch_chat_completion,
+ )
+ self.batch_completion = to_raw_response_wrapper(
+ inference.batch_completion,
+ )
self.chat_completion = to_raw_response_wrapper(
inference.chat_completion,
)
@@ -1084,6 +1302,12 @@ class AsyncInferenceResourceWithRawResponse:
def __init__(self, inference: AsyncInferenceResource) -> None:
self._inference = inference
+ self.batch_chat_completion = async_to_raw_response_wrapper(
+ inference.batch_chat_completion,
+ )
+ self.batch_completion = async_to_raw_response_wrapper(
+ inference.batch_completion,
+ )
self.chat_completion = async_to_raw_response_wrapper(
inference.chat_completion,
)
@@ -1099,6 +1323,12 @@ class InferenceResourceWithStreamingResponse:
def __init__(self, inference: InferenceResource) -> None:
self._inference = inference
+ self.batch_chat_completion = to_streamed_response_wrapper(
+ inference.batch_chat_completion,
+ )
+ self.batch_completion = to_streamed_response_wrapper(
+ inference.batch_completion,
+ )
self.chat_completion = to_streamed_response_wrapper(
inference.chat_completion,
)
@@ -1114,6 +1344,12 @@ class AsyncInferenceResourceWithStreamingResponse:
def __init__(self, inference: AsyncInferenceResource) -> None:
self._inference = inference
+ self.batch_chat_completion = async_to_streamed_response_wrapper(
+ inference.batch_chat_completion,
+ )
+ self.batch_completion = async_to_streamed_response_wrapper(
+ inference.batch_completion,
+ )
self.chat_completion = async_to_streamed_response_wrapper(
inference.chat_completion,
)
diff --git a/src/llama_stack_client/resources/tool_runtime/tool_runtime.py b/src/llama_stack_client/resources/tool_runtime/tool_runtime.py
index 2bd7347b..aa380f79 100644
--- a/src/llama_stack_client/resources/tool_runtime/tool_runtime.py
+++ b/src/llama_stack_client/resources/tool_runtime/tool_runtime.py
@@ -2,7 +2,7 @@
from __future__ import annotations
-from typing import Dict, Union, Iterable
+from typing import Dict, Type, Union, Iterable, cast
import httpx
@@ -28,10 +28,10 @@
async_to_raw_response_wrapper,
async_to_streamed_response_wrapper,
)
+from ..._wrappers import DataWrapper
from ..._base_client import make_request_options
-from ...types.tool_def import ToolDef
-from ..._decoders.jsonl import JSONLDecoder, AsyncJSONLDecoder
from ...types.tool_invocation_result import ToolInvocationResult
+from ...types.tool_runtime_list_tools_response import ToolRuntimeListToolsResponse
__all__ = ["ToolRuntimeResource", "AsyncToolRuntimeResource"]
@@ -110,7 +110,7 @@ def list_tools(
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> JSONLDecoder[ToolDef]:
+ ) -> ToolRuntimeListToolsResponse:
"""
Args:
extra_headers: Send extra headers
@@ -121,7 +121,6 @@ def list_tools(
timeout: Override the client-level default timeout for this request, in seconds
"""
- extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})}
return self._get(
"/v1/tool-runtime/list-tools",
options=make_request_options(
@@ -136,9 +135,9 @@ def list_tools(
},
tool_runtime_list_tools_params.ToolRuntimeListToolsParams,
),
+ post_parser=DataWrapper[ToolRuntimeListToolsResponse]._unwrapper,
),
- cast_to=JSONLDecoder[ToolDef],
- stream=True,
+ cast_to=cast(Type[ToolRuntimeListToolsResponse], DataWrapper[ToolRuntimeListToolsResponse]),
)
@@ -216,7 +215,7 @@ async def list_tools(
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> AsyncJSONLDecoder[ToolDef]:
+ ) -> ToolRuntimeListToolsResponse:
"""
Args:
extra_headers: Send extra headers
@@ -227,7 +226,6 @@ async def list_tools(
timeout: Override the client-level default timeout for this request, in seconds
"""
- extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})}
return await self._get(
"/v1/tool-runtime/list-tools",
options=make_request_options(
@@ -242,9 +240,9 @@ async def list_tools(
},
tool_runtime_list_tools_params.ToolRuntimeListToolsParams,
),
+ post_parser=DataWrapper[ToolRuntimeListToolsResponse]._unwrapper,
),
- cast_to=AsyncJSONLDecoder[ToolDef],
- stream=True,
+ cast_to=cast(Type[ToolRuntimeListToolsResponse], DataWrapper[ToolRuntimeListToolsResponse]),
)
diff --git a/src/llama_stack_client/types/__init__.py b/src/llama_stack_client/types/__init__.py
index b45996a9..a78eae03 100644
--- a/src/llama_stack_client/types/__init__.py
+++ b/src/llama_stack_client/types/__init__.py
@@ -126,19 +126,20 @@
from .list_post_training_jobs_response import ListPostTrainingJobsResponse as ListPostTrainingJobsResponse
from .scoring_function_register_params import ScoringFunctionRegisterParams as ScoringFunctionRegisterParams
from .telemetry_get_span_tree_response import TelemetryGetSpanTreeResponse as TelemetryGetSpanTreeResponse
-from .batch_inference_completion_params import BatchInferenceCompletionParams as BatchInferenceCompletionParams
+from .tool_runtime_list_tools_response import ToolRuntimeListToolsResponse as ToolRuntimeListToolsResponse
+from .inference_batch_completion_params import InferenceBatchCompletionParams as InferenceBatchCompletionParams
from .synthetic_data_generation_response import SyntheticDataGenerationResponse as SyntheticDataGenerationResponse
from .chat_completion_response_stream_chunk import (
ChatCompletionResponseStreamChunk as ChatCompletionResponseStreamChunk,
)
-from .batch_inference_chat_completion_params import (
- BatchInferenceChatCompletionParams as BatchInferenceChatCompletionParams,
+from .inference_batch_chat_completion_params import (
+ InferenceBatchChatCompletionParams as InferenceBatchChatCompletionParams,
)
from .telemetry_save_spans_to_dataset_params import (
TelemetrySaveSpansToDatasetParams as TelemetrySaveSpansToDatasetParams,
)
-from .batch_inference_chat_completion_response import (
- BatchInferenceChatCompletionResponse as BatchInferenceChatCompletionResponse,
+from .inference_batch_chat_completion_response import (
+ InferenceBatchChatCompletionResponse as InferenceBatchChatCompletionResponse,
)
from .post_training_preference_optimize_params import (
PostTrainingPreferenceOptimizeParams as PostTrainingPreferenceOptimizeParams,
diff --git a/src/llama_stack_client/types/batch_inference_chat_completion_params.py b/src/llama_stack_client/types/batch_inference_chat_completion_params.py
deleted file mode 100644
index 3091f5bb..00000000
--- a/src/llama_stack_client/types/batch_inference_chat_completion_params.py
+++ /dev/null
@@ -1,51 +0,0 @@
-# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
-
-from __future__ import annotations
-
-from typing import Dict, Union, Iterable
-from typing_extensions import Literal, Required, TypedDict
-
-from .shared_params.message import Message
-from .shared_params.response_format import ResponseFormat
-from .shared_params.sampling_params import SamplingParams
-from .shared_params.tool_param_definition import ToolParamDefinition
-
-__all__ = ["BatchInferenceChatCompletionParams", "Logprobs", "Tool"]
-
-
-class BatchInferenceChatCompletionParams(TypedDict, total=False):
- messages_batch: Required[Iterable[Iterable[Message]]]
-
- model: Required[str]
-
- logprobs: Logprobs
-
- response_format: ResponseFormat
- """Configuration for JSON schema-guided response generation."""
-
- sampling_params: SamplingParams
-
- tool_choice: Literal["auto", "required", "none"]
- """Whether tool use is required or automatic.
-
- This is a hint to the model which may not be followed. It depends on the
- Instruction Following capabilities of the model.
- """
-
- tool_prompt_format: Literal["json", "function_tag", "python_list"]
- """Prompt format for calling custom / zero shot tools."""
-
- tools: Iterable[Tool]
-
-
-class Logprobs(TypedDict, total=False):
- top_k: int
- """How many tokens (for each position) to return log probabilities for."""
-
-
-class Tool(TypedDict, total=False):
- tool_name: Required[Union[Literal["brave_search", "wolfram_alpha", "photogen", "code_interpreter"], str]]
-
- description: str
-
- parameters: Dict[str, ToolParamDefinition]
diff --git a/src/llama_stack_client/types/dataset_iterrows_response.py b/src/llama_stack_client/types/dataset_iterrows_response.py
index 48593bb2..9c451a8c 100644
--- a/src/llama_stack_client/types/dataset_iterrows_response.py
+++ b/src/llama_stack_client/types/dataset_iterrows_response.py
@@ -1,6 +1,6 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
-from typing import Dict, List, Union, Optional
+from typing import Dict, List, Union
from .._models import BaseModel
@@ -9,10 +9,7 @@
class DatasetIterrowsResponse(BaseModel):
data: List[Dict[str, Union[bool, float, str, List[object], object, None]]]
- """The rows in the current page."""
+ """The list of items for the current page"""
- next_start_index: Optional[int] = None
- """Index into dataset for the first row in the next page.
-
- None if there are no more rows.
- """
+ has_more: bool
+ """Whether there are more items available after this set"""
diff --git a/src/llama_stack_client/types/inference_batch_chat_completion_params.py b/src/llama_stack_client/types/inference_batch_chat_completion_params.py
new file mode 100644
index 00000000..ca53fdbf
--- /dev/null
+++ b/src/llama_stack_client/types/inference_batch_chat_completion_params.py
@@ -0,0 +1,74 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from __future__ import annotations
+
+from typing import Dict, Union, Iterable
+from typing_extensions import Literal, Required, TypedDict
+
+from .shared_params.message import Message
+from .shared_params.response_format import ResponseFormat
+from .shared_params.sampling_params import SamplingParams
+from .shared_params.tool_param_definition import ToolParamDefinition
+
+__all__ = ["InferenceBatchChatCompletionParams", "Logprobs", "ToolConfig", "Tool"]
+
+
+class InferenceBatchChatCompletionParams(TypedDict, total=False):
+ messages_batch: Required[Iterable[Iterable[Message]]]
+
+ model_id: Required[str]
+
+ logprobs: Logprobs
+
+ response_format: ResponseFormat
+ """Configuration for JSON schema-guided response generation."""
+
+ sampling_params: SamplingParams
+ """Sampling parameters."""
+
+ tool_config: ToolConfig
+ """Configuration for tool use."""
+
+ tools: Iterable[Tool]
+
+
+class Logprobs(TypedDict, total=False):
+ top_k: int
+ """How many tokens (for each position) to return log probabilities for."""
+
+
+class ToolConfig(TypedDict, total=False):
+ system_message_behavior: Literal["append", "replace"]
+ """(Optional) Config for how to override the default system prompt.
+
+ - `SystemMessageBehavior.append`: Appends the provided system message to the
+ default system prompt. - `SystemMessageBehavior.replace`: Replaces the default
+ system prompt with the provided system message. The system message can include
+ the string '{{function_definitions}}' to indicate where the function
+ definitions should be inserted.
+ """
+
+ tool_choice: Union[Literal["auto", "required", "none"], str]
+ """(Optional) Whether tool use is automatic, required, or none.
+
+ Can also specify a tool name to use a specific tool. Defaults to
+ ToolChoice.auto.
+ """
+
+ tool_prompt_format: Literal["json", "function_tag", "python_list"]
+ """(Optional) Instructs the model how to format tool calls.
+
+ By default, Llama Stack will attempt to use a format that is best adapted to the
+ model. - `ToolPromptFormat.json`: The tool calls are formatted as a JSON
+ object. - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a
+ tag. - `ToolPromptFormat.python_list`: The tool calls
+ are output as Python syntax -- a list of function calls.
+ """
+
+
+class Tool(TypedDict, total=False):
+ tool_name: Required[Union[Literal["brave_search", "wolfram_alpha", "photogen", "code_interpreter"], str]]
+
+ description: str
+
+ parameters: Dict[str, ToolParamDefinition]
diff --git a/src/llama_stack_client/types/batch_inference_chat_completion_response.py b/src/llama_stack_client/types/inference_batch_chat_completion_response.py
similarity index 70%
rename from src/llama_stack_client/types/batch_inference_chat_completion_response.py
rename to src/llama_stack_client/types/inference_batch_chat_completion_response.py
index 218b1275..84d6c425 100644
--- a/src/llama_stack_client/types/batch_inference_chat_completion_response.py
+++ b/src/llama_stack_client/types/inference_batch_chat_completion_response.py
@@ -5,8 +5,8 @@
from .._models import BaseModel
from .shared.chat_completion_response import ChatCompletionResponse
-__all__ = ["BatchInferenceChatCompletionResponse"]
+__all__ = ["InferenceBatchChatCompletionResponse"]
-class BatchInferenceChatCompletionResponse(BaseModel):
+class InferenceBatchChatCompletionResponse(BaseModel):
batch: List[ChatCompletionResponse]
diff --git a/src/llama_stack_client/types/batch_inference_completion_params.py b/src/llama_stack_client/types/inference_batch_completion_params.py
similarity index 80%
rename from src/llama_stack_client/types/batch_inference_completion_params.py
rename to src/llama_stack_client/types/inference_batch_completion_params.py
index 3f80d625..cbeb9309 100644
--- a/src/llama_stack_client/types/batch_inference_completion_params.py
+++ b/src/llama_stack_client/types/inference_batch_completion_params.py
@@ -9,13 +9,13 @@
from .shared_params.sampling_params import SamplingParams
from .shared_params.interleaved_content import InterleavedContent
-__all__ = ["BatchInferenceCompletionParams", "Logprobs"]
+__all__ = ["InferenceBatchCompletionParams", "Logprobs"]
-class BatchInferenceCompletionParams(TypedDict, total=False):
+class InferenceBatchCompletionParams(TypedDict, total=False):
content_batch: Required[List[InterleavedContent]]
- model: Required[str]
+ model_id: Required[str]
logprobs: Logprobs
@@ -23,6 +23,7 @@ class BatchInferenceCompletionParams(TypedDict, total=False):
"""Configuration for JSON schema-guided response generation."""
sampling_params: SamplingParams
+ """Sampling parameters."""
class Logprobs(TypedDict, total=False):
diff --git a/src/llama_stack_client/types/job.py b/src/llama_stack_client/types/job.py
index 74c6beb7..4953b3bf 100644
--- a/src/llama_stack_client/types/job.py
+++ b/src/llama_stack_client/types/job.py
@@ -10,4 +10,4 @@
class Job(BaseModel):
job_id: str
- status: Literal["completed", "in_progress", "failed", "scheduled"]
+ status: Literal["completed", "in_progress", "failed", "scheduled", "cancelled"]
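
`Job.status` gains a `cancelled` member, so code that exhaustively handles job states should account for it. A small hedged helper for illustration:

from llama_stack_client.types import Job

def is_terminal(job: Job) -> bool:
    # "cancelled" now counts as a terminal state alongside "completed" and "failed"
    return job.status in ("completed", "failed", "cancelled")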
diff --git a/src/llama_stack_client/types/post_training/job_status_response.py b/src/llama_stack_client/types/post_training/job_status_response.py
index 250bd82a..5ba60a6a 100644
--- a/src/llama_stack_client/types/post_training/job_status_response.py
+++ b/src/llama_stack_client/types/post_training/job_status_response.py
@@ -14,7 +14,7 @@ class JobStatusResponse(BaseModel):
job_uuid: str
- status: Literal["completed", "in_progress", "failed", "scheduled"]
+ status: Literal["completed", "in_progress", "failed", "scheduled", "cancelled"]
completed_at: Optional[datetime] = None
diff --git a/src/llama_stack_client/types/shared/agent_config.py b/src/llama_stack_client/types/shared/agent_config.py
index 273487ae..04997ac4 100644
--- a/src/llama_stack_client/types/shared/agent_config.py
+++ b/src/llama_stack_client/types/shared/agent_config.py
@@ -68,6 +68,7 @@ class AgentConfig(BaseModel):
"""Configuration for JSON schema-guided response generation."""
sampling_params: Optional[SamplingParams] = None
+ """Sampling parameters."""
tool_choice: Optional[Literal["auto", "required", "none"]] = None
"""Whether tool use is required or automatic.
diff --git a/src/llama_stack_client/types/shared/sampling_params.py b/src/llama_stack_client/types/shared/sampling_params.py
index bb5866a6..7ce2211e 100644
--- a/src/llama_stack_client/types/shared/sampling_params.py
+++ b/src/llama_stack_client/types/shared/sampling_params.py
@@ -1,6 +1,6 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
-from typing import Union, Optional
+from typing import List, Union, Optional
from typing_extensions import Literal, Annotated, TypeAlias
from ..._utils import PropertyInfo
@@ -41,7 +41,24 @@ class StrategyTopKSamplingStrategy(BaseModel):
class SamplingParams(BaseModel):
strategy: Strategy
+ """The sampling strategy."""
max_tokens: Optional[int] = None
+ """The maximum number of tokens that can be generated in the completion.
+
+ The token count of your prompt plus max_tokens cannot exceed the model's context
+ length.
+ """
repetition_penalty: Optional[float] = None
+ """Number between -2.0 and 2.0.
+
+ Positive values penalize new tokens based on whether they appear in the text so
+ far, increasing the model's likelihood to talk about new topics.
+ """
+
+ stop: Optional[List[str]] = None
+ """Up to 4 sequences where the API will stop generating further tokens.
+
+ The returned text will not contain the stop sequence.
+ """
diff --git a/src/llama_stack_client/types/shared_params/agent_config.py b/src/llama_stack_client/types/shared_params/agent_config.py
index 0107ee42..f07efa39 100644
--- a/src/llama_stack_client/types/shared_params/agent_config.py
+++ b/src/llama_stack_client/types/shared_params/agent_config.py
@@ -69,6 +69,7 @@ class AgentConfig(TypedDict, total=False):
"""Configuration for JSON schema-guided response generation."""
sampling_params: SamplingParams
+ """Sampling parameters."""
tool_choice: Literal["auto", "required", "none"]
"""Whether tool use is required or automatic.
diff --git a/src/llama_stack_client/types/shared_params/sampling_params.py b/src/llama_stack_client/types/shared_params/sampling_params.py
index 1d9bcaf5..158db1c5 100644
--- a/src/llama_stack_client/types/shared_params/sampling_params.py
+++ b/src/llama_stack_client/types/shared_params/sampling_params.py
@@ -2,7 +2,7 @@
from __future__ import annotations
-from typing import Union
+from typing import List, Union
from typing_extensions import Literal, Required, TypeAlias, TypedDict
__all__ = [
@@ -37,7 +37,24 @@ class StrategyTopKSamplingStrategy(TypedDict, total=False):
class SamplingParams(TypedDict, total=False):
strategy: Required[Strategy]
+ """The sampling strategy."""
max_tokens: int
+ """The maximum number of tokens that can be generated in the completion.
+
+ The token count of your prompt plus max_tokens cannot exceed the model's context
+ length.
+ """
repetition_penalty: float
+ """Number between -2.0 and 2.0.
+
+ Positive values penalize new tokens based on whether they appear in the text so
+ far, increasing the model's likelihood to talk about new topics.
+ """
+
+ stop: List[str]
+ """Up to 4 sequences where the API will stop generating further tokens.
+
+ The returned text will not contain the stop sequence.
+ """
diff --git a/src/llama_stack_client/types/tool_runtime_list_tools_response.py b/src/llama_stack_client/types/tool_runtime_list_tools_response.py
new file mode 100644
index 00000000..cd65754f
--- /dev/null
+++ b/src/llama_stack_client/types/tool_runtime_list_tools_response.py
@@ -0,0 +1,10 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from typing import List
+from typing_extensions import TypeAlias
+
+from .tool_def import ToolDef
+
+__all__ = ["ToolRuntimeListToolsResponse"]
+
+ToolRuntimeListToolsResponse: TypeAlias = List[ToolDef]
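
`tool_runtime.list_tools` therefore now returns a plain list of `ToolDef` objects rather than a JSONL-decoded stream (see the test changes further down). A hedged usage sketch, with a placeholder base URL and tool group id:

from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # illustrative URL

tools = client.tool_runtime.list_tools(tool_group_id="builtin::websearch")  # placeholder group
for tool in tools:
    print(tool.name, tool.description)  # ToolDef fields; description may be None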
diff --git a/tests/api_resources/test_agents.py b/tests/api_resources/test_agents.py
index 03a15837..235d6258 100644
--- a/tests/api_resources/test_agents.py
+++ b/tests/api_resources/test_agents.py
@@ -61,6 +61,7 @@ def test_method_create_with_all_params(self, client: LlamaStackClient) -> None:
"strategy": {"type": "greedy"},
"max_tokens": 0,
"repetition_penalty": 0,
+ "stop": ["string"],
},
"tool_choice": "auto",
"tool_config": {
@@ -190,6 +191,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncLlamaStack
"strategy": {"type": "greedy"},
"max_tokens": 0,
"repetition_penalty": 0,
+ "stop": ["string"],
},
"tool_choice": "auto",
"tool_config": {
diff --git a/tests/api_resources/test_batch_inference.py b/tests/api_resources/test_batch_inference.py
deleted file mode 100644
index 8e5cb9e5..00000000
--- a/tests/api_resources/test_batch_inference.py
+++ /dev/null
@@ -1,323 +0,0 @@
-# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
-
-from __future__ import annotations
-
-import os
-from typing import Any, cast
-
-import pytest
-
-from tests.utils import assert_matches_type
-from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient
-from llama_stack_client.types import (
- BatchInferenceChatCompletionResponse,
-)
-from llama_stack_client.types.shared import BatchCompletion
-
-base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
-
-
-class TestBatchInference:
- parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"])
-
- @parametrize
- def test_method_chat_completion(self, client: LlamaStackClient) -> None:
- batch_inference = client.batch_inference.chat_completion(
- messages_batch=[
- [
- {
- "content": "string",
- "role": "user",
- }
- ]
- ],
- model="model",
- )
- assert_matches_type(BatchInferenceChatCompletionResponse, batch_inference, path=["response"])
-
- @parametrize
- def test_method_chat_completion_with_all_params(self, client: LlamaStackClient) -> None:
- batch_inference = client.batch_inference.chat_completion(
- messages_batch=[
- [
- {
- "content": "string",
- "role": "user",
- "context": "string",
- }
- ]
- ],
- model="model",
- logprobs={"top_k": 0},
- response_format={
- "json_schema": {"foo": True},
- "type": "json_schema",
- },
- sampling_params={
- "strategy": {"type": "greedy"},
- "max_tokens": 0,
- "repetition_penalty": 0,
- },
- tool_choice="auto",
- tool_prompt_format="json",
- tools=[
- {
- "tool_name": "brave_search",
- "description": "description",
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "default": True,
- "description": "description",
- "required": True,
- }
- },
- }
- ],
- )
- assert_matches_type(BatchInferenceChatCompletionResponse, batch_inference, path=["response"])
-
- @parametrize
- def test_raw_response_chat_completion(self, client: LlamaStackClient) -> None:
- response = client.batch_inference.with_raw_response.chat_completion(
- messages_batch=[
- [
- {
- "content": "string",
- "role": "user",
- }
- ]
- ],
- model="model",
- )
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- batch_inference = response.parse()
- assert_matches_type(BatchInferenceChatCompletionResponse, batch_inference, path=["response"])
-
- @parametrize
- def test_streaming_response_chat_completion(self, client: LlamaStackClient) -> None:
- with client.batch_inference.with_streaming_response.chat_completion(
- messages_batch=[
- [
- {
- "content": "string",
- "role": "user",
- }
- ]
- ],
- model="model",
- ) as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- batch_inference = response.parse()
- assert_matches_type(BatchInferenceChatCompletionResponse, batch_inference, path=["response"])
-
- assert cast(Any, response.is_closed) is True
-
- @parametrize
- def test_method_completion(self, client: LlamaStackClient) -> None:
- batch_inference = client.batch_inference.completion(
- content_batch=["string"],
- model="model",
- )
- assert_matches_type(BatchCompletion, batch_inference, path=["response"])
-
- @parametrize
- def test_method_completion_with_all_params(self, client: LlamaStackClient) -> None:
- batch_inference = client.batch_inference.completion(
- content_batch=["string"],
- model="model",
- logprobs={"top_k": 0},
- response_format={
- "json_schema": {"foo": True},
- "type": "json_schema",
- },
- sampling_params={
- "strategy": {"type": "greedy"},
- "max_tokens": 0,
- "repetition_penalty": 0,
- },
- )
- assert_matches_type(BatchCompletion, batch_inference, path=["response"])
-
- @parametrize
- def test_raw_response_completion(self, client: LlamaStackClient) -> None:
- response = client.batch_inference.with_raw_response.completion(
- content_batch=["string"],
- model="model",
- )
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- batch_inference = response.parse()
- assert_matches_type(BatchCompletion, batch_inference, path=["response"])
-
- @parametrize
- def test_streaming_response_completion(self, client: LlamaStackClient) -> None:
- with client.batch_inference.with_streaming_response.completion(
- content_batch=["string"],
- model="model",
- ) as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- batch_inference = response.parse()
- assert_matches_type(BatchCompletion, batch_inference, path=["response"])
-
- assert cast(Any, response.is_closed) is True
-
-
-class TestAsyncBatchInference:
- parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
-
- @parametrize
- async def test_method_chat_completion(self, async_client: AsyncLlamaStackClient) -> None:
- batch_inference = await async_client.batch_inference.chat_completion(
- messages_batch=[
- [
- {
- "content": "string",
- "role": "user",
- }
- ]
- ],
- model="model",
- )
- assert_matches_type(BatchInferenceChatCompletionResponse, batch_inference, path=["response"])
-
- @parametrize
- async def test_method_chat_completion_with_all_params(self, async_client: AsyncLlamaStackClient) -> None:
- batch_inference = await async_client.batch_inference.chat_completion(
- messages_batch=[
- [
- {
- "content": "string",
- "role": "user",
- "context": "string",
- }
- ]
- ],
- model="model",
- logprobs={"top_k": 0},
- response_format={
- "json_schema": {"foo": True},
- "type": "json_schema",
- },
- sampling_params={
- "strategy": {"type": "greedy"},
- "max_tokens": 0,
- "repetition_penalty": 0,
- },
- tool_choice="auto",
- tool_prompt_format="json",
- tools=[
- {
- "tool_name": "brave_search",
- "description": "description",
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "default": True,
- "description": "description",
- "required": True,
- }
- },
- }
- ],
- )
- assert_matches_type(BatchInferenceChatCompletionResponse, batch_inference, path=["response"])
-
- @parametrize
- async def test_raw_response_chat_completion(self, async_client: AsyncLlamaStackClient) -> None:
- response = await async_client.batch_inference.with_raw_response.chat_completion(
- messages_batch=[
- [
- {
- "content": "string",
- "role": "user",
- }
- ]
- ],
- model="model",
- )
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- batch_inference = await response.parse()
- assert_matches_type(BatchInferenceChatCompletionResponse, batch_inference, path=["response"])
-
- @parametrize
- async def test_streaming_response_chat_completion(self, async_client: AsyncLlamaStackClient) -> None:
- async with async_client.batch_inference.with_streaming_response.chat_completion(
- messages_batch=[
- [
- {
- "content": "string",
- "role": "user",
- }
- ]
- ],
- model="model",
- ) as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- batch_inference = await response.parse()
- assert_matches_type(BatchInferenceChatCompletionResponse, batch_inference, path=["response"])
-
- assert cast(Any, response.is_closed) is True
-
- @parametrize
- async def test_method_completion(self, async_client: AsyncLlamaStackClient) -> None:
- batch_inference = await async_client.batch_inference.completion(
- content_batch=["string"],
- model="model",
- )
- assert_matches_type(BatchCompletion, batch_inference, path=["response"])
-
- @parametrize
- async def test_method_completion_with_all_params(self, async_client: AsyncLlamaStackClient) -> None:
- batch_inference = await async_client.batch_inference.completion(
- content_batch=["string"],
- model="model",
- logprobs={"top_k": 0},
- response_format={
- "json_schema": {"foo": True},
- "type": "json_schema",
- },
- sampling_params={
- "strategy": {"type": "greedy"},
- "max_tokens": 0,
- "repetition_penalty": 0,
- },
- )
- assert_matches_type(BatchCompletion, batch_inference, path=["response"])
-
- @parametrize
- async def test_raw_response_completion(self, async_client: AsyncLlamaStackClient) -> None:
- response = await async_client.batch_inference.with_raw_response.completion(
- content_batch=["string"],
- model="model",
- )
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- batch_inference = await response.parse()
- assert_matches_type(BatchCompletion, batch_inference, path=["response"])
-
- @parametrize
- async def test_streaming_response_completion(self, async_client: AsyncLlamaStackClient) -> None:
- async with async_client.batch_inference.with_streaming_response.completion(
- content_batch=["string"],
- model="model",
- ) as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- batch_inference = await response.parse()
- assert_matches_type(BatchCompletion, batch_inference, path=["response"])
-
- assert cast(Any, response.is_closed) is True
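
The deleted `batch_inference` tests above map onto the new `inference.batch_chat_completion` / `inference.batch_completion` tests added later in this diff. A hedged migration sketch, assuming `client` is a configured `LlamaStackClient` and using a placeholder model id:

# old: client.batch_inference.chat_completion(messages_batch=..., model="model")
response = client.inference.batch_chat_completion(
    messages_batch=[[{"role": "user", "content": "Hello"}]],
    model_id="meta-llama/Llama-3.1-8B-Instruct",  # was `model`
)
for chat_completion in response.batch:  # InferenceBatchChatCompletionResponse.batch
    print(chat_completion.completion_message.content)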
diff --git a/tests/api_resources/test_eval.py b/tests/api_resources/test_eval.py
index 9735b4c4..c519056b 100644
--- a/tests/api_resources/test_eval.py
+++ b/tests/api_resources/test_eval.py
@@ -53,6 +53,7 @@ def test_method_evaluate_rows_with_all_params(self, client: LlamaStackClient) ->
"strategy": {"type": "greedy"},
"max_tokens": 0,
"repetition_penalty": 0,
+ "stop": ["string"],
},
"type": "model",
"system_message": {
@@ -185,6 +186,7 @@ def test_method_evaluate_rows_alpha_with_all_params(self, client: LlamaStackClie
"strategy": {"type": "greedy"},
"max_tokens": 0,
"repetition_penalty": 0,
+ "stop": ["string"],
},
"type": "model",
"system_message": {
@@ -315,6 +317,7 @@ def test_method_run_eval_with_all_params(self, client: LlamaStackClient) -> None
"strategy": {"type": "greedy"},
"max_tokens": 0,
"repetition_penalty": 0,
+ "stop": ["string"],
},
"type": "model",
"system_message": {
@@ -437,6 +440,7 @@ def test_method_run_eval_alpha_with_all_params(self, client: LlamaStackClient) -
"strategy": {"type": "greedy"},
"max_tokens": 0,
"repetition_penalty": 0,
+ "stop": ["string"],
},
"type": "model",
"system_message": {
@@ -565,6 +569,7 @@ async def test_method_evaluate_rows_with_all_params(self, async_client: AsyncLla
"strategy": {"type": "greedy"},
"max_tokens": 0,
"repetition_penalty": 0,
+ "stop": ["string"],
},
"type": "model",
"system_message": {
@@ -697,6 +702,7 @@ async def test_method_evaluate_rows_alpha_with_all_params(self, async_client: As
"strategy": {"type": "greedy"},
"max_tokens": 0,
"repetition_penalty": 0,
+ "stop": ["string"],
},
"type": "model",
"system_message": {
@@ -827,6 +833,7 @@ async def test_method_run_eval_with_all_params(self, async_client: AsyncLlamaSta
"strategy": {"type": "greedy"},
"max_tokens": 0,
"repetition_penalty": 0,
+ "stop": ["string"],
},
"type": "model",
"system_message": {
@@ -949,6 +956,7 @@ async def test_method_run_eval_alpha_with_all_params(self, async_client: AsyncLl
"strategy": {"type": "greedy"},
"max_tokens": 0,
"repetition_penalty": 0,
+ "stop": ["string"],
},
"type": "model",
"system_message": {
diff --git a/tests/api_resources/test_inference.py b/tests/api_resources/test_inference.py
index 4d078587..d876ae56 100644
--- a/tests/api_resources/test_inference.py
+++ b/tests/api_resources/test_inference.py
@@ -12,8 +12,9 @@
from llama_stack_client.types import (
CompletionResponse,
EmbeddingsResponse,
+ InferenceBatchChatCompletionResponse,
)
-from llama_stack_client.types.shared import ChatCompletionResponse
+from llama_stack_client.types.shared import BatchCompletion, ChatCompletionResponse
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
@@ -21,6 +22,160 @@
class TestInference:
parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"])
+ @parametrize
+ def test_method_batch_chat_completion(self, client: LlamaStackClient) -> None:
+ inference = client.inference.batch_chat_completion(
+ messages_batch=[
+ [
+ {
+ "content": "string",
+ "role": "user",
+ }
+ ]
+ ],
+ model_id="model_id",
+ )
+ assert_matches_type(InferenceBatchChatCompletionResponse, inference, path=["response"])
+
+ @parametrize
+ def test_method_batch_chat_completion_with_all_params(self, client: LlamaStackClient) -> None:
+ inference = client.inference.batch_chat_completion(
+ messages_batch=[
+ [
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ }
+ ]
+ ],
+ model_id="model_id",
+ logprobs={"top_k": 0},
+ response_format={
+ "json_schema": {"foo": True},
+ "type": "json_schema",
+ },
+ sampling_params={
+ "strategy": {"type": "greedy"},
+ "max_tokens": 0,
+ "repetition_penalty": 0,
+ "stop": ["string"],
+ },
+ tool_config={
+ "system_message_behavior": "append",
+ "tool_choice": "auto",
+ "tool_prompt_format": "json",
+ },
+ tools=[
+ {
+ "tool_name": "brave_search",
+ "description": "description",
+ "parameters": {
+ "foo": {
+ "param_type": "param_type",
+ "default": True,
+ "description": "description",
+ "required": True,
+ }
+ },
+ }
+ ],
+ )
+ assert_matches_type(InferenceBatchChatCompletionResponse, inference, path=["response"])
+
+ @parametrize
+ def test_raw_response_batch_chat_completion(self, client: LlamaStackClient) -> None:
+ response = client.inference.with_raw_response.batch_chat_completion(
+ messages_batch=[
+ [
+ {
+ "content": "string",
+ "role": "user",
+ }
+ ]
+ ],
+ model_id="model_id",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ inference = response.parse()
+ assert_matches_type(InferenceBatchChatCompletionResponse, inference, path=["response"])
+
+ @parametrize
+ def test_streaming_response_batch_chat_completion(self, client: LlamaStackClient) -> None:
+ with client.inference.with_streaming_response.batch_chat_completion(
+ messages_batch=[
+ [
+ {
+ "content": "string",
+ "role": "user",
+ }
+ ]
+ ],
+ model_id="model_id",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ inference = response.parse()
+ assert_matches_type(InferenceBatchChatCompletionResponse, inference, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+ @parametrize
+ def test_method_batch_completion(self, client: LlamaStackClient) -> None:
+ inference = client.inference.batch_completion(
+ content_batch=["string"],
+ model_id="model_id",
+ )
+ assert_matches_type(BatchCompletion, inference, path=["response"])
+
+ @parametrize
+ def test_method_batch_completion_with_all_params(self, client: LlamaStackClient) -> None:
+ inference = client.inference.batch_completion(
+ content_batch=["string"],
+ model_id="model_id",
+ logprobs={"top_k": 0},
+ response_format={
+ "json_schema": {"foo": True},
+ "type": "json_schema",
+ },
+ sampling_params={
+ "strategy": {"type": "greedy"},
+ "max_tokens": 0,
+ "repetition_penalty": 0,
+ "stop": ["string"],
+ },
+ )
+ assert_matches_type(BatchCompletion, inference, path=["response"])
+
+ @parametrize
+ def test_raw_response_batch_completion(self, client: LlamaStackClient) -> None:
+ response = client.inference.with_raw_response.batch_completion(
+ content_batch=["string"],
+ model_id="model_id",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ inference = response.parse()
+ assert_matches_type(BatchCompletion, inference, path=["response"])
+
+ @parametrize
+ def test_streaming_response_batch_completion(self, client: LlamaStackClient) -> None:
+ with client.inference.with_streaming_response.batch_completion(
+ content_batch=["string"],
+ model_id="model_id",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ inference = response.parse()
+ assert_matches_type(BatchCompletion, inference, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
@parametrize
def test_method_chat_completion_overload_1(self, client: LlamaStackClient) -> None:
inference = client.inference.chat_completion(
@@ -54,6 +209,7 @@ def test_method_chat_completion_with_all_params_overload_1(self, client: LlamaSt
"strategy": {"type": "greedy"},
"max_tokens": 0,
"repetition_penalty": 0,
+ "stop": ["string"],
},
stream=False,
tool_choice="auto",
@@ -151,6 +307,7 @@ def test_method_chat_completion_with_all_params_overload_2(self, client: LlamaSt
"strategy": {"type": "greedy"},
"max_tokens": 0,
"repetition_penalty": 0,
+ "stop": ["string"],
},
tool_choice="auto",
tool_config={
@@ -235,6 +392,7 @@ def test_method_completion_with_all_params_overload_1(self, client: LlamaStackCl
"strategy": {"type": "greedy"},
"max_tokens": 0,
"repetition_penalty": 0,
+ "stop": ["string"],
},
stream=False,
)
@@ -290,6 +448,7 @@ def test_method_completion_with_all_params_overload_2(self, client: LlamaStackCl
"strategy": {"type": "greedy"},
"max_tokens": 0,
"repetition_penalty": 0,
+ "stop": ["string"],
},
)
inference_stream.response.close()
@@ -370,6 +529,160 @@ def test_streaming_response_embeddings(self, client: LlamaStackClient) -> None:
class TestAsyncInference:
parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
+ @parametrize
+ async def test_method_batch_chat_completion(self, async_client: AsyncLlamaStackClient) -> None:
+ inference = await async_client.inference.batch_chat_completion(
+ messages_batch=[
+ [
+ {
+ "content": "string",
+ "role": "user",
+ }
+ ]
+ ],
+ model_id="model_id",
+ )
+ assert_matches_type(InferenceBatchChatCompletionResponse, inference, path=["response"])
+
+ @parametrize
+ async def test_method_batch_chat_completion_with_all_params(self, async_client: AsyncLlamaStackClient) -> None:
+ inference = await async_client.inference.batch_chat_completion(
+ messages_batch=[
+ [
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ }
+ ]
+ ],
+ model_id="model_id",
+ logprobs={"top_k": 0},
+ response_format={
+ "json_schema": {"foo": True},
+ "type": "json_schema",
+ },
+ sampling_params={
+ "strategy": {"type": "greedy"},
+ "max_tokens": 0,
+ "repetition_penalty": 0,
+ "stop": ["string"],
+ },
+ tool_config={
+ "system_message_behavior": "append",
+ "tool_choice": "auto",
+ "tool_prompt_format": "json",
+ },
+ tools=[
+ {
+ "tool_name": "brave_search",
+ "description": "description",
+ "parameters": {
+ "foo": {
+ "param_type": "param_type",
+ "default": True,
+ "description": "description",
+ "required": True,
+ }
+ },
+ }
+ ],
+ )
+ assert_matches_type(InferenceBatchChatCompletionResponse, inference, path=["response"])
+
+ @parametrize
+ async def test_raw_response_batch_chat_completion(self, async_client: AsyncLlamaStackClient) -> None:
+ response = await async_client.inference.with_raw_response.batch_chat_completion(
+ messages_batch=[
+ [
+ {
+ "content": "string",
+ "role": "user",
+ }
+ ]
+ ],
+ model_id="model_id",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ inference = await response.parse()
+ assert_matches_type(InferenceBatchChatCompletionResponse, inference, path=["response"])
+
+ @parametrize
+ async def test_streaming_response_batch_chat_completion(self, async_client: AsyncLlamaStackClient) -> None:
+ async with async_client.inference.with_streaming_response.batch_chat_completion(
+ messages_batch=[
+ [
+ {
+ "content": "string",
+ "role": "user",
+ }
+ ]
+ ],
+ model_id="model_id",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ inference = await response.parse()
+ assert_matches_type(InferenceBatchChatCompletionResponse, inference, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+ @parametrize
+ async def test_method_batch_completion(self, async_client: AsyncLlamaStackClient) -> None:
+ inference = await async_client.inference.batch_completion(
+ content_batch=["string"],
+ model_id="model_id",
+ )
+ assert_matches_type(BatchCompletion, inference, path=["response"])
+
+ @parametrize
+ async def test_method_batch_completion_with_all_params(self, async_client: AsyncLlamaStackClient) -> None:
+ inference = await async_client.inference.batch_completion(
+ content_batch=["string"],
+ model_id="model_id",
+ logprobs={"top_k": 0},
+ response_format={
+ "json_schema": {"foo": True},
+ "type": "json_schema",
+ },
+ sampling_params={
+ "strategy": {"type": "greedy"},
+ "max_tokens": 0,
+ "repetition_penalty": 0,
+ "stop": ["string"],
+ },
+ )
+ assert_matches_type(BatchCompletion, inference, path=["response"])
+
+ @parametrize
+ async def test_raw_response_batch_completion(self, async_client: AsyncLlamaStackClient) -> None:
+ response = await async_client.inference.with_raw_response.batch_completion(
+ content_batch=["string"],
+ model_id="model_id",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ inference = await response.parse()
+ assert_matches_type(BatchCompletion, inference, path=["response"])
+
+ @parametrize
+ async def test_streaming_response_batch_completion(self, async_client: AsyncLlamaStackClient) -> None:
+ async with async_client.inference.with_streaming_response.batch_completion(
+ content_batch=["string"],
+ model_id="model_id",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ inference = await response.parse()
+ assert_matches_type(BatchCompletion, inference, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
@parametrize
async def test_method_chat_completion_overload_1(self, async_client: AsyncLlamaStackClient) -> None:
inference = await async_client.inference.chat_completion(
@@ -403,6 +716,7 @@ async def test_method_chat_completion_with_all_params_overload_1(self, async_cli
"strategy": {"type": "greedy"},
"max_tokens": 0,
"repetition_penalty": 0,
+ "stop": ["string"],
},
stream=False,
tool_choice="auto",
@@ -500,6 +814,7 @@ async def test_method_chat_completion_with_all_params_overload_2(self, async_cli
"strategy": {"type": "greedy"},
"max_tokens": 0,
"repetition_penalty": 0,
+ "stop": ["string"],
},
tool_choice="auto",
tool_config={
@@ -584,6 +899,7 @@ async def test_method_completion_with_all_params_overload_1(self, async_client:
"strategy": {"type": "greedy"},
"max_tokens": 0,
"repetition_penalty": 0,
+ "stop": ["string"],
},
stream=False,
)
@@ -639,6 +955,7 @@ async def test_method_completion_with_all_params_overload_2(self, async_client:
"strategy": {"type": "greedy"},
"max_tokens": 0,
"repetition_penalty": 0,
+ "stop": ["string"],
},
)
await inference_stream.response.aclose()
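
Note also that the new batch tests supply the old top-level `tool_choice` / `tool_prompt_format` arguments through a `tool_config` dict instead. A hedged sketch of that shape, with illustrative values:

tool_config = {
    "system_message_behavior": "append",  # or "replace"
    "tool_choice": "auto",                # "auto", "required", "none", or a specific tool name
    "tool_prompt_format": "json",         # "json", "function_tag", or "python_list"
}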
diff --git a/tests/api_resources/test_tool_runtime.py b/tests/api_resources/test_tool_runtime.py
index ca4279bb..b13e8c1f 100644
--- a/tests/api_resources/test_tool_runtime.py
+++ b/tests/api_resources/test_tool_runtime.py
@@ -10,10 +10,9 @@
from tests.utils import assert_matches_type
from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient
from llama_stack_client.types import (
- ToolDef,
ToolInvocationResult,
+ ToolRuntimeListToolsResponse,
)
-from llama_stack_client._decoders.jsonl import JSONLDecoder, AsyncJSONLDecoder
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
@@ -55,22 +54,19 @@ def test_streaming_response_invoke_tool(self, client: LlamaStackClient) -> None:
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip(reason="Prism doesn't support JSONL responses yet")
@parametrize
def test_method_list_tools(self, client: LlamaStackClient) -> None:
tool_runtime = client.tool_runtime.list_tools()
- assert_matches_type(JSONLDecoder[ToolDef], tool_runtime, path=["response"])
+ assert_matches_type(ToolRuntimeListToolsResponse, tool_runtime, path=["response"])
- @pytest.mark.skip(reason="Prism doesn't support JSONL responses yet")
@parametrize
def test_method_list_tools_with_all_params(self, client: LlamaStackClient) -> None:
tool_runtime = client.tool_runtime.list_tools(
mcp_endpoint={"uri": "uri"},
tool_group_id="tool_group_id",
)
- assert_matches_type(JSONLDecoder[ToolDef], tool_runtime, path=["response"])
+ assert_matches_type(ToolRuntimeListToolsResponse, tool_runtime, path=["response"])
- @pytest.mark.skip(reason="Prism doesn't support JSONL responses yet")
@parametrize
def test_raw_response_list_tools(self, client: LlamaStackClient) -> None:
response = client.tool_runtime.with_raw_response.list_tools()
@@ -78,9 +74,8 @@ def test_raw_response_list_tools(self, client: LlamaStackClient) -> None:
assert response.is_closed is True
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
tool_runtime = response.parse()
- assert_matches_type(JSONLDecoder[ToolDef], tool_runtime, path=["response"])
+ assert_matches_type(ToolRuntimeListToolsResponse, tool_runtime, path=["response"])
- @pytest.mark.skip(reason="Prism doesn't support JSONL responses yet")
@parametrize
def test_streaming_response_list_tools(self, client: LlamaStackClient) -> None:
with client.tool_runtime.with_streaming_response.list_tools() as response:
@@ -88,7 +83,7 @@ def test_streaming_response_list_tools(self, client: LlamaStackClient) -> None:
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
tool_runtime = response.parse()
- assert_matches_type(JSONLDecoder[ToolDef], tool_runtime, path=["response"])
+ assert_matches_type(ToolRuntimeListToolsResponse, tool_runtime, path=["response"])
assert cast(Any, response.is_closed) is True
@@ -130,22 +125,19 @@ async def test_streaming_response_invoke_tool(self, async_client: AsyncLlamaStac
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip(reason="Prism doesn't support JSONL responses yet")
@parametrize
async def test_method_list_tools(self, async_client: AsyncLlamaStackClient) -> None:
tool_runtime = await async_client.tool_runtime.list_tools()
- assert_matches_type(AsyncJSONLDecoder[ToolDef], tool_runtime, path=["response"])
+ assert_matches_type(ToolRuntimeListToolsResponse, tool_runtime, path=["response"])
- @pytest.mark.skip(reason="Prism doesn't support JSONL responses yet")
@parametrize
async def test_method_list_tools_with_all_params(self, async_client: AsyncLlamaStackClient) -> None:
tool_runtime = await async_client.tool_runtime.list_tools(
mcp_endpoint={"uri": "uri"},
tool_group_id="tool_group_id",
)
- assert_matches_type(AsyncJSONLDecoder[ToolDef], tool_runtime, path=["response"])
+ assert_matches_type(ToolRuntimeListToolsResponse, tool_runtime, path=["response"])
- @pytest.mark.skip(reason="Prism doesn't support JSONL responses yet")
@parametrize
async def test_raw_response_list_tools(self, async_client: AsyncLlamaStackClient) -> None:
response = await async_client.tool_runtime.with_raw_response.list_tools()
@@ -153,9 +145,8 @@ async def test_raw_response_list_tools(self, async_client: AsyncLlamaStackClient
assert response.is_closed is True
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
tool_runtime = await response.parse()
- assert_matches_type(AsyncJSONLDecoder[ToolDef], tool_runtime, path=["response"])
+ assert_matches_type(ToolRuntimeListToolsResponse, tool_runtime, path=["response"])
- @pytest.mark.skip(reason="Prism doesn't support JSONL responses yet")
@parametrize
async def test_streaming_response_list_tools(self, async_client: AsyncLlamaStackClient) -> None:
async with async_client.tool_runtime.with_streaming_response.list_tools() as response:
@@ -163,6 +154,6 @@ async def test_streaming_response_list_tools(self, async_client: AsyncLlamaStack
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
tool_runtime = await response.parse()
- assert_matches_type(AsyncJSONLDecoder[ToolDef], tool_runtime, path=["response"])
+ assert_matches_type(ToolRuntimeListToolsResponse, tool_runtime, path=["response"])
assert cast(Any, response.is_closed) is True
diff --git a/tests/decoders/test_jsonl.py b/tests/decoders/test_jsonl.py
deleted file mode 100644
index 54af8e49..00000000
--- a/tests/decoders/test_jsonl.py
+++ /dev/null
@@ -1,88 +0,0 @@
-from __future__ import annotations
-
-from typing import Any, Iterator, AsyncIterator
-from typing_extensions import TypeVar
-
-import httpx
-import pytest
-
-from llama_stack_client._decoders.jsonl import JSONLDecoder, AsyncJSONLDecoder
-
-_T = TypeVar("_T")
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
-async def test_basic(sync: bool) -> None:
- def body() -> Iterator[bytes]:
- yield b'{"foo":true}\n'
- yield b'{"bar":false}\n'
-
- iterator = make_jsonl_iterator(
- content=body(),
- sync=sync,
- line_type=object,
- )
-
- assert await iter_next(iterator) == {"foo": True}
- assert await iter_next(iterator) == {"bar": False}
-
- await assert_empty_iter(iterator)
-
-
-@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
-async def test_new_lines_in_json(
- sync: bool,
-) -> None:
- def body() -> Iterator[bytes]:
- yield b'{"content":"Hello, world!\\nHow are you doing?"}'
-
- iterator = make_jsonl_iterator(content=body(), sync=sync, line_type=object)
-
- assert await iter_next(iterator) == {"content": "Hello, world!\nHow are you doing?"}
-
-
-@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
-async def test_multi_byte_character_multiple_chunks(
- sync: bool,
-) -> None:
- def body() -> Iterator[bytes]:
- yield b'{"content":"'
- # bytes taken from the string 'известни' and arbitrarily split
- # so that some multi-byte characters span multiple chunks
- yield b"\xd0"
- yield b"\xb8\xd0\xb7\xd0"
- yield b"\xb2\xd0\xb5\xd1\x81\xd1\x82\xd0\xbd\xd0\xb8"
- yield b'"}\n'
-
- iterator = make_jsonl_iterator(content=body(), sync=sync, line_type=object)
-
- assert await iter_next(iterator) == {"content": "известни"}
-
-
-async def to_aiter(iter: Iterator[bytes]) -> AsyncIterator[bytes]:
- for chunk in iter:
- yield chunk
-
-
-async def iter_next(iter: Iterator[_T] | AsyncIterator[_T]) -> _T:
- if isinstance(iter, AsyncIterator):
- return await iter.__anext__()
- return next(iter)
-
-
-async def assert_empty_iter(decoder: JSONLDecoder[Any] | AsyncJSONLDecoder[Any]) -> None:
- with pytest.raises((StopAsyncIteration, RuntimeError)):
- await iter_next(decoder)
-
-
-def make_jsonl_iterator(
- content: Iterator[bytes],
- *,
- sync: bool,
- line_type: type[_T],
-) -> JSONLDecoder[_T] | AsyncJSONLDecoder[_T]:
- if sync:
- return JSONLDecoder(line_type=line_type, raw_iterator=content, http_response=httpx.Response(200))
-
- return AsyncJSONLDecoder(line_type=line_type, raw_iterator=to_aiter(content), http_response=httpx.Response(200))
diff --git a/tests/test_client.py b/tests/test_client.py
index 8a2992af..7ad6e189 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -313,9 +313,6 @@ def test_default_headers_option(self) -> None:
assert request.headers.get("x-foo") == "stainless"
assert request.headers.get("x-stainless-lang") == "my-overriding-header"
- def test_validate_headers(self) -> None:
- client = LlamaStackClient(base_url=base_url, _strict_response_validation=True)
-
def test_default_query_option(self) -> None:
client = LlamaStackClient(
base_url=base_url, _strict_response_validation=True, default_query={"query_param": "bar"}
@@ -1103,9 +1100,6 @@ def test_default_headers_option(self) -> None:
assert request.headers.get("x-foo") == "stainless"
assert request.headers.get("x-stainless-lang") == "my-overriding-header"
- def test_validate_headers(self) -> None:
- client = AsyncLlamaStackClient(base_url=base_url, _strict_response_validation=True)
-
def test_default_query_option(self) -> None:
client = AsyncLlamaStackClient(
base_url=base_url, _strict_response_validation=True, default_query={"query_param": "bar"}
@@ -1648,7 +1642,7 @@ def test_get_platform(self) -> None:
import threading
from llama_stack_client._utils import asyncify
- from llama_stack_client._base_client import get_platform
+ from llama_stack_client._base_client import get_platform
async def test_main() -> None:
result = await asyncify(get_platform)()
diff --git a/tests/test_transform.py b/tests/test_transform.py
index 8ceafb36..b6eb411d 100644
--- a/tests/test_transform.py
+++ b/tests/test_transform.py
@@ -8,7 +8,7 @@
import pytest
-from llama_stack_client._types import Base64FileInput
+from llama_stack_client._types import NOT_GIVEN, Base64FileInput
from llama_stack_client._utils import (
PropertyInfo,
transform as _transform,
@@ -432,3 +432,22 @@ async def test_base64_file_input(use_async: bool) -> None:
assert await transform({"foo": io.BytesIO(b"Hello, world!")}, TypedDictBase64Input, use_async) == {
"foo": "SGVsbG8sIHdvcmxkIQ=="
} # type: ignore[comparison-overlap]
+
+
+@parametrize
+@pytest.mark.asyncio
+async def test_transform_skipping(use_async: bool) -> None:
+ # lists of ints are left as-is
+ data = [1, 2, 3]
+ assert await transform(data, List[int], use_async) is data
+
+ # iterables of ints are converted to a list
+ data = iter([1, 2, 3])
+ assert await transform(data, Iterable[int], use_async) == [1, 2, 3]
+
+
+@parametrize
+@pytest.mark.asyncio
+async def test_strips_notgiven(use_async: bool) -> None:
+ assert await transform({"foo_bar": "bar"}, Foo1, use_async) == {"fooBar": "bar"}
+ assert await transform({"foo_bar": NOT_GIVEN}, Foo1, use_async) == {}
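
The new `test_strips_notgiven` case documents that `NOT_GIVEN`-valued keys are dropped during transform, so optional arguments can be forwarded unconditionally. A hedged sketch under that assumption; the helper and model id are made up for illustration:

from typing import Optional

from llama_stack_client import NOT_GIVEN, LlamaStackClient

def complete(client: LlamaStackClient, prompt: str, max_tokens: Optional[int] = None):
    # NOT_GIVEN keys are stripped from the request body, per test_strips_notgiven
    return client.inference.completion(
        content=prompt,
        model_id="meta-llama/Llama-3.1-8B-Instruct",  # placeholder
        sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens}
        if max_tokens is not None
        else NOT_GIVEN,
    )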