@@ -71,6 +71,39 @@ def create_user(user: UserInput, flag: bool) -> dict:
7171 assert "age" in tool .parameters ["$defs" ]["UserInput" ]["properties" ]
7272 assert "flag" in tool .parameters ["properties" ]
7373
74+ def test_add_callable_object (self ):
75+ """Test registering a callable object."""
76+
77+ class MyTool :
78+ def __init__ (self ):
79+ self .__name__ = "MyTool"
80+
81+ def __call__ (self , x : int ) -> int :
82+ return x * 2
83+
84+ manager = ToolManager ()
85+ tool = manager .add_tool (MyTool ())
86+ assert tool .name == "MyTool"
87+ assert tool .is_async is False
88+ assert tool .parameters ["properties" ]["x" ]["type" ] == "integer"
89+
90+ @pytest .mark .anyio
91+ async def test_add_async_callable_object (self ):
92+ """Test registering an async callable object."""
93+
94+ class MyAsyncTool :
95+ def __init__ (self ):
96+ self .__name__ = "MyAsyncTool"
97+
98+ async def __call__ (self , x : int ) -> int :
99+ return x * 2
100+
101+ manager = ToolManager ()
102+ tool = manager .add_tool (MyAsyncTool ())
103+ assert tool .name == "MyAsyncTool"
104+ assert tool .is_async is True
105+ assert tool .parameters ["properties" ]["x" ]["type" ] == "integer"
106+
74107 def test_add_invalid_tool (self ):
75108 manager = ToolManager ()
76109 with pytest .raises (AttributeError ):
@@ -137,6 +170,34 @@ async def double(n: int) -> int:
137170 result = await manager .call_tool ("double" , {"n" : 5 })
138171 assert result == 10
139172
173+ @pytest .mark .anyio
174+ async def test_call_object_tool (self ):
175+ class MyTool :
176+ def __init__ (self ):
177+ self .__name__ = "MyTool"
178+
179+ def __call__ (self , x : int ) -> int :
180+ return x * 2
181+
182+ manager = ToolManager ()
183+ tool = manager .add_tool (MyTool ())
184+ result = await tool .run ({"x" : 5 })
185+ assert result == 10
186+
187+ @pytest .mark .anyio
188+ async def test_call_async_object_tool (self ):
189+ class MyAsyncTool :
190+ def __init__ (self ):
191+ self .__name__ = "MyAsyncTool"
192+
193+ async def __call__ (self , x : int ) -> int :
194+ return x * 2
195+
196+ manager = ToolManager ()
197+ tool = manager .add_tool (MyAsyncTool ())
198+ result = await tool .run ({"x" : 5 })
199+ assert result == 10
200+
140201 @pytest .mark .anyio
141202 async def test_call_tool_with_default_args (self ):
142203 def add (a : int , b : int = 1 ) -> int :
0 commit comments