9
9
from llama_stack_client import LlamaStackClient
10
10
from llama_stack_client .types import ToolResponseMessage , ToolResponseParam , UserMessage
11
11
from llama_stack_client .types .agent_create_params import AgentConfig
12
+ from llama_stack_client .types .agents .agent_turn_response_stream_chunk import (
13
+ AgentTurnResponseStreamChunk ,
14
+ )
12
15
from llama_stack_client .types .agents .turn import CompletionMessage , Turn
13
16
from llama_stack_client .types .agents .turn_create_params import Document , Toolgroup
14
- from llama_stack_client .types .agents .agent_turn_response_stream_chunk import AgentTurnResponseStreamChunk
15
17
from llama_stack_client .types .shared .tool_call import ToolCall
16
18
from llama_stack_client .types .shared_params .agent_config import ToolConfig
17
19
from llama_stack_client .types .shared_params .response_format import ResponseFormat
18
20
from llama_stack_client .types .shared_params .sampling_params import SamplingParams
19
21
20
- from .client_tool import client_tool , ClientTool
22
+ from ..._types import Headers
23
+ from .client_tool import ClientTool , client_tool
21
24
from .tool_parser import ToolParser
22
25
23
26
DEFAULT_MAX_ITER = 10
27
30
28
31
class AgentUtils :
29
32
@staticmethod
30
- def get_client_tools (tools : Optional [List [Union [Toolgroup , ClientTool , Callable [..., Any ]]]]) -> List [ClientTool ]:
33
+ def get_client_tools (
34
+ tools : Optional [List [Union [Toolgroup , ClientTool , Callable [..., Any ]]]],
35
+ ) -> List [ClientTool ]:
31
36
if not tools :
32
37
return []
33
38
@@ -37,7 +42,10 @@ def get_client_tools(tools: Optional[List[Union[Toolgroup, ClientTool, Callable[
37
42
38
43
@staticmethod
39
44
def get_tool_calls (chunk : AgentTurnResponseStreamChunk , tool_parser : Optional [ToolParser ] = None ) -> List [ToolCall ]:
40
- if chunk .event .payload .event_type not in {"turn_complete" , "turn_awaiting_input" }:
45
+ if chunk .event .payload .event_type not in {
46
+ "turn_complete" ,
47
+ "turn_awaiting_input" ,
48
+ }:
41
49
return []
42
50
43
51
message = chunk .event .payload .turn .output_message
@@ -51,7 +59,10 @@ def get_tool_calls(chunk: AgentTurnResponseStreamChunk, tool_parser: Optional[To
51
59
52
60
@staticmethod
53
61
def get_turn_id (chunk : AgentTurnResponseStreamChunk ) -> Optional [str ]:
54
- if chunk .event .payload .event_type not in ["turn_complete" , "turn_awaiting_input" ]:
62
+ if chunk .event .payload .event_type not in [
63
+ "turn_complete" ,
64
+ "turn_awaiting_input" ,
65
+ ]:
55
66
return None
56
67
57
68
return chunk .event .payload .turn .turn_id
@@ -228,7 +239,10 @@ def _run_single_tool(self, tool_call: ToolCall) -> ToolResponseParam:
228
239
if tool_call .tool_name in self .builtin_tools :
229
240
tool_result = self .client .tool_runtime .invoke_tool (
230
241
tool_name = tool_call .tool_name ,
231
- kwargs = {** tool_call .arguments , ** self .builtin_tools [tool_call .tool_name ]},
242
+ kwargs = {
243
+ ** tool_call .arguments ,
244
+ ** self .builtin_tools [tool_call .tool_name ],
245
+ },
232
246
)
233
247
return ToolResponseParam (
234
248
call_id = tool_call .call_id ,
@@ -250,11 +264,21 @@ def create_turn(
250
264
toolgroups : Optional [List [Toolgroup ]] = None ,
251
265
documents : Optional [List [Document ]] = None ,
252
266
stream : bool = True ,
267
+ extra_headers : Headers | None = None ,
253
268
) -> Iterator [AgentTurnResponseStreamChunk ] | Turn :
254
269
if stream :
255
- return self ._create_turn_streaming (messages , session_id , toolgroups , documents )
270
+ return self ._create_turn_streaming (messages , session_id , toolgroups , documents , extra_headers = extra_headers )
256
271
else :
257
- chunks = [x for x in self ._create_turn_streaming (messages , session_id , toolgroups , documents )]
272
+ chunks = [
273
+ x
274
+ for x in self ._create_turn_streaming (
275
+ messages ,
276
+ session_id ,
277
+ toolgroups ,
278
+ documents ,
279
+ extra_headers = extra_headers ,
280
+ )
281
+ ]
258
282
if not chunks :
259
283
raise Exception ("Turn did not complete" )
260
284
@@ -276,6 +300,7 @@ def _create_turn_streaming(
276
300
session_id : Optional [str ] = None ,
277
301
toolgroups : Optional [List [Toolgroup ]] = None ,
278
302
documents : Optional [List [Document ]] = None ,
303
+ extra_headers : Headers | None = None ,
279
304
) -> Iterator [AgentTurnResponseStreamChunk ]:
280
305
n_iter = 0
281
306
@@ -288,6 +313,7 @@ def _create_turn_streaming(
288
313
stream = True ,
289
314
documents = documents ,
290
315
toolgroups = toolgroups ,
316
+ extra_headers = extra_headers ,
291
317
)
292
318
293
319
# 2. process turn and resume if there's a tool call
@@ -324,6 +350,7 @@ def _create_turn_streaming(
324
350
turn_id = turn_id ,
325
351
tool_responses = tool_responses ,
326
352
stream = True ,
353
+ extra_headers = extra_headers ,
327
354
)
328
355
n_iter += 1
329
356
@@ -478,7 +505,10 @@ async def _run_single_tool(self, tool_call: ToolCall) -> ToolResponseParam:
478
505
if tool_call .tool_name in self .builtin_tools :
479
506
tool_result = await self .client .tool_runtime .invoke_tool (
480
507
tool_name = tool_call .tool_name ,
481
- kwargs = {** tool_call .arguments , ** self .builtin_tools [tool_call .tool_name ]},
508
+ kwargs = {
509
+ ** tool_call .arguments ,
510
+ ** self .builtin_tools [tool_call .tool_name ],
511
+ },
482
512
)
483
513
return ToolResponseParam (
484
514
call_id = tool_call .call_id ,
0 commit comments