Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

best effort parsing + handle parsing errors #20111

Merged
merged 6 commits into from
Apr 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions libs/core/langchain_core/messages/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from langchain_core.messages.ai import (
AIMessage,
AIMessageChunk,
InvalidToolCall,
ToolCall,
ToolCallChunk,
)
Expand Down Expand Up @@ -55,6 +56,7 @@
"FunctionMessageChunk",
"HumanMessage",
"HumanMessageChunk",
"InvalidToolCall",
"MessageLikeRepresentation",
"SystemMessage",
"SystemMessageChunk",
Expand Down
77 changes: 23 additions & 54 deletions libs/core/langchain_core/messages/ai.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,24 @@
import warnings
from json import JSONDecodeError
from typing import Any, Dict, List, Literal, Optional
from typing import Any, List, Literal, Optional, Union

from langchain_core.load import Serializable
from langchain_core.messages.base import (
BaseMessage,
BaseMessageChunk,
merge_content,
)
from langchain_core.messages.tool import (
InvalidToolCall,
ToolCall,
ToolCallChunk,
default_tool_chunk_parser,
default_tool_parser,
)
from langchain_core.pydantic_v1 import root_validator
from langchain_core.utils._merge import merge_dicts, merge_lists
from langchain_core.utils.json import parse_partial_json


class ToolCall(Serializable):
"""A call to a tool.

Attributes:
name: (str) the name of the tool to be called
args: (dict) the arguments to the tool call
id: (str) if provided, an identifier associated with the tool call
index: (int) if provided, the index of the tool call in a sequence
of content
"""

name: str
args: Dict[str, Any]
id: Optional[str] = None
index: Optional[int] = None


class ToolCallChunk(Serializable):
"""A chunk of a tool call (e.g., as part of a stream).

When merging ToolCallChunks (e.g., via AIMessageChunk.__add__),
all string attributes are concatenated. Chunks are only merged if their
values of `index` are equal and not None.

Example:

.. code-block:: python
left_chunks = [ToolCallChunk(name="foo", args='{"a":', index=0)]
right_chunks = [ToolCallChunk(name=None, args='1}', index=0)]
(
AIMessageChunk(content="", tool_call_chunks=left_chunks)
+ AIMessageChunk(content="", tool_call_chunks=right_chunks)
).tool_call_chunks == [ToolCallChunk(name='foo', args='{"a":1}', index=0)]

Attributes:
name: (str) if provided, a substring of the name of the tool to be called
args: (str) if provided, a JSON substring of the arguments to the tool call
id: (str) if provided, a substring of an identifier for the tool call
index: (int) if provided, the index of the tool call in a sequence
"""

name: Optional[str] = None
args: Optional[str] = None
id: Optional[str] = None
index: Optional[int] = None
from langchain_core.utils.json import (
parse_partial_json,
)


class AIMessage(BaseMessage):
Expand All @@ -68,7 +29,7 @@ class AIMessage(BaseMessage):
conversation.
"""

tool_calls: Optional[List[ToolCall]] = None
tool_calls: Optional[List[Union[ToolCall, InvalidToolCall]]] = None
"""If provided, tool calls associated with the message."""

type: Literal["ai"] = "ai"
Expand All @@ -84,10 +45,18 @@ def _backwards_compat_tool_calls(cls, values: dict) -> dict:
tool_calls = values.get("tool_calls") or values.get("tool_call_chunks")
if raw_tool_calls and not tool_calls:
warnings.warn(
"You appear to be using an old tool calling model, please upgrade "
"your packages to versions that set message tool calls."
"New langchain packages are available that more efficiently handle "
"tool calling. Please upgrade your packages to versions that set "
"message tool calls. e.g., `pip install --upgrade langchain-anthropic"
"`, pip install--upgrade langchain-openai`, etc."
)
# TODO: best-effort parsing
try:
if issubclass(cls, AIMessageChunk): # type: ignore
values["tool_call_chunks"] = default_tool_chunk_parser(raw_tool_calls)
else:
values["tool_calls"] = default_tool_parser(raw_tool_calls)
except Exception:
pass
return values


Expand Down
104 changes: 103 additions & 1 deletion libs/core/langchain_core/messages/tool.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Any, List, Literal
import json
from typing import Any, Dict, List, Literal, Optional

from langchain_core.load import Serializable
from langchain_core.messages.base import (
BaseMessage,
BaseMessageChunk,
Expand Down Expand Up @@ -61,3 +63,103 @@ def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
)

return super().__add__(other)


class ToolCall(Serializable):
"""A call to a tool.

Attributes:
name: (str) the name of the tool to be called
args: (dict) the arguments to the tool call
id: (str) if provided, an identifier associated with the tool call
index: (int) if provided, the index of the tool call in a sequence
of content
"""

name: str
args: Dict[str, Any]
id: Optional[str] = None
index: Optional[int] = None


class ToolCallChunk(Serializable):
"""A chunk of a tool call (e.g., as part of a stream).

When merging ToolCallChunks (e.g., via AIMessageChunk.__add__),
all string attributes are concatenated. Chunks are only merged if their
values of `index` are equal and not None.

Example:

.. code-block:: python
left_chunks = [ToolCallChunk(name="foo", args='{"a":', index=0)]
right_chunks = [ToolCallChunk(name=None, args='1}', index=0)]
(
AIMessageChunk(content="", tool_call_chunks=left_chunks)
+ AIMessageChunk(content="", tool_call_chunks=right_chunks)
).tool_call_chunks == [ToolCallChunk(name='foo', args='{"a":1}', index=0)]

