Skip to content

Commit 2ba7150

Browse files
committed
agent: fix wait --std-tools
1 parent 9b9f195 commit 2ba7150

File tree

2 files changed

+67
-46
lines changed

2 files changed

+67
-46
lines changed

examples/agent/agent.py

+30-21
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from time import sleep
44
import typer
55
from pydantic import BaseModel, Json, TypeAdapter
6+
from pydantic_core import SchemaValidator, core_schema
67
from typing import Annotated, Any, Callable, Dict, List, Union, Optional, Type
78
import json, requests
89

@@ -13,16 +14,12 @@
1314
from examples.openai.prompting import ToolsPromptStyle
1415
from examples.openai.subprocesses import spawn_subprocess
1516

16-
def _get_params_schema(fn: Callable[[Any], Any], verbose):
17-
if isinstance(fn, OpenAPIMethod):
18-
return fn.parameters_schema
19-
20-
# converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
21-
schema = TypeAdapter(fn).json_schema()
22-
# Do NOT call converter.resolve_refs(schema) here. Let the server resolve local refs.
23-
if verbose:
24-
sys.stderr.write(f'# PARAMS SCHEMA: {json.dumps(schema, indent=2)}\n')
25-
return schema
17+
def make_call_adapter(ta: TypeAdapter, fn: Callable[..., Any]):
18+
args_validator = SchemaValidator(core_schema.call_schema(
19+
arguments=ta.core_schema['arguments_schema'],
20+
function=fn,
21+
))
22+
return lambda **kwargs: args_validator.validate_python(kwargs)
2623

2724
def completion_with_tool_usage(
2825
*,
@@ -50,18 +47,28 @@ def completion_with_tool_usage(
5047
schema = type_adapter.json_schema()
5148
response_format=ResponseFormat(type="json_object", schema=schema)
5249

53-
tool_map = {fn.__name__: fn for fn in tools}
54-
tools_schemas = [
55-
Tool(
56-
type="function",
57-
function=ToolFunction(
58-
name=fn.__name__,
59-
description=fn.__doc__ or '',
60-
parameters=_get_params_schema(fn, verbose=verbose)
50+
tool_map = {}
51+
tools_schemas = []
52+
for fn in tools:
53+
if isinstance(fn, OpenAPIMethod):
54+
tool_map[fn.__name__] = fn
55+
parameters_schema = fn.parameters_schema
56+
else:
57+
ta = TypeAdapter(fn)
58+
tool_map[fn.__name__] = make_call_adapter(ta, fn)
59+
parameters_schema = ta.json_schema()
60+
if verbose:
61+
sys.stderr.write(f'# PARAMS SCHEMA ({fn.__name__}): {json.dumps(parameters_schema, indent=2)}\n')
62+
tools_schemas.append(
63+
Tool(
64+
type="function",
65+
function=ToolFunction(
66+
name=fn.__name__,
67+
description=fn.__doc__ or '',
68+
parameters=parameters_schema,
69+
)
6170
)
6271
)
63-
for fn in tools
64-
]
6572

6673
i = 0
6774
while (max_iterations is None or i < max_iterations):
@@ -106,7 +113,7 @@ def completion_with_tool_usage(
106113
sys.stdout.write(f'⚙️ {pretty_call}')
107114
sys.stdout.flush()
108115
tool_result = tool_map[tool_call.function.name](**tool_call.function.arguments)
109-
sys.stdout.write(f" -> {tool_result}\n")
116+
sys.stdout.write(f" {tool_result}\n")
110117
messages.append(Message(
111118
tool_call_id=tool_call.id,
112119
role="tool",
@@ -203,6 +210,8 @@ def main(
203210
if std_tools:
204211
tool_functions.extend(collect_functions(StandardTools))
205212

213+
sys.stdout.write(f'🛠️ {", ".join(fn.__name__ for fn in tool_functions)}\n')
214+
206215
response_model: Union[type, Json[Any]] = None #str
207216
if format:
208217
if format in types:

examples/agent/tools/std_tools.py

+37-25
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,18 @@ class Duration(BaseModel):
1616
years: Optional[int] = None
1717

1818
def __str__(self) -> str:
19-
return f"{self.years} years, {self.months} months, {self.days} days, {self.hours} hours, {self.minutes} minutes, {self.seconds} seconds"
19+
return ', '.join([
20+
x
21+
for x in [
22+
f"{self.years} years" if self.years else None,
23+
f"{self.months} months" if self.months else None,
24+
f"{self.days} days" if self.days else None,
25+
f"{self.hours} hours" if self.hours else None,
26+
f"{self.minutes} minutes" if self.minutes else None,
27+
f"{self.seconds} seconds" if self.seconds else None,
28+
]
29+
if x is not None
30+
])
2031

2132
@property
2233
def get_total_seconds(self) -> int:
@@ -36,25 +47,6 @@ def __call__(self):
3647
sys.stderr.write(f"Waiting for {self.duration}...\n")
3748
time.sleep(self.duration.get_total_seconds)
3849

39-
class WaitForDate(BaseModel):
40-
until: date
41-
42-
def __call__(self):
43-
# Get the current date
44-
current_date = datetime.date.today()
45-
46-
if self.until < current_date:
47-
raise ValueError("Target date cannot be in the past.")
48-
49-
time_diff = datetime.datetime.combine(self.until, datetime.time.min) - datetime.datetime.combine(current_date, datetime.time.min)
50-
51-
days, seconds = time_diff.days, time_diff.seconds
52-
53-
sys.stderr.write(f"Waiting for {days} days and {seconds} seconds until {self.until}...\n")
54-
time.sleep(days * 86400 + seconds)
55-
sys.stderr.write(f"Reached the target date: {self.until}\n")
56-
57-
5850
class StandardTools:
5951

6052
@staticmethod
@@ -66,12 +58,32 @@ def ask_user(question: str) -> str:
6658
return typer.prompt(question)
6759

6860
@staticmethod
69-
def wait(_for: Union[WaitForDuration, WaitForDate]) -> None:
70-
'''
71-
Wait for a certain amount of time before continuing.
72-
This can be used to wait for a specific duration or until a specific date.
61+
def wait_for_duration(duration: Duration) -> None:
62+
'Wait for a certain amount of time before continuing.'
63+
64+
# sys.stderr.write(f"Waiting for {duration}...\n")
65+
time.sleep(duration.get_total_seconds)
66+
67+
@staticmethod
68+
def wait_for_date(target_date: date) -> None:
69+
f'''
70+
Wait until a specific date is reached before continuing.
71+
Today's date is {datetime.date.today()}
7372
'''
74-
return _for()
73+
74+
# Get the current date
75+
current_date = datetime.date.today()
76+
77+
if target_date < current_date:
78+
raise ValueError("Target date cannot be in the past.")
79+
80+
time_diff = datetime.datetime.combine(target_date, datetime.time.min) - datetime.datetime.combine(current_date, datetime.time.min)
81+
82+
days, seconds = time_diff.days, time_diff.seconds
83+
84+
# sys.stderr.write(f"Waiting for {days} days and {seconds} seconds until {target_date}...\n")
85+
time.sleep(days * 86400 + seconds)
86+
# sys.stderr.write(f"Reached the target date: {target_date}\n")
7587

7688
@staticmethod
7789
def say_out_loud(something: str) -> None:

0 commit comments

Comments
 (0)