Skip to content

Commit

Permalink
Tool calling & argref changes (#71)
Browse files Browse the repository at this point in the history
  • Loading branch information
sidnarayanan authored Oct 16, 2024
1 parent 9270797 commit 2a3037b
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 46 deletions.
7 changes: 3 additions & 4 deletions src/aviary/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,17 +207,16 @@ async def _exec_tool_call(tool_call: ToolCall) -> ToolResponseMessage:
except Exception as exc:
if not handle_tool_exc:
raise
logger_msg = (
f"Failed to execute tool call for tool {tool.info.name}: {exc!r}"
)
logger_msg = f"Encountered exception during tool call for tool {tool.info.name}: {exc!r}"
# logger.exception is just too verbose and clogs up console logging. This is a
# more human-friendly version: log a readable error message and emit the exception
# at DEBUG level.
logger.error(logger_msg) # noqa: TRY400
logger.debug(str(exc), exc_info=True)
tool_exc = exc
if tool_exc:
s_content: str = f"{logger_msg}:\n{tool_exc}"
# No need to mention tool.info.name here, since it'll get wrapped in a ToolResponseMessage
s_content = f"Encountered exception during tool call: {tool_exc}"
elif isinstance(content, str):
s_content = content
elif isinstance(content, BaseModel):
Expand Down
83 changes: 50 additions & 33 deletions src/aviary/tools/argref.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,12 @@ def make_pretty_id(prefix: str = "") -> str:
return prefix + "-" + uuid_frags[0]


ARGREF_NOTE = "(Pass a string key instead of the full object)"
DEFAULT_ARGREF_NOTE = "(Pass a string key instead of the full object)"
LIST_ARGREF_NOTE = "(Pass comma-separated string keys instead of the full object)"


def argref_wrapper(wrapper, wrapped, args_to_skip: set[str] | None):
def argref_wrapper(wrapper, wrapped, args_to_skip: set[str]):
"""Inject the ARGREF_NOTE into the Args."""
args_to_skip = (args_to_skip or set()) | {"state", "return"}

# normal wraps
wrapped_func = update_wrapper(wrapper, wrapped)
# when we modify wrapped_func's annotations, we don't want to mutate wrapped
Expand All @@ -51,13 +50,17 @@ def argref_wrapper(wrapper, wrapped, args_to_skip: set[str] | None):
if param.arg_name in args_to_skip:
continue

note = DEFAULT_ARGREF_NOTE

if (
param.type_name is None
and (type_hint := orig_annots.get(param.arg_name)) is not None
):
param.type_name = _type_to_str(type_hint)
if list in {type_hint, get_origin(type_hint)}:
note = LIST_ARGREF_NOTE

param.description = (param.description or "") + f" {ARGREF_NOTE}"
param.description = (param.description or "") + f" {note}"

wrapped_func.__doc__ = compose(ds)

Expand Down Expand Up @@ -115,6 +118,7 @@ def argref_by_name( # noqa: C901, PLR0915
>>> # Equivalent to my_func(state.refs["a"], state.refs["b"])
>>> wrapped_fxn("a", "b", state=state) # doctest: +SKIP
"""
args_to_skip = (args_to_skip or set()) | {"state", "return"}

def decorator(func): # noqa: C901, PLR0915
def get_call_args(*args, **kwargs): # noqa: C901
Expand All @@ -129,26 +133,36 @@ def get_call_args(*args, **kwargs): # noqa: C901

# now convert the keynames to actual references (if they are a string)
# tuple is (arg, if was dereferenced)
def maybe_deref_arg(arg):
def maybe_deref_arg(arg, must_exist: bool) -> tuple[Any, bool]:
try:
refs = state.refs
except AttributeError as e:
raise AttributeError(
"The state object must have a 'refs' attribute to use argref_by_name decorator."
) from e

if arg in refs:
return [refs[arg]], True

if isinstance(arg, str):
try:
if arg in state.refs:
return [state.refs[arg]], True
# sometimes it is not correctly converted to a tuple
# so as an attempt to be helpful...
if all(a.strip() in state.refs for a in arg.split(",")):
return [state.refs[a.strip()] for a in arg.split(",")], True
# fall through
except AttributeError as e:
raise AttributeError(
"The state object must have a 'refs' attribute to use argref_by_name decorator."
) from e
return arg, False
# sometimes it is not correctly converted to a tuple
# so as an attempt to be helpful...
split_args = [a.strip() for a in arg.split(",")]
if all(a in refs for a in split_args):
return [refs[a] for a in split_args], True

if not must_exist:
return arg, False

raise KeyError(
f'Not a valid element of the current key-value store: "{arg}"'
)

# the split thing makes it complicated and we cannot use comprehension
deref_args = []
for i, arg in enumerate(args):
a, dr = maybe_deref_arg(arg)
# In order to support *args, allow arguments that are either ref keys or strings
a, dr = maybe_deref_arg(arg, must_exist=False)
if dr:
deref_args.extend(a)
else:
Expand All @@ -157,18 +171,21 @@ def maybe_deref_arg(arg):
# likely the user intended to use a reference
raise KeyError(f"The key {arg} is not found in state.")
deref_args.append(a)

deref_kwargs = {}
for k, v in kwargs.items():
a, dr = maybe_deref_arg(v)
if dr:
if len(a) > 1:
raise ValueError(
f"Multiple values for argument '{k}' found in state. "
"Cannot use comma-separated notation for kwargs."
)
deref_kwargs[k] = a[0]
else:
if args_to_skip and k in args_to_skip:
deref_kwargs[k] = v
continue

# In the kwarg case, force arguments to be ref keys (unless in args_to_skip)
a, _ = maybe_deref_arg(v, must_exist=True)
if len(a) > 1:
# We got multiple items, so pass the whole list
deref_kwargs[k] = a
else:
# We only got one item - pass it directly
deref_kwargs[k] = a[0]

return deref_args, deref_kwargs, state

Expand Down Expand Up @@ -234,8 +251,8 @@ def _check_arg_types(func: Callable, args, kwargs) -> None:
if expected_type and not _isinstance_with_generics(arg, expected_type):
wrong_types.append((
param,
expected_type.__name__,
type(arg).__name__,
_type_to_str(expected_type),
_type_to_str(type(arg)),
))

# Check keyword arguments
Expand All @@ -245,8 +262,8 @@ def _check_arg_types(func: Callable, args, kwargs) -> None:
wrong_types.append((
param,
# sometimes need str for generics like Union
getattr(expected_type, "__name__", str(expected_type)),
type(arg).__name__,
_type_to_str(expected_type),
_type_to_str(type(arg)),
))

if wrong_types:
Expand Down
31 changes: 22 additions & 9 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,16 +561,18 @@ def __init__(self):
self.refs = {"foo": 1}

# Check we can use argref_by_name to add 1 + 2 using a value in refs
wrapped_add = argref_by_name()(add)
wrapped_add = argref_by_name(args_to_skip={"b"})(add)
s = MyState()

result = wrapped_add("foo", 2, state=s)
# Now s.refs has a new entry at the below `name`
name = result.split()[0]
assert s.refs[name] == 1 + 2

# Check we can still use argref_by_name without refs
result = wrapped_add(6, 2, state=s)
assert s.refs[result.split()[0]] == 6 + 2
# Check kwargs work too
result = wrapped_add(a="foo", b=2, state=s)
name = result.split()[0]
assert s.refs[name] == 1 + 2


@pytest.mark.asyncio
Expand All @@ -584,12 +586,16 @@ def __init__(self):

# Check if we use a key name that doesn't exist, we blow up
with pytest.raises(KeyError, match="not found in state"):
wrapped_add("bar", 2, state=MyState())
wrapped_add("bar", 2, state=s)

# Check if state doesn't have refs, we blow up
with pytest.raises(AttributeError, match="must have a 'refs' attribute"):
wrapped_add("foo", 2, state="not a state")

# Check that we cannot pass a direct value as a kwarg
with pytest.raises(KeyError, match="Not a valid element"):
wrapped_add(a=1, b=2, state=s)


@pytest.mark.asyncio
async def test_argref_by_name_async_functions() -> None:
Expand Down Expand Up @@ -652,14 +658,14 @@ def skip_deref_test(foo: float, a: str) -> str:
"""Some docstring."""
return f"{foo} {a}"

assert skip_deref_test("foo", a="not in state", state=s) == "1 not in state"
assert skip_deref_test("foo", "not in state", state=s) == "1 not in state"
assert skip_deref_test("foo", "foo", state=s) == "1 1"

# Call in context using Tool and related classes
wrapped_add = argref_by_name()(add)
wrapped_add = argref_by_name(args_to_skip={"b"})(add)
tool = Tool.from_function(wrapped_add)

tool_call = ToolCall.from_tool(tool, "foo", 2)
tool_call = ToolCall.from_tool(tool, "foo", b=2)
action = ToolRequestMessage(tool_calls=[tool_call])
my_env = DummyEnv()
my_env.tools = [tool]
Expand Down Expand Up @@ -687,6 +693,13 @@ async def want_state(a: int, state: MyState) -> int: # noqa: ARG001
my_env.tools = [tool]
await my_env.exec_tool_calls(action, state=MyState())

# Check we can pass kwarg lists as comma-separated keys
@argref_by_name(return_direct=True)
def kwarg_list_test(a: list[int]) -> int:
return sum(a)

assert kwarg_list_test(a="foo,foo", state=s) == 2


@pytest.mark.asyncio
async def test_argref_by_name_type_checking() -> None:
Expand Down Expand Up @@ -717,7 +730,7 @@ def typed_fn(a: int, b) -> int: # noqa: ARG001
type_checked_fn(a="int_arg", b="str_arg", state=s) # correctly-typed
with pytest.raises(TypeError):
# A non-int value is passed to a by name
type_checked_fn(a="str_arg", b="bar", state=s)
type_checked_fn(a="str_arg", b="str_arg", state=s)

def complex_typed_fn(c: Sequence[int], d: int | str) -> None:
"""Some docstring."""
Expand Down

0 comments on commit 2a3037b

Please sign in to comment.