Attributes:
name: (str) if provided, a substring of the name of the tool to be called
args: (str) if provided, a JSON substring of the arguments to the tool call
id: (str) if provided, a substring of an identifier for the tool call
index: (int) if provided, the index of the tool call in a sequence
"""

name: Optional[str] = None
args: Optional[str] = None
id: Optional[str] = None
index: Optional[int] = None


class InvalidToolCall(Serializable):
"""Allowance for errors made by LLM.

Here we add an `error` key to surface errors made during generation
(e.g., invalid JSON arguments.)
"""

name: Optional[str] = None
args: Optional[str] = None
id: Optional[str] = None
error: Optional[str] = None


def default_tool_parser(raw_tool_calls: List[dict]) -> List[ToolCall]:
"""Best-effort parsing of tools."""
tool_calls = []
for tool_call in raw_tool_calls:
if "function" not in tool_call:
function_args = None
function_name = None
else:
function_args = json.loads(tool_call["function"]["arguments"])
function_name = tool_call["function"]["name"]
parsed = ToolCall(
name=function_name or "",
args=function_args or {},
id=tool_call.get("id"),
index=tool_call.get("index"),
)
tool_calls.append(parsed)
return tool_calls


def default_tool_chunk_parser(raw_tool_calls: List[dict]) -> List[ToolCallChunk]:
"""Best-effort parsing of tool chunks."""
tool_call_chunks = []
for tool_call in raw_tool_calls:
if "function" not in tool_call:
function_args = None
function_name = None
else:
function_args = tool_call["function"]["arguments"]
function_name = tool_call["function"]["name"]
parsed = ToolCallChunk(
name=function_name,
args=function_args,
id=tool_call.get("id"),
index=tool_call.get("index"),
)
tool_call_chunks.append(parsed)
return tool_call_chunks
90 changes: 61 additions & 29 deletions libs/core/langchain_core/output_parsers/openai_tools.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,86 @@
import copy
import json
from json import JSONDecodeError
from typing import Any, List, Type
from typing import Any, Dict, List, Optional, Type

from langchain_core.exceptions import OutputParserException
from langchain_core.messages import AIMessage
from langchain_core.messages import AIMessage, InvalidToolCall
from langchain_core.output_parsers import BaseCumulativeTransformOutputParser
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import BaseModel, ValidationError
from langchain_core.utils.json import parse_partial_json


def parse_tool_call(
raw_tool_call: Dict[str, Any],
*,
partial: bool = False,
strict: bool = False,
return_id: bool = True,
) -> Optional[Dict[str, Any]]:
"""Parse a single tool call."""
if "function" not in raw_tool_call:
return None
if partial:
try:
function_args = parse_partial_json(
raw_tool_call["function"]["arguments"], strict=strict
)
except (JSONDecodeError, TypeError): # None args raise TypeError
return None
else:
try:
function_args = json.loads(
raw_tool_call["function"]["arguments"], strict=strict
)
except JSONDecodeError as e:
raise OutputParserException(
f"Function {raw_tool_call['function']['name']} arguments:\n\n"
f"{raw_tool_call['function']['arguments']}\n\nare not valid JSON. "
f"Received JSONDecodeError {e}"
)
parsed = {
"name": raw_tool_call["function"]["name"] or "",
"args": function_args or {},
}
if return_id:
parsed["id"] = raw_tool_call["id"]
return parsed


def make_invalid_tool_call(
raw_tool_call: Dict[str, Any],
error_msg: Optional[str],
) -> InvalidToolCall:
"""Create an InvalidToolCall from a raw tool call."""
return InvalidToolCall(
name=raw_tool_call["function"]["name"],
args=raw_tool_call["function"]["arguments"],
id=raw_tool_call.get("id"),
error=error_msg,
)


def parse_tool_calls(
raw_tool_calls: List[dict],
*,
partial: bool = False,
strict: bool = False,
return_id: bool = True,
) -> List[dict]:
"""Parse a list of tool calls."""
final_tools = []
exceptions = []
for tool_call in raw_tool_calls:
if "function" not in tool_call:
try:
parsed = parse_tool_call(
tool_call, partial=partial, strict=strict, return_id=return_id
)
if parsed:
final_tools.append(parsed)
except OutputParserException as e:
exceptions.append(str(e))
continue
if partial:
try:
function_args = parse_partial_json(
tool_call["function"]["arguments"], strict=strict
)
except (JSONDecodeError, TypeError): # None args raise TypeError
continue
else:
try:
function_args = json.loads(
tool_call["function"]["arguments"], strict=strict
)
except JSONDecodeError as e:
exceptions.append(
f"Function {tool_call['function']['name']} arguments:\n\n"
f"{tool_call['function']['arguments']}\n\nare not valid JSON. "
f"Received JSONDecodeError {e}"
)
continue
parsed = {
"name": tool_call["function"]["name"] or "",
"args": function_args or {},
}
if return_id:
parsed["id"] = tool_call["id"]
final_tools.append(parsed)
if exceptions:
raise OutputParserException("\n\n".join(exceptions))
return final_tools
Expand Down
1 change: 1 addition & 0 deletions libs/core/tests/unit_tests/messages/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"FunctionMessageChunk",
"HumanMessage",
"HumanMessageChunk",
"InvalidToolCall",
"SystemMessage",
"SystemMessageChunk",
"ToolCall",
Expand Down
Loading
Loading