19
19
Union ,
20
20
)
21
21
22
+ from typing_extensions import TypedDict
23
+
22
24
from llama_stack_client .types import CompletionMessage , Message
23
25
from llama_stack_client .types .alpha import ToolResponse
24
- from llama_stack_client .types .tool_def_param import Parameter , ToolDefParam
26
+ from llama_stack_client .types .tool_def_param import ToolDefParam
27
+
28
+
29
+ class JSONSchema (TypedDict , total = False ):
30
+ type : str
31
+ properties : Dict [str , Any ]
32
+ required : List [str ]
25
33
26
34
27
35
class ClientTool :
@@ -47,28 +55,18 @@ def get_description(self) -> str:
47
55
raise NotImplementedError
48
56
49
57
@abstractmethod
50
- def get_params_definition (self ) -> Dict [ str , Parameter ] :
58
+ def get_input_schema (self ) -> JSONSchema :
51
59
raise NotImplementedError
52
60
53
61
def get_instruction_string (self ) -> str :
54
62
return f"Use the function '{ self .get_name ()} ' to: { self .get_description ()} "
55
63
56
- def parameters_for_system_prompt (self ) -> str :
57
- return json .dumps (
58
- {
59
- "name" : self .get_name (),
60
- "description" : self .get_description (),
61
- "parameters" : {name : definition for name , definition in self .get_params_definition ().items ()},
62
- }
63
- )
64
-
65
64
def get_tool_definition (self ) -> ToolDefParam :
66
65
return ToolDefParam (
67
66
name = self .get_name (),
68
67
description = self .get_description (),
69
- parameters = list ( self .get_params_definition (). values () ),
68
+ input_schema = self .get_input_schema ( ),
70
69
metadata = {},
71
- tool_prompt_format = "python_list" ,
72
70
)
73
71
74
72
def run (
@@ -148,6 +146,37 @@ def async_run_impl(self, **kwargs):
148
146
T = TypeVar ("T" , bound = Callable )
149
147
150
148
149
+ def _python_type_to_json_schema_type (type_hint : Any ) -> str :
150
+ """Convert Python type hints to JSON Schema type strings."""
151
+ # Handle Union types (e.g., Optional[str])
152
+ origin = get_origin (type_hint )
153
+ if origin is Union :
154
+ # Get non-None types from Union
155
+ args = [arg for arg in get_args (type_hint ) if arg is not type (None )]
156
+ if args :
157
+ type_hint = args [0 ] # Use first non-None type
158
+
159
+ # Get the actual type if it's a generic
160
+ if hasattr (type_hint , "__origin__" ):
161
+ type_hint = type_hint .__origin__
162
+
163
+ # Map Python types to JSON Schema types
164
+ type_name = getattr (type_hint , "__name__" , str (type_hint ))
165
+
166
+ type_mapping = {
167
+ "bool" : "boolean" ,
168
+ "int" : "integer" ,
169
+ "float" : "number" ,
170
+ "str" : "string" ,
171
+ "list" : "array" ,
172
+ "dict" : "object" ,
173
+ "List" : "array" ,
174
+ "Dict" : "object" ,
175
+ }
176
+
177
+ return type_mapping .get (type_name , "string" ) # Default to string if unknown
178
+
179
+
151
180
def client_tool (func : T ) -> ClientTool :
152
181
"""
153
182
Decorator to convert a function into a ClientTool.
@@ -188,13 +217,14 @@ def get_description(self) -> str:
188
217
f"No description found for client tool { __name__ } . Please provide a RST-style docstring with description and :param tags for each parameter."
189
218
)
190
219
191
- def get_params_definition (self ) -> Dict [ str , Parameter ] :
220
+ def get_input_schema (self ) -> JSONSchema :
192
221
hints = get_type_hints (func )
193
222
# Remove return annotation if present
194
223
hints .pop ("return" , None )
195
224
196
225
# Get parameter descriptions from docstring
197
- params = {}
226
+ properties = {}
227
+ required = []
198
228
sig = inspect .signature (func )
199
229
doc = inspect .getdoc (func ) or ""
200
230
@@ -212,15 +242,20 @@ def get_params_definition(self) -> Dict[str, Parameter]:
212
242
param = sig .parameters [name ]
213
243
is_optional_type = get_origin (type_hint ) is Union and type (None ) in get_args (type_hint )
214
244
is_required = param .default == inspect .Parameter .empty and not is_optional_type
215
- params [name ] = Parameter (
216
- name = name ,
217
- description = param_doc or f"Parameter { name } " ,
218
- parameter_type = type_hint .__name__ ,
219
- default = (param .default if param .default != inspect .Parameter .empty else None ),
220
- required = is_required ,
221
- )
222
245
223
- return params
246
+ properties [name ] = {
247
+ "type" : _python_type_to_json_schema_type (type_hint ),
248
+ "description" : param_doc ,
249
+ }
250
+
251
+ if is_required :
252
+ required .append (name )
253
+
254
+ return {
255
+ "type" : "object" ,
256
+ "properties" : properties ,
257
+ "required" : required ,
258
+ }
224
259
225
260
def run_impl (self , ** kwargs ) -> Any :
226
261
if inspect .iscoroutinefunction (func ):
0 commit comments