Skip to content

Commit 67b3d02

Browse files
committed
fix(manual): update lib/ code for the input_schema changes
1 parent 064b98b commit 67b3d02

File tree

2 files changed

+60
-25
lines changed

2 files changed

+60
-25
lines changed

src/llama_stack_client/lib/agents/client_tool.py

Lines changed: 58 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,17 @@
1919
Union,
2020
)
2121

22+
from typing_extensions import TypedDict
23+
2224
from llama_stack_client.types import CompletionMessage, Message
2325
from llama_stack_client.types.alpha import ToolResponse
24-
from llama_stack_client.types.tool_def_param import Parameter, ToolDefParam
26+
from llama_stack_client.types.tool_def_param import ToolDefParam
27+
28+
29+
class JSONSchema(TypedDict, total=False):
30+
type: str
31+
properties: Dict[str, Any]
32+
required: List[str]
2533

2634

2735
class ClientTool:
@@ -47,28 +55,18 @@ def get_description(self) -> str:
4755
raise NotImplementedError
4856

4957
@abstractmethod
50-
def get_params_definition(self) -> Dict[str, Parameter]:
58+
def get_input_schema(self) -> JSONSchema:
5159
raise NotImplementedError
5260

5361
def get_instruction_string(self) -> str:
5462
return f"Use the function '{self.get_name()}' to: {self.get_description()}"
5563

56-
def parameters_for_system_prompt(self) -> str:
57-
return json.dumps(
58-
{
59-
"name": self.get_name(),
60-
"description": self.get_description(),
61-
"parameters": {name: definition for name, definition in self.get_params_definition().items()},
62-
}
63-
)
64-
6564
def get_tool_definition(self) -> ToolDefParam:
6665
return ToolDefParam(
6766
name=self.get_name(),
6867
description=self.get_description(),
69-
parameters=list(self.get_params_definition().values()),
68+
input_schema=self.get_input_schema(),
7069
metadata={},
71-
tool_prompt_format="python_list",
7270
)
7371

7472
def run(
@@ -148,6 +146,37 @@ def async_run_impl(self, **kwargs):
148146
T = TypeVar("T", bound=Callable)
149147

150148

149+
def _python_type_to_json_schema_type(type_hint: Any) -> str:
150+
"""Convert Python type hints to JSON Schema type strings."""
151+
# Handle Union types (e.g., Optional[str])
152+
origin = get_origin(type_hint)
153+
if origin is Union:
154+
# Get non-None types from Union
155+
args = [arg for arg in get_args(type_hint) if arg is not type(None)]
156+
if args:
157+
type_hint = args[0] # Use first non-None type
158+
159+
# Get the actual type if it's a generic
160+
if hasattr(type_hint, "__origin__"):
161+
type_hint = type_hint.__origin__
162+
163+
# Map Python types to JSON Schema types
164+
type_name = getattr(type_hint, "__name__", str(type_hint))
165+
166+
type_mapping = {
167+
"bool": "boolean",
168+
"int": "integer",
169+
"float": "number",
170+
"str": "string",
171+
"list": "array",
172+
"dict": "object",
173+
"List": "array",
174+
"Dict": "object",
175+
}
176+
177+
return type_mapping.get(type_name, "string") # Default to string if unknown
178+
179+
151180
def client_tool(func: T) -> ClientTool:
152181
"""
153182
Decorator to convert a function into a ClientTool.
@@ -188,13 +217,14 @@ def get_description(self) -> str:
188217
f"No description found for client tool {__name__}. Please provide a RST-style docstring with description and :param tags for each parameter."
189218
)
190219

191-
def get_params_definition(self) -> Dict[str, Parameter]:
220+
def get_input_schema(self) -> JSONSchema:
192221
hints = get_type_hints(func)
193222
# Remove return annotation if present
194223
hints.pop("return", None)
195224

196225
# Get parameter descriptions from docstring
197-
params = {}
226+
properties = {}
227+
required = []
198228
sig = inspect.signature(func)
199229
doc = inspect.getdoc(func) or ""
200230

@@ -212,15 +242,20 @@ def get_params_definition(self) -> Dict[str, Parameter]:
212242
param = sig.parameters[name]
213243
is_optional_type = get_origin(type_hint) is Union and type(None) in get_args(type_hint)
214244
is_required = param.default == inspect.Parameter.empty and not is_optional_type
215-
params[name] = Parameter(
216-
name=name,
217-
description=param_doc or f"Parameter {name}",
218-
parameter_type=type_hint.__name__,
219-
default=(param.default if param.default != inspect.Parameter.empty else None),
220-
required=is_required,
221-
)
222245

223-
return params
246+
properties[name] = {
247+
"type": _python_type_to_json_schema_type(type_hint),
248+
"description": param_doc,
249+
}
250+
251+
if is_required:
252+
required.append(name)
253+
254+
return {
255+
"type": "object",
256+
"properties": properties,
257+
"required": required,
258+
}
224259

225260
def run_impl(self, **kwargs) -> Any:
226261
if inspect.iscoroutinefunction(func):

src/llama_stack_client/lib/agents/react/agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def get_tool_defs(
3737
{
3838
"name": tool.identifier,
3939
"description": tool.description,
40-
"parameters": tool.parameters,
40+
"input_schema": tool.input_schema,
4141
}
4242
for tool in client.tools.list(toolgroup_id=toolgroup_id)
4343
]
@@ -48,7 +48,7 @@ def get_tool_defs(
4848
{
4949
"name": tool.get_name(),
5050
"description": tool.get_description(),
51-
"parameters": tool.get_params_definition(),
51+
"input_schema": tool.get_input_schema(),
5252
}
5353
for tool in client_tools
5454
]

0 commit comments

Comments
 (0)