From 08681155d85f4b330f777fbfe4e83482ba457b7a Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 20 Feb 2025 17:07:26 -0800 Subject: [PATCH 1/6] Sync updates from stainless branch: yanxi0830/dev --- src/llama_stack_client/_client.py | 12 +- src/llama_stack_client/_utils/_logs.py | 2 +- .../resources/agents/turn.py | 261 +++++++++- .../types/agents/__init__.py | 1 + .../types/agents/turn_continue_params.py | 29 ++ .../agents/turn_response_event_payload.py | 9 + tests/api_resources/agents/test_turn.py | 478 ++++++++++++++++++ tests/test_client.py | 4 +- 8 files changed, 786 insertions(+), 10 deletions(-) create mode 100644 src/llama_stack_client/types/agents/turn_continue_params.py diff --git a/src/llama_stack_client/_client.py b/src/llama_stack_client/_client.py index bb5bb755..760eaeee 100644 --- a/src/llama_stack_client/_client.py +++ b/src/llama_stack_client/_client.py @@ -126,14 +126,14 @@ def __init__( ) -> None: """Construct a new synchronous llama-stack-client client instance. - This automatically infers the `api_key` argument from the `LLAMA_STACK_CLIENT_API_KEY` environment variable if it is not provided. + This automatically infers the `api_key` argument from the `LLAMA_STACK_API_KEY` environment variable if it is not provided. """ if api_key is None: - api_key = os.environ.get("LLAMA_STACK_CLIENT_API_KEY") + api_key = os.environ.get("LLAMA_STACK_API_KEY") self.api_key = api_key if base_url is None: - base_url = os.environ.get("LLAMA_STACK_CLIENT_BASE_URL") + base_url = os.environ.get("LLAMA_STACK_BASE_URL") if base_url is None: base_url = f"http://any-hosted-llama-stack.com" @@ -342,14 +342,14 @@ def __init__( ) -> None: """Construct a new async llama-stack-client client instance. - This automatically infers the `api_key` argument from the `LLAMA_STACK_CLIENT_API_KEY` environment variable if it is not provided. + This automatically infers the `api_key` argument from the `LLAMA_STACK_API_KEY` environment variable if it is not provided. 
""" if api_key is None: - api_key = os.environ.get("LLAMA_STACK_CLIENT_API_KEY") + api_key = os.environ.get("LLAMA_STACK_API_KEY") self.api_key = api_key if base_url is None: - base_url = os.environ.get("LLAMA_STACK_CLIENT_BASE_URL") + base_url = os.environ.get("LLAMA_STACK_BASE_URL") if base_url is None: base_url = f"http://any-hosted-llama-stack.com" diff --git a/src/llama_stack_client/_utils/_logs.py b/src/llama_stack_client/_utils/_logs.py index 39ff9635..49f3ee8c 100644 --- a/src/llama_stack_client/_utils/_logs.py +++ b/src/llama_stack_client/_utils/_logs.py @@ -14,7 +14,7 @@ def _basic_config() -> None: def setup_logging() -> None: - env = os.environ.get("LLAMA_STACK_CLIENT_LOG") + env = os.environ.get("LLAMA_STACK_LOG") if env == "debug": _basic_config() logger.setLevel(logging.DEBUG) diff --git a/src/llama_stack_client/resources/agents/turn.py b/src/llama_stack_client/resources/agents/turn.py index da659e26..fd3414a8 100644 --- a/src/llama_stack_client/resources/agents/turn.py +++ b/src/llama_stack_client/resources/agents/turn.py @@ -23,8 +23,9 @@ ) from ..._streaming import Stream, AsyncStream from ..._base_client import make_request_options -from ...types.agents import turn_create_params +from ...types.agents import turn_create_params, turn_continue_params from ...types.agents.turn import Turn +from ...types.shared_params.tool_response_message import ToolResponseMessage from ...types.agents.agent_turn_response_stream_chunk import AgentTurnResponseStreamChunk __all__ = ["TurnResource", "AsyncTurnResource"] @@ -225,6 +226,129 @@ def retrieve( cast_to=Turn, ) + @overload + def continue_( + self, + turn_id: str, + *, + agent_id: str, + session_id: str, + tool_responses: Iterable[ToolResponseMessage], + stream: Literal[False] | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Turn: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @overload + def continue_( + self, + turn_id: str, + *, + agent_id: str, + session_id: str, + stream: Literal[True], + tool_responses: Iterable[ToolResponseMessage], + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Stream[AgentTurnResponseStreamChunk]: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... 
+ + @overload + def continue_( + self, + turn_id: str, + *, + agent_id: str, + session_id: str, + stream: bool, + tool_responses: Iterable[ToolResponseMessage], + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Turn | Stream[AgentTurnResponseStreamChunk]: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @required_args(["agent_id", "session_id", "tool_responses"], ["agent_id", "session_id", "stream", "tool_responses"]) + def continue_( + self, + turn_id: str, + *, + agent_id: str, + session_id: str, + tool_responses: Iterable[ToolResponseMessage], + stream: Literal[False] | Literal[True] | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Turn | Stream[AgentTurnResponseStreamChunk]: + if not agent_id: + raise ValueError(f"Expected a non-empty value for `agent_id` but received {agent_id!r}") + if not session_id: + raise ValueError(f"Expected a non-empty value for `session_id` but received {session_id!r}") + if not turn_id: + raise ValueError(f"Expected a non-empty value for `turn_id` but received {turn_id!r}") + return self._post( + f"/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/continue", + body=maybe_transform( + { + "tool_responses": tool_responses, + "stream": stream, + }, + turn_continue_params.TurnContinueParams, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=Turn, + stream=stream or False, + stream_cls=Stream[AgentTurnResponseStreamChunk], + ) + class AsyncTurnResource(AsyncAPIResource): @cached_property @@ -421,6 +545,129 @@ async def retrieve( cast_to=Turn, ) + @overload + async def continue_( + self, + turn_id: str, + *, + agent_id: str, + session_id: str, + tool_responses: Iterable[ToolResponseMessage], + stream: Literal[False] | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Turn: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... 
+ + @overload + async def continue_( + self, + turn_id: str, + *, + agent_id: str, + session_id: str, + stream: Literal[True], + tool_responses: Iterable[ToolResponseMessage], + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> AsyncStream[AgentTurnResponseStreamChunk]: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @overload + async def continue_( + self, + turn_id: str, + *, + agent_id: str, + session_id: str, + stream: bool, + tool_responses: Iterable[ToolResponseMessage], + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Turn | AsyncStream[AgentTurnResponseStreamChunk]: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @required_args(["agent_id", "session_id", "tool_responses"], ["agent_id", "session_id", "stream", "tool_responses"]) + async def continue_( + self, + turn_id: str, + *, + agent_id: str, + session_id: str, + tool_responses: Iterable[ToolResponseMessage], + stream: Literal[False] | Literal[True] | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. 
+ extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Turn | AsyncStream[AgentTurnResponseStreamChunk]: + if not agent_id: + raise ValueError(f"Expected a non-empty value for `agent_id` but received {agent_id!r}") + if not session_id: + raise ValueError(f"Expected a non-empty value for `session_id` but received {session_id!r}") + if not turn_id: + raise ValueError(f"Expected a non-empty value for `turn_id` but received {turn_id!r}") + return await self._post( + f"/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/continue", + body=await async_maybe_transform( + { + "tool_responses": tool_responses, + "stream": stream, + }, + turn_continue_params.TurnContinueParams, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=Turn, + stream=stream or False, + stream_cls=AsyncStream[AgentTurnResponseStreamChunk], + ) + class TurnResourceWithRawResponse: def __init__(self, turn: TurnResource) -> None: @@ -432,6 +679,9 @@ def __init__(self, turn: TurnResource) -> None: self.retrieve = to_raw_response_wrapper( turn.retrieve, ) + self.continue_ = to_raw_response_wrapper( + turn.continue_, + ) class AsyncTurnResourceWithRawResponse: @@ -444,6 +694,9 @@ def __init__(self, turn: AsyncTurnResource) -> None: self.retrieve = async_to_raw_response_wrapper( turn.retrieve, ) + self.continue_ = async_to_raw_response_wrapper( + turn.continue_, + ) class TurnResourceWithStreamingResponse: @@ -456,6 +709,9 @@ def __init__(self, turn: TurnResource) -> None: self.retrieve = to_streamed_response_wrapper( turn.retrieve, ) + self.continue_ = to_streamed_response_wrapper( + turn.continue_, + ) class AsyncTurnResourceWithStreamingResponse: @@ -468,3 +724,6 @@ def __init__(self, turn: AsyncTurnResource) -> None: self.retrieve = async_to_streamed_response_wrapper( turn.retrieve, ) + self.continue_ = async_to_streamed_response_wrapper( + turn.continue_, + ) diff --git a/src/llama_stack_client/types/agents/__init__.py b/src/llama_stack_client/types/agents/__init__.py index be21f291..e1a831dd 100644 --- a/src/llama_stack_client/types/agents/__init__.py +++ b/src/llama_stack_client/types/agents/__init__.py @@ -6,6 +6,7 @@ from .session import Session as Session from .turn_create_params import TurnCreateParams as TurnCreateParams from .turn_response_event import TurnResponseEvent as TurnResponseEvent +from .turn_continue_params import TurnContinueParams as TurnContinueParams from .session_create_params import SessionCreateParams as SessionCreateParams from .step_retrieve_response import StepRetrieveResponse as StepRetrieveResponse from .session_create_response import SessionCreateResponse as SessionCreateResponse diff --git a/src/llama_stack_client/types/agents/turn_continue_params.py b/src/llama_stack_client/types/agents/turn_continue_params.py new file mode 100644 index 00000000..c4baf885 --- /dev/null +++ b/src/llama_stack_client/types/agents/turn_continue_params.py @@ -0,0 +1,29 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
+ +from __future__ import annotations + +from typing import Union, Iterable +from typing_extensions import Literal, Required, TypedDict + +from ..shared_params.tool_response_message import ToolResponseMessage + +__all__ = ["TurnContinueParamsBase", "TurnContinueParamsNonStreaming", "TurnContinueParamsStreaming"] + + +class TurnContinueParamsBase(TypedDict, total=False): + agent_id: Required[str] + + session_id: Required[str] + + tool_responses: Required[Iterable[ToolResponseMessage]] + + +class TurnContinueParamsNonStreaming(TurnContinueParamsBase, total=False): + stream: Literal[False] + + +class TurnContinueParamsStreaming(TurnContinueParamsBase): + stream: Required[Literal[True]] + + +TurnContinueParams = Union[TurnContinueParamsNonStreaming, TurnContinueParamsStreaming] diff --git a/src/llama_stack_client/types/agents/turn_response_event_payload.py b/src/llama_stack_client/types/agents/turn_response_event_payload.py index f12f8b03..e3315cb3 100644 --- a/src/llama_stack_client/types/agents/turn_response_event_payload.py +++ b/src/llama_stack_client/types/agents/turn_response_event_payload.py @@ -20,6 +20,7 @@ "AgentTurnResponseStepCompletePayloadStepDetails", "AgentTurnResponseTurnStartPayload", "AgentTurnResponseTurnCompletePayload", + "AgentTurnResponseTurnAwaitingInputPayload", ] @@ -72,6 +73,13 @@ class AgentTurnResponseTurnCompletePayload(BaseModel): """A single turn in an interaction with an Agentic System.""" +class AgentTurnResponseTurnAwaitingInputPayload(BaseModel): + event_type: Literal["turn_awaiting_input"] + + turn: Turn + """A single turn in an interaction with an Agentic System.""" + + TurnResponseEventPayload: TypeAlias = Annotated[ Union[ AgentTurnResponseStepStartPayload, @@ -79,6 +87,7 @@ class AgentTurnResponseTurnCompletePayload(BaseModel): AgentTurnResponseStepCompletePayload, AgentTurnResponseTurnStartPayload, AgentTurnResponseTurnCompletePayload, + AgentTurnResponseTurnAwaitingInputPayload, ], PropertyInfo(discriminator="event_type"), ] diff --git a/tests/api_resources/agents/test_turn.py b/tests/api_resources/agents/test_turn.py index b64bf957..2b935b69 100644 --- a/tests/api_resources/agents/test_turn.py +++ b/tests/api_resources/agents/test_turn.py @@ -293,6 +293,245 @@ def test_path_params_retrieve(self, client: LlamaStackClient) -> None: session_id="session_id", ) + @parametrize + def test_method_continue_overload_1(self, client: LlamaStackClient) -> None: + turn = client.agents.turn.continue_( + turn_id="turn_id", + agent_id="agent_id", + session_id="session_id", + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) + assert_matches_type(Turn, turn, path=["response"]) + + @parametrize + def test_method_continue_with_all_params_overload_1(self, client: LlamaStackClient) -> None: + turn = client.agents.turn.continue_( + turn_id="turn_id", + agent_id="agent_id", + session_id="session_id", + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + stream=False, + ) + assert_matches_type(Turn, turn, path=["response"]) + + @parametrize + def test_raw_response_continue_overload_1(self, client: LlamaStackClient) -> None: + response = client.agents.turn.with_raw_response.continue_( + turn_id="turn_id", + agent_id="agent_id", + session_id="session_id", + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) + + assert response.is_closed is True 
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python" + turn = response.parse() + assert_matches_type(Turn, turn, path=["response"]) + + @parametrize + def test_streaming_response_continue_overload_1(self, client: LlamaStackClient) -> None: + with client.agents.turn.with_streaming_response.continue_( + turn_id="turn_id", + agent_id="agent_id", + session_id="session_id", + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + turn = response.parse() + assert_matches_type(Turn, turn, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_continue_overload_1(self, client: LlamaStackClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `agent_id` but received ''"): + client.agents.turn.with_raw_response.continue_( + turn_id="turn_id", + agent_id="", + session_id="session_id", + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `session_id` but received ''"): + client.agents.turn.with_raw_response.continue_( + turn_id="turn_id", + agent_id="agent_id", + session_id="", + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `turn_id` but received ''"): + client.agents.turn.with_raw_response.continue_( + turn_id="", + agent_id="agent_id", + session_id="session_id", + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) + + @parametrize + def test_method_continue_overload_2(self, client: LlamaStackClient) -> None: + turn_stream = client.agents.turn.continue_( + turn_id="turn_id", + agent_id="agent_id", + session_id="session_id", + stream=True, + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) + turn_stream.response.close() + + @parametrize + def test_raw_response_continue_overload_2(self, client: LlamaStackClient) -> None: + response = client.agents.turn.with_raw_response.continue_( + turn_id="turn_id", + agent_id="agent_id", + session_id="session_id", + stream=True, + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) + + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + stream = response.parse() + stream.close() + + @parametrize + def test_streaming_response_continue_overload_2(self, client: LlamaStackClient) -> None: + with client.agents.turn.with_streaming_response.continue_( + turn_id="turn_id", + agent_id="agent_id", + session_id="session_id", + stream=True, + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + stream = response.parse() + stream.close() + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_continue_overload_2(self, client: LlamaStackClient) -> None: + with pytest.raises(ValueError, 
match=r"Expected a non-empty value for `agent_id` but received ''"): + client.agents.turn.with_raw_response.continue_( + turn_id="turn_id", + agent_id="", + session_id="session_id", + stream=True, + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `session_id` but received ''"): + client.agents.turn.with_raw_response.continue_( + turn_id="turn_id", + agent_id="agent_id", + session_id="", + stream=True, + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `turn_id` but received ''"): + client.agents.turn.with_raw_response.continue_( + turn_id="", + agent_id="agent_id", + session_id="session_id", + stream=True, + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) + class TestAsyncTurn: parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) @@ -572,3 +811,242 @@ async def test_path_params_retrieve(self, async_client: AsyncLlamaStackClient) - agent_id="agent_id", session_id="session_id", ) + + @parametrize + async def test_method_continue_overload_1(self, async_client: AsyncLlamaStackClient) -> None: + turn = await async_client.agents.turn.continue_( + turn_id="turn_id", + agent_id="agent_id", + session_id="session_id", + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) + assert_matches_type(Turn, turn, path=["response"]) + + @parametrize + async def test_method_continue_with_all_params_overload_1(self, async_client: AsyncLlamaStackClient) -> None: + turn = await async_client.agents.turn.continue_( + turn_id="turn_id", + agent_id="agent_id", + session_id="session_id", + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + stream=False, + ) + assert_matches_type(Turn, turn, path=["response"]) + + @parametrize + async def test_raw_response_continue_overload_1(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.agents.turn.with_raw_response.continue_( + turn_id="turn_id", + agent_id="agent_id", + session_id="session_id", + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + turn = await response.parse() + assert_matches_type(Turn, turn, path=["response"]) + + @parametrize + async def test_streaming_response_continue_overload_1(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.agents.turn.with_streaming_response.continue_( + turn_id="turn_id", + agent_id="agent_id", + session_id="session_id", + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + turn = await response.parse() + assert_matches_type(Turn, turn, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_continue_overload_1(self, async_client: 
AsyncLlamaStackClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `agent_id` but received ''"): + await async_client.agents.turn.with_raw_response.continue_( + turn_id="turn_id", + agent_id="", + session_id="session_id", + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `session_id` but received ''"): + await async_client.agents.turn.with_raw_response.continue_( + turn_id="turn_id", + agent_id="agent_id", + session_id="", + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `turn_id` but received ''"): + await async_client.agents.turn.with_raw_response.continue_( + turn_id="", + agent_id="agent_id", + session_id="session_id", + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) + + @parametrize + async def test_method_continue_overload_2(self, async_client: AsyncLlamaStackClient) -> None: + turn_stream = await async_client.agents.turn.continue_( + turn_id="turn_id", + agent_id="agent_id", + session_id="session_id", + stream=True, + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) + await turn_stream.response.aclose() + + @parametrize + async def test_raw_response_continue_overload_2(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.agents.turn.with_raw_response.continue_( + turn_id="turn_id", + agent_id="agent_id", + session_id="session_id", + stream=True, + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) + + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + stream = await response.parse() + await stream.close() + + @parametrize + async def test_streaming_response_continue_overload_2(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.agents.turn.with_streaming_response.continue_( + turn_id="turn_id", + agent_id="agent_id", + session_id="session_id", + stream=True, + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + stream = await response.parse() + await stream.close() + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_continue_overload_2(self, async_client: AsyncLlamaStackClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `agent_id` but received ''"): + await async_client.agents.turn.with_raw_response.continue_( + turn_id="turn_id", + agent_id="", + session_id="session_id", + stream=True, + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `session_id` but received ''"): + await async_client.agents.turn.with_raw_response.continue_( + turn_id="turn_id", + agent_id="agent_id", + session_id="", + stream=True, + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + 
"tool_name": "brave_search", + } + ], + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `turn_id` but received ''"): + await async_client.agents.turn.with_raw_response.continue_( + turn_id="", + agent_id="agent_id", + session_id="session_id", + stream=True, + tool_responses=[ + { + "call_id": "call_id", + "content": "string", + "role": "tool", + "tool_name": "brave_search", + } + ], + ) diff --git a/tests/test_client.py b/tests/test_client.py index f282f616..8a2992af 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -526,7 +526,7 @@ def test_base_url_setter(self) -> None: assert client.base_url == "https://example.com/from_setter/" def test_base_url_env(self) -> None: - with update_env(LLAMA_STACK_CLIENT_BASE_URL="http://localhost:5000/from/env"): + with update_env(LLAMA_STACK_BASE_URL="http://localhost:5000/from/env"): client = LlamaStackClient(_strict_response_validation=True) assert client.base_url == "http://localhost:5000/from/env/" @@ -1316,7 +1316,7 @@ def test_base_url_setter(self) -> None: assert client.base_url == "https://example.com/from_setter/" def test_base_url_env(self) -> None: - with update_env(LLAMA_STACK_CLIENT_BASE_URL="http://localhost:5000/from/env"): + with update_env(LLAMA_STACK_BASE_URL="http://localhost:5000/from/env"): client = AsyncLlamaStackClient(_strict_response_validation=True) assert client.base_url == "http://localhost:5000/from/env/" From 8958330746b26d2fd11348950f1a0976958561e9 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 20 Feb 2025 17:17:30 -0800 Subject: [PATCH 2/6] add lib fix --- src/llama_stack_client/lib/agents/agent.py | 4 ++-- src/llama_stack_client/lib/agents/event_logger.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index 0a8ab226..8e44d8d7 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -65,7 +65,7 @@ def create_session(self, session_name: str) -> int: return self.session_id def _get_tool_calls(self, chunk: AgentTurnResponseStreamChunk) -> List[ToolCall]: - if chunk.event.payload.event_type != "turn_complete": + if chunk.event.payload.event_type not in ["turn_complete", "turn_awaiting_input"]: return [] message = chunk.event.payload.turn.output_message @@ -133,7 +133,7 @@ def create_turn( else: chunks = [] for chunk in self._create_turn_streaming(messages, session_id, toolgroups, documents): - if chunk.event.payload.event_type == "turn_complete": + if chunk.event.payload.event_type in ["turn_complete", "turn_awaiting_input"]: chunks.append(chunk) pass if not chunks: diff --git a/src/llama_stack_client/lib/agents/event_logger.py b/src/llama_stack_client/lib/agents/event_logger.py index d7fa514a..8a40353b 100644 --- a/src/llama_stack_client/lib/agents/event_logger.py +++ b/src/llama_stack_client/lib/agents/event_logger.py @@ -75,7 +75,7 @@ def _yield_printable_events( event = chunk.event event_type = event.payload.event_type - if event_type in {"turn_start", "turn_complete"}: + if event_type in {"turn_start", "turn_complete", "turn_awaiting_input"}: # Currently not logging any turn realted info yield TurnStreamPrintableEvent(role=None, content="", end="", color="grey") return @@ -149,7 +149,7 @@ def _get_event_type_step_type(self, chunk: Any) -> Tuple[Optional[str], Optional if hasattr(chunk, "event"): previous_event_type = chunk.event.payload.event_type if hasattr(chunk, "event") else None previous_step_type = ( - 
chunk.event.payload.step_type if previous_event_type not in {"turn_start", "turn_complete"} else None + chunk.event.payload.step_type if previous_event_type not in {"turn_start", "turn_complete", "turn_awaiting_input"} else None ) return previous_event_type, previous_step_type return None, None From 81787f26725f31ff17e68bd6bec27baea44d67f7 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 20 Feb 2025 18:19:47 -0800 Subject: [PATCH 3/6] agent sdk --- src/llama_stack_client/lib/agents/agent.py | 87 +++++++++++----------- 1 file changed, 44 insertions(+), 43 deletions(-) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index 8e44d8d7..936cb0d4 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -77,6 +77,12 @@ def _get_tool_calls(self, chunk: AgentTurnResponseStreamChunk) -> List[ToolCall] return message.tool_calls + def _get_turn_id(self, chunk: AgentTurnResponseStreamChunk) -> Optional[str]: + if chunk.event.payload.event_type not in ["turn_complete", "turn_awaiting_input"]: + return None + + return chunk.event.payload.turn.turn_id + def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseMessage: assert len(tool_calls) == 1, "Only one tool call is supported" tool_call = tool_calls[0] @@ -163,19 +169,46 @@ def _create_turn_streaming( stop = False n_iter = 0 max_iter = self.agent_config.get("max_infer_iters", DEFAULT_MAX_ITER) - while not stop and n_iter < max_iter: - response = self.client.agents.turn.create( + + # 1. create an agent turn + turn_response = self.client.agents.turn.create( + agent_id=self.agent_id, + # use specified session_id or last session created + session_id=session_id or self.session_id[-1], + messages=messages, + stream=True, + documents=documents, + toolgroups=toolgroups, + ) + is_turn_complete = True + turn_id = None + for chunk in turn_response: + tool_calls = self._get_tool_calls(chunk) + if hasattr(chunk, "error"): + yield chunk + return + elif not tool_calls: + yield chunk + else: + is_turn_complete = False + turn_id = self._get_turn_id(chunk) + break + + # 2. 
while the turn is not complete, continue the turn + while not is_turn_complete and n_iter < max_iter: + assert turn_id is not None, "turn_id is None" + + # run the tools + tool_response_message = self._run_tool(tool_calls) + + continue_response = self.client.agents.turn.continue_( agent_id=self.agent_id, - # use specified session_id or last session created session_id=session_id or self.session_id[-1], - messages=messages, + turn_id=turn_id, + tool_responses=[tool_response_message], stream=True, - documents=documents, - toolgroups=toolgroups, ) - # by default, we stop after the first turn - stop = True - for chunk in response: + for chunk in continue_response: tool_calls = self._get_tool_calls(chunk) if hasattr(chunk, "error"): yield chunk @@ -183,39 +216,7 @@ def _create_turn_streaming( elif not tool_calls: yield chunk else: - tool_execution_start_time = datetime.now() + is_turn_complete = False + turn_id = self._get_turn_id(chunk) tool_response_message = self._run_tool(tool_calls) - tool_execution_step = ToolExecutionStep( - step_type="tool_execution", - step_id=str(uuid.uuid4()), - tool_calls=tool_calls, - tool_responses=[ - ToolResponse( - tool_name=tool_response_message.tool_name, - content=tool_response_message.content, - call_id=tool_response_message.call_id, - ) - ], - turn_id=chunk.event.payload.turn.turn_id, - completed_at=datetime.now(), - started_at=tool_execution_start_time, - ) - yield AgentTurnResponseStreamChunk( - event=TurnResponseEvent( - payload=AgentTurnResponseStepCompletePayload( - event_type="step_complete", - step_id=tool_execution_step.step_id, - step_type="tool_execution", - step_details=tool_execution_step, - ) - ) - ) - - # HACK: append the tool execution step to the turn - chunk.event.payload.turn.steps.append(tool_execution_step) - yield chunk - - # continue the turn when there's a tool call - stop = False - messages = [tool_response_message] n_iter += 1 From a327f32e0df20396c3091531953db60a84844488 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 20 Feb 2025 18:35:34 -0800 Subject: [PATCH 4/6] new loop --- src/llama_stack_client/lib/agents/agent.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index 936cb0d4..f85f33b2 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -192,10 +192,12 @@ def _create_turn_streaming( else: is_turn_complete = False turn_id = self._get_turn_id(chunk) + yield chunk break # 2. 
while the turn is not complete, continue the turn while not is_turn_complete and n_iter < max_iter: + is_turn_complete = True assert turn_id is not None, "turn_id is None" # run the tools @@ -218,5 +220,4 @@ def _create_turn_streaming( else: is_turn_complete = False turn_id = self._get_turn_id(chunk) - tool_response_message = self._run_tool(tool_calls) n_iter += 1 From a0f0b3808276ffb69acd9d0de6c35b27c101c372 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 20 Feb 2025 19:35:56 -0800 Subject: [PATCH 5/6] use continue --- src/llama_stack_client/lib/agents/agent.py | 12 ++---------- src/llama_stack_client/lib/agents/event_logger.py | 4 +++- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index f85f33b2..ec2cb23b 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -13,18 +13,10 @@ from llama_stack_client.types.agents.turn_create_response import ( AgentTurnResponseStreamChunk, ) -from llama_stack_client.types.agents.turn_response_event import TurnResponseEvent -from llama_stack_client.types.agents.turn_response_event_payload import ( - AgentTurnResponseStepCompletePayload, -) from llama_stack_client.types.shared.tool_call import ToolCall from llama_stack_client.types.agents.turn import CompletionMessage from .client_tool import ClientTool from .tool_parser import ToolParser -from datetime import datetime -import uuid -from llama_stack_client.types.tool_execution_step import ToolExecutionStep -from llama_stack_client.types.tool_response import ToolResponse DEFAULT_MAX_ITER = 10 @@ -169,7 +161,7 @@ def _create_turn_streaming( stop = False n_iter = 0 max_iter = self.agent_config.get("max_infer_iters", DEFAULT_MAX_ITER) - + # 1. create an agent turn turn_response = self.client.agents.turn.create( agent_id=self.agent_id, @@ -194,7 +186,7 @@ def _create_turn_streaming( turn_id = self._get_turn_id(chunk) yield chunk break - + # 2. 
while the turn is not complete, continue the turn while not is_turn_complete and n_iter < max_iter: is_turn_complete = True diff --git a/src/llama_stack_client/lib/agents/event_logger.py b/src/llama_stack_client/lib/agents/event_logger.py index 8a40353b..40a1d359 100644 --- a/src/llama_stack_client/lib/agents/event_logger.py +++ b/src/llama_stack_client/lib/agents/event_logger.py @@ -149,7 +149,9 @@ def _get_event_type_step_type(self, chunk: Any) -> Tuple[Optional[str], Optional if hasattr(chunk, "event"): previous_event_type = chunk.event.payload.event_type if hasattr(chunk, "event") else None previous_step_type = ( - chunk.event.payload.step_type if previous_event_type not in {"turn_start", "turn_complete", "turn_awaiting_input"} else None + chunk.event.payload.step_type + if previous_event_type not in {"turn_start", "turn_complete", "turn_awaiting_input"} + else None ) return previous_event_type, previous_step_type return None, None From bb428adefaee71c78cf8538d70dae8dd4995adcc Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 20 Feb 2025 19:40:52 -0800 Subject: [PATCH 6/6] non-streaming case --- src/llama_stack_client/lib/agents/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index ec2cb23b..ab4570be 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -131,7 +131,7 @@ def create_turn( else: chunks = [] for chunk in self._create_turn_streaming(messages, session_id, toolgroups, documents): - if chunk.event.payload.event_type in ["turn_complete", "turn_awaiting_input"]: + if chunk.event.payload.event_type == "turn_complete": chunks.append(chunk) pass if not chunks:
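
Usage note (sketch): this series replaces the old client-side workaround (synthesizing a
ToolExecutionStep and re-submitting the tool result as a new turn, marked "HACK" in the
removed code) with a real server-side turn lifecycle: a
`POST /v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/continue` endpoint plus a
`turn_awaiting_input` stream event. Below is a minimal end-to-end sketch of that loop,
assuming a running Llama Stack server with an existing agent and session; the base URL,
IDs, user message, and hard-coded tool result are illustrative placeholders, not values
taken from this patch set:

    from llama_stack_client import LlamaStackClient

    client = LlamaStackClient(base_url="http://localhost:8321")  # assumed server address
    agent_id, session_id = "agent_id", "session_id"  # placeholder identifiers

    # 1. Create a turn. When the agent needs a client-side tool result, the
    #    stream now ends with `turn_awaiting_input` instead of `turn_complete`.
    turn_id = None
    pending_calls = []
    for chunk in client.agents.turn.create(
        agent_id=agent_id,
        session_id=session_id,
        messages=[{"role": "user", "content": "What is the weather in SF?"}],
        stream=True,
    ):
        payload = chunk.event.payload
        if payload.event_type == "turn_awaiting_input":
            turn_id = payload.turn.turn_id
            pending_calls = payload.turn.output_message.tool_calls

    # 2. Execute the tool locally, then resume the *same* turn with the new
    #    `continue_` method rather than opening a fresh turn.
    if turn_id is not None:
        tool_responses = [
            {
                "call_id": call.call_id,
                "content": "sunny, 20C",  # placeholder for a locally computed tool result
                "role": "tool",
                "tool_name": call.tool_name,
            }
            for call in pending_calls
        ]
        for chunk in client.agents.turn.continue_(
            turn_id=turn_id,
            agent_id=agent_id,
            session_id=session_id,
            tool_responses=tool_responses,
            stream=True,
        ):
            if chunk.event.payload.event_type == "turn_complete":
                print(chunk.event.payload.turn.output_message.content)

This is the same flow Agent._create_turn_streaming now runs internally (bounded by
`max_infer_iters`); the sketch handles a single awaiting-input round for brevity.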