Skip to content

Commit e72d9e8

Browse files
authored
fix: accept extra_headers in agent.create_turn and pass them faithfully (#228)
We use extra headers to pass provider data if you wanted to override the provider data on a per-API-call basis. Normally provider data is set via LlamaStackClient initialization and is passed down in each API call automatically.
1 parent e39ba88 commit e72d9e8

File tree

1 file changed

+39
-9
lines changed
  • src/llama_stack_client/lib/agents

1 file changed

+39
-9
lines changed

src/llama_stack_client/lib/agents/agent.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,18 @@
99
from llama_stack_client import LlamaStackClient
1010
from llama_stack_client.types import ToolResponseMessage, ToolResponseParam, UserMessage
1111
from llama_stack_client.types.agent_create_params import AgentConfig
12+
from llama_stack_client.types.agents.agent_turn_response_stream_chunk import (
13+
AgentTurnResponseStreamChunk,
14+
)
1215
from llama_stack_client.types.agents.turn import CompletionMessage, Turn
1316
from llama_stack_client.types.agents.turn_create_params import Document, Toolgroup
14-
from llama_stack_client.types.agents.agent_turn_response_stream_chunk import AgentTurnResponseStreamChunk
1517
from llama_stack_client.types.shared.tool_call import ToolCall
1618
from llama_stack_client.types.shared_params.agent_config import ToolConfig
1719
from llama_stack_client.types.shared_params.response_format import ResponseFormat
1820
from llama_stack_client.types.shared_params.sampling_params import SamplingParams
1921

20-
from .client_tool import client_tool, ClientTool
22+
from ..._types import Headers
23+
from .client_tool import ClientTool, client_tool
2124
from .tool_parser import ToolParser
2225

2326
DEFAULT_MAX_ITER = 10
@@ -27,7 +30,9 @@
2730

2831
class AgentUtils:
2932
@staticmethod
30-
def get_client_tools(tools: Optional[List[Union[Toolgroup, ClientTool, Callable[..., Any]]]]) -> List[ClientTool]:
33+
def get_client_tools(
34+
tools: Optional[List[Union[Toolgroup, ClientTool, Callable[..., Any]]]],
35+
) -> List[ClientTool]:
3136
if not tools:
3237
return []
3338

@@ -37,7 +42,10 @@ def get_client_tools(tools: Optional[List[Union[Toolgroup, ClientTool, Callable[
3742

3843
@staticmethod
3944
def get_tool_calls(chunk: AgentTurnResponseStreamChunk, tool_parser: Optional[ToolParser] = None) -> List[ToolCall]:
40-
if chunk.event.payload.event_type not in {"turn_complete", "turn_awaiting_input"}:
45+
if chunk.event.payload.event_type not in {
46+
"turn_complete",
47+
"turn_awaiting_input",
48+
}:
4149
return []
4250

4351
message = chunk.event.payload.turn.output_message
@@ -51,7 +59,10 @@ def get_tool_calls(chunk: AgentTurnResponseStreamChunk, tool_parser: Optional[To
5159

5260
@staticmethod
5361
def get_turn_id(chunk: AgentTurnResponseStreamChunk) -> Optional[str]:
54-
if chunk.event.payload.event_type not in ["turn_complete", "turn_awaiting_input"]:
62+
if chunk.event.payload.event_type not in [
63+
"turn_complete",
64+
"turn_awaiting_input",
65+
]:
5566
return None
5667

5768
return chunk.event.payload.turn.turn_id
@@ -228,7 +239,10 @@ def _run_single_tool(self, tool_call: ToolCall) -> ToolResponseParam:
228239
if tool_call.tool_name in self.builtin_tools:
229240
tool_result = self.client.tool_runtime.invoke_tool(
230241
tool_name=tool_call.tool_name,
231-
kwargs={**tool_call.arguments, **self.builtin_tools[tool_call.tool_name]},
242+
kwargs={
243+
**tool_call.arguments,
244+
**self.builtin_tools[tool_call.tool_name],
245+
},
232246
)
233247
return ToolResponseParam(
234248
call_id=tool_call.call_id,
@@ -250,11 +264,21 @@ def create_turn(
250264
toolgroups: Optional[List[Toolgroup]] = None,
251265
documents: Optional[List[Document]] = None,
252266
stream: bool = True,
267+
extra_headers: Headers | None = None,
253268
) -> Iterator[AgentTurnResponseStreamChunk] | Turn:
254269
if stream:
255-
return self._create_turn_streaming(messages, session_id, toolgroups, documents)
270+
return self._create_turn_streaming(messages, session_id, toolgroups, documents, extra_headers=extra_headers)
256271
else:
257-
chunks = [x for x in self._create_turn_streaming(messages, session_id, toolgroups, documents)]
272+
chunks = [
273+
x
274+
for x in self._create_turn_streaming(
275+
messages,
276+
session_id,
277+
toolgroups,
278+
documents,
279+
extra_headers=extra_headers,
280+
)
281+
]
258282
if not chunks:
259283
raise Exception("Turn did not complete")
260284

@@ -276,6 +300,7 @@ def _create_turn_streaming(
276300
session_id: Optional[str] = None,
277301
toolgroups: Optional[List[Toolgroup]] = None,
278302
documents: Optional[List[Document]] = None,
303+
extra_headers: Headers | None = None,
279304
) -> Iterator[AgentTurnResponseStreamChunk]:
280305
n_iter = 0
281306

@@ -288,6 +313,7 @@ def _create_turn_streaming(
288313
stream=True,
289314
documents=documents,
290315
toolgroups=toolgroups,
316+
extra_headers=extra_headers,
291317
)
292318

293319
# 2. process turn and resume if there's a tool call
@@ -324,6 +350,7 @@ def _create_turn_streaming(
324350
turn_id=turn_id,
325351
tool_responses=tool_responses,
326352
stream=True,
353+
extra_headers=extra_headers,
327354
)
328355
n_iter += 1
329356

@@ -478,7 +505,10 @@ async def _run_single_tool(self, tool_call: ToolCall) -> ToolResponseParam:
478505
if tool_call.tool_name in self.builtin_tools:
479506
tool_result = await self.client.tool_runtime.invoke_tool(
480507
tool_name=tool_call.tool_name,
481-
kwargs={**tool_call.arguments, **self.builtin_tools[tool_call.tool_name]},
508+
kwargs={
509+
**tool_call.arguments,
510+
**self.builtin_tools[tool_call.tool_name],
511+
},
482512
)
483513
return ToolResponseParam(
484514
call_id=tool_call.call_id,

0 commit comments

Comments
 (0)