diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index 6177f4ad..aa002440 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -9,15 +9,18 @@ from llama_stack_client import LlamaStackClient from llama_stack_client.types import ToolResponseMessage, ToolResponseParam, UserMessage from llama_stack_client.types.agent_create_params import AgentConfig +from llama_stack_client.types.agents.agent_turn_response_stream_chunk import ( + AgentTurnResponseStreamChunk, +) from llama_stack_client.types.agents.turn import CompletionMessage, Turn from llama_stack_client.types.agents.turn_create_params import Document, Toolgroup -from llama_stack_client.types.agents.agent_turn_response_stream_chunk import AgentTurnResponseStreamChunk from llama_stack_client.types.shared.tool_call import ToolCall from llama_stack_client.types.shared_params.agent_config import ToolConfig from llama_stack_client.types.shared_params.response_format import ResponseFormat from llama_stack_client.types.shared_params.sampling_params import SamplingParams -from .client_tool import client_tool, ClientTool +from ..._types import Headers +from .client_tool import ClientTool, client_tool from .tool_parser import ToolParser DEFAULT_MAX_ITER = 10 @@ -27,7 +30,9 @@ class AgentUtils: @staticmethod - def get_client_tools(tools: Optional[List[Union[Toolgroup, ClientTool, Callable[..., Any]]]]) -> List[ClientTool]: + def get_client_tools( + tools: Optional[List[Union[Toolgroup, ClientTool, Callable[..., Any]]]], + ) -> List[ClientTool]: if not tools: return [] @@ -37,7 +42,10 @@ def get_client_tools(tools: Optional[List[Union[Toolgroup, ClientTool, Callable[ @staticmethod def get_tool_calls(chunk: AgentTurnResponseStreamChunk, tool_parser: Optional[ToolParser] = None) -> List[ToolCall]: - if chunk.event.payload.event_type not in {"turn_complete", "turn_awaiting_input"}: + if chunk.event.payload.event_type not in { + "turn_complete", + "turn_awaiting_input", + }: return [] message = chunk.event.payload.turn.output_message @@ -51,7 +59,10 @@ def get_tool_calls(chunk: AgentTurnResponseStreamChunk, tool_parser: Optional[To @staticmethod def get_turn_id(chunk: AgentTurnResponseStreamChunk) -> Optional[str]: - if chunk.event.payload.event_type not in ["turn_complete", "turn_awaiting_input"]: + if chunk.event.payload.event_type not in [ + "turn_complete", + "turn_awaiting_input", + ]: return None return chunk.event.payload.turn.turn_id @@ -228,7 +239,10 @@ def _run_single_tool(self, tool_call: ToolCall) -> ToolResponseParam: if tool_call.tool_name in self.builtin_tools: tool_result = self.client.tool_runtime.invoke_tool( tool_name=tool_call.tool_name, - kwargs={**tool_call.arguments, **self.builtin_tools[tool_call.tool_name]}, + kwargs={ + **tool_call.arguments, + **self.builtin_tools[tool_call.tool_name], + }, ) return ToolResponseParam( call_id=tool_call.call_id, @@ -250,11 +264,21 @@ def create_turn( toolgroups: Optional[List[Toolgroup]] = None, documents: Optional[List[Document]] = None, stream: bool = True, + extra_headers: Headers | None = None, ) -> Iterator[AgentTurnResponseStreamChunk] | Turn: if stream: - return self._create_turn_streaming(messages, session_id, toolgroups, documents) + return self._create_turn_streaming(messages, session_id, toolgroups, documents, extra_headers=extra_headers) else: - chunks = [x for x in self._create_turn_streaming(messages, session_id, toolgroups, documents)] + chunks = [ + x + for x in self._create_turn_streaming( + messages, + session_id, + toolgroups, + documents, + extra_headers=extra_headers, + ) + ] if not chunks: raise Exception("Turn did not complete") @@ -276,6 +300,7 @@ def _create_turn_streaming( session_id: Optional[str] = None, toolgroups: Optional[List[Toolgroup]] = None, documents: Optional[List[Document]] = None, + extra_headers: Headers | None = None, ) -> Iterator[AgentTurnResponseStreamChunk]: n_iter = 0 @@ -288,6 +313,7 @@ def _create_turn_streaming( stream=True, documents=documents, toolgroups=toolgroups, + extra_headers=extra_headers, ) # 2. process turn and resume if there's a tool call @@ -324,6 +350,7 @@ def _create_turn_streaming( turn_id=turn_id, tool_responses=tool_responses, stream=True, + extra_headers=extra_headers, ) n_iter += 1 @@ -478,7 +505,10 @@ async def _run_single_tool(self, tool_call: ToolCall) -> ToolResponseParam: if tool_call.tool_name in self.builtin_tools: tool_result = await self.client.tool_runtime.invoke_tool( tool_name=tool_call.tool_name, - kwargs={**tool_call.arguments, **self.builtin_tools[tool_call.tool_name]}, + kwargs={ + **tool_call.arguments, + **self.builtin_tools[tool_call.tool_name], + }, ) return ToolResponseParam( call_id=tool_call.call_id,