diff --git a/pyproject.toml b/pyproject.toml index 19c5603a..08b99493 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "llama_stack_client" -version = "0.0.53rc4" +version = "0.0.53rc5" description = "The official Python library for the llama-stack-client API" dynamic = ["readme"] license = "Apache-2.0" diff --git a/src/llama_stack_client/resources/inference.py b/src/llama_stack_client/resources/inference.py index 248c00e8..63039ac6 100644 --- a/src/llama_stack_client/resources/inference.py +++ b/src/llama_stack_client/resources/inference.py @@ -3,7 +3,7 @@ from __future__ import annotations from typing import Any, List, Iterable, cast -from typing_extensions import Literal +from typing_extensions import Literal, overload import httpx @@ -14,6 +14,7 @@ ) from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven from .._utils import ( + required_args, maybe_transform, strip_not_given, async_maybe_transform, @@ -26,6 +27,7 @@ async_to_raw_response_wrapper, async_to_streamed_response_wrapper, ) +from .._streaming import Stream, AsyncStream from .._base_client import make_request_options from ..types.embeddings_response import EmbeddingsResponse from ..types.inference_completion_response import InferenceCompletionResponse @@ -55,6 +57,7 @@ def with_streaming_response(self) -> InferenceResourceWithStreamingResponse: """ return InferenceResourceWithStreamingResponse(self) + @overload def chat_completion( self, *, @@ -63,7 +66,7 @@ def chat_completion( logprobs: inference_chat_completion_params.Logprobs | NotGiven = NOT_GIVEN, response_format: inference_chat_completion_params.ResponseFormat | NotGiven = NOT_GIVEN, sampling_params: SamplingParams | NotGiven = NOT_GIVEN, - stream: bool | NotGiven = NOT_GIVEN, + stream: Literal[False] | NotGiven = NOT_GIVEN, tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN, tool_prompt_format: Literal["json", "function_tag", "python_list"] | NotGiven = NOT_GIVEN, tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN, @@ -95,6 +98,115 @@ def chat_completion( timeout: Override the client-level default timeout for this request, in seconds """ + ... + + @overload + def chat_completion( + self, + *, + messages: Iterable[inference_chat_completion_params.Message], + model_id: str, + stream: Literal[True], + logprobs: inference_chat_completion_params.Logprobs | NotGiven = NOT_GIVEN, + response_format: inference_chat_completion_params.ResponseFormat | NotGiven = NOT_GIVEN, + sampling_params: SamplingParams | NotGiven = NOT_GIVEN, + tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN, + tool_prompt_format: Literal["json", "function_tag", "python_list"] | NotGiven = NOT_GIVEN, + tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Stream[InferenceChatCompletionResponse]: + """ + Args: + tool_prompt_format: `json` -- Refers to the json format for calling tools. The json format takes the + form like { "type": "function", "function" : { "name": "function_name", + "description": "function_description", "parameters": {...} } } + + `function_tag` -- This is an example of how you could define your own user + defined format for making tool calls. The function_tag format looks like this, + (parameters) + + The detailed prompts for each of these formats are added to llama cli + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @overload + def chat_completion( + self, + *, + messages: Iterable[inference_chat_completion_params.Message], + model_id: str, + stream: bool, + logprobs: inference_chat_completion_params.Logprobs | NotGiven = NOT_GIVEN, + response_format: inference_chat_completion_params.ResponseFormat | NotGiven = NOT_GIVEN, + sampling_params: SamplingParams | NotGiven = NOT_GIVEN, + tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN, + tool_prompt_format: Literal["json", "function_tag", "python_list"] | NotGiven = NOT_GIVEN, + tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> InferenceChatCompletionResponse | Stream[InferenceChatCompletionResponse]: + """ + Args: + tool_prompt_format: `json` -- Refers to the json format for calling tools. The json format takes the + form like { "type": "function", "function" : { "name": "function_name", + "description": "function_description", "parameters": {...} } } + + `function_tag` -- This is an example of how you could define your own user + defined format for making tool calls. The function_tag format looks like this, + (parameters) + + The detailed prompts for each of these formats are added to llama cli + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @required_args(["messages", "model_id"], ["messages", "model_id", "stream"]) + def chat_completion( + self, + *, + messages: Iterable[inference_chat_completion_params.Message], + model_id: str, + logprobs: inference_chat_completion_params.Logprobs | NotGiven = NOT_GIVEN, + response_format: inference_chat_completion_params.ResponseFormat | NotGiven = NOT_GIVEN, + sampling_params: SamplingParams | NotGiven = NOT_GIVEN, + stream: Literal[False] | Literal[True] | NotGiven = NOT_GIVEN, + tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN, + tool_prompt_format: Literal["json", "function_tag", "python_list"] | NotGiven = NOT_GIVEN, + tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> InferenceChatCompletionResponse | Stream[InferenceChatCompletionResponse]: extra_headers = {"Accept": "text/event-stream", **(extra_headers or {})} extra_headers = { **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), @@ -124,9 +236,12 @@ def chat_completion( cast_to=cast( Any, InferenceChatCompletionResponse ), # Union types cannot be passed in as arguments in the type system + stream=stream or False, + stream_cls=Stream[InferenceChatCompletionResponse], ), ) + @overload def completion( self, *, @@ -135,7 +250,7 @@ def completion( logprobs: inference_completion_params.Logprobs | NotGiven = NOT_GIVEN, response_format: inference_completion_params.ResponseFormat | NotGiven = NOT_GIVEN, sampling_params: SamplingParams | NotGiven = NOT_GIVEN, - stream: bool | NotGiven = NOT_GIVEN, + stream: Literal[False] | NotGiven = NOT_GIVEN, x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. @@ -154,6 +269,86 @@ def completion( timeout: Override the client-level default timeout for this request, in seconds """ + ... + + @overload + def completion( + self, + *, + content: inference_completion_params.Content, + model_id: str, + stream: Literal[True], + logprobs: inference_completion_params.Logprobs | NotGiven = NOT_GIVEN, + response_format: inference_completion_params.ResponseFormat | NotGiven = NOT_GIVEN, + sampling_params: SamplingParams | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Stream[InferenceCompletionResponse]: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @overload + def completion( + self, + *, + content: inference_completion_params.Content, + model_id: str, + stream: bool, + logprobs: inference_completion_params.Logprobs | NotGiven = NOT_GIVEN, + response_format: inference_completion_params.ResponseFormat | NotGiven = NOT_GIVEN, + sampling_params: SamplingParams | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> InferenceCompletionResponse | Stream[InferenceCompletionResponse]: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @required_args(["content", "model_id"], ["content", "model_id", "stream"]) + def completion( + self, + *, + content: inference_completion_params.Content, + model_id: str, + logprobs: inference_completion_params.Logprobs | NotGiven = NOT_GIVEN, + response_format: inference_completion_params.ResponseFormat | NotGiven = NOT_GIVEN, + sampling_params: SamplingParams | NotGiven = NOT_GIVEN, + stream: Literal[False] | Literal[True] | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> InferenceCompletionResponse | Stream[InferenceCompletionResponse]: extra_headers = {"Accept": "text/event-stream", **(extra_headers or {})} extra_headers = { **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), @@ -180,6 +375,8 @@ def completion( cast_to=cast( Any, InferenceCompletionResponse ), # Union types cannot be passed in as arguments in the type system + stream=stream or False, + stream_cls=Stream[InferenceCompletionResponse], ), ) @@ -246,6 +443,7 @@ def with_streaming_response(self) -> AsyncInferenceResourceWithStreamingResponse """ return AsyncInferenceResourceWithStreamingResponse(self) + @overload async def chat_completion( self, *, @@ -254,7 +452,7 @@ async def chat_completion( logprobs: inference_chat_completion_params.Logprobs | NotGiven = NOT_GIVEN, response_format: inference_chat_completion_params.ResponseFormat | NotGiven = NOT_GIVEN, sampling_params: SamplingParams | NotGiven = NOT_GIVEN, - stream: bool | NotGiven = NOT_GIVEN, + stream: Literal[False] | NotGiven = NOT_GIVEN, tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN, tool_prompt_format: Literal["json", "function_tag", "python_list"] | NotGiven = NOT_GIVEN, tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN, @@ -286,6 +484,115 @@ async def chat_completion( timeout: Override the client-level default timeout for this request, in seconds """ + ... + + @overload + async def chat_completion( + self, + *, + messages: Iterable[inference_chat_completion_params.Message], + model_id: str, + stream: Literal[True], + logprobs: inference_chat_completion_params.Logprobs | NotGiven = NOT_GIVEN, + response_format: inference_chat_completion_params.ResponseFormat | NotGiven = NOT_GIVEN, + sampling_params: SamplingParams | NotGiven = NOT_GIVEN, + tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN, + tool_prompt_format: Literal["json", "function_tag", "python_list"] | NotGiven = NOT_GIVEN, + tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> AsyncStream[InferenceChatCompletionResponse]: + """ + Args: + tool_prompt_format: `json` -- Refers to the json format for calling tools. The json format takes the + form like { "type": "function", "function" : { "name": "function_name", + "description": "function_description", "parameters": {...} } } + + `function_tag` -- This is an example of how you could define your own user + defined format for making tool calls. The function_tag format looks like this, + (parameters) + + The detailed prompts for each of these formats are added to llama cli + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @overload + async def chat_completion( + self, + *, + messages: Iterable[inference_chat_completion_params.Message], + model_id: str, + stream: bool, + logprobs: inference_chat_completion_params.Logprobs | NotGiven = NOT_GIVEN, + response_format: inference_chat_completion_params.ResponseFormat | NotGiven = NOT_GIVEN, + sampling_params: SamplingParams | NotGiven = NOT_GIVEN, + tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN, + tool_prompt_format: Literal["json", "function_tag", "python_list"] | NotGiven = NOT_GIVEN, + tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> InferenceChatCompletionResponse | AsyncStream[InferenceChatCompletionResponse]: + """ + Args: + tool_prompt_format: `json` -- Refers to the json format for calling tools. The json format takes the + form like { "type": "function", "function" : { "name": "function_name", + "description": "function_description", "parameters": {...} } } + + `function_tag` -- This is an example of how you could define your own user + defined format for making tool calls. The function_tag format looks like this, + (parameters) + + The detailed prompts for each of these formats are added to llama cli + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @required_args(["messages", "model_id"], ["messages", "model_id", "stream"]) + async def chat_completion( + self, + *, + messages: Iterable[inference_chat_completion_params.Message], + model_id: str, + logprobs: inference_chat_completion_params.Logprobs | NotGiven = NOT_GIVEN, + response_format: inference_chat_completion_params.ResponseFormat | NotGiven = NOT_GIVEN, + sampling_params: SamplingParams | NotGiven = NOT_GIVEN, + stream: Literal[False] | Literal[True] | NotGiven = NOT_GIVEN, + tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN, + tool_prompt_format: Literal["json", "function_tag", "python_list"] | NotGiven = NOT_GIVEN, + tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> InferenceChatCompletionResponse | AsyncStream[InferenceChatCompletionResponse]: extra_headers = {"Accept": "text/event-stream", **(extra_headers or {})} extra_headers = { **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), @@ -315,9 +622,12 @@ async def chat_completion( cast_to=cast( Any, InferenceChatCompletionResponse ), # Union types cannot be passed in as arguments in the type system + stream=stream or False, + stream_cls=AsyncStream[InferenceChatCompletionResponse], ), ) + @overload async def completion( self, *, @@ -326,7 +636,7 @@ async def completion( logprobs: inference_completion_params.Logprobs | NotGiven = NOT_GIVEN, response_format: inference_completion_params.ResponseFormat | NotGiven = NOT_GIVEN, sampling_params: SamplingParams | NotGiven = NOT_GIVEN, - stream: bool | NotGiven = NOT_GIVEN, + stream: Literal[False] | NotGiven = NOT_GIVEN, x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. @@ -345,6 +655,86 @@ async def completion( timeout: Override the client-level default timeout for this request, in seconds """ + ... + + @overload + async def completion( + self, + *, + content: inference_completion_params.Content, + model_id: str, + stream: Literal[True], + logprobs: inference_completion_params.Logprobs | NotGiven = NOT_GIVEN, + response_format: inference_completion_params.ResponseFormat | NotGiven = NOT_GIVEN, + sampling_params: SamplingParams | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> AsyncStream[InferenceCompletionResponse]: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @overload + async def completion( + self, + *, + content: inference_completion_params.Content, + model_id: str, + stream: bool, + logprobs: inference_completion_params.Logprobs | NotGiven = NOT_GIVEN, + response_format: inference_completion_params.ResponseFormat | NotGiven = NOT_GIVEN, + sampling_params: SamplingParams | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> InferenceCompletionResponse | AsyncStream[InferenceCompletionResponse]: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @required_args(["content", "model_id"], ["content", "model_id", "stream"]) + async def completion( + self, + *, + content: inference_completion_params.Content, + model_id: str, + logprobs: inference_completion_params.Logprobs | NotGiven = NOT_GIVEN, + response_format: inference_completion_params.ResponseFormat | NotGiven = NOT_GIVEN, + sampling_params: SamplingParams | NotGiven = NOT_GIVEN, + stream: Literal[False] | Literal[True] | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> InferenceCompletionResponse | AsyncStream[InferenceCompletionResponse]: extra_headers = {"Accept": "text/event-stream", **(extra_headers or {})} extra_headers = { **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), @@ -371,6 +761,8 @@ async def completion( cast_to=cast( Any, InferenceCompletionResponse ), # Union types cannot be passed in as arguments in the type system + stream=stream or False, + stream_cls=AsyncStream[InferenceCompletionResponse], ), ) diff --git a/src/llama_stack_client/types/inference_chat_completion_params.py b/src/llama_stack_client/types/inference_chat_completion_params.py index ad393061..e4f263b0 100644 --- a/src/llama_stack_client/types/inference_chat_completion_params.py +++ b/src/llama_stack_client/types/inference_chat_completion_params.py @@ -14,17 +14,19 @@ from .shared_params.tool_response_message import ToolResponseMessage __all__ = [ - "InferenceChatCompletionParams", + "InferenceChatCompletionParamsBase", "Message", "Logprobs", "ResponseFormat", "ResponseFormatJsonSchemaFormat", "ResponseFormatGrammarFormat", "Tool", + "InferenceChatCompletionParamsNonStreaming", + "InferenceChatCompletionParamsStreaming", ] -class InferenceChatCompletionParams(TypedDict, total=False): +class InferenceChatCompletionParamsBase(TypedDict, total=False): messages: Required[Iterable[Message]] model_id: Required[str] @@ -35,8 +37,6 @@ class InferenceChatCompletionParams(TypedDict, total=False): sampling_params: SamplingParams - stream: bool - tool_choice: Literal["auto", "required"] tool_prompt_format: Literal["json", "function_tag", "python_list"] @@ -85,3 +85,14 @@ class Tool(TypedDict, total=False): description: str parameters: Dict[str, ToolParamDefinition] + + +class InferenceChatCompletionParamsNonStreaming(InferenceChatCompletionParamsBase, total=False): + stream: Literal[False] + + +class InferenceChatCompletionParamsStreaming(InferenceChatCompletionParamsBase): + stream: Required[Literal[True]] + + +InferenceChatCompletionParams = Union[InferenceChatCompletionParamsNonStreaming, InferenceChatCompletionParamsStreaming] diff --git a/src/llama_stack_client/types/inference_completion_params.py b/src/llama_stack_client/types/inference_completion_params.py index 8c20a1f9..c7d813bc 100644 --- a/src/llama_stack_client/types/inference_completion_params.py +++ b/src/llama_stack_client/types/inference_completion_params.py @@ -10,17 +10,19 @@ from .shared_params.sampling_params import SamplingParams __all__ = [ - "InferenceCompletionParams", + "InferenceCompletionParamsBase", "Content", "ContentImageMediaArray", "Logprobs", "ResponseFormat", "ResponseFormatJsonSchemaFormat", "ResponseFormatGrammarFormat", + "InferenceCompletionParamsNonStreaming", + "InferenceCompletionParamsStreaming", ] -class InferenceCompletionParams(TypedDict, total=False): +class InferenceCompletionParamsBase(TypedDict, total=False): content: Required[Content] model_id: Required[str] @@ -31,8 +33,6 @@ class InferenceCompletionParams(TypedDict, total=False): sampling_params: SamplingParams - stream: bool - x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] @@ -58,3 +58,14 @@ class ResponseFormatGrammarFormat(TypedDict, total=False): ResponseFormat: TypeAlias = Union[ResponseFormatJsonSchemaFormat, ResponseFormatGrammarFormat] + + +class InferenceCompletionParamsNonStreaming(InferenceCompletionParamsBase, total=False): + stream: Literal[False] + + +class InferenceCompletionParamsStreaming(InferenceCompletionParamsBase): + stream: Required[Literal[True]] + + +InferenceCompletionParams = Union[InferenceCompletionParamsNonStreaming, InferenceCompletionParamsStreaming] diff --git a/tests/api_resources/test_inference.py b/tests/api_resources/test_inference.py index 60895ef3..539ddd1b 100644 --- a/tests/api_resources/test_inference.py +++ b/tests/api_resources/test_inference.py @@ -25,7 +25,7 @@ class TestInference: reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" ) @parametrize - def test_method_chat_completion(self, client: LlamaStackClient) -> None: + def test_method_chat_completion_overload_1(self, client: LlamaStackClient) -> None: inference = client.inference.chat_completion( messages=[ { @@ -41,7 +41,7 @@ def test_method_chat_completion(self, client: LlamaStackClient) -> None: reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" ) @parametrize - def test_method_chat_completion_with_all_params(self, client: LlamaStackClient) -> None: + def test_method_chat_completion_with_all_params_overload_1(self, client: LlamaStackClient) -> None: inference = client.inference.chat_completion( messages=[ { @@ -64,7 +64,7 @@ def test_method_chat_completion_with_all_params(self, client: LlamaStackClient) "top_k": 0, "top_p": 0, }, - stream=True, + stream=False, tool_choice="auto", tool_prompt_format="json", tools=[ @@ -89,7 +89,7 @@ def test_method_chat_completion_with_all_params(self, client: LlamaStackClient) reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" ) @parametrize - def test_raw_response_chat_completion(self, client: LlamaStackClient) -> None: + def test_raw_response_chat_completion_overload_1(self, client: LlamaStackClient) -> None: response = client.inference.with_raw_response.chat_completion( messages=[ { @@ -109,7 +109,7 @@ def test_raw_response_chat_completion(self, client: LlamaStackClient) -> None: reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" ) @parametrize - def test_streaming_response_chat_completion(self, client: LlamaStackClient) -> None: + def test_streaming_response_chat_completion_overload_1(self, client: LlamaStackClient) -> None: with client.inference.with_streaming_response.chat_completion( messages=[ { @@ -131,7 +131,115 @@ def test_streaming_response_chat_completion(self, client: LlamaStackClient) -> N reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" ) @parametrize - def test_method_completion(self, client: LlamaStackClient) -> None: + def test_method_chat_completion_overload_2(self, client: LlamaStackClient) -> None: + inference_stream = client.inference.chat_completion( + messages=[ + { + "content": "string", + "role": "user", + } + ], + model_id="model_id", + stream=True, + ) + inference_stream.response.close() + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) + @parametrize + def test_method_chat_completion_with_all_params_overload_2(self, client: LlamaStackClient) -> None: + inference_stream = client.inference.chat_completion( + messages=[ + { + "content": "string", + "role": "user", + "context": "string", + } + ], + model_id="model_id", + stream=True, + logprobs={"top_k": 0}, + response_format={ + "json_schema": {"foo": True}, + "type": "json_schema", + }, + sampling_params={ + "strategy": "greedy", + "max_tokens": 0, + "repetition_penalty": 0, + "temperature": 0, + "top_k": 0, + "top_p": 0, + }, + tool_choice="auto", + tool_prompt_format="json", + tools=[ + { + "tool_name": "brave_search", + "description": "description", + "parameters": { + "foo": { + "param_type": "param_type", + "default": True, + "description": "description", + "required": True, + } + }, + } + ], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + inference_stream.response.close() + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) + @parametrize + def test_raw_response_chat_completion_overload_2(self, client: LlamaStackClient) -> None: + response = client.inference.with_raw_response.chat_completion( + messages=[ + { + "content": "string", + "role": "user", + } + ], + model_id="model_id", + stream=True, + ) + + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + stream = response.parse() + stream.close() + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) + @parametrize + def test_streaming_response_chat_completion_overload_2(self, client: LlamaStackClient) -> None: + with client.inference.with_streaming_response.chat_completion( + messages=[ + { + "content": "string", + "role": "user", + } + ], + model_id="model_id", + stream=True, + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + stream = response.parse() + stream.close() + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) + @parametrize + def test_method_completion_overload_1(self, client: LlamaStackClient) -> None: inference = client.inference.completion( content="string", model_id="model_id", @@ -142,7 +250,7 @@ def test_method_completion(self, client: LlamaStackClient) -> None: reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" ) @parametrize - def test_method_completion_with_all_params(self, client: LlamaStackClient) -> None: + def test_method_completion_with_all_params_overload_1(self, client: LlamaStackClient) -> None: inference = client.inference.completion( content="string", model_id="model_id", @@ -159,7 +267,7 @@ def test_method_completion_with_all_params(self, client: LlamaStackClient) -> No "top_k": 0, "top_p": 0, }, - stream=True, + stream=False, x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) assert_matches_type(InferenceCompletionResponse, inference, path=["response"]) @@ -168,7 +276,7 @@ def test_method_completion_with_all_params(self, client: LlamaStackClient) -> No reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" ) @parametrize - def test_raw_response_completion(self, client: LlamaStackClient) -> None: + def test_raw_response_completion_overload_1(self, client: LlamaStackClient) -> None: response = client.inference.with_raw_response.completion( content="string", model_id="model_id", @@ -183,7 +291,7 @@ def test_raw_response_completion(self, client: LlamaStackClient) -> None: reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" ) @parametrize - def test_streaming_response_completion(self, client: LlamaStackClient) -> None: + def test_streaming_response_completion_overload_1(self, client: LlamaStackClient) -> None: with client.inference.with_streaming_response.completion( content="string", model_id="model_id", @@ -196,6 +304,77 @@ def test_streaming_response_completion(self, client: LlamaStackClient) -> None: assert cast(Any, response.is_closed) is True + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) + @parametrize + def test_method_completion_overload_2(self, client: LlamaStackClient) -> None: + inference_stream = client.inference.completion( + content="string", + model_id="model_id", + stream=True, + ) + inference_stream.response.close() + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) + @parametrize + def test_method_completion_with_all_params_overload_2(self, client: LlamaStackClient) -> None: + inference_stream = client.inference.completion( + content="string", + model_id="model_id", + stream=True, + logprobs={"top_k": 0}, + response_format={ + "json_schema": {"foo": True}, + "type": "json_schema", + }, + sampling_params={ + "strategy": "greedy", + "max_tokens": 0, + "repetition_penalty": 0, + "temperature": 0, + "top_k": 0, + "top_p": 0, + }, + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + inference_stream.response.close() + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) + @parametrize + def test_raw_response_completion_overload_2(self, client: LlamaStackClient) -> None: + response = client.inference.with_raw_response.completion( + content="string", + model_id="model_id", + stream=True, + ) + + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + stream = response.parse() + stream.close() + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) + @parametrize + def test_streaming_response_completion_overload_2(self, client: LlamaStackClient) -> None: + with client.inference.with_streaming_response.completion( + content="string", + model_id="model_id", + stream=True, + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + stream = response.parse() + stream.close() + + assert cast(Any, response.is_closed) is True + @parametrize def test_method_embeddings(self, client: LlamaStackClient) -> None: inference = client.inference.embeddings( @@ -247,7 +426,7 @@ class TestAsyncInference: reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" ) @parametrize - async def test_method_chat_completion(self, async_client: AsyncLlamaStackClient) -> None: + async def test_method_chat_completion_overload_1(self, async_client: AsyncLlamaStackClient) -> None: inference = await async_client.inference.chat_completion( messages=[ { @@ -263,7 +442,7 @@ async def test_method_chat_completion(self, async_client: AsyncLlamaStackClient) reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" ) @parametrize - async def test_method_chat_completion_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + async def test_method_chat_completion_with_all_params_overload_1(self, async_client: AsyncLlamaStackClient) -> None: inference = await async_client.inference.chat_completion( messages=[ { @@ -286,7 +465,7 @@ async def test_method_chat_completion_with_all_params(self, async_client: AsyncL "top_k": 0, "top_p": 0, }, - stream=True, + stream=False, tool_choice="auto", tool_prompt_format="json", tools=[ @@ -311,7 +490,7 @@ async def test_method_chat_completion_with_all_params(self, async_client: AsyncL reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" ) @parametrize - async def test_raw_response_chat_completion(self, async_client: AsyncLlamaStackClient) -> None: + async def test_raw_response_chat_completion_overload_1(self, async_client: AsyncLlamaStackClient) -> None: response = await async_client.inference.with_raw_response.chat_completion( messages=[ { @@ -331,7 +510,7 @@ async def test_raw_response_chat_completion(self, async_client: AsyncLlamaStackC reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" ) @parametrize - async def test_streaming_response_chat_completion(self, async_client: AsyncLlamaStackClient) -> None: + async def test_streaming_response_chat_completion_overload_1(self, async_client: AsyncLlamaStackClient) -> None: async with async_client.inference.with_streaming_response.chat_completion( messages=[ { @@ -353,7 +532,115 @@ async def test_streaming_response_chat_completion(self, async_client: AsyncLlama reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" ) @parametrize - async def test_method_completion(self, async_client: AsyncLlamaStackClient) -> None: + async def test_method_chat_completion_overload_2(self, async_client: AsyncLlamaStackClient) -> None: + inference_stream = await async_client.inference.chat_completion( + messages=[ + { + "content": "string", + "role": "user", + } + ], + model_id="model_id", + stream=True, + ) + await inference_stream.response.aclose() + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) + @parametrize + async def test_method_chat_completion_with_all_params_overload_2(self, async_client: AsyncLlamaStackClient) -> None: + inference_stream = await async_client.inference.chat_completion( + messages=[ + { + "content": "string", + "role": "user", + "context": "string", + } + ], + model_id="model_id", + stream=True, + logprobs={"top_k": 0}, + response_format={ + "json_schema": {"foo": True}, + "type": "json_schema", + }, + sampling_params={ + "strategy": "greedy", + "max_tokens": 0, + "repetition_penalty": 0, + "temperature": 0, + "top_k": 0, + "top_p": 0, + }, + tool_choice="auto", + tool_prompt_format="json", + tools=[ + { + "tool_name": "brave_search", + "description": "description", + "parameters": { + "foo": { + "param_type": "param_type", + "default": True, + "description": "description", + "required": True, + } + }, + } + ], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + await inference_stream.response.aclose() + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) + @parametrize + async def test_raw_response_chat_completion_overload_2(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.inference.with_raw_response.chat_completion( + messages=[ + { + "content": "string", + "role": "user", + } + ], + model_id="model_id", + stream=True, + ) + + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + stream = await response.parse() + await stream.close() + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) + @parametrize + async def test_streaming_response_chat_completion_overload_2(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.inference.with_streaming_response.chat_completion( + messages=[ + { + "content": "string", + "role": "user", + } + ], + model_id="model_id", + stream=True, + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + stream = await response.parse() + await stream.close() + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) + @parametrize + async def test_method_completion_overload_1(self, async_client: AsyncLlamaStackClient) -> None: inference = await async_client.inference.completion( content="string", model_id="model_id", @@ -364,7 +651,7 @@ async def test_method_completion(self, async_client: AsyncLlamaStackClient) -> N reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" ) @parametrize - async def test_method_completion_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + async def test_method_completion_with_all_params_overload_1(self, async_client: AsyncLlamaStackClient) -> None: inference = await async_client.inference.completion( content="string", model_id="model_id", @@ -381,7 +668,7 @@ async def test_method_completion_with_all_params(self, async_client: AsyncLlamaS "top_k": 0, "top_p": 0, }, - stream=True, + stream=False, x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) assert_matches_type(InferenceCompletionResponse, inference, path=["response"]) @@ -390,7 +677,7 @@ async def test_method_completion_with_all_params(self, async_client: AsyncLlamaS reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" ) @parametrize - async def test_raw_response_completion(self, async_client: AsyncLlamaStackClient) -> None: + async def test_raw_response_completion_overload_1(self, async_client: AsyncLlamaStackClient) -> None: response = await async_client.inference.with_raw_response.completion( content="string", model_id="model_id", @@ -405,7 +692,7 @@ async def test_raw_response_completion(self, async_client: AsyncLlamaStackClient reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" ) @parametrize - async def test_streaming_response_completion(self, async_client: AsyncLlamaStackClient) -> None: + async def test_streaming_response_completion_overload_1(self, async_client: AsyncLlamaStackClient) -> None: async with async_client.inference.with_streaming_response.completion( content="string", model_id="model_id", @@ -418,6 +705,77 @@ async def test_streaming_response_completion(self, async_client: AsyncLlamaStack assert cast(Any, response.is_closed) is True + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) + @parametrize + async def test_method_completion_overload_2(self, async_client: AsyncLlamaStackClient) -> None: + inference_stream = await async_client.inference.completion( + content="string", + model_id="model_id", + stream=True, + ) + await inference_stream.response.aclose() + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) + @parametrize + async def test_method_completion_with_all_params_overload_2(self, async_client: AsyncLlamaStackClient) -> None: + inference_stream = await async_client.inference.completion( + content="string", + model_id="model_id", + stream=True, + logprobs={"top_k": 0}, + response_format={ + "json_schema": {"foo": True}, + "type": "json_schema", + }, + sampling_params={ + "strategy": "greedy", + "max_tokens": 0, + "repetition_penalty": 0, + "temperature": 0, + "top_k": 0, + "top_p": 0, + }, + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + await inference_stream.response.aclose() + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) + @parametrize + async def test_raw_response_completion_overload_2(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.inference.with_raw_response.completion( + content="string", + model_id="model_id", + stream=True, + ) + + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + stream = await response.parse() + await stream.close() + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) + @parametrize + async def test_streaming_response_completion_overload_2(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.inference.with_streaming_response.completion( + content="string", + model_id="model_id", + stream=True, + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + stream = await response.parse() + await stream.close() + + assert cast(Any, response.is_closed) is True + @parametrize async def test_method_embeddings(self, async_client: AsyncLlamaStackClient) -> None: inference = await async_client.inference.embeddings(