@@ -54,20 +54,25 @@ def get_tool_def(self, name: str) -> ToolDefinition | None:
5454 except KeyError :
5555 return None
5656
57- async def handle_call (self , call : ToolCallPart , allow_partial : bool = False ) -> Any :
57+ async def handle_call (
58+ self , call : ToolCallPart , allow_partial : bool = False , wrap_validation_errors : bool = True
59+ ) -> Any :
5860 """Handle a tool call by validating the arguments, calling the tool, and handling retries.
5961
6062 Args:
6163 call: The tool call part to handle.
6264 allow_partial: Whether to allow partial validation of the tool arguments.
65+ wrap_validation_errors: Whether to wrap validation errors in a retry prompt part.
6366 """
6467 if (tool := self .tools .get (call .tool_name )) and tool .tool_def .kind == 'output' :
6568 # Output tool calls are not traced
66- return await self ._call_tool (call , allow_partial )
69+ return await self ._call_tool (call , allow_partial , wrap_validation_errors )
6770 else :
68- return await self ._call_tool_traced (call , allow_partial )
71+ return await self ._call_tool_traced (call , allow_partial , wrap_validation_errors )
6972
70- async def _call_tool (self , call : ToolCallPart , allow_partial : bool = False ) -> Any :
73+ async def _call_tool (
74+ self , call : ToolCallPart , allow_partial : bool = False , wrap_validation_errors : bool = True
75+ ) -> Any :
7176 name = call .tool_name
7277 tool = self .tools .get (name )
7378 try :
@@ -100,30 +105,35 @@ async def _call_tool(self, call: ToolCallPart, allow_partial: bool = False) -> A
100105 if current_retry == max_retries :
101106 raise UnexpectedModelBehavior (f'Tool { name !r} exceeded max retries count of { max_retries } ' ) from e
102107 else :
103- if isinstance (e , ValidationError ):
104- m = _messages .RetryPromptPart (
105- tool_name = name ,
106- content = e .errors (include_url = False , include_context = False ),
107- tool_call_id = call .tool_call_id ,
108- )
109- e = ToolRetryError (m )
110- elif isinstance (e , ModelRetry ):
111- m = _messages .RetryPromptPart (
112- tool_name = name ,
113- content = e .message ,
114- tool_call_id = call .tool_call_id ,
115- )
116- e = ToolRetryError (m )
117- else :
118- assert_never (e )
108+ if wrap_validation_errors :
109+ if isinstance (e , ValidationError ):
110+ m = _messages .RetryPromptPart (
111+ tool_name = name ,
112+ content = e .errors (include_url = False , include_context = False ),
113+ tool_call_id = call .tool_call_id ,
114+ )
115+ e = ToolRetryError (m )
116+ elif isinstance (e , ModelRetry ):
117+ m = _messages .RetryPromptPart (
118+ tool_name = name ,
119+ content = e .message ,
120+ tool_call_id = call .tool_call_id ,
121+ )
122+ e = ToolRetryError (m )
123+ else :
124+ assert_never (e )
125+
126+ if not allow_partial :
127+ self .ctx .retries [name ] = current_retry + 1
119128
120- self .ctx .retries [name ] = current_retry + 1
121129 raise e
122130 else :
123131 self .ctx .retries .pop (name , None )
124132 return output
125133
126- async def _call_tool_traced (self , call : ToolCallPart , allow_partial : bool = False ) -> Any :
134+ async def _call_tool_traced (
135+ self , call : ToolCallPart , allow_partial : bool = False , wrap_validation_errors : bool = True
136+ ) -> Any :
127137 """See <https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span>."""
128138 span_attributes = {
129139 'gen_ai.tool.name' : call .tool_name ,
@@ -152,7 +162,7 @@ async def _call_tool_traced(self, call: ToolCallPart, allow_partial: bool = Fals
152162 }
153163 with self .ctx .tracer .start_as_current_span ('running tool' , attributes = span_attributes ) as span :
154164 try :
155- tool_result = await self ._call_tool (call , allow_partial )
165+ tool_result = await self ._call_tool (call , allow_partial , wrap_validation_errors )
156166 except ToolRetryError as e :
157167 part = e .tool_retry
158168 if self .ctx .trace_include_content and span .is_recording ():
0 commit comments