Skip to content

Commit fc4face

Browse files
authored
[feat] More specific error handling with simple function tools (#184)
* update AsyncSimpleFunctionTool error handling * update SimpleFunctionTool error handling * changelog * coverage
1 parent 1b2411c commit fc4face

File tree

3 files changed

+106
-21
lines changed

3 files changed

+106
-21
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/).
1717

1818
### Changed
1919

20+
- [feat] More specific error handling with simple function tools (#184)
2021
- refactor: Store custom desc in _desc for SimpleFunctionTool and AsyncSimpleFunctionTool. (#181)
2122
- Remove return_history param in OllamaLLM.chat() (#126)
2223
- Rename continue_conversation_with_tool_results to continue_chat_with_tool_results (#123)

src/llm_agents_from_scratch/tools/simple_function.py

Lines changed: 48 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
"""Simple Function Tool."""
22

33
import inspect
4+
import json
45
from typing import Any, Awaitable, Callable, get_type_hints
56

6-
from jsonschema import validate
7+
from jsonschema import SchemaError, ValidationError, validate
78

89
from llm_agents_from_scratch.base.tool import AsyncBaseTool, BaseTool
910
from llm_agents_from_scratch.data_structures import ToolCall, ToolCallResult
@@ -125,18 +126,35 @@ def __call__(
125126
try:
126127
# validate the arguments
127128
validate(tool_call.arguments, schema=self.parameters_json_schema)
129+
except (SchemaError, ValidationError) as e:
130+
error_details = {
131+
"error_type": e.__class__.__name__,
132+
"message": e.message,
133+
}
134+
return ToolCallResult(
135+
tool_call_id=tool_call.id_,
136+
content=json.dumps(error_details),
137+
error=True,
138+
)
139+
140+
try:
128141
# execute the function
129142
res = self.func(**tool_call.arguments)
130-
content = str(res)
131-
error = False
132143
except Exception as e:
133-
content = f"Failed to execute function call: {e}"
134-
error = True
144+
error_details = {
145+
"error_type": e.__class__.__name__,
146+
"message": f"Internal error while executing tool: {str(e)}",
147+
}
148+
return ToolCallResult(
149+
tool_call_id=tool_call.id_,
150+
content=json.dumps(error_details),
151+
error=True,
152+
)
135153

136154
return ToolCallResult(
137155
tool_call_id=tool_call.id_,
138-
content=content,
139-
error=error,
156+
content=str(res),
157+
error=False,
140158
)
141159

142160

@@ -198,16 +216,33 @@ async def __call__(
198216
try:
199217
# validate the arguments
200218
validate(tool_call.arguments, schema=self.parameters_json_schema)
219+
except (SchemaError, ValidationError) as e:
220+
error_details = {
221+
"error_type": e.__class__.__name__,
222+
"message": e.message,
223+
}
224+
return ToolCallResult(
225+
tool_call_id=tool_call.id_,
226+
content=json.dumps(error_details),
227+
error=True,
228+
)
229+
230+
try:
201231
# execute the function
202232
res = await self.func(**tool_call.arguments)
203-
content = str(res)
204-
error = False
205233
except Exception as e:
206-
content = f"Failed to execute function call: {e}"
207-
error = True
234+
error_details = {
235+
"error_type": e.__class__.__name__,
236+
"message": f"Internal error while executing tool: {str(e)}",
237+
}
238+
return ToolCallResult(
239+
tool_call_id=tool_call.id_,
240+
content=json.dumps(error_details),
241+
error=True,
242+
)
208243

209244
return ToolCallResult(
210245
tool_call_id=tool_call.id_,
211-
content=content,
212-
error=error,
246+
content=str(res),
247+
error=False,
213248
)

tests/tools/test_function_tool.py

Lines changed: 57 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,15 @@ async def my_mock_fn_3(
3838
return f"{param1} and {param2}"
3939

4040

41+
def my_mock_fn_that_raises(
42+
param1: int,
43+
param2: str = "x",
44+
*args: Any,
45+
**kwargs: Any,
46+
) -> str:
47+
raise RuntimeError("Oops!")
48+
49+
4150
@pytest.mark.parametrize(
4251
("func", "properties", "required"),
4352
[
@@ -114,7 +123,7 @@ def test_function_tool_call(mock_validate: MagicMock) -> None:
114123
assert result.error is False
115124

116125

117-
def test_function_tool_call_returns_error() -> None:
126+
def test_function_tool_call_returns_validation_error() -> None:
118127
"""Tests a function tool call raises error at validation of params."""
119128
tool = SimpleFunctionTool(my_mock_fn_1, desc="mock desc")
120129
tool_call = ToolCall(
@@ -124,10 +133,29 @@ def test_function_tool_call_returns_error() -> None:
124133

125134
result = tool(tool_call=tool_call)
126135

127-
assert (
128-
"Failed to execute function call: '1' is not of type 'number'"
129-
in result.content
136+
expected_content = (
137+
'{"error_type": "ValidationError", "message": "\'1\' '
138+
"is not of type 'number'\"}"
130139
)
140+
assert expected_content == result.content
141+
assert result.error is True
142+
143+
144+
def test_function_tool_call_returns_execution_error() -> None:
145+
"""Tests a function tool call raises error at validation of params."""
146+
tool = SimpleFunctionTool(my_mock_fn_that_raises, desc="mock desc")
147+
tool_call = ToolCall(
148+
tool_name="my_mock_fn_that_raises",
149+
arguments={"param1": 1, "param2": "y"},
150+
)
151+
152+
result = tool(tool_call=tool_call)
153+
154+
expected_content = (
155+
'{"error_type": "RuntimeError", '
156+
'"message": "Internal error while executing tool: Oops!"}'
157+
)
158+
assert expected_content == result.content
131159
assert result.error is True
132160

133161

@@ -166,7 +194,7 @@ async def test_async_function_tool_call(mock_validate: MagicMock) -> None:
166194

167195

168196
@pytest.mark.asyncio
169-
async def test_async_function_tool_call_returns_error() -> None:
197+
async def test_async_function_tool_call_returns_validation_error() -> None:
170198
"""Tests a function tool call."""
171199
tool = AsyncSimpleFunctionTool(my_mock_fn_1, desc="mock desc")
172200
tool_call = ToolCall(
@@ -176,8 +204,29 @@ async def test_async_function_tool_call_returns_error() -> None:
176204

177205
result = await tool(tool_call=tool_call)
178206

179-
assert (
180-
"Failed to execute function call: '1' is not of type 'number'"
181-
in result.content
207+
expected_content = (
208+
'{"error_type": "ValidationError", "message": "\'1\' '
209+
"is not of type 'number'\"}"
210+
)
211+
212+
assert expected_content == result.content
213+
assert result.error is True
214+
215+
216+
@pytest.mark.asyncio
217+
async def test_async_function_tool_call_returns_execution_error() -> None:
218+
"""Tests a function tool call raises error at validation of params."""
219+
tool = AsyncSimpleFunctionTool(my_mock_fn_that_raises, desc="mock desc")
220+
tool_call = ToolCall(
221+
tool_name="my_mock_fn_that_raises",
222+
arguments={"param1": 1, "param2": "y"},
223+
)
224+
225+
result = await tool(tool_call=tool_call)
226+
227+
expected_content = (
228+
'{"error_type": "RuntimeError", '
229+
'"message": "Internal error while executing tool: Oops!"}'
182230
)
231+
assert expected_content == result.content
183232
assert result.error is True

0 commit comments

Comments
 (0)