|
3 | 3 | from time import sleep
|
4 | 4 | import typer
|
5 | 5 | from pydantic import BaseModel, Json, TypeAdapter
|
| 6 | +from pydantic_core import SchemaValidator, core_schema |
6 | 7 | from typing import Annotated, Any, Callable, Dict, List, Union, Optional, Type
|
7 | 8 | import json, requests
|
8 | 9 |
|
|
13 | 14 | from examples.openai.prompting import ToolsPromptStyle
|
14 | 15 | from examples.openai.subprocesses import spawn_subprocess
|
15 | 16 |
|
16 |
| -def _get_params_schema(fn: Callable[[Any], Any], verbose): |
17 |
| - if isinstance(fn, OpenAPIMethod): |
18 |
| - return fn.parameters_schema |
19 |
| - |
20 |
| - # converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False) |
21 |
| - schema = TypeAdapter(fn).json_schema() |
22 |
| - # Do NOT call converter.resolve_refs(schema) here. Let the server resolve local refs. |
23 |
| - if verbose: |
24 |
| - sys.stderr.write(f'# PARAMS SCHEMA: {json.dumps(schema, indent=2)}\n') |
25 |
| - return schema |
| 17 | +def make_call_adapter(ta: TypeAdapter, fn: Callable[..., Any]): |
| 18 | + args_validator = SchemaValidator(core_schema.call_schema( |
| 19 | + arguments=ta.core_schema['arguments_schema'], |
| 20 | + function=fn, |
| 21 | + )) |
| 22 | + return lambda **kwargs: args_validator.validate_python(kwargs) |
26 | 23 |
|
27 | 24 | def completion_with_tool_usage(
|
28 | 25 | *,
|
@@ -50,18 +47,28 @@ def completion_with_tool_usage(
|
50 | 47 | schema = type_adapter.json_schema()
|
51 | 48 | response_format=ResponseFormat(type="json_object", schema=schema)
|
52 | 49 |
|
53 |
| - tool_map = {fn.__name__: fn for fn in tools} |
54 |
| - tools_schemas = [ |
55 |
| - Tool( |
56 |
| - type="function", |
57 |
| - function=ToolFunction( |
58 |
| - name=fn.__name__, |
59 |
| - description=fn.__doc__ or '', |
60 |
| - parameters=_get_params_schema(fn, verbose=verbose) |
| 50 | + tool_map = {} |
| 51 | + tools_schemas = [] |
| 52 | + for fn in tools: |
| 53 | + if isinstance(fn, OpenAPIMethod): |
| 54 | + tool_map[fn.__name__] = fn |
| 55 | + parameters_schema = fn.parameters_schema |
| 56 | + else: |
| 57 | + ta = TypeAdapter(fn) |
| 58 | + tool_map[fn.__name__] = make_call_adapter(ta, fn) |
| 59 | + parameters_schema = ta.json_schema() |
| 60 | + if verbose: |
| 61 | + sys.stderr.write(f'# PARAMS SCHEMA ({fn.__name__}): {json.dumps(parameters_schema, indent=2)}\n') |
| 62 | + tools_schemas.append( |
| 63 | + Tool( |
| 64 | + type="function", |
| 65 | + function=ToolFunction( |
| 66 | + name=fn.__name__, |
| 67 | + description=fn.__doc__ or '', |
| 68 | + parameters=parameters_schema, |
| 69 | + ) |
61 | 70 | )
|
62 | 71 | )
|
63 |
| - for fn in tools |
64 |
| - ] |
65 | 72 |
|
66 | 73 | i = 0
|
67 | 74 | while (max_iterations is None or i < max_iterations):
|
@@ -106,7 +113,7 @@ def completion_with_tool_usage(
|
106 | 113 | sys.stdout.write(f'⚙️ {pretty_call}')
|
107 | 114 | sys.stdout.flush()
|
108 | 115 | tool_result = tool_map[tool_call.function.name](**tool_call.function.arguments)
|
109 |
| - sys.stdout.write(f" -> {tool_result}\n") |
| 116 | + sys.stdout.write(f" → {tool_result}\n") |
110 | 117 | messages.append(Message(
|
111 | 118 | tool_call_id=tool_call.id,
|
112 | 119 | role="tool",
|
@@ -203,6 +210,8 @@ def main(
|
203 | 210 | if std_tools:
|
204 | 211 | tool_functions.extend(collect_functions(StandardTools))
|
205 | 212 |
|
| 213 | + sys.stdout.write(f'🛠️ {", ".join(fn.__name__ for fn in tool_functions)}\n') |
| 214 | + |
206 | 215 | response_model: Union[type, Json[Any]] = None #str
|
207 | 216 | if format:
|
208 | 217 | if format in types:
|
|
0 commit comments