diff --git a/src/llama_stack_client/_base_client.py b/src/llama_stack_client/_base_client.py index c8b0b413..90df64c4 100644 --- a/src/llama_stack_client/_base_client.py +++ b/src/llama_stack_client/_base_client.py @@ -518,7 +518,7 @@ def _build_request( # so that passing a `TypedDict` doesn't cause an error. # https://github.com/microsoft/pyright/issues/3526#event-6715453066 params=self.qs.stringify(cast(Mapping[str, Any], params)) if params else None, - json=json_data, + json=json_data if is_given(json_data) else None, files=files, **kwargs, ) diff --git a/src/llama_stack_client/_client.py b/src/llama_stack_client/_client.py index bb5bb755..35df3a1e 100644 --- a/src/llama_stack_client/_client.py +++ b/src/llama_stack_client/_client.py @@ -126,16 +126,16 @@ def __init__( ) -> None: """Construct a new synchronous llama-stack-client client instance. - This automatically infers the `api_key` argument from the `LLAMA_STACK_CLIENT_API_KEY` environment variable if it is not provided. + This automatically infers the `api_key` argument from the `LLAMA_STACK_API_KEY` environment variable if it is not provided. """ if api_key is None: - api_key = os.environ.get("LLAMA_STACK_CLIENT_API_KEY") + api_key = os.environ.get("LLAMA_STACK_API_KEY") self.api_key = api_key if base_url is None: - base_url = os.environ.get("LLAMA_STACK_CLIENT_BASE_URL") + base_url = os.environ.get("LLAMA_STACK_BASE_URL") if base_url is None: - base_url = f"http://any-hosted-llama-stack.com" + base_url = "http://any-hosted-llama-stack.com" custom_headers = default_headers or {} custom_headers["X-LlamaStack-Client-Version"] = __version__ @@ -342,16 +342,16 @@ def __init__( ) -> None: """Construct a new async llama-stack-client client instance. - This automatically infers the `api_key` argument from the `LLAMA_STACK_CLIENT_API_KEY` environment variable if it is not provided. + This automatically infers the `api_key` argument from the `LLAMA_STACK_API_KEY` environment variable if it is not provided. """ if api_key is None: - api_key = os.environ.get("LLAMA_STACK_CLIENT_API_KEY") + api_key = os.environ.get("LLAMA_STACK_API_KEY") self.api_key = api_key if base_url is None: - base_url = os.environ.get("LLAMA_STACK_CLIENT_BASE_URL") + base_url = os.environ.get("LLAMA_STACK_BASE_URL") if base_url is None: - base_url = f"http://any-hosted-llama-stack.com" + base_url = "http://any-hosted-llama-stack.com" custom_headers = default_headers or {} custom_headers["X-LlamaStack-Client-Version"] = __version__ diff --git a/src/llama_stack_client/_files.py b/src/llama_stack_client/_files.py index 715cc207..43d5ca1c 100644 --- a/src/llama_stack_client/_files.py +++ b/src/llama_stack_client/_files.py @@ -71,7 +71,7 @@ def _transform_file(file: FileTypes) -> HttpxFileTypes: if is_tuple_t(file): return (file[0], _read_file_content(file[1]), *file[2:]) - raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple") + raise TypeError("Expected file types input to be a FileContent type or to be a tuple") def _read_file_content(file: FileContent) -> HttpxFileContent: @@ -113,7 +113,7 @@ async def _async_transform_file(file: FileTypes) -> HttpxFileTypes: if is_tuple_t(file): return (file[0], await _async_read_file_content(file[1]), *file[2:]) - raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple") + raise TypeError("Expected file types input to be a FileContent type or to be a tuple") async def _async_read_file_content(file: FileContent) -> HttpxFileContent: diff --git a/src/llama_stack_client/_response.py b/src/llama_stack_client/_response.py index ea35182f..31f945b7 100644 --- a/src/llama_stack_client/_response.py +++ b/src/llama_stack_client/_response.py @@ -229,7 +229,7 @@ def _parse(self, *, to: type[_T] | None = None) -> R | _T: # the response class ourselves but that is something that should be supported directly in httpx # as it would be easy to incorrectly construct the Response object due to the multitude of arguments. if cast_to != httpx.Response: - raise ValueError(f"Subclasses of httpx.Response cannot be passed to `cast_to`") + raise ValueError("Subclasses of httpx.Response cannot be passed to `cast_to`") return cast(R, response) if ( @@ -245,9 +245,9 @@ def _parse(self, *, to: type[_T] | None = None) -> R | _T: if ( cast_to is not object - and not origin is list - and not origin is dict - and not origin is Union + and origin is not list + and origin is not dict + and origin is not Union and not issubclass(origin, BaseModel) ): raise RuntimeError( diff --git a/src/llama_stack_client/_utils/_logs.py b/src/llama_stack_client/_utils/_logs.py index 39ff9635..49f3ee8c 100644 --- a/src/llama_stack_client/_utils/_logs.py +++ b/src/llama_stack_client/_utils/_logs.py @@ -14,7 +14,7 @@ def _basic_config() -> None: def setup_logging() -> None: - env = os.environ.get("LLAMA_STACK_CLIENT_LOG") + env = os.environ.get("LLAMA_STACK_LOG") if env == "debug": _basic_config() logger.setLevel(logging.DEBUG) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index 3b7bcc7f..87badd46 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -3,8 +3,6 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import uuid -from datetime import datetime from typing import Iterator, List, Optional, Tuple, Union from llama_stack_client import LlamaStackClient @@ -16,13 +14,7 @@ from llama_stack_client.types.agents.turn_create_response import ( AgentTurnResponseStreamChunk, ) -from llama_stack_client.types.agents.turn_response_event import TurnResponseEvent -from llama_stack_client.types.agents.turn_response_event_payload import ( - AgentTurnResponseStepCompletePayload, -) from llama_stack_client.types.shared.tool_call import ToolCall -from llama_stack_client.types.tool_execution_step import ToolExecutionStep -from llama_stack_client.types.tool_response import ToolResponse from .client_tool import ClientTool from .tool_parser import ToolParser @@ -66,7 +58,7 @@ def create_session(self, session_name: str) -> str: return self.session_id def _get_tool_calls(self, chunk: AgentTurnResponseStreamChunk) -> List[ToolCall]: - if chunk.event.payload.event_type != "turn_complete": + if chunk.event.payload.event_type not in {"turn_complete", "turn_awaiting_input"}: return [] message = chunk.event.payload.turn.output_message @@ -78,6 +70,12 @@ def _get_tool_calls(self, chunk: AgentTurnResponseStreamChunk) -> List[ToolCall] return message.tool_calls + def _get_turn_id(self, chunk: AgentTurnResponseStreamChunk) -> Optional[str]: + if chunk.event.payload.event_type not in ["turn_complete", "turn_awaiting_input"]: + return None + + return chunk.event.payload.turn.turn_id + def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseMessage: assert len(tool_calls) == 1, "Only one tool call is supported" tool_call = tool_calls[0] @@ -132,27 +130,10 @@ def create_turn( if stream: return self._create_turn_streaming(messages, session_id, toolgroups, documents) else: - chunks = [] - for chunk in self._create_turn_streaming(messages, session_id, toolgroups, documents): - if chunk.event.payload.event_type == "turn_complete": - chunks.append(chunk) - pass + chunks = [x for x in self._create_turn_streaming(messages, session_id, toolgroups, documents)] if not chunks: raise Exception("Turn did not complete") - - # merge chunks - return Turn( - input_messages=chunks[0].event.payload.turn.input_messages, - output_message=chunks[-1].event.payload.turn.output_message, - session_id=chunks[0].event.payload.turn.session_id, - steps=[step for chunk in chunks for step in chunk.event.payload.turn.steps], - turn_id=chunks[0].event.payload.turn.turn_id, - started_at=chunks[0].event.payload.turn.started_at, - completed_at=chunks[-1].event.payload.turn.completed_at, - output_attachments=[ - attachment for chunk in chunks for attachment in chunk.event.payload.turn.output_attachments - ], - ) + return chunks[-1].event.payload.turn def _create_turn_streaming( self, @@ -161,22 +142,26 @@ def _create_turn_streaming( toolgroups: Optional[List[Toolgroup]] = None, documents: Optional[List[Document]] = None, ) -> Iterator[AgentTurnResponseStreamChunk]: - stop = False n_iter = 0 max_iter = self.agent_config.get("max_infer_iters", DEFAULT_MAX_ITER) - while not stop and n_iter < max_iter: - response = self.client.agents.turn.create( - agent_id=self.agent_id, - # use specified session_id or last session created - session_id=session_id or self.session_id[-1], - messages=messages, - stream=True, - documents=documents, - toolgroups=toolgroups, - ) - # by default, we stop after the first turn - stop = True - for chunk in response: + + # 1. create an agent turn + turn_response = self.client.agents.turn.create( + agent_id=self.agent_id, + # use specified session_id or last session created + session_id=session_id or self.session_id[-1], + messages=messages, + stream=True, + documents=documents, + toolgroups=toolgroups, + allow_turn_resume=True, + ) + + # 2. process turn and resume if there's a tool call + is_turn_complete = False + while not is_turn_complete: + is_turn_complete = True + for chunk in turn_response: tool_calls = self._get_tool_calls(chunk) if hasattr(chunk, "error"): yield chunk @@ -184,39 +169,23 @@ def _create_turn_streaming( elif not tool_calls: yield chunk else: - tool_execution_start_time = datetime.now() + is_turn_complete = False + turn_id = self._get_turn_id(chunk) + if n_iter == 0: + yield chunk + + # run the tools tool_response_message = self._run_tool(tool_calls) - tool_execution_step = ToolExecutionStep( - step_type="tool_execution", - step_id=str(uuid.uuid4()), - tool_calls=tool_calls, - tool_responses=[ - ToolResponse( - tool_name=tool_response_message.tool_name, - content=tool_response_message.content, - call_id=tool_response_message.call_id, - ) - ], - turn_id=chunk.event.payload.turn.turn_id, - completed_at=datetime.now(), - started_at=tool_execution_start_time, - ) - yield AgentTurnResponseStreamChunk( - event=TurnResponseEvent( - payload=AgentTurnResponseStepCompletePayload( - event_type="step_complete", - step_id=tool_execution_step.step_id, - step_type="tool_execution", - step_details=tool_execution_step, - ) - ) + # pass it to next iteration + turn_response = self.client.agents.turn.resume( + agent_id=self.agent_id, + session_id=session_id or self.session_id[-1], + turn_id=turn_id, + tool_responses=[tool_response_message], + stream=True, ) - - # HACK: append the tool execution step to the turn - chunk.event.payload.turn.steps.append(tool_execution_step) - yield chunk - - # continue the turn when there's a tool call - stop = False - messages = [tool_response_message] n_iter += 1 + break + + if n_iter >= max_iter: + raise Exception(f"Turn did not complete in {max_iter} iterations") diff --git a/src/llama_stack_client/lib/agents/event_logger.py b/src/llama_stack_client/lib/agents/event_logger.py index d7fa514a..40a1d359 100644 --- a/src/llama_stack_client/lib/agents/event_logger.py +++ b/src/llama_stack_client/lib/agents/event_logger.py @@ -75,7 +75,7 @@ def _yield_printable_events( event = chunk.event event_type = event.payload.event_type - if event_type in {"turn_start", "turn_complete"}: + if event_type in {"turn_start", "turn_complete", "turn_awaiting_input"}: # Currently not logging any turn realted info yield TurnStreamPrintableEvent(role=None, content="", end="", color="grey") return @@ -149,7 +149,9 @@ def _get_event_type_step_type(self, chunk: Any) -> Tuple[Optional[str], Optional if hasattr(chunk, "event"): previous_event_type = chunk.event.payload.event_type if hasattr(chunk, "event") else None previous_step_type = ( - chunk.event.payload.step_type if previous_event_type not in {"turn_start", "turn_complete"} else None + chunk.event.payload.step_type + if previous_event_type not in {"turn_start", "turn_complete", "turn_awaiting_input"} + else None ) return previous_event_type, previous_step_type return None, None diff --git a/src/llama_stack_client/resources/agents/turn.py b/src/llama_stack_client/resources/agents/turn.py index da659e26..56e130f3 100644 --- a/src/llama_stack_client/resources/agents/turn.py +++ b/src/llama_stack_client/resources/agents/turn.py @@ -23,8 +23,9 @@ ) from ..._streaming import Stream, AsyncStream from ..._base_client import make_request_options -from ...types.agents import turn_create_params +from ...types.agents import turn_create_params, turn_resume_params from ...types.agents.turn import Turn +from ...types.shared_params.tool_response_message import ToolResponseMessage from ...types.agents.agent_turn_response_stream_chunk import AgentTurnResponseStreamChunk __all__ = ["TurnResource", "AsyncTurnResource"] @@ -57,6 +58,7 @@ def create( *, agent_id: str, messages: Iterable[turn_create_params.Message], + allow_turn_resume: bool | NotGiven = NOT_GIVEN, documents: Iterable[turn_create_params.Document] | NotGiven = NOT_GIVEN, stream: Literal[False] | NotGiven = NOT_GIVEN, tool_config: turn_create_params.ToolConfig | NotGiven = NOT_GIVEN, @@ -90,6 +92,7 @@ def create( agent_id: str, messages: Iterable[turn_create_params.Message], stream: Literal[True], + allow_turn_resume: bool | NotGiven = NOT_GIVEN, documents: Iterable[turn_create_params.Document] | NotGiven = NOT_GIVEN, tool_config: turn_create_params.ToolConfig | NotGiven = NOT_GIVEN, toolgroups: List[turn_create_params.Toolgroup] | NotGiven = NOT_GIVEN, @@ -122,6 +125,7 @@ def create( agent_id: str, messages: Iterable[turn_create_params.Message], stream: bool, + allow_turn_resume: bool | NotGiven = NOT_GIVEN, documents: Iterable[turn_create_params.Document] | NotGiven = NOT_GIVEN, tool_config: turn_create_params.ToolConfig | NotGiven = NOT_GIVEN, toolgroups: List[turn_create_params.Toolgroup] | NotGiven = NOT_GIVEN, @@ -153,6 +157,7 @@ def create( *, agent_id: str, messages: Iterable[turn_create_params.Message], + allow_turn_resume: bool | NotGiven = NOT_GIVEN, documents: Iterable[turn_create_params.Document] | NotGiven = NOT_GIVEN, stream: Literal[False] | Literal[True] | NotGiven = NOT_GIVEN, tool_config: turn_create_params.ToolConfig | NotGiven = NOT_GIVEN, @@ -173,6 +178,7 @@ def create( body=maybe_transform( { "messages": messages, + "allow_turn_resume": allow_turn_resume, "documents": documents, "stream": stream, "tool_config": tool_config, @@ -225,6 +231,129 @@ def retrieve( cast_to=Turn, ) + @overload + def resume( + self, + turn_id: str, + *, + agent_id: str, + session_id: str, + tool_responses: Iterable[ToolResponseMessage], + stream: Literal[False] | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Turn: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @overload + def resume( + self, + turn_id: str, + *, + agent_id: str, + session_id: str, + stream: Literal[True], + tool_responses: Iterable[ToolResponseMessage], + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Stream[AgentTurnResponseStreamChunk]: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @overload + def resume( + self, + turn_id: str, + *, + agent_id: str, + session_id: str, + stream: bool, + tool_responses: Iterable[ToolResponseMessage], + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Turn | Stream[AgentTurnResponseStreamChunk]: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @required_args(["agent_id", "session_id", "tool_responses"], ["agent_id", "session_id", "stream", "tool_responses"]) + def resume( + self, + turn_id: str, + *, + agent_id: str, + session_id: str, + tool_responses: Iterable[ToolResponseMessage], + stream: Literal[False] | Literal[True] | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Turn | Stream[AgentTurnResponseStreamChunk]: + if not agent_id: + raise ValueError(f"Expected a non-empty value for `agent_id` but received {agent_id!r}") + if not session_id: + raise ValueError(f"Expected a non-empty value for `session_id` but received {session_id!r}") + if not turn_id: + raise ValueError(f"Expected a non-empty value for `turn_id` but received {turn_id!r}") + return self._post( + f"/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume", + body=maybe_transform( + { + "tool_responses": tool_responses, + "stream": stream, + }, + turn_resume_params.TurnResumeParams, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=Turn, + stream=stream or False, + stream_cls=Stream[AgentTurnResponseStreamChunk], + ) + class AsyncTurnResource(AsyncAPIResource): @cached_property @@ -253,6 +382,7 @@ async def create( *, agent_id: str, messages: Iterable[turn_create_params.Message], + allow_turn_resume: bool | NotGiven = NOT_GIVEN, documents: Iterable[turn_create_params.Document] | NotGiven = NOT_GIVEN, stream: Literal[False] | NotGiven = NOT_GIVEN, tool_config: turn_create_params.ToolConfig | NotGiven = NOT_GIVEN, @@ -286,6 +416,7 @@ async def create( agent_id: str, messages: Iterable[turn_create_params.Message], stream: Literal[True], + allow_turn_resume: bool | NotGiven = NOT_GIVEN, documents: Iterable[turn_create_params.Document] | NotGiven = NOT_GIVEN, tool_config: turn_create_params.ToolConfig | NotGiven = NOT_GIVEN, toolgroups: List[turn_create_params.Toolgroup] | NotGiven = NOT_GIVEN, @@ -318,6 +449,7 @@ async def create( agent_id: str, messages: Iterable[turn_create_params.Message], stream: bool, + allow_turn_resume: bool | NotGiven = NOT_GIVEN, documents: Iterable[turn_create_params.Document] | NotGiven = NOT_GIVEN, tool_config: turn_create_params.ToolConfig | NotGiven = NOT_GIVEN, toolgroups: List[turn_create_params.Toolgroup] | NotGiven = NOT_GIVEN, @@ -349,6 +481,7 @@ async def create( *, agent_id: str, messages: Iterable[turn_create_params.Message], + allow_turn_resume: bool | NotGiven = NOT_GIVEN, documents: Iterable[turn_create_params.Document] | NotGiven = NOT_GIVEN, stream: Literal[False] | Literal[True] | NotGiven = NOT_GIVEN, tool_config: turn_create_params.ToolConfig | NotGiven = NOT_GIVEN, @@ -369,6 +502,7 @@ async def create( body=await async_maybe_transform( { "messages": messages, + "allow_turn_resume": allow_turn_resume, "documents": documents, "stream": stream, "tool_config": tool_config, @@ -421,6 +555,129 @@ async def retrieve( cast_to=Turn, ) + @overload + async def resume( + self, + turn_id: str, + *, + agent_id: str, + session_id: str, + tool_responses: Iterable[ToolResponseMessage], + stream: Literal[False] | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Turn: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @overload + async def resume( + self, + turn_id: str, + *, + agent_id: str, + session_id: str, + stream: Literal[True], + tool_responses: Iterable[ToolResponseMessage], + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> AsyncStream[AgentTurnResponseStreamChunk]: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @overload + async def resume( + self, + turn_id: str, + *, + agent_id: str, + session_id: str, + stream: bool, + tool_responses: Iterable[ToolResponseMessage], + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Turn | AsyncStream[AgentTurnResponseStreamChunk]: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @required_args(["agent_id", "session_id", "tool_responses"], ["agent_id", "session_id", "stream", "tool_responses"]) + async def resume( + self, + turn_id: str, + *, + agent_id: str, + session_id: str, + tool_responses: Iterable[ToolResponseMessage], + stream: Literal[False] | Literal[True] | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Turn | AsyncStream[AgentTurnResponseStreamChunk]: + if not agent_id: + raise ValueError(f"Expected a non-empty value for `agent_id` but received {agent_id!r}") + if not session_id: + raise ValueError(f"Expected a non-empty value for `session_id` but received {session_id!r}") + if not turn_id: + raise ValueError(f"Expected a non-empty value for `turn_id` but received {turn_id!r}") + return await self._post( + f"/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume", + body=await async_maybe_transform( + { + "tool_responses": tool_responses, + "stream": stream, + }, + turn_resume_params.TurnResumeParams, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=Turn, + stream=stream or False, + stream_cls=AsyncStream[AgentTurnResponseStreamChunk], + ) + class TurnResourceWithRawResponse: def __init__(self, turn: TurnResource) -> None: @@ -432,6 +689,9 @@ def __init__(self, turn: TurnResource) -> None: self.retrieve = to_raw_response_wrapper( turn.retrieve, ) + self.resume = to_raw_response_wrapper( + turn.resume, + ) class AsyncTurnResourceWithRawResponse: @@ -444,6 +704,9 @@ def __init__(self, turn: AsyncTurnResource) -> None: self.retrieve = async_to_raw_response_wrapper( turn.retrieve, ) + self.resume = async_to_raw_response_wrapper( + turn.resume, + ) class TurnResourceWithStreamingResponse: @@ -456,6 +719,9 @@ def __init__(self, turn: TurnResource) -> None: self.retrieve = to_streamed_response_wrapper( turn.retrieve, ) + self.resume = to_streamed_response_wrapper( + turn.resume, + ) class AsyncTurnResourceWithStreamingResponse: @@ -468,3 +734,6 @@ def __init__(self, turn: AsyncTurnResource) -> None: self.retrieve = async_to_streamed_response_wrapper( turn.retrieve, ) + self.resume = async_to_streamed_response_wrapper( + turn.resume, + ) diff --git a/src/llama_stack_client/types/agents/__init__.py b/src/llama_stack_client/types/agents/__init__.py index be21f291..30355cbf 100644 --- a/src/llama_stack_client/types/agents/__init__.py +++ b/src/llama_stack_client/types/agents/__init__.py @@ -5,6 +5,7 @@ from .turn import Turn as Turn from .session import Session as Session from .turn_create_params import TurnCreateParams as TurnCreateParams +from .turn_resume_params import TurnResumeParams as TurnResumeParams from .turn_response_event import TurnResponseEvent as TurnResponseEvent from .session_create_params import SessionCreateParams as SessionCreateParams from .step_retrieve_response import StepRetrieveResponse as StepRetrieveResponse diff --git a/src/llama_stack_client/types/agents/turn_create_params.py b/src/llama_stack_client/types/agents/turn_create_params.py index 357f572c..729ab74c 100644 --- a/src/llama_stack_client/types/agents/turn_create_params.py +++ b/src/llama_stack_client/types/agents/turn_create_params.py @@ -32,6 +32,8 @@ class TurnCreateParamsBase(TypedDict, total=False): messages: Required[Iterable[Message]] + allow_turn_resume: bool + documents: Iterable[Document] tool_config: ToolConfig diff --git a/src/llama_stack_client/types/agents/turn_response_event_payload.py b/src/llama_stack_client/types/agents/turn_response_event_payload.py index f12f8b03..e3315cb3 100644 --- a/src/llama_stack_client/types/agents/turn_response_event_payload.py +++ b/src/llama_stack_client/types/agents/turn_response_event_payload.py @@ -20,6 +20,7 @@ "AgentTurnResponseStepCompletePayloadStepDetails", "AgentTurnResponseTurnStartPayload", "AgentTurnResponseTurnCompletePayload", + "AgentTurnResponseTurnAwaitingInputPayload", ] @@ -72,6 +73,13 @@ class AgentTurnResponseTurnCompletePayload(BaseModel): """A single turn in an interaction with an Agentic System.""" +class AgentTurnResponseTurnAwaitingInputPayload(BaseModel): + event_type: Literal["turn_awaiting_input"] + + turn: Turn + """A single turn in an interaction with an Agentic System.""" + + TurnResponseEventPayload: TypeAlias = Annotated[ Union[ AgentTurnResponseStepStartPayload, @@ -79,6 +87,7 @@ class AgentTurnResponseTurnCompletePayload(BaseModel): AgentTurnResponseStepCompletePayload, AgentTurnResponseTurnStartPayload, AgentTurnResponseTurnCompletePayload, + AgentTurnResponseTurnAwaitingInputPayload, ], PropertyInfo(discriminator="event_type"), ] diff --git a/src/llama_stack_client/types/agents/turn_resume_params.py b/src/llama_stack_client/types/agents/turn_resume_params.py new file mode 100644 index 00000000..0df97072 --- /dev/null +++ b/src/llama_stack_client/types/agents/turn_resume_params.py @@ -0,0 +1,29 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing import Union, Iterable +from typing_extensions import Literal, Required, TypedDict + +from ..shared_params.tool_response_message import ToolResponseMessage + +__all__ = ["TurnResumeParamsBase", "TurnResumeParamsNonStreaming", "TurnResumeParamsStreaming"] + + +class TurnResumeParamsBase(TypedDict, total=False): + agent_id: Required[str] + + session_id: Required[str] + + tool_responses: Required[Iterable[ToolResponseMessage]] + + +class TurnResumeParamsNonStreaming(TurnResumeParamsBase, total=False): + stream: Literal[False] + + +class TurnResumeParamsStreaming(TurnResumeParamsBase): + stream: Required[Literal[True]] + + +TurnResumeParams = Union[TurnResumeParamsNonStreaming, TurnResumeParamsStreaming] diff --git a/tests/api_resources/agents/test_turn.py b/tests/api_resources/agents/test_turn.py index b64bf957..e74502bd 100644 --- a/tests/api_resources/agents/test_turn.py +++ b/tests/api_resources/agents/test_turn.py @@ -43,6 +43,7 @@ def test_method_create_with_all_params_overload_1(self, client: LlamaStackClient "context": "string", } ], + allow_turn_resume=True, documents=[ { "content": "string", @@ -151,6 +152,7 @@ def test_method_create_with_all_params_overload_2(self, client: LlamaStackClient } ], stream=True, + allow_turn_resume=True, documents=[ { "content": "string", @@ -293,6 +295,245 @@ def test_path_params_retrieve(self, client: LlamaStackClient) -> None: session_id="session_id", ) + @parametrize + def test_method_resume_overload_1(self, client: LlamaStackClient) -> None: + turn = client.agents.turn.resume( + turn_id="turn_id", + agent_id="agent_id", + session_id="session_id", + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) + assert_matches_type(Turn, turn, path=["response"]) + + @parametrize + def test_method_resume_with_all_params_overload_1(self, client: LlamaStackClient) -> None: + turn = client.agents.turn.resume( + turn_id="turn_id", + agent_id="agent_id", + session_id="session_id", + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + stream=False, + ) + assert_matches_type(Turn, turn, path=["response"]) + + @parametrize + def test_raw_response_resume_overload_1(self, client: LlamaStackClient) -> None: + response = client.agents.turn.with_raw_response.resume( + turn_id="turn_id", + agent_id="agent_id", + session_id="session_id", + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + turn = response.parse() + assert_matches_type(Turn, turn, path=["response"]) + + @parametrize + def test_streaming_response_resume_overload_1(self, client: LlamaStackClient) -> None: + with client.agents.turn.with_streaming_response.resume( + turn_id="turn_id", + agent_id="agent_id", + session_id="session_id", + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + turn = response.parse() + assert_matches_type(Turn, turn, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_resume_overload_1(self, client: LlamaStackClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `agent_id` but received ''"): + client.agents.turn.with_raw_response.resume( + turn_id="turn_id", + agent_id="", + session_id="session_id", + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `session_id` but received ''"): + client.agents.turn.with_raw_response.resume( + turn_id="turn_id", + agent_id="agent_id", + session_id="", + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `turn_id` but received ''"): + client.agents.turn.with_raw_response.resume( + turn_id="", + agent_id="agent_id", + session_id="session_id", + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) + + @parametrize + def test_method_resume_overload_2(self, client: LlamaStackClient) -> None: + turn_stream = client.agents.turn.resume( + turn_id="turn_id", + agent_id="agent_id", + session_id="session_id", + stream=True, + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) + turn_stream.response.close() + + @parametrize + def test_raw_response_resume_overload_2(self, client: LlamaStackClient) -> None: + response = client.agents.turn.with_raw_response.resume( + turn_id="turn_id", + agent_id="agent_id", + session_id="session_id", + stream=True, + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) + + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + stream = response.parse() + stream.close() + + @parametrize + def test_streaming_response_resume_overload_2(self, client: LlamaStackClient) -> None: + with client.agents.turn.with_streaming_response.resume( + turn_id="turn_id", + agent_id="agent_id", + session_id="session_id", + stream=True, + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + stream = response.parse() + stream.close() + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_resume_overload_2(self, client: LlamaStackClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `agent_id` but received ''"): + client.agents.turn.with_raw_response.resume( + turn_id="turn_id", + agent_id="", + session_id="session_id", + stream=True, + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `session_id` but received ''"): + client.agents.turn.with_raw_response.resume( + turn_id="turn_id", + agent_id="agent_id", + session_id="", + stream=True, + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `turn_id` but received ''"): + client.agents.turn.with_raw_response.resume( + turn_id="", + agent_id="agent_id", + session_id="session_id", + stream=True, + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) + class TestAsyncTurn: parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) @@ -323,6 +564,7 @@ async def test_method_create_with_all_params_overload_1(self, async_client: Asyn "context": "string", } ], + allow_turn_resume=True, documents=[ { "content": "string", @@ -431,6 +673,7 @@ async def test_method_create_with_all_params_overload_2(self, async_client: Asyn } ], stream=True, + allow_turn_resume=True, documents=[ { "content": "string", @@ -572,3 +815,242 @@ async def test_path_params_retrieve(self, async_client: AsyncLlamaStackClient) - agent_id="agent_id", session_id="session_id", ) + + @parametrize + async def test_method_resume_overload_1(self, async_client: AsyncLlamaStackClient) -> None: + turn = await async_client.agents.turn.resume( + turn_id="turn_id", + agent_id="agent_id", + session_id="session_id", + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) + assert_matches_type(Turn, turn, path=["response"]) + + @parametrize + async def test_method_resume_with_all_params_overload_1(self, async_client: AsyncLlamaStackClient) -> None: + turn = await async_client.agents.turn.resume( + turn_id="turn_id", + agent_id="agent_id", + session_id="session_id", + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + stream=False, + ) + assert_matches_type(Turn, turn, path=["response"]) + + @parametrize + async def test_raw_response_resume_overload_1(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.agents.turn.with_raw_response.resume( + turn_id="turn_id", + agent_id="agent_id", + session_id="session_id", + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + turn = await response.parse() + assert_matches_type(Turn, turn, path=["response"]) + + @parametrize + async def test_streaming_response_resume_overload_1(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.agents.turn.with_streaming_response.resume( + turn_id="turn_id", + agent_id="agent_id", + session_id="session_id", + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + turn = await response.parse() + assert_matches_type(Turn, turn, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_resume_overload_1(self, async_client: AsyncLlamaStackClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `agent_id` but received ''"): + await async_client.agents.turn.with_raw_response.resume( + turn_id="turn_id", + agent_id="", + session_id="session_id", + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `session_id` but received ''"): + await async_client.agents.turn.with_raw_response.resume( + turn_id="turn_id", + agent_id="agent_id", + session_id="", + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `turn_id` but received ''"): + await async_client.agents.turn.with_raw_response.resume( + turn_id="", + agent_id="agent_id", + session_id="session_id", + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) + + @parametrize + async def test_method_resume_overload_2(self, async_client: AsyncLlamaStackClient) -> None: + turn_stream = await async_client.agents.turn.resume( + turn_id="turn_id", + agent_id="agent_id", + session_id="session_id", + stream=True, + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) + await turn_stream.response.aclose() + + @parametrize + async def test_raw_response_resume_overload_2(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.agents.turn.with_raw_response.resume( + turn_id="turn_id", + agent_id="agent_id", + session_id="session_id", + stream=True, + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) + + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + stream = await response.parse() + await stream.close() + + @parametrize + async def test_streaming_response_resume_overload_2(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.agents.turn.with_streaming_response.resume( + turn_id="turn_id", + agent_id="agent_id", + session_id="session_id", + stream=True, + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + stream = await response.parse() + await stream.close() + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_resume_overload_2(self, async_client: AsyncLlamaStackClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `agent_id` but received ''"): + await async_client.agents.turn.with_raw_response.resume( + turn_id="turn_id", + agent_id="", + session_id="session_id", + stream=True, + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `session_id` but received ''"): + await async_client.agents.turn.with_raw_response.resume( + turn_id="turn_id", + agent_id="agent_id", + session_id="", + stream=True, + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `turn_id` but received ''"): + await async_client.agents.turn.with_raw_response.resume( + turn_id="", + agent_id="agent_id", + session_id="session_id", + stream=True, + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) diff --git a/tests/test_client.py b/tests/test_client.py index f282f616..8a2992af 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -526,7 +526,7 @@ def test_base_url_setter(self) -> None: assert client.base_url == "https://example.com/from_setter/" def test_base_url_env(self) -> None: - with update_env(LLAMA_STACK_CLIENT_BASE_URL="http://localhost:5000/from/env"): + with update_env(LLAMA_STACK_BASE_URL="http://localhost:5000/from/env"): client = LlamaStackClient(_strict_response_validation=True) assert client.base_url == "http://localhost:5000/from/env/" @@ -1316,7 +1316,7 @@ def test_base_url_setter(self) -> None: assert client.base_url == "https://example.com/from_setter/" def test_base_url_env(self) -> None: - with update_env(LLAMA_STACK_CLIENT_BASE_URL="http://localhost:5000/from/env"): + with update_env(LLAMA_STACK_BASE_URL="http://localhost:5000/from/env"): client = AsyncLlamaStackClient(_strict_response_validation=True) assert client.base_url == "http://localhost:5000/from/env/"