diff --git a/pyproject.toml b/pyproject.toml
index 19c5603a..08b99493 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "llama_stack_client"
-version = "0.0.53rc4"
+version = "0.0.53rc5"
description = "The official Python library for the llama-stack-client API"
dynamic = ["readme"]
license = "Apache-2.0"
diff --git a/src/llama_stack_client/resources/inference.py b/src/llama_stack_client/resources/inference.py
index 248c00e8..63039ac6 100644
--- a/src/llama_stack_client/resources/inference.py
+++ b/src/llama_stack_client/resources/inference.py
@@ -3,7 +3,7 @@
from __future__ import annotations
from typing import Any, List, Iterable, cast
-from typing_extensions import Literal
+from typing_extensions import Literal, overload
import httpx
@@ -14,6 +14,7 @@
)
from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from .._utils import (
+ required_args,
maybe_transform,
strip_not_given,
async_maybe_transform,
@@ -26,6 +27,7 @@
async_to_raw_response_wrapper,
async_to_streamed_response_wrapper,
)
+from .._streaming import Stream, AsyncStream
from .._base_client import make_request_options
from ..types.embeddings_response import EmbeddingsResponse
from ..types.inference_completion_response import InferenceCompletionResponse
@@ -55,6 +57,7 @@ def with_streaming_response(self) -> InferenceResourceWithStreamingResponse:
"""
return InferenceResourceWithStreamingResponse(self)
+ @overload
def chat_completion(
self,
*,
@@ -63,7 +66,7 @@ def chat_completion(
logprobs: inference_chat_completion_params.Logprobs | NotGiven = NOT_GIVEN,
response_format: inference_chat_completion_params.ResponseFormat | NotGiven = NOT_GIVEN,
sampling_params: SamplingParams | NotGiven = NOT_GIVEN,
- stream: bool | NotGiven = NOT_GIVEN,
+ stream: Literal[False] | NotGiven = NOT_GIVEN,
tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN,
tool_prompt_format: Literal["json", "function_tag", "python_list"] | NotGiven = NOT_GIVEN,
tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN,
@@ -95,6 +98,115 @@ def chat_completion(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ ...
+
+ @overload
+ def chat_completion(
+ self,
+ *,
+ messages: Iterable[inference_chat_completion_params.Message],
+ model_id: str,
+ stream: Literal[True],
+ logprobs: inference_chat_completion_params.Logprobs | NotGiven = NOT_GIVEN,
+ response_format: inference_chat_completion_params.ResponseFormat | NotGiven = NOT_GIVEN,
+ sampling_params: SamplingParams | NotGiven = NOT_GIVEN,
+ tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN,
+ tool_prompt_format: Literal["json", "function_tag", "python_list"] | NotGiven = NOT_GIVEN,
+ tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> Stream[InferenceChatCompletionResponse]:
+ """
+ Args:
+ tool_prompt_format: `json` -- Refers to the json format for calling tools. The json format takes the
+ form like { "type": "function", "function" : { "name": "function_name",
+ "description": "function_description", "parameters": {...} } }
+
+              `function_tag` -- This is an example of how you could define your own format
+              for making tool calls. The function_tag format looks like this,
+ (parameters)
+
+              The detailed prompts for each of these formats are added to the llama CLI
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ ...
+
+ @overload
+ def chat_completion(
+ self,
+ *,
+ messages: Iterable[inference_chat_completion_params.Message],
+ model_id: str,
+ stream: bool,
+ logprobs: inference_chat_completion_params.Logprobs | NotGiven = NOT_GIVEN,
+ response_format: inference_chat_completion_params.ResponseFormat | NotGiven = NOT_GIVEN,
+ sampling_params: SamplingParams | NotGiven = NOT_GIVEN,
+ tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN,
+ tool_prompt_format: Literal["json", "function_tag", "python_list"] | NotGiven = NOT_GIVEN,
+ tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> InferenceChatCompletionResponse | Stream[InferenceChatCompletionResponse]:
+ """
+ Args:
+ tool_prompt_format: `json` -- Refers to the json format for calling tools. The json format takes the
+ form like { "type": "function", "function" : { "name": "function_name",
+ "description": "function_description", "parameters": {...} } }
+
+              `function_tag` -- This is an example of how you could define your own format
+              for making tool calls. The function_tag format looks like this,
+ (parameters)
+
+              The detailed prompts for each of these formats are added to the llama CLI
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ ...
+
+ @required_args(["messages", "model_id"], ["messages", "model_id", "stream"])
+ def chat_completion(
+ self,
+ *,
+ messages: Iterable[inference_chat_completion_params.Message],
+ model_id: str,
+ logprobs: inference_chat_completion_params.Logprobs | NotGiven = NOT_GIVEN,
+ response_format: inference_chat_completion_params.ResponseFormat | NotGiven = NOT_GIVEN,
+ sampling_params: SamplingParams | NotGiven = NOT_GIVEN,
+ stream: Literal[False] | Literal[True] | NotGiven = NOT_GIVEN,
+ tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN,
+ tool_prompt_format: Literal["json", "function_tag", "python_list"] | NotGiven = NOT_GIVEN,
+ tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> InferenceChatCompletionResponse | Stream[InferenceChatCompletionResponse]:
extra_headers = {"Accept": "text/event-stream", **(extra_headers or {})}
extra_headers = {
**strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
@@ -124,9 +236,12 @@ def chat_completion(
cast_to=cast(
Any, InferenceChatCompletionResponse
), # Union types cannot be passed in as arguments in the type system
+ stream=stream or False,
+ stream_cls=Stream[InferenceChatCompletionResponse],
),
)
+ @overload
def completion(
self,
*,
@@ -135,7 +250,7 @@ def completion(
logprobs: inference_completion_params.Logprobs | NotGiven = NOT_GIVEN,
response_format: inference_completion_params.ResponseFormat | NotGiven = NOT_GIVEN,
sampling_params: SamplingParams | NotGiven = NOT_GIVEN,
- stream: bool | NotGiven = NOT_GIVEN,
+ stream: Literal[False] | NotGiven = NOT_GIVEN,
x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
@@ -154,6 +269,86 @@ def completion(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ ...
+
+ @overload
+ def completion(
+ self,
+ *,
+ content: inference_completion_params.Content,
+ model_id: str,
+ stream: Literal[True],
+ logprobs: inference_completion_params.Logprobs | NotGiven = NOT_GIVEN,
+ response_format: inference_completion_params.ResponseFormat | NotGiven = NOT_GIVEN,
+ sampling_params: SamplingParams | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> Stream[InferenceCompletionResponse]:
+ """
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ ...
+
+ @overload
+ def completion(
+ self,
+ *,
+ content: inference_completion_params.Content,
+ model_id: str,
+ stream: bool,
+ logprobs: inference_completion_params.Logprobs | NotGiven = NOT_GIVEN,
+ response_format: inference_completion_params.ResponseFormat | NotGiven = NOT_GIVEN,
+ sampling_params: SamplingParams | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> InferenceCompletionResponse | Stream[InferenceCompletionResponse]:
+ """
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ ...
+
+ @required_args(["content", "model_id"], ["content", "model_id", "stream"])
+ def completion(
+ self,
+ *,
+ content: inference_completion_params.Content,
+ model_id: str,
+ logprobs: inference_completion_params.Logprobs | NotGiven = NOT_GIVEN,
+ response_format: inference_completion_params.ResponseFormat | NotGiven = NOT_GIVEN,
+ sampling_params: SamplingParams | NotGiven = NOT_GIVEN,
+ stream: Literal[False] | Literal[True] | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> InferenceCompletionResponse | Stream[InferenceCompletionResponse]:
extra_headers = {"Accept": "text/event-stream", **(extra_headers or {})}
extra_headers = {
**strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
@@ -180,6 +375,8 @@ def completion(
cast_to=cast(
Any, InferenceCompletionResponse
), # Union types cannot be passed in as arguments in the type system
+ stream=stream or False,
+ stream_cls=Stream[InferenceCompletionResponse],
),
)
@@ -246,6 +443,7 @@ def with_streaming_response(self) -> AsyncInferenceResourceWithStreamingResponse
"""
return AsyncInferenceResourceWithStreamingResponse(self)
+ @overload
async def chat_completion(
self,
*,
@@ -254,7 +452,7 @@ async def chat_completion(
logprobs: inference_chat_completion_params.Logprobs | NotGiven = NOT_GIVEN,
response_format: inference_chat_completion_params.ResponseFormat | NotGiven = NOT_GIVEN,
sampling_params: SamplingParams | NotGiven = NOT_GIVEN,
- stream: bool | NotGiven = NOT_GIVEN,
+ stream: Literal[False] | NotGiven = NOT_GIVEN,
tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN,
tool_prompt_format: Literal["json", "function_tag", "python_list"] | NotGiven = NOT_GIVEN,
tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN,
@@ -286,6 +484,115 @@ async def chat_completion(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ ...
+
+ @overload
+ async def chat_completion(
+ self,
+ *,
+ messages: Iterable[inference_chat_completion_params.Message],
+ model_id: str,
+ stream: Literal[True],
+ logprobs: inference_chat_completion_params.Logprobs | NotGiven = NOT_GIVEN,
+ response_format: inference_chat_completion_params.ResponseFormat | NotGiven = NOT_GIVEN,
+ sampling_params: SamplingParams | NotGiven = NOT_GIVEN,
+ tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN,
+ tool_prompt_format: Literal["json", "function_tag", "python_list"] | NotGiven = NOT_GIVEN,
+ tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> AsyncStream[InferenceChatCompletionResponse]:
+ """
+ Args:
+ tool_prompt_format: `json` -- Refers to the json format for calling tools. The json format takes the
+ form like { "type": "function", "function" : { "name": "function_name",
+ "description": "function_description", "parameters": {...} } }
+
+              `function_tag` -- This is an example of how you could define your own format
+              for making tool calls. The function_tag format looks like this,
+ (parameters)
+
+              The detailed prompts for each of these formats are added to the llama CLI
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ ...
+
+ @overload
+ async def chat_completion(
+ self,
+ *,
+ messages: Iterable[inference_chat_completion_params.Message],
+ model_id: str,
+ stream: bool,
+ logprobs: inference_chat_completion_params.Logprobs | NotGiven = NOT_GIVEN,
+ response_format: inference_chat_completion_params.ResponseFormat | NotGiven = NOT_GIVEN,
+ sampling_params: SamplingParams | NotGiven = NOT_GIVEN,
+ tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN,
+ tool_prompt_format: Literal["json", "function_tag", "python_list"] | NotGiven = NOT_GIVEN,
+ tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> InferenceChatCompletionResponse | AsyncStream[InferenceChatCompletionResponse]:
+ """
+ Args:
+ tool_prompt_format: `json` -- Refers to the json format for calling tools. The json format takes the
+ form like { "type": "function", "function" : { "name": "function_name",
+ "description": "function_description", "parameters": {...} } }
+
+              `function_tag` -- This is an example of how you could define your own format
+              for making tool calls. The function_tag format looks like this,
+ (parameters)
+
+              The detailed prompts for each of these formats are added to the llama CLI
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ ...
+
+ @required_args(["messages", "model_id"], ["messages", "model_id", "stream"])
+ async def chat_completion(
+ self,
+ *,
+ messages: Iterable[inference_chat_completion_params.Message],
+ model_id: str,
+ logprobs: inference_chat_completion_params.Logprobs | NotGiven = NOT_GIVEN,
+ response_format: inference_chat_completion_params.ResponseFormat | NotGiven = NOT_GIVEN,
+ sampling_params: SamplingParams | NotGiven = NOT_GIVEN,
+ stream: Literal[False] | Literal[True] | NotGiven = NOT_GIVEN,
+ tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN,
+ tool_prompt_format: Literal["json", "function_tag", "python_list"] | NotGiven = NOT_GIVEN,
+ tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> InferenceChatCompletionResponse | AsyncStream[InferenceChatCompletionResponse]:
extra_headers = {"Accept": "text/event-stream", **(extra_headers or {})}
extra_headers = {
**strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
@@ -315,9 +622,12 @@ async def chat_completion(
cast_to=cast(
Any, InferenceChatCompletionResponse
), # Union types cannot be passed in as arguments in the type system
+ stream=stream or False,
+ stream_cls=AsyncStream[InferenceChatCompletionResponse],
),
)
+ @overload
async def completion(
self,
*,
@@ -326,7 +636,7 @@ async def completion(
logprobs: inference_completion_params.Logprobs | NotGiven = NOT_GIVEN,
response_format: inference_completion_params.ResponseFormat | NotGiven = NOT_GIVEN,
sampling_params: SamplingParams | NotGiven = NOT_GIVEN,
- stream: bool | NotGiven = NOT_GIVEN,
+ stream: Literal[False] | NotGiven = NOT_GIVEN,
x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
@@ -345,6 +655,86 @@ async def completion(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ ...
+
+ @overload
+ async def completion(
+ self,
+ *,
+ content: inference_completion_params.Content,
+ model_id: str,
+ stream: Literal[True],
+ logprobs: inference_completion_params.Logprobs | NotGiven = NOT_GIVEN,
+ response_format: inference_completion_params.ResponseFormat | NotGiven = NOT_GIVEN,
+ sampling_params: SamplingParams | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> AsyncStream[InferenceCompletionResponse]:
+ """
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ ...
+
+ @overload
+ async def completion(
+ self,
+ *,
+ content: inference_completion_params.Content,
+ model_id: str,
+ stream: bool,
+ logprobs: inference_completion_params.Logprobs | NotGiven = NOT_GIVEN,
+ response_format: inference_completion_params.ResponseFormat | NotGiven = NOT_GIVEN,
+ sampling_params: SamplingParams | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> InferenceCompletionResponse | AsyncStream[InferenceCompletionResponse]:
+ """
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ ...
+
+ @required_args(["content", "model_id"], ["content", "model_id", "stream"])
+ async def completion(
+ self,
+ *,
+ content: inference_completion_params.Content,
+ model_id: str,
+ logprobs: inference_completion_params.Logprobs | NotGiven = NOT_GIVEN,
+ response_format: inference_completion_params.ResponseFormat | NotGiven = NOT_GIVEN,
+ sampling_params: SamplingParams | NotGiven = NOT_GIVEN,
+ stream: Literal[False] | Literal[True] | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> InferenceCompletionResponse | AsyncStream[InferenceCompletionResponse]:
extra_headers = {"Accept": "text/event-stream", **(extra_headers or {})}
extra_headers = {
**strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
@@ -371,6 +761,8 @@ async def completion(
cast_to=cast(
Any, InferenceCompletionResponse
), # Union types cannot be passed in as arguments in the type system
+ stream=stream or False,
+ stream_cls=AsyncStream[InferenceCompletionResponse],
),
)
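
For context, a minimal usage sketch of the new chat_completion overloads added above (illustrative only, not part of the diff; the base_url and model_id values are placeholder assumptions):

from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5000")  # assumed local server URL

# Omitting stream (or passing stream=False) selects the non-streaming overload
# and returns an InferenceChatCompletionResponse.
response = client.inference.chat_completion(
    messages=[{"role": "user", "content": "Hello"}],
    model_id="model_id",
)

# Passing stream=True selects the streaming overload and returns a
# Stream[InferenceChatCompletionResponse] that yields chunks as they arrive.
for chunk in client.inference.chat_completion(
    messages=[{"role": "user", "content": "Hello"}],
    model_id="model_id",
    stream=True,
):
    print(chunk)
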
diff --git a/src/llama_stack_client/types/inference_chat_completion_params.py b/src/llama_stack_client/types/inference_chat_completion_params.py
index ad393061..e4f263b0 100644
--- a/src/llama_stack_client/types/inference_chat_completion_params.py
+++ b/src/llama_stack_client/types/inference_chat_completion_params.py
@@ -14,17 +14,19 @@
from .shared_params.tool_response_message import ToolResponseMessage
__all__ = [
- "InferenceChatCompletionParams",
+ "InferenceChatCompletionParamsBase",
"Message",
"Logprobs",
"ResponseFormat",
"ResponseFormatJsonSchemaFormat",
"ResponseFormatGrammarFormat",
"Tool",
+ "InferenceChatCompletionParamsNonStreaming",
+ "InferenceChatCompletionParamsStreaming",
]
-class InferenceChatCompletionParams(TypedDict, total=False):
+class InferenceChatCompletionParamsBase(TypedDict, total=False):
messages: Required[Iterable[Message]]
model_id: Required[str]
@@ -35,8 +37,6 @@ class InferenceChatCompletionParams(TypedDict, total=False):
sampling_params: SamplingParams
- stream: bool
-
tool_choice: Literal["auto", "required"]
tool_prompt_format: Literal["json", "function_tag", "python_list"]
@@ -85,3 +85,14 @@ class Tool(TypedDict, total=False):
description: str
parameters: Dict[str, ToolParamDefinition]
+
+
+class InferenceChatCompletionParamsNonStreaming(InferenceChatCompletionParamsBase, total=False):
+ stream: Literal[False]
+
+
+class InferenceChatCompletionParamsStreaming(InferenceChatCompletionParamsBase):
+ stream: Required[Literal[True]]
+
+
+InferenceChatCompletionParams = Union[InferenceChatCompletionParamsNonStreaming, InferenceChatCompletionParamsStreaming]
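
For reference, a short sketch of how the split params TypedDicts behave under a type checker (illustrative only; the message and model_id values are placeholders): the streaming variant requires stream=True, while the non-streaming variant lets stream be omitted or set to False.

from llama_stack_client.types.inference_chat_completion_params import (
    InferenceChatCompletionParamsNonStreaming,
    InferenceChatCompletionParamsStreaming,
)

# Non-streaming params: "stream" may be omitted or set to False.
non_streaming: InferenceChatCompletionParamsNonStreaming = {
    "messages": [{"role": "user", "content": "Hello"}],
    "model_id": "model_id",
}

# Streaming params: "stream" is Required[Literal[True]], so a type checker
# flags this dict if stream is missing or False.
streaming: InferenceChatCompletionParamsStreaming = {
    "messages": [{"role": "user", "content": "Hello"}],
    "model_id": "model_id",
    "stream": True,
}
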
diff --git a/src/llama_stack_client/types/inference_completion_params.py b/src/llama_stack_client/types/inference_completion_params.py
index 8c20a1f9..c7d813bc 100644
--- a/src/llama_stack_client/types/inference_completion_params.py
+++ b/src/llama_stack_client/types/inference_completion_params.py
@@ -10,17 +10,19 @@
from .shared_params.sampling_params import SamplingParams
__all__ = [
- "InferenceCompletionParams",
+ "InferenceCompletionParamsBase",
"Content",
"ContentImageMediaArray",
"Logprobs",
"ResponseFormat",
"ResponseFormatJsonSchemaFormat",
"ResponseFormatGrammarFormat",
+ "InferenceCompletionParamsNonStreaming",
+ "InferenceCompletionParamsStreaming",
]
-class InferenceCompletionParams(TypedDict, total=False):
+class InferenceCompletionParamsBase(TypedDict, total=False):
content: Required[Content]
model_id: Required[str]
@@ -31,8 +33,6 @@ class InferenceCompletionParams(TypedDict, total=False):
sampling_params: SamplingParams
- stream: bool
-
x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
@@ -58,3 +58,14 @@ class ResponseFormatGrammarFormat(TypedDict, total=False):
ResponseFormat: TypeAlias = Union[ResponseFormatJsonSchemaFormat, ResponseFormatGrammarFormat]
+
+
+class InferenceCompletionParamsNonStreaming(InferenceCompletionParamsBase, total=False):
+ stream: Literal[False]
+
+
+class InferenceCompletionParamsStreaming(InferenceCompletionParamsBase):
+ stream: Required[Literal[True]]
+
+
+InferenceCompletionParams = Union[InferenceCompletionParamsNonStreaming, InferenceCompletionParamsStreaming]
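
Similarly, a minimal async sketch of the new completion overloads (illustrative only; base_url and model_id are placeholder assumptions): stream=True resolves to the AsyncStream overload, which is consumed with async for.

import asyncio

from llama_stack_client import AsyncLlamaStackClient


async def main() -> None:
    client = AsyncLlamaStackClient(base_url="http://localhost:5000")  # assumed local server URL

    # stream=True selects the AsyncStream[InferenceCompletionResponse] overload.
    stream = await client.inference.completion(
        content="string",
        model_id="model_id",
        stream=True,
    )
    async for chunk in stream:
        print(chunk)


asyncio.run(main())
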
diff --git a/tests/api_resources/test_inference.py b/tests/api_resources/test_inference.py
index 60895ef3..539ddd1b 100644
--- a/tests/api_resources/test_inference.py
+++ b/tests/api_resources/test_inference.py
@@ -25,7 +25,7 @@ class TestInference:
reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail"
)
@parametrize
- def test_method_chat_completion(self, client: LlamaStackClient) -> None:
+ def test_method_chat_completion_overload_1(self, client: LlamaStackClient) -> None:
inference = client.inference.chat_completion(
messages=[
{
@@ -41,7 +41,7 @@ def test_method_chat_completion(self, client: LlamaStackClient) -> None:
reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail"
)
@parametrize
- def test_method_chat_completion_with_all_params(self, client: LlamaStackClient) -> None:
+ def test_method_chat_completion_with_all_params_overload_1(self, client: LlamaStackClient) -> None:
inference = client.inference.chat_completion(
messages=[
{
@@ -64,7 +64,7 @@ def test_method_chat_completion_with_all_params(self, client: LlamaStackClient)
"top_k": 0,
"top_p": 0,
},
- stream=True,
+ stream=False,
tool_choice="auto",
tool_prompt_format="json",
tools=[
@@ -89,7 +89,7 @@ def test_method_chat_completion_with_all_params(self, client: LlamaStackClient)
reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail"
)
@parametrize
- def test_raw_response_chat_completion(self, client: LlamaStackClient) -> None:
+ def test_raw_response_chat_completion_overload_1(self, client: LlamaStackClient) -> None:
response = client.inference.with_raw_response.chat_completion(
messages=[
{
@@ -109,7 +109,7 @@ def test_raw_response_chat_completion(self, client: LlamaStackClient) -> None:
reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail"
)
@parametrize
- def test_streaming_response_chat_completion(self, client: LlamaStackClient) -> None:
+ def test_streaming_response_chat_completion_overload_1(self, client: LlamaStackClient) -> None:
with client.inference.with_streaming_response.chat_completion(
messages=[
{
@@ -131,7 +131,115 @@ def test_streaming_response_chat_completion(self, client: LlamaStackClient) -> N
reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail"
)
@parametrize
- def test_method_completion(self, client: LlamaStackClient) -> None:
+ def test_method_chat_completion_overload_2(self, client: LlamaStackClient) -> None:
+ inference_stream = client.inference.chat_completion(
+ messages=[
+ {
+ "content": "string",
+ "role": "user",
+ }
+ ],
+ model_id="model_id",
+ stream=True,
+ )
+ inference_stream.response.close()
+
+ @pytest.mark.skip(
+ reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail"
+ )
+ @parametrize
+ def test_method_chat_completion_with_all_params_overload_2(self, client: LlamaStackClient) -> None:
+ inference_stream = client.inference.chat_completion(
+ messages=[
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ }
+ ],
+ model_id="model_id",
+ stream=True,
+ logprobs={"top_k": 0},
+ response_format={
+ "json_schema": {"foo": True},
+ "type": "json_schema",
+ },
+ sampling_params={
+ "strategy": "greedy",
+ "max_tokens": 0,
+ "repetition_penalty": 0,
+ "temperature": 0,
+ "top_k": 0,
+ "top_p": 0,
+ },
+ tool_choice="auto",
+ tool_prompt_format="json",
+ tools=[
+ {
+ "tool_name": "brave_search",
+ "description": "description",
+ "parameters": {
+ "foo": {
+ "param_type": "param_type",
+ "default": True,
+ "description": "description",
+ "required": True,
+ }
+ },
+ }
+ ],
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ inference_stream.response.close()
+
+ @pytest.mark.skip(
+ reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail"
+ )
+ @parametrize
+ def test_raw_response_chat_completion_overload_2(self, client: LlamaStackClient) -> None:
+ response = client.inference.with_raw_response.chat_completion(
+ messages=[
+ {
+ "content": "string",
+ "role": "user",
+ }
+ ],
+ model_id="model_id",
+ stream=True,
+ )
+
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ stream = response.parse()
+ stream.close()
+
+ @pytest.mark.skip(
+ reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail"
+ )
+ @parametrize
+ def test_streaming_response_chat_completion_overload_2(self, client: LlamaStackClient) -> None:
+ with client.inference.with_streaming_response.chat_completion(
+ messages=[
+ {
+ "content": "string",
+ "role": "user",
+ }
+ ],
+ model_id="model_id",
+ stream=True,
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ stream = response.parse()
+ stream.close()
+
+ assert cast(Any, response.is_closed) is True
+
+ @pytest.mark.skip(
+ reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail"
+ )
+ @parametrize
+ def test_method_completion_overload_1(self, client: LlamaStackClient) -> None:
inference = client.inference.completion(
content="string",
model_id="model_id",
@@ -142,7 +250,7 @@ def test_method_completion(self, client: LlamaStackClient) -> None:
reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail"
)
@parametrize
- def test_method_completion_with_all_params(self, client: LlamaStackClient) -> None:
+ def test_method_completion_with_all_params_overload_1(self, client: LlamaStackClient) -> None:
inference = client.inference.completion(
content="string",
model_id="model_id",
@@ -159,7 +267,7 @@ def test_method_completion_with_all_params(self, client: LlamaStackClient) -> No
"top_k": 0,
"top_p": 0,
},
- stream=True,
+ stream=False,
x_llama_stack_provider_data="X-LlamaStack-ProviderData",
)
assert_matches_type(InferenceCompletionResponse, inference, path=["response"])
@@ -168,7 +276,7 @@ def test_method_completion_with_all_params(self, client: LlamaStackClient) -> No
reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail"
)
@parametrize
- def test_raw_response_completion(self, client: LlamaStackClient) -> None:
+ def test_raw_response_completion_overload_1(self, client: LlamaStackClient) -> None:
response = client.inference.with_raw_response.completion(
content="string",
model_id="model_id",
@@ -183,7 +291,7 @@ def test_raw_response_completion(self, client: LlamaStackClient) -> None:
reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail"
)
@parametrize
- def test_streaming_response_completion(self, client: LlamaStackClient) -> None:
+ def test_streaming_response_completion_overload_1(self, client: LlamaStackClient) -> None:
with client.inference.with_streaming_response.completion(
content="string",
model_id="model_id",
@@ -196,6 +304,77 @@ def test_streaming_response_completion(self, client: LlamaStackClient) -> None:
assert cast(Any, response.is_closed) is True
+ @pytest.mark.skip(
+ reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail"
+ )
+ @parametrize
+ def test_method_completion_overload_2(self, client: LlamaStackClient) -> None:
+ inference_stream = client.inference.completion(
+ content="string",
+ model_id="model_id",
+ stream=True,
+ )
+ inference_stream.response.close()
+
+ @pytest.mark.skip(
+ reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail"
+ )
+ @parametrize
+ def test_method_completion_with_all_params_overload_2(self, client: LlamaStackClient) -> None:
+ inference_stream = client.inference.completion(
+ content="string",
+ model_id="model_id",
+ stream=True,
+ logprobs={"top_k": 0},
+ response_format={
+ "json_schema": {"foo": True},
+ "type": "json_schema",
+ },
+ sampling_params={
+ "strategy": "greedy",
+ "max_tokens": 0,
+ "repetition_penalty": 0,
+ "temperature": 0,
+ "top_k": 0,
+ "top_p": 0,
+ },
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ inference_stream.response.close()
+
+ @pytest.mark.skip(
+ reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail"
+ )
+ @parametrize
+ def test_raw_response_completion_overload_2(self, client: LlamaStackClient) -> None:
+ response = client.inference.with_raw_response.completion(
+ content="string",
+ model_id="model_id",
+ stream=True,
+ )
+
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ stream = response.parse()
+ stream.close()
+
+ @pytest.mark.skip(
+ reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail"
+ )
+ @parametrize
+ def test_streaming_response_completion_overload_2(self, client: LlamaStackClient) -> None:
+ with client.inference.with_streaming_response.completion(
+ content="string",
+ model_id="model_id",
+ stream=True,
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ stream = response.parse()
+ stream.close()
+
+ assert cast(Any, response.is_closed) is True
+
@parametrize
def test_method_embeddings(self, client: LlamaStackClient) -> None:
inference = client.inference.embeddings(
@@ -247,7 +426,7 @@ class TestAsyncInference:
reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail"
)
@parametrize
- async def test_method_chat_completion(self, async_client: AsyncLlamaStackClient) -> None:
+ async def test_method_chat_completion_overload_1(self, async_client: AsyncLlamaStackClient) -> None:
inference = await async_client.inference.chat_completion(
messages=[
{
@@ -263,7 +442,7 @@ async def test_method_chat_completion(self, async_client: AsyncLlamaStackClient)
reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail"
)
@parametrize
- async def test_method_chat_completion_with_all_params(self, async_client: AsyncLlamaStackClient) -> None:
+ async def test_method_chat_completion_with_all_params_overload_1(self, async_client: AsyncLlamaStackClient) -> None:
inference = await async_client.inference.chat_completion(
messages=[
{
@@ -286,7 +465,7 @@ async def test_method_chat_completion_with_all_params(self, async_client: AsyncL
"top_k": 0,
"top_p": 0,
},
- stream=True,
+ stream=False,
tool_choice="auto",
tool_prompt_format="json",
tools=[
@@ -311,7 +490,7 @@ async def test_method_chat_completion_with_all_params(self, async_client: AsyncL
reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail"
)
@parametrize
- async def test_raw_response_chat_completion(self, async_client: AsyncLlamaStackClient) -> None:
+ async def test_raw_response_chat_completion_overload_1(self, async_client: AsyncLlamaStackClient) -> None:
response = await async_client.inference.with_raw_response.chat_completion(
messages=[
{
@@ -331,7 +510,7 @@ async def test_raw_response_chat_completion(self, async_client: AsyncLlamaStackC
reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail"
)
@parametrize
- async def test_streaming_response_chat_completion(self, async_client: AsyncLlamaStackClient) -> None:
+ async def test_streaming_response_chat_completion_overload_1(self, async_client: AsyncLlamaStackClient) -> None:
async with async_client.inference.with_streaming_response.chat_completion(
messages=[
{
@@ -353,7 +532,115 @@ async def test_streaming_response_chat_completion(self, async_client: AsyncLlama
reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail"
)
@parametrize
- async def test_method_completion(self, async_client: AsyncLlamaStackClient) -> None:
+ async def test_method_chat_completion_overload_2(self, async_client: AsyncLlamaStackClient) -> None:
+ inference_stream = await async_client.inference.chat_completion(
+ messages=[
+ {
+ "content": "string",
+ "role": "user",
+ }
+ ],
+ model_id="model_id",
+ stream=True,
+ )
+ await inference_stream.response.aclose()
+
+ @pytest.mark.skip(
+ reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail"
+ )
+ @parametrize
+ async def test_method_chat_completion_with_all_params_overload_2(self, async_client: AsyncLlamaStackClient) -> None:
+ inference_stream = await async_client.inference.chat_completion(
+ messages=[
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ }
+ ],
+ model_id="model_id",
+ stream=True,
+ logprobs={"top_k": 0},
+ response_format={
+ "json_schema": {"foo": True},
+ "type": "json_schema",
+ },
+ sampling_params={
+ "strategy": "greedy",
+ "max_tokens": 0,
+ "repetition_penalty": 0,
+ "temperature": 0,
+ "top_k": 0,
+ "top_p": 0,
+ },
+ tool_choice="auto",
+ tool_prompt_format="json",
+ tools=[
+ {
+ "tool_name": "brave_search",
+ "description": "description",
+ "parameters": {
+ "foo": {
+ "param_type": "param_type",
+ "default": True,
+ "description": "description",
+ "required": True,
+ }
+ },
+ }
+ ],
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ await inference_stream.response.aclose()
+
+ @pytest.mark.skip(
+ reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail"
+ )
+ @parametrize
+ async def test_raw_response_chat_completion_overload_2(self, async_client: AsyncLlamaStackClient) -> None:
+ response = await async_client.inference.with_raw_response.chat_completion(
+ messages=[
+ {
+ "content": "string",
+ "role": "user",
+ }
+ ],
+ model_id="model_id",
+ stream=True,
+ )
+
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ stream = await response.parse()
+ await stream.close()
+
+ @pytest.mark.skip(
+ reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail"
+ )
+ @parametrize
+ async def test_streaming_response_chat_completion_overload_2(self, async_client: AsyncLlamaStackClient) -> None:
+ async with async_client.inference.with_streaming_response.chat_completion(
+ messages=[
+ {
+ "content": "string",
+ "role": "user",
+ }
+ ],
+ model_id="model_id",
+ stream=True,
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ stream = await response.parse()
+ await stream.close()
+
+ assert cast(Any, response.is_closed) is True
+
+ @pytest.mark.skip(
+ reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail"
+ )
+ @parametrize
+ async def test_method_completion_overload_1(self, async_client: AsyncLlamaStackClient) -> None:
inference = await async_client.inference.completion(
content="string",
model_id="model_id",
@@ -364,7 +651,7 @@ async def test_method_completion(self, async_client: AsyncLlamaStackClient) -> N
reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail"
)
@parametrize
- async def test_method_completion_with_all_params(self, async_client: AsyncLlamaStackClient) -> None:
+ async def test_method_completion_with_all_params_overload_1(self, async_client: AsyncLlamaStackClient) -> None:
inference = await async_client.inference.completion(
content="string",
model_id="model_id",
@@ -381,7 +668,7 @@ async def test_method_completion_with_all_params(self, async_client: AsyncLlamaS
"top_k": 0,
"top_p": 0,
},
- stream=True,
+ stream=False,
x_llama_stack_provider_data="X-LlamaStack-ProviderData",
)
assert_matches_type(InferenceCompletionResponse, inference, path=["response"])
@@ -390,7 +677,7 @@ async def test_method_completion_with_all_params(self, async_client: AsyncLlamaS
reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail"
)
@parametrize
- async def test_raw_response_completion(self, async_client: AsyncLlamaStackClient) -> None:
+ async def test_raw_response_completion_overload_1(self, async_client: AsyncLlamaStackClient) -> None:
response = await async_client.inference.with_raw_response.completion(
content="string",
model_id="model_id",
@@ -405,7 +692,7 @@ async def test_raw_response_completion(self, async_client: AsyncLlamaStackClient
reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail"
)
@parametrize
- async def test_streaming_response_completion(self, async_client: AsyncLlamaStackClient) -> None:
+ async def test_streaming_response_completion_overload_1(self, async_client: AsyncLlamaStackClient) -> None:
async with async_client.inference.with_streaming_response.completion(
content="string",
model_id="model_id",
@@ -418,6 +705,77 @@ async def test_streaming_response_completion(self, async_client: AsyncLlamaStack
assert cast(Any, response.is_closed) is True
+ @pytest.mark.skip(
+ reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail"
+ )
+ @parametrize
+ async def test_method_completion_overload_2(self, async_client: AsyncLlamaStackClient) -> None:
+ inference_stream = await async_client.inference.completion(
+ content="string",
+ model_id="model_id",
+ stream=True,
+ )
+ await inference_stream.response.aclose()
+
+ @pytest.mark.skip(
+ reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail"
+ )
+ @parametrize
+ async def test_method_completion_with_all_params_overload_2(self, async_client: AsyncLlamaStackClient) -> None:
+ inference_stream = await async_client.inference.completion(
+ content="string",
+ model_id="model_id",
+ stream=True,
+ logprobs={"top_k": 0},
+ response_format={
+ "json_schema": {"foo": True},
+ "type": "json_schema",
+ },
+ sampling_params={
+ "strategy": "greedy",
+ "max_tokens": 0,
+ "repetition_penalty": 0,
+ "temperature": 0,
+ "top_k": 0,
+ "top_p": 0,
+ },
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ await inference_stream.response.aclose()
+
+ @pytest.mark.skip(
+ reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail"
+ )
+ @parametrize
+ async def test_raw_response_completion_overload_2(self, async_client: AsyncLlamaStackClient) -> None:
+ response = await async_client.inference.with_raw_response.completion(
+ content="string",
+ model_id="model_id",
+ stream=True,
+ )
+
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ stream = await response.parse()
+ await stream.close()
+
+ @pytest.mark.skip(
+ reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail"
+ )
+ @parametrize
+ async def test_streaming_response_completion_overload_2(self, async_client: AsyncLlamaStackClient) -> None:
+ async with async_client.inference.with_streaming_response.completion(
+ content="string",
+ model_id="model_id",
+ stream=True,
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ stream = await response.parse()
+ await stream.close()
+
+ assert cast(Any, response.is_closed) is True
+
@parametrize
async def test_method_embeddings(self, async_client: AsyncLlamaStackClient) -> None:
inference = await async_client.inference.embeddings(