From f3d139ad256789af1f7502dce37c9fa0aa6b5669 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 7 Apr 2025 19:36:02 -0700 Subject: [PATCH] feat: add updated batch inference types --- src/llama_stack_client/_client.py | 9 - src/llama_stack_client/_decoders/jsonl.py | 123 ------- src/llama_stack_client/_models.py | 2 +- src/llama_stack_client/_response.py | 22 -- src/llama_stack_client/_utils/_transform.py | 49 ++- src/llama_stack_client/_utils/_typing.py | 2 + .../lib/inference/event_logger.py | 14 +- src/llama_stack_client/resources/__init__.py | 14 - .../resources/agents/turn.py | 16 +- .../resources/batch_inference.py | 326 ------------------ src/llama_stack_client/resources/datasets.py | 20 +- src/llama_stack_client/resources/inference.py | 244 ++++++++++++- .../resources/tool_runtime/tool_runtime.py | 20 +- src/llama_stack_client/types/__init__.py | 11 +- .../batch_inference_chat_completion_params.py | 51 --- .../types/dataset_iterrows_response.py | 11 +- .../inference_batch_chat_completion_params.py | 74 ++++ ...ference_batch_chat_completion_response.py} | 4 +- ...y => inference_batch_completion_params.py} | 7 +- src/llama_stack_client/types/job.py | 2 +- .../post_training/job_status_response.py | 2 +- .../types/shared/agent_config.py | 1 + .../types/shared/sampling_params.py | 19 +- .../types/shared_params/agent_config.py | 1 + .../types/shared_params/sampling_params.py | 19 +- .../types/tool_runtime_list_tools_response.py | 10 + tests/api_resources/test_agents.py | 2 + tests/api_resources/test_batch_inference.py | 323 ----------------- tests/api_resources/test_eval.py | 8 + tests/api_resources/test_inference.py | 319 ++++++++++++++++- tests/api_resources/test_tool_runtime.py | 27 +- tests/decoders/test_jsonl.py | 88 ----- tests/test_client.py | 8 +- tests/test_transform.py | 21 +- 34 files changed, 840 insertions(+), 1029 deletions(-) delete mode 100644 src/llama_stack_client/_decoders/jsonl.py delete mode 100644 src/llama_stack_client/resources/batch_inference.py delete mode 100644 src/llama_stack_client/types/batch_inference_chat_completion_params.py create mode 100644 src/llama_stack_client/types/inference_batch_chat_completion_params.py rename src/llama_stack_client/types/{batch_inference_chat_completion_response.py => inference_batch_chat_completion_response.py} (70%) rename src/llama_stack_client/types/{batch_inference_completion_params.py => inference_batch_completion_params.py} (80%) create mode 100644 src/llama_stack_client/types/tool_runtime_list_tools_response.py delete mode 100644 tests/api_resources/test_batch_inference.py delete mode 100644 tests/decoders/test_jsonl.py diff --git a/src/llama_stack_client/_client.py b/src/llama_stack_client/_client.py index 00922520..7066ae2a 100644 --- a/src/llama_stack_client/_client.py +++ b/src/llama_stack_client/_client.py @@ -41,7 +41,6 @@ benchmarks, toolgroups, vector_dbs, - batch_inference, scoring_functions, synthetic_data_generation, ) @@ -74,7 +73,6 @@ class LlamaStackClient(SyncAPIClient): tools: tools.ToolsResource tool_runtime: tool_runtime.ToolRuntimeResource agents: agents.AgentsResource - batch_inference: batch_inference.BatchInferenceResource datasets: datasets.DatasetsResource eval: eval.EvalResource inspect: inspect.InspectResource @@ -155,7 +153,6 @@ def __init__( self.tools = tools.ToolsResource(self) self.tool_runtime = tool_runtime.ToolRuntimeResource(self) self.agents = agents.AgentsResource(self) - self.batch_inference = batch_inference.BatchInferenceResource(self) self.datasets = 
datasets.DatasetsResource(self) self.eval = eval.EvalResource(self) self.inspect = inspect.InspectResource(self) @@ -288,7 +285,6 @@ class AsyncLlamaStackClient(AsyncAPIClient): tools: tools.AsyncToolsResource tool_runtime: tool_runtime.AsyncToolRuntimeResource agents: agents.AsyncAgentsResource - batch_inference: batch_inference.AsyncBatchInferenceResource datasets: datasets.AsyncDatasetsResource eval: eval.AsyncEvalResource inspect: inspect.AsyncInspectResource @@ -369,7 +365,6 @@ def __init__( self.tools = tools.AsyncToolsResource(self) self.tool_runtime = tool_runtime.AsyncToolRuntimeResource(self) self.agents = agents.AsyncAgentsResource(self) - self.batch_inference = batch_inference.AsyncBatchInferenceResource(self) self.datasets = datasets.AsyncDatasetsResource(self) self.eval = eval.AsyncEvalResource(self) self.inspect = inspect.AsyncInspectResource(self) @@ -503,7 +498,6 @@ def __init__(self, client: LlamaStackClient) -> None: self.tools = tools.ToolsResourceWithRawResponse(client.tools) self.tool_runtime = tool_runtime.ToolRuntimeResourceWithRawResponse(client.tool_runtime) self.agents = agents.AgentsResourceWithRawResponse(client.agents) - self.batch_inference = batch_inference.BatchInferenceResourceWithRawResponse(client.batch_inference) self.datasets = datasets.DatasetsResourceWithRawResponse(client.datasets) self.eval = eval.EvalResourceWithRawResponse(client.eval) self.inspect = inspect.InspectResourceWithRawResponse(client.inspect) @@ -531,7 +525,6 @@ def __init__(self, client: AsyncLlamaStackClient) -> None: self.tools = tools.AsyncToolsResourceWithRawResponse(client.tools) self.tool_runtime = tool_runtime.AsyncToolRuntimeResourceWithRawResponse(client.tool_runtime) self.agents = agents.AsyncAgentsResourceWithRawResponse(client.agents) - self.batch_inference = batch_inference.AsyncBatchInferenceResourceWithRawResponse(client.batch_inference) self.datasets = datasets.AsyncDatasetsResourceWithRawResponse(client.datasets) self.eval = eval.AsyncEvalResourceWithRawResponse(client.eval) self.inspect = inspect.AsyncInspectResourceWithRawResponse(client.inspect) @@ -561,7 +554,6 @@ def __init__(self, client: LlamaStackClient) -> None: self.tools = tools.ToolsResourceWithStreamingResponse(client.tools) self.tool_runtime = tool_runtime.ToolRuntimeResourceWithStreamingResponse(client.tool_runtime) self.agents = agents.AgentsResourceWithStreamingResponse(client.agents) - self.batch_inference = batch_inference.BatchInferenceResourceWithStreamingResponse(client.batch_inference) self.datasets = datasets.DatasetsResourceWithStreamingResponse(client.datasets) self.eval = eval.EvalResourceWithStreamingResponse(client.eval) self.inspect = inspect.InspectResourceWithStreamingResponse(client.inspect) @@ -591,7 +583,6 @@ def __init__(self, client: AsyncLlamaStackClient) -> None: self.tools = tools.AsyncToolsResourceWithStreamingResponse(client.tools) self.tool_runtime = tool_runtime.AsyncToolRuntimeResourceWithStreamingResponse(client.tool_runtime) self.agents = agents.AsyncAgentsResourceWithStreamingResponse(client.agents) - self.batch_inference = batch_inference.AsyncBatchInferenceResourceWithStreamingResponse(client.batch_inference) self.datasets = datasets.AsyncDatasetsResourceWithStreamingResponse(client.datasets) self.eval = eval.AsyncEvalResourceWithStreamingResponse(client.eval) self.inspect = inspect.AsyncInspectResourceWithStreamingResponse(client.inspect) diff --git a/src/llama_stack_client/_decoders/jsonl.py b/src/llama_stack_client/_decoders/jsonl.py deleted file mode 100644 index 
ac5ac74f..00000000 --- a/src/llama_stack_client/_decoders/jsonl.py +++ /dev/null @@ -1,123 +0,0 @@ -from __future__ import annotations - -import json -from typing_extensions import Generic, TypeVar, Iterator, AsyncIterator - -import httpx - -from .._models import construct_type_unchecked - -_T = TypeVar("_T") - - -class JSONLDecoder(Generic[_T]): - """A decoder for [JSON Lines](https://jsonlines.org) format. - - This class provides an iterator over a byte-iterator that parses each JSON Line - into a given type. - """ - - http_response: httpx.Response - """The HTTP response this decoder was constructed from""" - - def __init__( - self, - *, - raw_iterator: Iterator[bytes], - line_type: type[_T], - http_response: httpx.Response, - ) -> None: - super().__init__() - self.http_response = http_response - self._raw_iterator = raw_iterator - self._line_type = line_type - self._iterator = self.__decode__() - - def close(self) -> None: - """Close the response body stream. - - This is called automatically if you consume the entire stream. - """ - self.http_response.close() - - def __decode__(self) -> Iterator[_T]: - buf = b"" - for chunk in self._raw_iterator: - for line in chunk.splitlines(keepends=True): - buf += line - if buf.endswith((b"\r", b"\n", b"\r\n")): - yield construct_type_unchecked( - value=json.loads(buf), - type_=self._line_type, - ) - buf = b"" - - # flush - if buf: - yield construct_type_unchecked( - value=json.loads(buf), - type_=self._line_type, - ) - - def __next__(self) -> _T: - return self._iterator.__next__() - - def __iter__(self) -> Iterator[_T]: - for item in self._iterator: - yield item - - -class AsyncJSONLDecoder(Generic[_T]): - """A decoder for [JSON Lines](https://jsonlines.org) format. - - This class provides an async iterator over a byte-iterator that parses each JSON Line - into a given type. - """ - - http_response: httpx.Response - - def __init__( - self, - *, - raw_iterator: AsyncIterator[bytes], - line_type: type[_T], - http_response: httpx.Response, - ) -> None: - super().__init__() - self.http_response = http_response - self._raw_iterator = raw_iterator - self._line_type = line_type - self._iterator = self.__decode__() - - async def close(self) -> None: - """Close the response body stream. - - This is called automatically if you consume the entire stream. 
- """ - await self.http_response.aclose() - - async def __decode__(self) -> AsyncIterator[_T]: - buf = b"" - async for chunk in self._raw_iterator: - for line in chunk.splitlines(keepends=True): - buf += line - if buf.endswith((b"\r", b"\n", b"\r\n")): - yield construct_type_unchecked( - value=json.loads(buf), - type_=self._line_type, - ) - buf = b"" - - # flush - if buf: - yield construct_type_unchecked( - value=json.loads(buf), - type_=self._line_type, - ) - - async def __anext__(self) -> _T: - return await self._iterator.__anext__() - - async def __aiter__(self) -> AsyncIterator[_T]: - async for item in self._iterator: - yield item diff --git a/src/llama_stack_client/_models.py b/src/llama_stack_client/_models.py index b51a1bf5..34935716 100644 --- a/src/llama_stack_client/_models.py +++ b/src/llama_stack_client/_models.py @@ -681,7 +681,7 @@ def set_pydantic_config(typ: Any, config: pydantic.ConfigDict) -> None: setattr(typ, "__pydantic_config__", config) # noqa: B010 -# our use of subclasssing here causes weirdness for type checkers, +# our use of subclassing here causes weirdness for type checkers, # so we just pretend that we don't subclass if TYPE_CHECKING: GenericModel = BaseModel diff --git a/src/llama_stack_client/_response.py b/src/llama_stack_client/_response.py index ea35182f..1938ae74 100644 --- a/src/llama_stack_client/_response.py +++ b/src/llama_stack_client/_response.py @@ -30,7 +30,6 @@ from ._constants import RAW_RESPONSE_HEADER, OVERRIDE_CAST_TO_HEADER from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type from ._exceptions import LlamaStackClientError, APIResponseValidationError -from ._decoders.jsonl import JSONLDecoder, AsyncJSONLDecoder if TYPE_CHECKING: from ._models import FinalRequestOptions @@ -139,27 +138,6 @@ def _parse(self, *, to: type[_T] | None = None) -> R | _T: origin = get_origin(cast_to) or cast_to - if inspect.isclass(origin): - if issubclass(cast(Any, origin), JSONLDecoder): - return cast( - R, - cast("type[JSONLDecoder[Any]]", cast_to)( - raw_iterator=self.http_response.iter_bytes(chunk_size=64), - line_type=extract_type_arg(cast_to, 0), - http_response=self.http_response, - ), - ) - - if issubclass(cast(Any, origin), AsyncJSONLDecoder): - return cast( - R, - cast("type[AsyncJSONLDecoder[Any]]", cast_to)( - raw_iterator=self.http_response.aiter_bytes(chunk_size=64), - line_type=extract_type_arg(cast_to, 0), - http_response=self.http_response, - ), - ) - if self._is_sse_stream: if to: if not is_stream_class_type(to): diff --git a/src/llama_stack_client/_utils/_transform.py b/src/llama_stack_client/_utils/_transform.py index 18afd9d8..b0cc20a7 100644 --- a/src/llama_stack_client/_utils/_transform.py +++ b/src/llama_stack_client/_utils/_transform.py @@ -5,13 +5,15 @@ import pathlib from typing import Any, Mapping, TypeVar, cast from datetime import date, datetime -from typing_extensions import Literal, get_args, override, get_type_hints +from typing_extensions import Literal, get_args, override, get_type_hints as _get_type_hints import anyio import pydantic from ._utils import ( is_list, + is_given, + lru_cache, is_mapping, is_iterable, ) @@ -108,6 +110,7 @@ class Params(TypedDict, total=False): return cast(_T, transformed) +@lru_cache(maxsize=8096) def _get_annotated_type(type_: type) -> type | None: """If the given type is an `Annotated` type then it is returned, if not `None` is returned. 
@@ -126,7 +129,7 @@ def _get_annotated_type(type_: type) -> type | None: def _maybe_transform_key(key: str, type_: type) -> str: """Transform the given `data` based on the annotations provided in `type_`. - Note: this function only looks at `Annotated` types that contain `PropertInfo` metadata. + Note: this function only looks at `Annotated` types that contain `PropertyInfo` metadata. """ annotated_type = _get_annotated_type(type_) if annotated_type is None: @@ -142,6 +145,10 @@ def _maybe_transform_key(key: str, type_: type) -> str: return key +def _no_transform_needed(annotation: type) -> bool: + return annotation == float or annotation == int + + def _transform_recursive( data: object, *, @@ -184,6 +191,15 @@ def _transform_recursive( return cast(object, data) inner_type = extract_type_arg(stripped_type, 0) + if _no_transform_needed(inner_type): + # for some types there is no need to transform anything, so we can get a small + # perf boost from skipping that work. + # + # but we still need to convert to a list to ensure the data is json-serializable + if is_list(data): + return data + return list(data) + return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data] if is_union_type(stripped_type): @@ -245,6 +261,11 @@ def _transform_typeddict( result: dict[str, object] = {} annotations = get_type_hints(expected_type, include_extras=True) for key, value in data.items(): + if not is_given(value): + # we don't need to include `NotGiven` values here as they'll + # be stripped out before the request is sent anyway + continue + type_ = annotations.get(key) if type_ is None: # we do not have a type annotation for this field, leave it as is @@ -332,6 +353,15 @@ async def _async_transform_recursive( return cast(object, data) inner_type = extract_type_arg(stripped_type, 0) + if _no_transform_needed(inner_type): + # for some types there is no need to transform anything, so we can get a small + # perf boost from skipping that work. 
+ # + # but we still need to convert to a list to ensure the data is json-serializable + if is_list(data): + return data + return list(data) + return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data] if is_union_type(stripped_type): @@ -393,6 +423,11 @@ async def _async_transform_typeddict( result: dict[str, object] = {} annotations = get_type_hints(expected_type, include_extras=True) for key, value in data.items(): + if not is_given(value): + # we don't need to include `NotGiven` values here as they'll + # be stripped out before the request is sent anyway + continue + type_ = annotations.get(key) if type_ is None: # we do not have a type annotation for this field, leave it as is @@ -400,3 +435,13 @@ async def _async_transform_typeddict( else: result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(value, annotation=type_) return result + + +@lru_cache(maxsize=8096) +def get_type_hints( + obj: Any, + globalns: dict[str, Any] | None = None, + localns: Mapping[str, Any] | None = None, + include_extras: bool = False, +) -> dict[str, Any]: + return _get_type_hints(obj, globalns=globalns, localns=localns, include_extras=include_extras) diff --git a/src/llama_stack_client/_utils/_typing.py b/src/llama_stack_client/_utils/_typing.py index 278749b1..1958820f 100644 --- a/src/llama_stack_client/_utils/_typing.py +++ b/src/llama_stack_client/_utils/_typing.py @@ -13,6 +13,7 @@ get_origin, ) +from ._utils import lru_cache from .._types import InheritsGeneric from .._compat import is_union as _is_union @@ -66,6 +67,7 @@ def is_type_alias_type(tp: Any, /) -> TypeIs[typing_extensions.TypeAliasType]: # Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]] +@lru_cache(maxsize=8096) def strip_annotated_type(typ: type) -> type: if is_required_type(typ) or is_annotated_type(typ): return strip_annotated_type(cast(type, get_args(typ)[0])) diff --git a/src/llama_stack_client/lib/inference/event_logger.py b/src/llama_stack_client/lib/inference/event_logger.py index e8c63a1f..f74f1bd4 100644 --- a/src/llama_stack_client/lib/inference/event_logger.py +++ b/src/llama_stack_client/lib/inference/event_logger.py @@ -22,12 +22,24 @@ def print(self, flush=True): class InferenceStreamLogEventPrinter: + def __init__(self): + self.is_thinking = False + def yield_printable_events(self, chunk): event = chunk.event if event.event_type == "start": yield InferenceStreamPrintableEvent("Assistant> ", color="cyan", end="") elif event.event_type == "progress": - yield InferenceStreamPrintableEvent(event.delta.text, color="yellow", end="") + if event.delta.type == "reasoning": + if not self.is_thinking: + yield InferenceStreamPrintableEvent(" ", color="magenta", end="") + self.is_thinking = True + yield InferenceStreamPrintableEvent(event.delta.reasoning, color="magenta", end="") + else: + if self.is_thinking: + yield InferenceStreamPrintableEvent("", color="magenta", end="") + self.is_thinking = False + yield InferenceStreamPrintableEvent(event.delta.text, color="yellow", end="") elif event.event_type == "complete": yield InferenceStreamPrintableEvent("") diff --git a/src/llama_stack_client/resources/__init__.py b/src/llama_stack_client/resources/__init__.py index 865d77e0..0e3373dc 100644 --- a/src/llama_stack_client/resources/__init__.py +++ b/src/llama_stack_client/resources/__init__.py @@ -152,14 +152,6 @@ PostTrainingResourceWithStreamingResponse, AsyncPostTrainingResourceWithStreamingResponse, ) -from .batch_inference import ( - 
BatchInferenceResource, - AsyncBatchInferenceResource, - BatchInferenceResourceWithRawResponse, - AsyncBatchInferenceResourceWithRawResponse, - BatchInferenceResourceWithStreamingResponse, - AsyncBatchInferenceResourceWithStreamingResponse, -) from .scoring_functions import ( ScoringFunctionsResource, AsyncScoringFunctionsResource, @@ -202,12 +194,6 @@ "AsyncAgentsResourceWithRawResponse", "AgentsResourceWithStreamingResponse", "AsyncAgentsResourceWithStreamingResponse", - "BatchInferenceResource", - "AsyncBatchInferenceResource", - "BatchInferenceResourceWithRawResponse", - "AsyncBatchInferenceResourceWithRawResponse", - "BatchInferenceResourceWithStreamingResponse", - "AsyncBatchInferenceResourceWithStreamingResponse", "DatasetsResource", "AsyncDatasetsResource", "DatasetsResourceWithRawResponse", diff --git a/src/llama_stack_client/resources/agents/turn.py b/src/llama_stack_client/resources/agents/turn.py index 23c44677..6b1b4ae2 100644 --- a/src/llama_stack_client/resources/agents/turn.py +++ b/src/llama_stack_client/resources/agents/turn.py @@ -218,7 +218,9 @@ def create( "tool_config": tool_config, "toolgroups": toolgroups, }, - turn_create_params.TurnCreateParams, + turn_create_params.TurnCreateParamsStreaming + if stream + else turn_create_params.TurnCreateParamsNonStreaming, ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout @@ -410,7 +412,9 @@ def resume( "tool_responses": tool_responses, "stream": stream, }, - turn_resume_params.TurnResumeParams, + turn_resume_params.TurnResumeParamsStreaming + if stream + else turn_resume_params.TurnResumeParamsNonStreaming, ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout @@ -608,7 +612,9 @@ async def create( "tool_config": tool_config, "toolgroups": toolgroups, }, - turn_create_params.TurnCreateParams, + turn_create_params.TurnCreateParamsStreaming + if stream + else turn_create_params.TurnCreateParamsNonStreaming, ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout @@ -800,7 +806,9 @@ async def resume( "tool_responses": tool_responses, "stream": stream, }, - turn_resume_params.TurnResumeParams, + turn_resume_params.TurnResumeParamsStreaming + if stream + else turn_resume_params.TurnResumeParamsNonStreaming, ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout diff --git a/src/llama_stack_client/resources/batch_inference.py b/src/llama_stack_client/resources/batch_inference.py deleted file mode 100644 index 92d437cb..00000000 --- a/src/llama_stack_client/resources/batch_inference.py +++ /dev/null @@ -1,326 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
- -from __future__ import annotations - -from typing import List, Iterable -from typing_extensions import Literal - -import httpx - -from ..types import batch_inference_completion_params, batch_inference_chat_completion_params -from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven -from .._utils import ( - maybe_transform, - async_maybe_transform, -) -from .._compat import cached_property -from .._resource import SyncAPIResource, AsyncAPIResource -from .._response import ( - to_raw_response_wrapper, - to_streamed_response_wrapper, - async_to_raw_response_wrapper, - async_to_streamed_response_wrapper, -) -from .._base_client import make_request_options -from ..types.shared_params.message import Message -from ..types.shared.batch_completion import BatchCompletion -from ..types.shared_params.response_format import ResponseFormat -from ..types.shared_params.sampling_params import SamplingParams -from ..types.shared_params.interleaved_content import InterleavedContent -from ..types.batch_inference_chat_completion_response import BatchInferenceChatCompletionResponse - -__all__ = ["BatchInferenceResource", "AsyncBatchInferenceResource"] - - -class BatchInferenceResource(SyncAPIResource): - @cached_property - def with_raw_response(self) -> BatchInferenceResourceWithRawResponse: - """ - This property can be used as a prefix for any HTTP method call to return - the raw response object instead of the parsed content. - - For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers - """ - return BatchInferenceResourceWithRawResponse(self) - - @cached_property - def with_streaming_response(self) -> BatchInferenceResourceWithStreamingResponse: - """ - An alternative to `.with_raw_response` that doesn't eagerly read the response body. - - For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response - """ - return BatchInferenceResourceWithStreamingResponse(self) - - def chat_completion( - self, - *, - messages_batch: Iterable[Iterable[Message]], - model: str, - logprobs: batch_inference_chat_completion_params.Logprobs | NotGiven = NOT_GIVEN, - response_format: ResponseFormat | NotGiven = NOT_GIVEN, - sampling_params: SamplingParams | NotGiven = NOT_GIVEN, - tool_choice: Literal["auto", "required", "none"] | NotGiven = NOT_GIVEN, - tool_prompt_format: Literal["json", "function_tag", "python_list"] | NotGiven = NOT_GIVEN, - tools: Iterable[batch_inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> BatchInferenceChatCompletionResponse: - """ - Args: - response_format: Configuration for JSON schema-guided response generation. - - tool_choice: Whether tool use is required or automatic. This is a hint to the model which may - not be followed. It depends on the Instruction Following capabilities of the - model. - - tool_prompt_format: Prompt format for calling custom / zero shot tools. 
- - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - return self._post( - "/v1/batch-inference/chat-completion", - body=maybe_transform( - { - "messages_batch": messages_batch, - "model": model, - "logprobs": logprobs, - "response_format": response_format, - "sampling_params": sampling_params, - "tool_choice": tool_choice, - "tool_prompt_format": tool_prompt_format, - "tools": tools, - }, - batch_inference_chat_completion_params.BatchInferenceChatCompletionParams, - ), - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=BatchInferenceChatCompletionResponse, - ) - - def completion( - self, - *, - content_batch: List[InterleavedContent], - model: str, - logprobs: batch_inference_completion_params.Logprobs | NotGiven = NOT_GIVEN, - response_format: ResponseFormat | NotGiven = NOT_GIVEN, - sampling_params: SamplingParams | NotGiven = NOT_GIVEN, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> BatchCompletion: - """ - Args: - response_format: Configuration for JSON schema-guided response generation. - - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - return self._post( - "/v1/batch-inference/completion", - body=maybe_transform( - { - "content_batch": content_batch, - "model": model, - "logprobs": logprobs, - "response_format": response_format, - "sampling_params": sampling_params, - }, - batch_inference_completion_params.BatchInferenceCompletionParams, - ), - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=BatchCompletion, - ) - - -class AsyncBatchInferenceResource(AsyncAPIResource): - @cached_property - def with_raw_response(self) -> AsyncBatchInferenceResourceWithRawResponse: - """ - This property can be used as a prefix for any HTTP method call to return - the raw response object instead of the parsed content. - - For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers - """ - return AsyncBatchInferenceResourceWithRawResponse(self) - - @cached_property - def with_streaming_response(self) -> AsyncBatchInferenceResourceWithStreamingResponse: - """ - An alternative to `.with_raw_response` that doesn't eagerly read the response body. 
- - For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response - """ - return AsyncBatchInferenceResourceWithStreamingResponse(self) - - async def chat_completion( - self, - *, - messages_batch: Iterable[Iterable[Message]], - model: str, - logprobs: batch_inference_chat_completion_params.Logprobs | NotGiven = NOT_GIVEN, - response_format: ResponseFormat | NotGiven = NOT_GIVEN, - sampling_params: SamplingParams | NotGiven = NOT_GIVEN, - tool_choice: Literal["auto", "required", "none"] | NotGiven = NOT_GIVEN, - tool_prompt_format: Literal["json", "function_tag", "python_list"] | NotGiven = NOT_GIVEN, - tools: Iterable[batch_inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> BatchInferenceChatCompletionResponse: - """ - Args: - response_format: Configuration for JSON schema-guided response generation. - - tool_choice: Whether tool use is required or automatic. This is a hint to the model which may - not be followed. It depends on the Instruction Following capabilities of the - model. - - tool_prompt_format: Prompt format for calling custom / zero shot tools. - - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - return await self._post( - "/v1/batch-inference/chat-completion", - body=await async_maybe_transform( - { - "messages_batch": messages_batch, - "model": model, - "logprobs": logprobs, - "response_format": response_format, - "sampling_params": sampling_params, - "tool_choice": tool_choice, - "tool_prompt_format": tool_prompt_format, - "tools": tools, - }, - batch_inference_chat_completion_params.BatchInferenceChatCompletionParams, - ), - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=BatchInferenceChatCompletionResponse, - ) - - async def completion( - self, - *, - content_batch: List[InterleavedContent], - model: str, - logprobs: batch_inference_completion_params.Logprobs | NotGiven = NOT_GIVEN, - response_format: ResponseFormat | NotGiven = NOT_GIVEN, - sampling_params: SamplingParams | NotGiven = NOT_GIVEN, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> BatchCompletion: - """ - Args: - response_format: Configuration for JSON schema-guided response generation. 
- - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - return await self._post( - "/v1/batch-inference/completion", - body=await async_maybe_transform( - { - "content_batch": content_batch, - "model": model, - "logprobs": logprobs, - "response_format": response_format, - "sampling_params": sampling_params, - }, - batch_inference_completion_params.BatchInferenceCompletionParams, - ), - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=BatchCompletion, - ) - - -class BatchInferenceResourceWithRawResponse: - def __init__(self, batch_inference: BatchInferenceResource) -> None: - self._batch_inference = batch_inference - - self.chat_completion = to_raw_response_wrapper( - batch_inference.chat_completion, - ) - self.completion = to_raw_response_wrapper( - batch_inference.completion, - ) - - -class AsyncBatchInferenceResourceWithRawResponse: - def __init__(self, batch_inference: AsyncBatchInferenceResource) -> None: - self._batch_inference = batch_inference - - self.chat_completion = async_to_raw_response_wrapper( - batch_inference.chat_completion, - ) - self.completion = async_to_raw_response_wrapper( - batch_inference.completion, - ) - - -class BatchInferenceResourceWithStreamingResponse: - def __init__(self, batch_inference: BatchInferenceResource) -> None: - self._batch_inference = batch_inference - - self.chat_completion = to_streamed_response_wrapper( - batch_inference.chat_completion, - ) - self.completion = to_streamed_response_wrapper( - batch_inference.completion, - ) - - -class AsyncBatchInferenceResourceWithStreamingResponse: - def __init__(self, batch_inference: AsyncBatchInferenceResource) -> None: - self._batch_inference = batch_inference - - self.chat_completion = async_to_streamed_response_wrapper( - batch_inference.chat_completion, - ) - self.completion = async_to_streamed_response_wrapper( - batch_inference.completion, - ) diff --git a/src/llama_stack_client/resources/datasets.py b/src/llama_stack_client/resources/datasets.py index ed56ac80..aaa27a5e 100644 --- a/src/llama_stack_client/resources/datasets.py +++ b/src/llama_stack_client/resources/datasets.py @@ -119,7 +119,15 @@ def iterrows( ) -> DatasetIterrowsResponse: """Get a paginated list of rows from a dataset. - Uses cursor-based pagination. + Uses offset-based pagination where: + + - start_index: The starting index (0-based). If None, starts from beginning. + - limit: Number of items to return. If None or -1, returns all items. + + The response includes: + + - data: List of items for the current page + - has_more: Whether there are more items available after this set Args: limit: The number of rows to get. @@ -344,7 +352,15 @@ async def iterrows( ) -> DatasetIterrowsResponse: """Get a paginated list of rows from a dataset. - Uses cursor-based pagination. + Uses offset-based pagination where: + + - start_index: The starting index (0-based). If None, starts from beginning. + - limit: Number of items to return. If None or -1, returns all items. + + The response includes: + + - data: List of items for the current page + - has_more: Whether there are more items available after this set Args: limit: The number of rows to get. 
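A minimal usage sketch of the offset-based pagination described in the updated `iterrows` docstrings above. The `base_url` and `dataset_id` values are assumptions for illustration; `start_index`, `limit`, `data`, and `has_more` are the names this patch documents.

from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # assumed local server

rows = []
start_index = 0
while True:
    page = client.datasets.iterrows(
        dataset_id="my-dataset",  # hypothetical dataset id
        start_index=start_index,  # 0-based starting index
        limit=100,                # rows per page; None or -1 would return all
    )
    rows.extend(page.data)        # list of row dicts for the current page
    if not page.has_more:         # no more items available after this set
        break
    start_index += len(page.data)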
diff --git a/src/llama_stack_client/resources/inference.py b/src/llama_stack_client/resources/inference.py index c4d85852..428956ce 100644 --- a/src/llama_stack_client/resources/inference.py +++ b/src/llama_stack_client/resources/inference.py @@ -11,6 +11,8 @@ inference_completion_params, inference_embeddings_params, inference_chat_completion_params, + inference_batch_completion_params, + inference_batch_chat_completion_params, ) from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven from .._utils import ( @@ -31,12 +33,14 @@ from ..types.completion_response import CompletionResponse from ..types.embeddings_response import EmbeddingsResponse from ..types.shared_params.message import Message +from ..types.shared.batch_completion import BatchCompletion from ..types.shared_params.response_format import ResponseFormat from ..types.shared_params.sampling_params import SamplingParams from ..types.shared.chat_completion_response import ChatCompletionResponse from ..types.shared_params.interleaved_content import InterleavedContent from ..types.chat_completion_response_stream_chunk import ChatCompletionResponseStreamChunk from ..types.shared_params.interleaved_content_item import InterleavedContentItem +from ..types.inference_batch_chat_completion_response import InferenceBatchChatCompletionResponse __all__ = ["InferenceResource", "AsyncInferenceResource"] @@ -61,6 +65,106 @@ def with_streaming_response(self) -> InferenceResourceWithStreamingResponse: """ return InferenceResourceWithStreamingResponse(self) + def batch_chat_completion( + self, + *, + messages_batch: Iterable[Iterable[Message]], + model_id: str, + logprobs: inference_batch_chat_completion_params.Logprobs | NotGiven = NOT_GIVEN, + response_format: ResponseFormat | NotGiven = NOT_GIVEN, + sampling_params: SamplingParams | NotGiven = NOT_GIVEN, + tool_config: inference_batch_chat_completion_params.ToolConfig | NotGiven = NOT_GIVEN, + tools: Iterable[inference_batch_chat_completion_params.Tool] | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> InferenceBatchChatCompletionResponse: + """ + Args: + response_format: Configuration for JSON schema-guided response generation. + + sampling_params: Sampling parameters. + + tool_config: Configuration for tool use. 
+ + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + return self._post( + "/v1/inference/batch-chat-completion", + body=maybe_transform( + { + "messages_batch": messages_batch, + "model_id": model_id, + "logprobs": logprobs, + "response_format": response_format, + "sampling_params": sampling_params, + "tool_config": tool_config, + "tools": tools, + }, + inference_batch_chat_completion_params.InferenceBatchChatCompletionParams, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=InferenceBatchChatCompletionResponse, + ) + + def batch_completion( + self, + *, + content_batch: List[InterleavedContent], + model_id: str, + logprobs: inference_batch_completion_params.Logprobs | NotGiven = NOT_GIVEN, + response_format: ResponseFormat | NotGiven = NOT_GIVEN, + sampling_params: SamplingParams | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> BatchCompletion: + """ + Args: + response_format: Configuration for JSON schema-guided response generation. + + sampling_params: Sampling parameters. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + return self._post( + "/v1/inference/batch-completion", + body=maybe_transform( + { + "content_batch": content_batch, + "model_id": model_id, + "logprobs": logprobs, + "response_format": response_format, + "sampling_params": sampling_params, + }, + inference_batch_completion_params.InferenceBatchCompletionParams, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=BatchCompletion, + ) + @overload def chat_completion( self, @@ -303,7 +407,9 @@ def chat_completion( "tool_prompt_format": tool_prompt_format, "tools": tools, }, - inference_chat_completion_params.InferenceChatCompletionParams, + inference_chat_completion_params.InferenceChatCompletionParamsStreaming + if stream + else inference_chat_completion_params.InferenceChatCompletionParamsNonStreaming, ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout @@ -481,7 +587,9 @@ def completion( "sampling_params": sampling_params, "stream": stream, }, - inference_completion_params.InferenceCompletionParams, + inference_completion_params.InferenceCompletionParamsStreaming + if stream + else inference_completion_params.InferenceCompletionParamsNonStreaming, ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout @@ -573,6 +681,106 @@ def with_streaming_response(self) -> AsyncInferenceResourceWithStreamingResponse """ return AsyncInferenceResourceWithStreamingResponse(self) + async def batch_chat_completion( + self, + *, 
+ messages_batch: Iterable[Iterable[Message]], + model_id: str, + logprobs: inference_batch_chat_completion_params.Logprobs | NotGiven = NOT_GIVEN, + response_format: ResponseFormat | NotGiven = NOT_GIVEN, + sampling_params: SamplingParams | NotGiven = NOT_GIVEN, + tool_config: inference_batch_chat_completion_params.ToolConfig | NotGiven = NOT_GIVEN, + tools: Iterable[inference_batch_chat_completion_params.Tool] | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> InferenceBatchChatCompletionResponse: + """ + Args: + response_format: Configuration for JSON schema-guided response generation. + + sampling_params: Sampling parameters. + + tool_config: Configuration for tool use. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + return await self._post( + "/v1/inference/batch-chat-completion", + body=await async_maybe_transform( + { + "messages_batch": messages_batch, + "model_id": model_id, + "logprobs": logprobs, + "response_format": response_format, + "sampling_params": sampling_params, + "tool_config": tool_config, + "tools": tools, + }, + inference_batch_chat_completion_params.InferenceBatchChatCompletionParams, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=InferenceBatchChatCompletionResponse, + ) + + async def batch_completion( + self, + *, + content_batch: List[InterleavedContent], + model_id: str, + logprobs: inference_batch_completion_params.Logprobs | NotGiven = NOT_GIVEN, + response_format: ResponseFormat | NotGiven = NOT_GIVEN, + sampling_params: SamplingParams | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> BatchCompletion: + """ + Args: + response_format: Configuration for JSON schema-guided response generation. + + sampling_params: Sampling parameters. 
+ + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + return await self._post( + "/v1/inference/batch-completion", + body=await async_maybe_transform( + { + "content_batch": content_batch, + "model_id": model_id, + "logprobs": logprobs, + "response_format": response_format, + "sampling_params": sampling_params, + }, + inference_batch_completion_params.InferenceBatchCompletionParams, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=BatchCompletion, + ) + @overload async def chat_completion( self, @@ -815,7 +1023,9 @@ async def chat_completion( "tool_prompt_format": tool_prompt_format, "tools": tools, }, - inference_chat_completion_params.InferenceChatCompletionParams, + inference_chat_completion_params.InferenceChatCompletionParamsStreaming + if stream + else inference_chat_completion_params.InferenceChatCompletionParamsNonStreaming, ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout @@ -993,7 +1203,9 @@ async def completion( "sampling_params": sampling_params, "stream": stream, }, - inference_completion_params.InferenceCompletionParams, + inference_completion_params.InferenceCompletionParamsStreaming + if stream + else inference_completion_params.InferenceCompletionParamsNonStreaming, ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout @@ -1069,6 +1281,12 @@ class InferenceResourceWithRawResponse: def __init__(self, inference: InferenceResource) -> None: self._inference = inference + self.batch_chat_completion = to_raw_response_wrapper( + inference.batch_chat_completion, + ) + self.batch_completion = to_raw_response_wrapper( + inference.batch_completion, + ) self.chat_completion = to_raw_response_wrapper( inference.chat_completion, ) @@ -1084,6 +1302,12 @@ class AsyncInferenceResourceWithRawResponse: def __init__(self, inference: AsyncInferenceResource) -> None: self._inference = inference + self.batch_chat_completion = async_to_raw_response_wrapper( + inference.batch_chat_completion, + ) + self.batch_completion = async_to_raw_response_wrapper( + inference.batch_completion, + ) self.chat_completion = async_to_raw_response_wrapper( inference.chat_completion, ) @@ -1099,6 +1323,12 @@ class InferenceResourceWithStreamingResponse: def __init__(self, inference: InferenceResource) -> None: self._inference = inference + self.batch_chat_completion = to_streamed_response_wrapper( + inference.batch_chat_completion, + ) + self.batch_completion = to_streamed_response_wrapper( + inference.batch_completion, + ) self.chat_completion = to_streamed_response_wrapper( inference.chat_completion, ) @@ -1114,6 +1344,12 @@ class AsyncInferenceResourceWithStreamingResponse: def __init__(self, inference: AsyncInferenceResource) -> None: self._inference = inference + self.batch_chat_completion = async_to_streamed_response_wrapper( + inference.batch_chat_completion, + ) + self.batch_completion = async_to_streamed_response_wrapper( + inference.batch_completion, + ) self.chat_completion = async_to_streamed_response_wrapper( inference.chat_completion, ) diff --git a/src/llama_stack_client/resources/tool_runtime/tool_runtime.py 
b/src/llama_stack_client/resources/tool_runtime/tool_runtime.py index 2bd7347b..aa380f79 100644 --- a/src/llama_stack_client/resources/tool_runtime/tool_runtime.py +++ b/src/llama_stack_client/resources/tool_runtime/tool_runtime.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Dict, Union, Iterable +from typing import Dict, Type, Union, Iterable, cast import httpx @@ -28,10 +28,10 @@ async_to_raw_response_wrapper, async_to_streamed_response_wrapper, ) +from ..._wrappers import DataWrapper from ..._base_client import make_request_options -from ...types.tool_def import ToolDef -from ..._decoders.jsonl import JSONLDecoder, AsyncJSONLDecoder from ...types.tool_invocation_result import ToolInvocationResult +from ...types.tool_runtime_list_tools_response import ToolRuntimeListToolsResponse __all__ = ["ToolRuntimeResource", "AsyncToolRuntimeResource"] @@ -110,7 +110,7 @@ def list_tools( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> JSONLDecoder[ToolDef]: + ) -> ToolRuntimeListToolsResponse: """ Args: extra_headers: Send extra headers @@ -121,7 +121,6 @@ def list_tools( timeout: Override the client-level default timeout for this request, in seconds """ - extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})} return self._get( "/v1/tool-runtime/list-tools", options=make_request_options( @@ -136,9 +135,9 @@ def list_tools( }, tool_runtime_list_tools_params.ToolRuntimeListToolsParams, ), + post_parser=DataWrapper[ToolRuntimeListToolsResponse]._unwrapper, ), - cast_to=JSONLDecoder[ToolDef], - stream=True, + cast_to=cast(Type[ToolRuntimeListToolsResponse], DataWrapper[ToolRuntimeListToolsResponse]), ) @@ -216,7 +215,7 @@ async def list_tools( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> AsyncJSONLDecoder[ToolDef]: + ) -> ToolRuntimeListToolsResponse: """ Args: extra_headers: Send extra headers @@ -227,7 +226,6 @@ async def list_tools( timeout: Override the client-level default timeout for this request, in seconds """ - extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})} return await self._get( "/v1/tool-runtime/list-tools", options=make_request_options( @@ -242,9 +240,9 @@ async def list_tools( }, tool_runtime_list_tools_params.ToolRuntimeListToolsParams, ), + post_parser=DataWrapper[ToolRuntimeListToolsResponse]._unwrapper, ), - cast_to=AsyncJSONLDecoder[ToolDef], - stream=True, + cast_to=cast(Type[ToolRuntimeListToolsResponse], DataWrapper[ToolRuntimeListToolsResponse]), ) diff --git a/src/llama_stack_client/types/__init__.py b/src/llama_stack_client/types/__init__.py index b45996a9..a78eae03 100644 --- a/src/llama_stack_client/types/__init__.py +++ b/src/llama_stack_client/types/__init__.py @@ -126,19 +126,20 @@ from .list_post_training_jobs_response import ListPostTrainingJobsResponse as ListPostTrainingJobsResponse from .scoring_function_register_params import ScoringFunctionRegisterParams as ScoringFunctionRegisterParams from .telemetry_get_span_tree_response import TelemetryGetSpanTreeResponse as TelemetryGetSpanTreeResponse -from .batch_inference_completion_params import BatchInferenceCompletionParams as BatchInferenceCompletionParams +from .tool_runtime_list_tools_response import ToolRuntimeListToolsResponse as ToolRuntimeListToolsResponse +from .inference_batch_completion_params import InferenceBatchCompletionParams as InferenceBatchCompletionParams from 
.synthetic_data_generation_response import SyntheticDataGenerationResponse as SyntheticDataGenerationResponse from .chat_completion_response_stream_chunk import ( ChatCompletionResponseStreamChunk as ChatCompletionResponseStreamChunk, ) -from .batch_inference_chat_completion_params import ( - BatchInferenceChatCompletionParams as BatchInferenceChatCompletionParams, +from .inference_batch_chat_completion_params import ( + InferenceBatchChatCompletionParams as InferenceBatchChatCompletionParams, ) from .telemetry_save_spans_to_dataset_params import ( TelemetrySaveSpansToDatasetParams as TelemetrySaveSpansToDatasetParams, ) -from .batch_inference_chat_completion_response import ( - BatchInferenceChatCompletionResponse as BatchInferenceChatCompletionResponse, +from .inference_batch_chat_completion_response import ( + InferenceBatchChatCompletionResponse as InferenceBatchChatCompletionResponse, ) from .post_training_preference_optimize_params import ( PostTrainingPreferenceOptimizeParams as PostTrainingPreferenceOptimizeParams, diff --git a/src/llama_stack_client/types/batch_inference_chat_completion_params.py b/src/llama_stack_client/types/batch_inference_chat_completion_params.py deleted file mode 100644 index 3091f5bb..00000000 --- a/src/llama_stack_client/types/batch_inference_chat_completion_params.py +++ /dev/null @@ -1,51 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - -from __future__ import annotations - -from typing import Dict, Union, Iterable -from typing_extensions import Literal, Required, TypedDict - -from .shared_params.message import Message -from .shared_params.response_format import ResponseFormat -from .shared_params.sampling_params import SamplingParams -from .shared_params.tool_param_definition import ToolParamDefinition - -__all__ = ["BatchInferenceChatCompletionParams", "Logprobs", "Tool"] - - -class BatchInferenceChatCompletionParams(TypedDict, total=False): - messages_batch: Required[Iterable[Iterable[Message]]] - - model: Required[str] - - logprobs: Logprobs - - response_format: ResponseFormat - """Configuration for JSON schema-guided response generation.""" - - sampling_params: SamplingParams - - tool_choice: Literal["auto", "required", "none"] - """Whether tool use is required or automatic. - - This is a hint to the model which may not be followed. It depends on the - Instruction Following capabilities of the model. - """ - - tool_prompt_format: Literal["json", "function_tag", "python_list"] - """Prompt format for calling custom / zero shot tools.""" - - tools: Iterable[Tool] - - -class Logprobs(TypedDict, total=False): - top_k: int - """How many tokens (for each position) to return log probabilities for.""" - - -class Tool(TypedDict, total=False): - tool_name: Required[Union[Literal["brave_search", "wolfram_alpha", "photogen", "code_interpreter"], str]] - - description: str - - parameters: Dict[str, ToolParamDefinition] diff --git a/src/llama_stack_client/types/dataset_iterrows_response.py b/src/llama_stack_client/types/dataset_iterrows_response.py index 48593bb2..9c451a8c 100644 --- a/src/llama_stack_client/types/dataset_iterrows_response.py +++ b/src/llama_stack_client/types/dataset_iterrows_response.py @@ -1,6 +1,6 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
-from typing import Dict, List, Union, Optional +from typing import Dict, List, Union from .._models import BaseModel @@ -9,10 +9,7 @@ class DatasetIterrowsResponse(BaseModel): data: List[Dict[str, Union[bool, float, str, List[object], object, None]]] - """The rows in the current page.""" + """The list of items for the current page""" - next_start_index: Optional[int] = None - """Index into dataset for the first row in the next page. - - None if there are no more rows. - """ + has_more: bool + """Whether there are more items available after this set""" diff --git a/src/llama_stack_client/types/inference_batch_chat_completion_params.py b/src/llama_stack_client/types/inference_batch_chat_completion_params.py new file mode 100644 index 00000000..ca53fdbf --- /dev/null +++ b/src/llama_stack_client/types/inference_batch_chat_completion_params.py @@ -0,0 +1,74 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing import Dict, Union, Iterable +from typing_extensions import Literal, Required, TypedDict + +from .shared_params.message import Message +from .shared_params.response_format import ResponseFormat +from .shared_params.sampling_params import SamplingParams +from .shared_params.tool_param_definition import ToolParamDefinition + +__all__ = ["InferenceBatchChatCompletionParams", "Logprobs", "ToolConfig", "Tool"] + + +class InferenceBatchChatCompletionParams(TypedDict, total=False): + messages_batch: Required[Iterable[Iterable[Message]]] + + model_id: Required[str] + + logprobs: Logprobs + + response_format: ResponseFormat + """Configuration for JSON schema-guided response generation.""" + + sampling_params: SamplingParams + """Sampling parameters.""" + + tool_config: ToolConfig + """Configuration for tool use.""" + + tools: Iterable[Tool] + + +class Logprobs(TypedDict, total=False): + top_k: int + """How many tokens (for each position) to return log probabilities for.""" + + +class ToolConfig(TypedDict, total=False): + system_message_behavior: Literal["append", "replace"] + """(Optional) Config for how to override the default system prompt. + + - `SystemMessageBehavior.append`: Appends the provided system message to the + default system prompt. - `SystemMessageBehavior.replace`: Replaces the default + system prompt with the provided system message. The system message can include + the string '{{function_definitions}}' to indicate where the function + definitions should be inserted. + """ + + tool_choice: Union[Literal["auto", "required", "none"], str] + """(Optional) Whether tool use is automatic, required, or none. + + Can also specify a tool name to use a specific tool. Defaults to + ToolChoice.auto. + """ + + tool_prompt_format: Literal["json", "function_tag", "python_list"] + """(Optional) Instructs the model how to format tool calls. + + By default, Llama Stack will attempt to use a format that is best adapted to the + model. - `ToolPromptFormat.json`: The tool calls are formatted as a JSON + object. - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a + tag. - `ToolPromptFormat.python_list`: The tool calls + are output as Python syntax -- a list of function calls. 
+ """ + + +class Tool(TypedDict, total=False): + tool_name: Required[Union[Literal["brave_search", "wolfram_alpha", "photogen", "code_interpreter"], str]] + + description: str + + parameters: Dict[str, ToolParamDefinition] diff --git a/src/llama_stack_client/types/batch_inference_chat_completion_response.py b/src/llama_stack_client/types/inference_batch_chat_completion_response.py similarity index 70% rename from src/llama_stack_client/types/batch_inference_chat_completion_response.py rename to src/llama_stack_client/types/inference_batch_chat_completion_response.py index 218b1275..84d6c425 100644 --- a/src/llama_stack_client/types/batch_inference_chat_completion_response.py +++ b/src/llama_stack_client/types/inference_batch_chat_completion_response.py @@ -5,8 +5,8 @@ from .._models import BaseModel from .shared.chat_completion_response import ChatCompletionResponse -__all__ = ["BatchInferenceChatCompletionResponse"] +__all__ = ["InferenceBatchChatCompletionResponse"] -class BatchInferenceChatCompletionResponse(BaseModel): +class InferenceBatchChatCompletionResponse(BaseModel): batch: List[ChatCompletionResponse] diff --git a/src/llama_stack_client/types/batch_inference_completion_params.py b/src/llama_stack_client/types/inference_batch_completion_params.py similarity index 80% rename from src/llama_stack_client/types/batch_inference_completion_params.py rename to src/llama_stack_client/types/inference_batch_completion_params.py index 3f80d625..cbeb9309 100644 --- a/src/llama_stack_client/types/batch_inference_completion_params.py +++ b/src/llama_stack_client/types/inference_batch_completion_params.py @@ -9,13 +9,13 @@ from .shared_params.sampling_params import SamplingParams from .shared_params.interleaved_content import InterleavedContent -__all__ = ["BatchInferenceCompletionParams", "Logprobs"] +__all__ = ["InferenceBatchCompletionParams", "Logprobs"] -class BatchInferenceCompletionParams(TypedDict, total=False): +class InferenceBatchCompletionParams(TypedDict, total=False): content_batch: Required[List[InterleavedContent]] - model: Required[str] + model_id: Required[str] logprobs: Logprobs @@ -23,6 +23,7 @@ class BatchInferenceCompletionParams(TypedDict, total=False): """Configuration for JSON schema-guided response generation.""" sampling_params: SamplingParams + """Sampling parameters.""" class Logprobs(TypedDict, total=False): diff --git a/src/llama_stack_client/types/job.py b/src/llama_stack_client/types/job.py index 74c6beb7..4953b3bf 100644 --- a/src/llama_stack_client/types/job.py +++ b/src/llama_stack_client/types/job.py @@ -10,4 +10,4 @@ class Job(BaseModel): job_id: str - status: Literal["completed", "in_progress", "failed", "scheduled"] + status: Literal["completed", "in_progress", "failed", "scheduled", "cancelled"] diff --git a/src/llama_stack_client/types/post_training/job_status_response.py b/src/llama_stack_client/types/post_training/job_status_response.py index 250bd82a..5ba60a6a 100644 --- a/src/llama_stack_client/types/post_training/job_status_response.py +++ b/src/llama_stack_client/types/post_training/job_status_response.py @@ -14,7 +14,7 @@ class JobStatusResponse(BaseModel): job_uuid: str - status: Literal["completed", "in_progress", "failed", "scheduled"] + status: Literal["completed", "in_progress", "failed", "scheduled", "cancelled"] completed_at: Optional[datetime] = None diff --git a/src/llama_stack_client/types/shared/agent_config.py b/src/llama_stack_client/types/shared/agent_config.py index 273487ae..04997ac4 100644 --- 
a/src/llama_stack_client/types/shared/agent_config.py +++ b/src/llama_stack_client/types/shared/agent_config.py @@ -68,6 +68,7 @@ class AgentConfig(BaseModel): """Configuration for JSON schema-guided response generation.""" sampling_params: Optional[SamplingParams] = None + """Sampling parameters.""" tool_choice: Optional[Literal["auto", "required", "none"]] = None """Whether tool use is required or automatic. diff --git a/src/llama_stack_client/types/shared/sampling_params.py b/src/llama_stack_client/types/shared/sampling_params.py index bb5866a6..7ce2211e 100644 --- a/src/llama_stack_client/types/shared/sampling_params.py +++ b/src/llama_stack_client/types/shared/sampling_params.py @@ -1,6 +1,6 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. -from typing import Union, Optional +from typing import List, Union, Optional from typing_extensions import Literal, Annotated, TypeAlias from ..._utils import PropertyInfo @@ -41,7 +41,24 @@ class StrategyTopKSamplingStrategy(BaseModel): class SamplingParams(BaseModel): strategy: Strategy + """The sampling strategy.""" max_tokens: Optional[int] = None + """The maximum number of tokens that can be generated in the completion. + + The token count of your prompt plus max_tokens cannot exceed the model's context + length. + """ repetition_penalty: Optional[float] = None + """Number between -2.0 and 2.0. + + Positive values penalize new tokens based on whether they appear in the text so + far, increasing the model's likelihood to talk about new topics. + """ + + stop: Optional[List[str]] = None + """Up to 4 sequences where the API will stop generating further tokens. + + The returned text will not contain the stop sequence. + """ diff --git a/src/llama_stack_client/types/shared_params/agent_config.py b/src/llama_stack_client/types/shared_params/agent_config.py index 0107ee42..f07efa39 100644 --- a/src/llama_stack_client/types/shared_params/agent_config.py +++ b/src/llama_stack_client/types/shared_params/agent_config.py @@ -69,6 +69,7 @@ class AgentConfig(TypedDict, total=False): """Configuration for JSON schema-guided response generation.""" sampling_params: SamplingParams + """Sampling parameters.""" tool_choice: Literal["auto", "required", "none"] """Whether tool use is required or automatic. diff --git a/src/llama_stack_client/types/shared_params/sampling_params.py b/src/llama_stack_client/types/shared_params/sampling_params.py index 1d9bcaf5..158db1c5 100644 --- a/src/llama_stack_client/types/shared_params/sampling_params.py +++ b/src/llama_stack_client/types/shared_params/sampling_params.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Union +from typing import List, Union from typing_extensions import Literal, Required, TypeAlias, TypedDict __all__ = [ @@ -37,7 +37,24 @@ class StrategyTopKSamplingStrategy(TypedDict, total=False): class SamplingParams(TypedDict, total=False): strategy: Required[Strategy] + """The sampling strategy.""" max_tokens: int + """The maximum number of tokens that can be generated in the completion. + + The token count of your prompt plus max_tokens cannot exceed the model's context + length. + """ repetition_penalty: float + """Number between -2.0 and 2.0. + + Positive values penalize new tokens based on whether they appear in the text so + far, increasing the model's likelihood to talk about new topics. + """ + + stop: List[str] + """Up to 4 sequences where the API will stop generating further tokens. + + The returned text will not contain the stop sequence. 
+ """ diff --git a/src/llama_stack_client/types/tool_runtime_list_tools_response.py b/src/llama_stack_client/types/tool_runtime_list_tools_response.py new file mode 100644 index 00000000..cd65754f --- /dev/null +++ b/src/llama_stack_client/types/tool_runtime_list_tools_response.py @@ -0,0 +1,10 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import List +from typing_extensions import TypeAlias + +from .tool_def import ToolDef + +__all__ = ["ToolRuntimeListToolsResponse"] + +ToolRuntimeListToolsResponse: TypeAlias = List[ToolDef] diff --git a/tests/api_resources/test_agents.py b/tests/api_resources/test_agents.py index 03a15837..235d6258 100644 --- a/tests/api_resources/test_agents.py +++ b/tests/api_resources/test_agents.py @@ -61,6 +61,7 @@ def test_method_create_with_all_params(self, client: LlamaStackClient) -> None: "strategy": {"type": "greedy"}, "max_tokens": 0, "repetition_penalty": 0, + "stop": ["string"], }, "tool_choice": "auto", "tool_config": { @@ -190,6 +191,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncLlamaStack "strategy": {"type": "greedy"}, "max_tokens": 0, "repetition_penalty": 0, + "stop": ["string"], }, "tool_choice": "auto", "tool_config": { diff --git a/tests/api_resources/test_batch_inference.py b/tests/api_resources/test_batch_inference.py deleted file mode 100644 index 8e5cb9e5..00000000 --- a/tests/api_resources/test_batch_inference.py +++ /dev/null @@ -1,323 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - -from __future__ import annotations - -import os -from typing import Any, cast - -import pytest - -from tests.utils import assert_matches_type -from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient -from llama_stack_client.types import ( - BatchInferenceChatCompletionResponse, -) -from llama_stack_client.types.shared import BatchCompletion - -base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") - - -class TestBatchInference: - parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) - - @parametrize - def test_method_chat_completion(self, client: LlamaStackClient) -> None: - batch_inference = client.batch_inference.chat_completion( - messages_batch=[ - [ - { - "content": "string", - "role": "user", - } - ] - ], - model="model", - ) - assert_matches_type(BatchInferenceChatCompletionResponse, batch_inference, path=["response"]) - - @parametrize - def test_method_chat_completion_with_all_params(self, client: LlamaStackClient) -> None: - batch_inference = client.batch_inference.chat_completion( - messages_batch=[ - [ - { - "content": "string", - "role": "user", - "context": "string", - } - ] - ], - model="model", - logprobs={"top_k": 0}, - response_format={ - "json_schema": {"foo": True}, - "type": "json_schema", - }, - sampling_params={ - "strategy": {"type": "greedy"}, - "max_tokens": 0, - "repetition_penalty": 0, - }, - tool_choice="auto", - tool_prompt_format="json", - tools=[ - { - "tool_name": "brave_search", - "description": "description", - "parameters": { - "foo": { - "param_type": "param_type", - "default": True, - "description": "description", - "required": True, - } - }, - } - ], - ) - assert_matches_type(BatchInferenceChatCompletionResponse, batch_inference, path=["response"]) - - @parametrize - def test_raw_response_chat_completion(self, client: LlamaStackClient) -> None: - response = 
client.batch_inference.with_raw_response.chat_completion( - messages_batch=[ - [ - { - "content": "string", - "role": "user", - } - ] - ], - model="model", - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - batch_inference = response.parse() - assert_matches_type(BatchInferenceChatCompletionResponse, batch_inference, path=["response"]) - - @parametrize - def test_streaming_response_chat_completion(self, client: LlamaStackClient) -> None: - with client.batch_inference.with_streaming_response.chat_completion( - messages_batch=[ - [ - { - "content": "string", - "role": "user", - } - ] - ], - model="model", - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - batch_inference = response.parse() - assert_matches_type(BatchInferenceChatCompletionResponse, batch_inference, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @parametrize - def test_method_completion(self, client: LlamaStackClient) -> None: - batch_inference = client.batch_inference.completion( - content_batch=["string"], - model="model", - ) - assert_matches_type(BatchCompletion, batch_inference, path=["response"]) - - @parametrize - def test_method_completion_with_all_params(self, client: LlamaStackClient) -> None: - batch_inference = client.batch_inference.completion( - content_batch=["string"], - model="model", - logprobs={"top_k": 0}, - response_format={ - "json_schema": {"foo": True}, - "type": "json_schema", - }, - sampling_params={ - "strategy": {"type": "greedy"}, - "max_tokens": 0, - "repetition_penalty": 0, - }, - ) - assert_matches_type(BatchCompletion, batch_inference, path=["response"]) - - @parametrize - def test_raw_response_completion(self, client: LlamaStackClient) -> None: - response = client.batch_inference.with_raw_response.completion( - content_batch=["string"], - model="model", - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - batch_inference = response.parse() - assert_matches_type(BatchCompletion, batch_inference, path=["response"]) - - @parametrize - def test_streaming_response_completion(self, client: LlamaStackClient) -> None: - with client.batch_inference.with_streaming_response.completion( - content_batch=["string"], - model="model", - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - batch_inference = response.parse() - assert_matches_type(BatchCompletion, batch_inference, path=["response"]) - - assert cast(Any, response.is_closed) is True - - -class TestAsyncBatchInference: - parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) - - @parametrize - async def test_method_chat_completion(self, async_client: AsyncLlamaStackClient) -> None: - batch_inference = await async_client.batch_inference.chat_completion( - messages_batch=[ - [ - { - "content": "string", - "role": "user", - } - ] - ], - model="model", - ) - assert_matches_type(BatchInferenceChatCompletionResponse, batch_inference, path=["response"]) - - @parametrize - async def test_method_chat_completion_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: - batch_inference = await async_client.batch_inference.chat_completion( - messages_batch=[ - [ - { - "content": "string", - "role": "user", - "context": "string", - } - ] - ], - model="model", - logprobs={"top_k": 0}, - 
response_format={ - "json_schema": {"foo": True}, - "type": "json_schema", - }, - sampling_params={ - "strategy": {"type": "greedy"}, - "max_tokens": 0, - "repetition_penalty": 0, - }, - tool_choice="auto", - tool_prompt_format="json", - tools=[ - { - "tool_name": "brave_search", - "description": "description", - "parameters": { - "foo": { - "param_type": "param_type", - "default": True, - "description": "description", - "required": True, - } - }, - } - ], - ) - assert_matches_type(BatchInferenceChatCompletionResponse, batch_inference, path=["response"]) - - @parametrize - async def test_raw_response_chat_completion(self, async_client: AsyncLlamaStackClient) -> None: - response = await async_client.batch_inference.with_raw_response.chat_completion( - messages_batch=[ - [ - { - "content": "string", - "role": "user", - } - ] - ], - model="model", - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - batch_inference = await response.parse() - assert_matches_type(BatchInferenceChatCompletionResponse, batch_inference, path=["response"]) - - @parametrize - async def test_streaming_response_chat_completion(self, async_client: AsyncLlamaStackClient) -> None: - async with async_client.batch_inference.with_streaming_response.chat_completion( - messages_batch=[ - [ - { - "content": "string", - "role": "user", - } - ] - ], - model="model", - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - batch_inference = await response.parse() - assert_matches_type(BatchInferenceChatCompletionResponse, batch_inference, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @parametrize - async def test_method_completion(self, async_client: AsyncLlamaStackClient) -> None: - batch_inference = await async_client.batch_inference.completion( - content_batch=["string"], - model="model", - ) - assert_matches_type(BatchCompletion, batch_inference, path=["response"]) - - @parametrize - async def test_method_completion_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: - batch_inference = await async_client.batch_inference.completion( - content_batch=["string"], - model="model", - logprobs={"top_k": 0}, - response_format={ - "json_schema": {"foo": True}, - "type": "json_schema", - }, - sampling_params={ - "strategy": {"type": "greedy"}, - "max_tokens": 0, - "repetition_penalty": 0, - }, - ) - assert_matches_type(BatchCompletion, batch_inference, path=["response"]) - - @parametrize - async def test_raw_response_completion(self, async_client: AsyncLlamaStackClient) -> None: - response = await async_client.batch_inference.with_raw_response.completion( - content_batch=["string"], - model="model", - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - batch_inference = await response.parse() - assert_matches_type(BatchCompletion, batch_inference, path=["response"]) - - @parametrize - async def test_streaming_response_completion(self, async_client: AsyncLlamaStackClient) -> None: - async with async_client.batch_inference.with_streaming_response.completion( - content_batch=["string"], - model="model", - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - batch_inference = await response.parse() - assert_matches_type(BatchCompletion, batch_inference, path=["response"]) - - assert cast(Any, response.is_closed) is True diff 
--git a/tests/api_resources/test_eval.py b/tests/api_resources/test_eval.py index 9735b4c4..c519056b 100644 --- a/tests/api_resources/test_eval.py +++ b/tests/api_resources/test_eval.py @@ -53,6 +53,7 @@ def test_method_evaluate_rows_with_all_params(self, client: LlamaStackClient) -> "strategy": {"type": "greedy"}, "max_tokens": 0, "repetition_penalty": 0, + "stop": ["string"], }, "type": "model", "system_message": { @@ -185,6 +186,7 @@ def test_method_evaluate_rows_alpha_with_all_params(self, client: LlamaStackClie "strategy": {"type": "greedy"}, "max_tokens": 0, "repetition_penalty": 0, + "stop": ["string"], }, "type": "model", "system_message": { @@ -315,6 +317,7 @@ def test_method_run_eval_with_all_params(self, client: LlamaStackClient) -> None "strategy": {"type": "greedy"}, "max_tokens": 0, "repetition_penalty": 0, + "stop": ["string"], }, "type": "model", "system_message": { @@ -437,6 +440,7 @@ def test_method_run_eval_alpha_with_all_params(self, client: LlamaStackClient) - "strategy": {"type": "greedy"}, "max_tokens": 0, "repetition_penalty": 0, + "stop": ["string"], }, "type": "model", "system_message": { @@ -565,6 +569,7 @@ async def test_method_evaluate_rows_with_all_params(self, async_client: AsyncLla "strategy": {"type": "greedy"}, "max_tokens": 0, "repetition_penalty": 0, + "stop": ["string"], }, "type": "model", "system_message": { @@ -697,6 +702,7 @@ async def test_method_evaluate_rows_alpha_with_all_params(self, async_client: As "strategy": {"type": "greedy"}, "max_tokens": 0, "repetition_penalty": 0, + "stop": ["string"], }, "type": "model", "system_message": { @@ -827,6 +833,7 @@ async def test_method_run_eval_with_all_params(self, async_client: AsyncLlamaSta "strategy": {"type": "greedy"}, "max_tokens": 0, "repetition_penalty": 0, + "stop": ["string"], }, "type": "model", "system_message": { @@ -949,6 +956,7 @@ async def test_method_run_eval_alpha_with_all_params(self, async_client: AsyncLl "strategy": {"type": "greedy"}, "max_tokens": 0, "repetition_penalty": 0, + "stop": ["string"], }, "type": "model", "system_message": { diff --git a/tests/api_resources/test_inference.py b/tests/api_resources/test_inference.py index 4d078587..d876ae56 100644 --- a/tests/api_resources/test_inference.py +++ b/tests/api_resources/test_inference.py @@ -12,8 +12,9 @@ from llama_stack_client.types import ( CompletionResponse, EmbeddingsResponse, + InferenceBatchChatCompletionResponse, ) -from llama_stack_client.types.shared import ChatCompletionResponse +from llama_stack_client.types.shared import BatchCompletion, ChatCompletionResponse base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -21,6 +22,160 @@ class TestInference: parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + @parametrize + def test_method_batch_chat_completion(self, client: LlamaStackClient) -> None: + inference = client.inference.batch_chat_completion( + messages_batch=[ + [ + { + "content": "string", + "role": "user", + } + ] + ], + model_id="model_id", + ) + assert_matches_type(InferenceBatchChatCompletionResponse, inference, path=["response"]) + + @parametrize + def test_method_batch_chat_completion_with_all_params(self, client: LlamaStackClient) -> None: + inference = client.inference.batch_chat_completion( + messages_batch=[ + [ + { + "content": "string", + "role": "user", + "context": "string", + } + ] + ], + model_id="model_id", + logprobs={"top_k": 0}, + response_format={ + "json_schema": {"foo": True}, + "type": "json_schema", + }, + 
sampling_params={ + "strategy": {"type": "greedy"}, + "max_tokens": 0, + "repetition_penalty": 0, + "stop": ["string"], + }, + tool_config={ + "system_message_behavior": "append", + "tool_choice": "auto", + "tool_prompt_format": "json", + }, + tools=[ + { + "tool_name": "brave_search", + "description": "description", + "parameters": { + "foo": { + "param_type": "param_type", + "default": True, + "description": "description", + "required": True, + } + }, + } + ], + ) + assert_matches_type(InferenceBatchChatCompletionResponse, inference, path=["response"]) + + @parametrize + def test_raw_response_batch_chat_completion(self, client: LlamaStackClient) -> None: + response = client.inference.with_raw_response.batch_chat_completion( + messages_batch=[ + [ + { + "content": "string", + "role": "user", + } + ] + ], + model_id="model_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + inference = response.parse() + assert_matches_type(InferenceBatchChatCompletionResponse, inference, path=["response"]) + + @parametrize + def test_streaming_response_batch_chat_completion(self, client: LlamaStackClient) -> None: + with client.inference.with_streaming_response.batch_chat_completion( + messages_batch=[ + [ + { + "content": "string", + "role": "user", + } + ] + ], + model_id="model_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + inference = response.parse() + assert_matches_type(InferenceBatchChatCompletionResponse, inference, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_method_batch_completion(self, client: LlamaStackClient) -> None: + inference = client.inference.batch_completion( + content_batch=["string"], + model_id="model_id", + ) + assert_matches_type(BatchCompletion, inference, path=["response"]) + + @parametrize + def test_method_batch_completion_with_all_params(self, client: LlamaStackClient) -> None: + inference = client.inference.batch_completion( + content_batch=["string"], + model_id="model_id", + logprobs={"top_k": 0}, + response_format={ + "json_schema": {"foo": True}, + "type": "json_schema", + }, + sampling_params={ + "strategy": {"type": "greedy"}, + "max_tokens": 0, + "repetition_penalty": 0, + "stop": ["string"], + }, + ) + assert_matches_type(BatchCompletion, inference, path=["response"]) + + @parametrize + def test_raw_response_batch_completion(self, client: LlamaStackClient) -> None: + response = client.inference.with_raw_response.batch_completion( + content_batch=["string"], + model_id="model_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + inference = response.parse() + assert_matches_type(BatchCompletion, inference, path=["response"]) + + @parametrize + def test_streaming_response_batch_completion(self, client: LlamaStackClient) -> None: + with client.inference.with_streaming_response.batch_completion( + content_batch=["string"], + model_id="model_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + inference = response.parse() + assert_matches_type(BatchCompletion, inference, path=["response"]) + + assert cast(Any, response.is_closed) is True + @parametrize def test_method_chat_completion_overload_1(self, client: LlamaStackClient) -> None: inference = client.inference.chat_completion( @@ -54,6 +209,7 @@ def 
test_method_chat_completion_with_all_params_overload_1(self, client: LlamaSt "strategy": {"type": "greedy"}, "max_tokens": 0, "repetition_penalty": 0, + "stop": ["string"], }, stream=False, tool_choice="auto", @@ -151,6 +307,7 @@ def test_method_chat_completion_with_all_params_overload_2(self, client: LlamaSt "strategy": {"type": "greedy"}, "max_tokens": 0, "repetition_penalty": 0, + "stop": ["string"], }, tool_choice="auto", tool_config={ @@ -235,6 +392,7 @@ def test_method_completion_with_all_params_overload_1(self, client: LlamaStackCl "strategy": {"type": "greedy"}, "max_tokens": 0, "repetition_penalty": 0, + "stop": ["string"], }, stream=False, ) @@ -290,6 +448,7 @@ def test_method_completion_with_all_params_overload_2(self, client: LlamaStackCl "strategy": {"type": "greedy"}, "max_tokens": 0, "repetition_penalty": 0, + "stop": ["string"], }, ) inference_stream.response.close() @@ -370,6 +529,160 @@ def test_streaming_response_embeddings(self, client: LlamaStackClient) -> None: class TestAsyncInference: parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + @parametrize + async def test_method_batch_chat_completion(self, async_client: AsyncLlamaStackClient) -> None: + inference = await async_client.inference.batch_chat_completion( + messages_batch=[ + [ + { + "content": "string", + "role": "user", + } + ] + ], + model_id="model_id", + ) + assert_matches_type(InferenceBatchChatCompletionResponse, inference, path=["response"]) + + @parametrize + async def test_method_batch_chat_completion_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + inference = await async_client.inference.batch_chat_completion( + messages_batch=[ + [ + { + "content": "string", + "role": "user", + "context": "string", + } + ] + ], + model_id="model_id", + logprobs={"top_k": 0}, + response_format={ + "json_schema": {"foo": True}, + "type": "json_schema", + }, + sampling_params={ + "strategy": {"type": "greedy"}, + "max_tokens": 0, + "repetition_penalty": 0, + "stop": ["string"], + }, + tool_config={ + "system_message_behavior": "append", + "tool_choice": "auto", + "tool_prompt_format": "json", + }, + tools=[ + { + "tool_name": "brave_search", + "description": "description", + "parameters": { + "foo": { + "param_type": "param_type", + "default": True, + "description": "description", + "required": True, + } + }, + } + ], + ) + assert_matches_type(InferenceBatchChatCompletionResponse, inference, path=["response"]) + + @parametrize + async def test_raw_response_batch_chat_completion(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.inference.with_raw_response.batch_chat_completion( + messages_batch=[ + [ + { + "content": "string", + "role": "user", + } + ] + ], + model_id="model_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + inference = await response.parse() + assert_matches_type(InferenceBatchChatCompletionResponse, inference, path=["response"]) + + @parametrize + async def test_streaming_response_batch_chat_completion(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.inference.with_streaming_response.batch_chat_completion( + messages_batch=[ + [ + { + "content": "string", + "role": "user", + } + ] + ], + model_id="model_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + inference = await response.parse() + 
assert_matches_type(InferenceBatchChatCompletionResponse, inference, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_method_batch_completion(self, async_client: AsyncLlamaStackClient) -> None: + inference = await async_client.inference.batch_completion( + content_batch=["string"], + model_id="model_id", + ) + assert_matches_type(BatchCompletion, inference, path=["response"]) + + @parametrize + async def test_method_batch_completion_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + inference = await async_client.inference.batch_completion( + content_batch=["string"], + model_id="model_id", + logprobs={"top_k": 0}, + response_format={ + "json_schema": {"foo": True}, + "type": "json_schema", + }, + sampling_params={ + "strategy": {"type": "greedy"}, + "max_tokens": 0, + "repetition_penalty": 0, + "stop": ["string"], + }, + ) + assert_matches_type(BatchCompletion, inference, path=["response"]) + + @parametrize + async def test_raw_response_batch_completion(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.inference.with_raw_response.batch_completion( + content_batch=["string"], + model_id="model_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + inference = await response.parse() + assert_matches_type(BatchCompletion, inference, path=["response"]) + + @parametrize + async def test_streaming_response_batch_completion(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.inference.with_streaming_response.batch_completion( + content_batch=["string"], + model_id="model_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + inference = await response.parse() + assert_matches_type(BatchCompletion, inference, path=["response"]) + + assert cast(Any, response.is_closed) is True + @parametrize async def test_method_chat_completion_overload_1(self, async_client: AsyncLlamaStackClient) -> None: inference = await async_client.inference.chat_completion( @@ -403,6 +716,7 @@ async def test_method_chat_completion_with_all_params_overload_1(self, async_cli "strategy": {"type": "greedy"}, "max_tokens": 0, "repetition_penalty": 0, + "stop": ["string"], }, stream=False, tool_choice="auto", @@ -500,6 +814,7 @@ async def test_method_chat_completion_with_all_params_overload_2(self, async_cli "strategy": {"type": "greedy"}, "max_tokens": 0, "repetition_penalty": 0, + "stop": ["string"], }, tool_choice="auto", tool_config={ @@ -584,6 +899,7 @@ async def test_method_completion_with_all_params_overload_1(self, async_client: "strategy": {"type": "greedy"}, "max_tokens": 0, "repetition_penalty": 0, + "stop": ["string"], }, stream=False, ) @@ -639,6 +955,7 @@ async def test_method_completion_with_all_params_overload_2(self, async_client: "strategy": {"type": "greedy"}, "max_tokens": 0, "repetition_penalty": 0, + "stop": ["string"], }, ) await inference_stream.response.aclose() diff --git a/tests/api_resources/test_tool_runtime.py b/tests/api_resources/test_tool_runtime.py index ca4279bb..b13e8c1f 100644 --- a/tests/api_resources/test_tool_runtime.py +++ b/tests/api_resources/test_tool_runtime.py @@ -10,10 +10,9 @@ from tests.utils import assert_matches_type from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient from llama_stack_client.types import ( - ToolDef, ToolInvocationResult, + ToolRuntimeListToolsResponse, ) -from 
llama_stack_client._decoders.jsonl import JSONLDecoder, AsyncJSONLDecoder base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -55,22 +54,19 @@ def test_streaming_response_invoke_tool(self, client: LlamaStackClient) -> None: assert cast(Any, response.is_closed) is True - @pytest.mark.skip(reason="Prism doesn't support JSONL responses yet") @parametrize def test_method_list_tools(self, client: LlamaStackClient) -> None: tool_runtime = client.tool_runtime.list_tools() - assert_matches_type(JSONLDecoder[ToolDef], tool_runtime, path=["response"]) + assert_matches_type(ToolRuntimeListToolsResponse, tool_runtime, path=["response"]) - @pytest.mark.skip(reason="Prism doesn't support JSONL responses yet") @parametrize def test_method_list_tools_with_all_params(self, client: LlamaStackClient) -> None: tool_runtime = client.tool_runtime.list_tools( mcp_endpoint={"uri": "uri"}, tool_group_id="tool_group_id", ) - assert_matches_type(JSONLDecoder[ToolDef], tool_runtime, path=["response"]) + assert_matches_type(ToolRuntimeListToolsResponse, tool_runtime, path=["response"]) - @pytest.mark.skip(reason="Prism doesn't support JSONL responses yet") @parametrize def test_raw_response_list_tools(self, client: LlamaStackClient) -> None: response = client.tool_runtime.with_raw_response.list_tools() @@ -78,9 +74,8 @@ def test_raw_response_list_tools(self, client: LlamaStackClient) -> None: assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" tool_runtime = response.parse() - assert_matches_type(JSONLDecoder[ToolDef], tool_runtime, path=["response"]) + assert_matches_type(ToolRuntimeListToolsResponse, tool_runtime, path=["response"]) - @pytest.mark.skip(reason="Prism doesn't support JSONL responses yet") @parametrize def test_streaming_response_list_tools(self, client: LlamaStackClient) -> None: with client.tool_runtime.with_streaming_response.list_tools() as response: @@ -88,7 +83,7 @@ def test_streaming_response_list_tools(self, client: LlamaStackClient) -> None: assert response.http_request.headers.get("X-Stainless-Lang") == "python" tool_runtime = response.parse() - assert_matches_type(JSONLDecoder[ToolDef], tool_runtime, path=["response"]) + assert_matches_type(ToolRuntimeListToolsResponse, tool_runtime, path=["response"]) assert cast(Any, response.is_closed) is True @@ -130,22 +125,19 @@ async def test_streaming_response_invoke_tool(self, async_client: AsyncLlamaStac assert cast(Any, response.is_closed) is True - @pytest.mark.skip(reason="Prism doesn't support JSONL responses yet") @parametrize async def test_method_list_tools(self, async_client: AsyncLlamaStackClient) -> None: tool_runtime = await async_client.tool_runtime.list_tools() - assert_matches_type(AsyncJSONLDecoder[ToolDef], tool_runtime, path=["response"]) + assert_matches_type(ToolRuntimeListToolsResponse, tool_runtime, path=["response"]) - @pytest.mark.skip(reason="Prism doesn't support JSONL responses yet") @parametrize async def test_method_list_tools_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: tool_runtime = await async_client.tool_runtime.list_tools( mcp_endpoint={"uri": "uri"}, tool_group_id="tool_group_id", ) - assert_matches_type(AsyncJSONLDecoder[ToolDef], tool_runtime, path=["response"]) + assert_matches_type(ToolRuntimeListToolsResponse, tool_runtime, path=["response"]) - @pytest.mark.skip(reason="Prism doesn't support JSONL responses yet") @parametrize async def test_raw_response_list_tools(self, async_client: AsyncLlamaStackClient) -> 
None: response = await async_client.tool_runtime.with_raw_response.list_tools() @@ -153,9 +145,8 @@ async def test_raw_response_list_tools(self, async_client: AsyncLlamaStackClient assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" tool_runtime = await response.parse() - assert_matches_type(AsyncJSONLDecoder[ToolDef], tool_runtime, path=["response"]) + assert_matches_type(ToolRuntimeListToolsResponse, tool_runtime, path=["response"]) - @pytest.mark.skip(reason="Prism doesn't support JSONL responses yet") @parametrize async def test_streaming_response_list_tools(self, async_client: AsyncLlamaStackClient) -> None: async with async_client.tool_runtime.with_streaming_response.list_tools() as response: @@ -163,6 +154,6 @@ async def test_streaming_response_list_tools(self, async_client: AsyncLlamaStack assert response.http_request.headers.get("X-Stainless-Lang") == "python" tool_runtime = await response.parse() - assert_matches_type(AsyncJSONLDecoder[ToolDef], tool_runtime, path=["response"]) + assert_matches_type(ToolRuntimeListToolsResponse, tool_runtime, path=["response"]) assert cast(Any, response.is_closed) is True diff --git a/tests/decoders/test_jsonl.py b/tests/decoders/test_jsonl.py deleted file mode 100644 index 54af8e49..00000000 --- a/tests/decoders/test_jsonl.py +++ /dev/null @@ -1,88 +0,0 @@ -from __future__ import annotations - -from typing import Any, Iterator, AsyncIterator -from typing_extensions import TypeVar - -import httpx -import pytest - -from llama_stack_client._decoders.jsonl import JSONLDecoder, AsyncJSONLDecoder - -_T = TypeVar("_T") - - -@pytest.mark.asyncio -@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) -async def test_basic(sync: bool) -> None: - def body() -> Iterator[bytes]: - yield b'{"foo":true}\n' - yield b'{"bar":false}\n' - - iterator = make_jsonl_iterator( - content=body(), - sync=sync, - line_type=object, - ) - - assert await iter_next(iterator) == {"foo": True} - assert await iter_next(iterator) == {"bar": False} - - await assert_empty_iter(iterator) - - -@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) -async def test_new_lines_in_json( - sync: bool, -) -> None: - def body() -> Iterator[bytes]: - yield b'{"content":"Hello, world!\\nHow are you doing?"}' - - iterator = make_jsonl_iterator(content=body(), sync=sync, line_type=object) - - assert await iter_next(iterator) == {"content": "Hello, world!\nHow are you doing?"} - - -@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) -async def test_multi_byte_character_multiple_chunks( - sync: bool, -) -> None: - def body() -> Iterator[bytes]: - yield b'{"content":"' - # bytes taken from the string 'известни' and arbitrarily split - # so that some multi-byte characters span multiple chunks - yield b"\xd0" - yield b"\xb8\xd0\xb7\xd0" - yield b"\xb2\xd0\xb5\xd1\x81\xd1\x82\xd0\xbd\xd0\xb8" - yield b'"}\n' - - iterator = make_jsonl_iterator(content=body(), sync=sync, line_type=object) - - assert await iter_next(iterator) == {"content": "известни"} - - -async def to_aiter(iter: Iterator[bytes]) -> AsyncIterator[bytes]: - for chunk in iter: - yield chunk - - -async def iter_next(iter: Iterator[_T] | AsyncIterator[_T]) -> _T: - if isinstance(iter, AsyncIterator): - return await iter.__anext__() - return next(iter) - - -async def assert_empty_iter(decoder: JSONLDecoder[Any] | AsyncJSONLDecoder[Any]) -> None: - with pytest.raises((StopAsyncIteration, RuntimeError)): - await iter_next(decoder) - - 
-def make_jsonl_iterator( - content: Iterator[bytes], - *, - sync: bool, - line_type: type[_T], -) -> JSONLDecoder[_T] | AsyncJSONLDecoder[_T]: - if sync: - return JSONLDecoder(line_type=line_type, raw_iterator=content, http_response=httpx.Response(200)) - - return AsyncJSONLDecoder(line_type=line_type, raw_iterator=to_aiter(content), http_response=httpx.Response(200)) diff --git a/tests/test_client.py b/tests/test_client.py index 8a2992af..7ad6e189 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -313,9 +313,6 @@ def test_default_headers_option(self) -> None: assert request.headers.get("x-foo") == "stainless" assert request.headers.get("x-stainless-lang") == "my-overriding-header" - def test_validate_headers(self) -> None: - client = LlamaStackClient(base_url=base_url, _strict_response_validation=True) - def test_default_query_option(self) -> None: client = LlamaStackClient( base_url=base_url, _strict_response_validation=True, default_query={"query_param": "bar"} @@ -1103,9 +1100,6 @@ def test_default_headers_option(self) -> None: assert request.headers.get("x-foo") == "stainless" assert request.headers.get("x-stainless-lang") == "my-overriding-header" - def test_validate_headers(self) -> None: - client = AsyncLlamaStackClient(base_url=base_url, _strict_response_validation=True) - def test_default_query_option(self) -> None: client = AsyncLlamaStackClient( base_url=base_url, _strict_response_validation=True, default_query={"query_param": "bar"} @@ -1648,7 +1642,7 @@ def test_get_platform(self) -> None: import threading from llama_stack_client._utils import asyncify - from llama_stack_client._base_client import get_platform + from llama_stack_client._base_client import get_platform async def test_main() -> None: result = await asyncify(get_platform)() diff --git a/tests/test_transform.py b/tests/test_transform.py index 8ceafb36..b6eb411d 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -8,7 +8,7 @@ import pytest -from llama_stack_client._types import Base64FileInput +from llama_stack_client._types import NOT_GIVEN, Base64FileInput from llama_stack_client._utils import ( PropertyInfo, transform as _transform, @@ -432,3 +432,22 @@ async def test_base64_file_input(use_async: bool) -> None: assert await transform({"foo": io.BytesIO(b"Hello, world!")}, TypedDictBase64Input, use_async) == { "foo": "SGVsbG8sIHdvcmxkIQ==" } # type: ignore[comparison-overlap] + + +@parametrize +@pytest.mark.asyncio +async def test_transform_skipping(use_async: bool) -> None: + # lists of ints are left as-is + data = [1, 2, 3] + assert await transform(data, List[int], use_async) is data + + # iterables of ints are converted to a list + data = iter([1, 2, 3]) + assert await transform(data, Iterable[int], use_async) == [1, 2, 3] + + +@parametrize +@pytest.mark.asyncio +async def test_strips_notgiven(use_async: bool) -> None: + assert await transform({"foo_bar": "bar"}, Foo1, use_async) == {"fooBar": "bar"} + assert await transform({"foo_bar": NOT_GIVEN}, Foo1, use_async) == {}
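
Usage sketch (not part of the generated patch): with this change, batch inference moves from the removed batch_inference resource onto the inference resource, the model parameter becomes model_id, and SamplingParams gains a stop field. The sketch below mirrors the call shapes exercised in tests/api_resources/test_inference.py; the base URL and model id are placeholders, and the raw response objects are printed rather than assuming any particular field layout beyond the documented batch attribute.

from llama_stack_client import LlamaStackClient

# Placeholder endpoint; point this at a running Llama Stack server.
client = LlamaStackClient(base_url="http://localhost:8321")

# Batched chat completion: one list of messages per conversation in the batch.
# This method was previously client.batch_inference.chat_completion(model=...).
chat_batch = client.inference.batch_chat_completion(
    messages_batch=[
        [{"role": "user", "content": "Say hello."}],
        [{"role": "user", "content": "Name a prime number."}],
    ],
    model_id="my-model-id",  # placeholder; the parameter was renamed from `model`
    sampling_params={
        "strategy": {"type": "greedy"},
        "max_tokens": 64,
        "stop": ["\n\n"],  # new optional stop sequences on SamplingParams
    },
)
# InferenceBatchChatCompletionResponse carries one ChatCompletionResponse per input.
for chat_response in chat_batch.batch:
    print(chat_response)

# Batched plain-text completion follows the same pattern and returns a BatchCompletion.
completion_batch = client.inference.batch_completion(
    content_batch=["The capital of France is", "2 + 2 ="],
    model_id="my-model-id",  # placeholder
)
print(completion_batch)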
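
Likewise, tool_runtime.list_tools() now parses into ToolRuntimeListToolsResponse, a plain list of ToolDef, instead of a JSONL-decoded stream, so the result can be iterated directly. A minimal sketch under the same assumptions (placeholder base URL and tool group id; ToolDef entries are printed whole rather than assuming specific fields):

from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # placeholder endpoint

# ToolRuntimeListToolsResponse is just List[ToolDef], so ordinary iteration works.
tools = client.tool_runtime.list_tools(tool_group_id="my-tool-group")  # placeholder group id
for tool in tools:
    print(tool)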