Skip to content

Commit

Permalink
langgraph: expand handle_tool_errors in ToolNode (#1667)
Browse files Browse the repository at this point in the history
This change expands error-handling functionality of the `ToolNode` by
introducing more options for `handle_tool_errors`. Default behavior of
the `ToolNode` is unchanged -- all errors are handled and wrapped in a
`ToolMessage` to be sent back to LLM.

With this change, users have flexibility to only handle the exceptions
that they need to pass back to the LLM:

* they can specify exceptions to handle by passing a tuple of exceptions
in `handle_tool_errors`
* specify `handle_tool_errors=True/str/callable`
* when `handle_tool_errors` is a callable, the signature will be
inspected and exceptions from the signature will be handled

---------

Co-authored-by: vbarda <vadym@langchain.dev>
  • Loading branch information
isahers1 and vbarda authored Oct 24, 2024
1 parent 62a5ec5 commit bdc75a2
Show file tree
Hide file tree
Showing 2 changed files with 484 additions and 84 deletions.
136 changes: 127 additions & 9 deletions libs/langgraph/langgraph/prebuilt/tool_node.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import inspect
import json
from copy import copy
from typing import (
Expand All @@ -16,6 +17,7 @@
Type,
Union,
cast,
get_type_hints,
)

from langchain_core.messages import (
Expand Down Expand Up @@ -67,13 +69,96 @@ def msg_content_output(output: Any) -> str | List[dict]:
return str(output)


def _handle_tool_error(
e: Exception,
*,
flag: Union[
bool,
str,
Callable[..., str],
tuple[type[Exception], ...],
],
) -> str:
if isinstance(flag, (bool, tuple)):
content = TOOL_CALL_ERROR_TEMPLATE.format(error=repr(e))
elif isinstance(flag, str):
content = flag
elif callable(flag):
content = flag(e)
else:
raise ValueError(
f"Got unexpected type of `handle_tool_error`. Expected bool, str "
f"or callable. Received: {flag}"
)
return content


def _infer_handled_types(handler: Callable[..., str]) -> tuple[type[Exception]]:
sig = inspect.signature(handler)
params = list(sig.parameters.values())
if params:
# If it's a method, the first argument is typically 'self' or 'cls'
if params[0].name in ["self", "cls"] and len(params) == 2:
first_param = params[1]
else:
first_param = params[0]

type_hints = get_type_hints(handler)
if first_param.name in type_hints:
origin = get_origin(first_param.annotation)
if origin is Union:
args = get_args(first_param.annotation)
if all(issubclass(arg, Exception) for arg in args):
return tuple(args)
else:
raise ValueError(
"All types in the error handler error annotation must be Exception types. "
"For example, `def custom_handler(e: Union[ValueError, TypeError])`. "
f"Got '{first_param.annotation}' instead."
)

exception_type = type_hints[first_param.name]
if Exception in exception_type.__mro__:
return (exception_type,)
else:
raise ValueError(
f"Arbitrary types are not supported in the error handler signature. "
"Please annotate the error with either a specific Exception type or a union of Exception types. "
"For example, `def custom_handler(e: ValueError)` or `def custom_handler(e: Union[ValueError, TypeError])`. "
f"Got '{exception_type}' instead."
)

# If no type information is available, return (Exception,) for backwards compatibility.
return (Exception,)


class ToolNode(RunnableCallable):
"""A node that runs the tools called in the last AIMessage.
It can be used either in StateGraph with a "messages" key (or a custom key passed via ToolNode's 'messages_key').
It can be used either in StateGraph with a "messages" state key (or a custom key passed via ToolNode's 'messages_key').
If multiple tool calls are requested, they will be run in parallel. The output will be
a list of ToolMessages, one for each tool call.
Args:
tools: A sequence of tools that can be invoked by the ToolNode.
name: The name of the ToolNode in the graph. Defaults to "tools".
tags: Optional tags to associate with the node. Defaults to None.
handle_tool_errors: How to handle tool errors raised by tools inside the node. Defaults to True.
Must be one of the following:
- True: all errors will be caught and
a ToolMessage with a default error message (TOOL_CALL_ERROR_TEMPLATE) will be returned.
- str: all errors will be caught and
a ToolMessage with the string value of 'handle_tool_errors' will be returned.
- tuple[type[Exception], ...]: exceptions in the tuple will be caught and
a ToolMessage with a default error message (TOOL_CALL_ERROR_TEMPLATE) will be returned.
- Callable[..., str]: exceptions from the signature of the callable will be caught and
a ToolMessage with the string value of the result of the 'handle_tool_errors' callable will be returned.
- False: none of the errors raised by the tools will be caught
messages_key: The state key in the input that contains the list of messages.
The same key will be used for the output from the ToolNode.
Defaults to "messages".
The `ToolNode` is roughly analogous to:
```python
Expand Down Expand Up @@ -101,7 +186,9 @@ def __init__(
*,
name: str = "tools",
tags: Optional[list[str]] = None,
handle_tool_errors: Optional[bool] = True,
handle_tool_errors: Union[
bool, str, Callable[..., str], tuple[type[Exception], ...]
] = True,
messages_key: str = "messages",
) -> None:
super().__init__(self._func, self._afunc, name=name, tags=tags, trace=False)
Expand Down Expand Up @@ -181,14 +268,29 @@ def _run_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage:
)
return tool_message
except Exception as e:
if not self.handle_tool_errors:
if isinstance(self.handle_tool_errors, tuple):
handled_types: tuple = self.handle_tool_errors
elif callable(self.handle_tool_errors):
handled_types = _infer_handled_types(self.handle_tool_errors)
else:
# default behavior is catching all exceptions
handled_types = (Exception,)

# Unhandled
if not self.handle_tool_errors or not isinstance(e, handled_types):
raise e
content = TOOL_CALL_ERROR_TEMPLATE.format(error=repr(e))
return ToolMessage(content, name=call["name"], tool_call_id=call["id"])
# Handled
else:
content = _handle_tool_error(e, flag=self.handle_tool_errors)

return ToolMessage(
content=content, name=call["name"], tool_call_id=call["id"], status="error"
)

async def _arun_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage:
if invalid_tool_message := self._validate_tool_call(call):
return invalid_tool_message

try:
input = {**call, **{"type": "tool_call"}}
tool_message: ToolMessage = await self.tools_by_name[call["name"]].ainvoke(
Expand All @@ -199,10 +301,24 @@ async def _arun_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage
)
return tool_message
except Exception as e:
if not self.handle_tool_errors:
if isinstance(self.handle_tool_errors, tuple):
handled_types: tuple = self.handle_tool_errors
elif callable(self.handle_tool_errors):
handled_types = _infer_handled_types(self.handle_tool_errors)
else:
# default behavior is catching all exceptions
handled_types = (Exception,)

# Unhandled
if not self.handle_tool_errors or not isinstance(e, handled_types):
raise e
content = TOOL_CALL_ERROR_TEMPLATE.format(error=repr(e))
return ToolMessage(content, name=call["name"], tool_call_id=call["id"])
# Handled
else:
content = _handle_tool_error(e, flag=self.handle_tool_errors)

return ToolMessage(
content=content, name=call["name"], tool_call_id=call["id"], status="error"
)

def _parse_input(
self,
Expand Down Expand Up @@ -240,7 +356,9 @@ def _validate_tool_call(self, call: ToolCall) -> Optional[ToolMessage]:
requested_tool=requested_tool,
available_tools=", ".join(self.tools_by_name.keys()),
)
return ToolMessage(content, name=requested_tool, tool_call_id=call["id"])
return ToolMessage(
content, name=requested_tool, tool_call_id=call["id"], status="error"
)
else:
return None

Expand Down
Loading

0 comments on commit bdc75a2

Please sign in to comment.