Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a streaming json parser #11193

Merged
merged 15 commits into from
Sep 29, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ def _parse_ai_message(message: BaseMessage) -> Union[AgentAction, AgentFinish]:
return_values={"output": message.content}, log=message.content
)

def parse_result(self, result: List[Generation]) -> Union[AgentAction, AgentFinish]:
def parse_result(
self, result: List[Generation], *, partial: bool = False
) -> Union[AgentAction, AgentFinish]:
if not isinstance(result[0], ChatGeneration):
raise ValueError("This output parser only works on ChatGeneration output")
message = result[0].message
Expand Down
92 changes: 85 additions & 7 deletions libs/langchain/langchain/output_parsers/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,14 @@
import json
import re
from json import JSONDecodeError
from typing import Any, List
from typing import Any, Callable, List, Optional

from langchain.schema import BaseOutputParser, OutputParserException
import jsonpatch

from langchain.schema.output_parser import (
BaseCumulativeTransformOutputParser,
OutputParserException,
)


def _replace_new_line(match: re.Match[str]) -> str:
Expand Down Expand Up @@ -38,7 +43,70 @@ def _custom_parser(multiline_string: str) -> str:
return multiline_string


def parse_json_markdown(json_string: str) -> dict:
# Adapted from https://github.com/KillianLucas/open-interpreter/blob/main/interpreter/utils/parse_partial_json.py
# MIT License
def parse_partial_json(s: str, *, strict: bool = False) -> Any:
# Attempt to parse the string as-is.
try:
return json.loads(s, strict=strict)
except json.JSONDecodeError:
pass

# Initialize variables.
new_s = ""
stack = []
is_inside_string = False
escaped = False

# Process each character in the string one at a time.
for char in s:
if is_inside_string:
if char == '"' and not escaped:
is_inside_string = False
elif char == "\n" and not escaped:
char = "\\n" # Replace the newline character with the escape sequence.
elif char == "\\":
escaped = not escaped
else:
escaped = False
else:
if char == '"':
is_inside_string = True
escaped = False
elif char == "{":
stack.append("}")
elif char == "[":
stack.append("]")
elif char == "}" or char == "]":
if stack and stack[-1] == char:
stack.pop()
else:
# Mismatched closing character; the input is malformed.
return None

# Append the processed character to the new string.
new_s += char

# If we're still inside a string at the end of processing,
# we need to close the string.
if is_inside_string:
new_s += '"'

# Close any remaining open structures in the reverse order that they were opened.
for closing_char in reversed(stack):
new_s += closing_char

# Attempt to parse the modified string as JSON.
try:
return json.loads(new_s, strict=strict)
except json.JSONDecodeError:
# If we still can't parse the string as JSON, return None to indicate failure.
return None


def parse_json_markdown(
json_string: str, parser: Callable[[str], Any] = json.loads
nfcampos marked this conversation as resolved.
Show resolved Hide resolved
) -> dict:
"""
Parse a JSON string from a Markdown string.

Expand All @@ -65,7 +133,7 @@ def parse_json_markdown(json_string: str) -> dict:
json_str = _custom_parser(json_str)

# Parse the JSON string into a Python dictionary
parsed = json.loads(json_str)
parsed = parser(json_str)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if I'm misreading the regexp, but it seems to require closing ```? Should that requirement be relaxed?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure, but I don't want to change that in this pr


return parsed

Expand Down Expand Up @@ -95,13 +163,23 @@ def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
return json_obj


class SimpleJsonOutputParser(BaseOutputParser[Any]):
"""Parse the output of an LLM call to a JSON object."""
class SimpleJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
"""Parse the output of an LLM call to a JSON object.

When used in streaming mode, it will yield partial JSON objects containing
all the keys that have been returned so far.

In streaming, if `diff` is set to `True`, yields JSONPatch operations
describing the difference between the previous and the current object.
"""

def _diff(self, prev: Optional[Any], next: Any) -> Any:
return jsonpatch.make_patch(prev, next).patch

def parse(self, text: str) -> Any:
nfcampos marked this conversation as resolved.
Show resolved Hide resolved
text = text.strip()
try:
return json.loads(text)
return parse_json_markdown(text.strip(), parse_partial_json)
nfcampos marked this conversation as resolved.
Show resolved Hide resolved
except JSONDecodeError as e:
raise OutputParserException(f"Invalid json output: {text}") from e

Expand Down
105 changes: 79 additions & 26 deletions libs/langchain/langchain/output_parsers/openai_functions.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
import copy
import json
from typing import Any, Dict, List, Type, Union
from typing import Any, Dict, List, Optional, Type, Union

import jsonpatch

from langchain.output_parsers.json import parse_partial_json
from langchain.pydantic_v1 import BaseModel, root_validator
from langchain.schema import (
ChatGeneration,
Generation,
OutputParserException,
)
from langchain.schema.output_parser import BaseGenerationOutputParser
from langchain.schema.output_parser import (
BaseCumulativeTransformOutputParser,
BaseGenerationOutputParser,
)


class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
Expand All @@ -17,7 +23,7 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
args_only: bool = True
"""Whether to only return the arguments to the function call."""

def parse_result(self, result: List[Generation]) -> Any:
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this raise on patial = True?

generation = result[0]
if not isinstance(generation, ChatGeneration):
raise OutputParserException(
Expand All @@ -34,7 +40,7 @@ def parse_result(self, result: List[Generation]) -> Any:
return func_call


class JsonOutputFunctionsParser(OutputFunctionsParser):
class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
"""Parse an output as the Json object."""

strict: bool = False
Expand All @@ -45,25 +51,72 @@ class JsonOutputFunctionsParser(OutputFunctionsParser):
Useful when the parsed output may include unicode characters or new lines.
"""

def parse_result(self, result: List[Generation]) -> Any:
function_call_info = super().parse_result(result)
if self.args_only:
try:
return json.loads(function_call_info, strict=self.strict)
except (json.JSONDecodeError, TypeError) as exc:
raise OutputParserException(
f"Could not parse function call data: {exc}"
)
else:
try:
function_call_info["arguments"] = json.loads(
function_call_info["arguments"], strict=self.strict
)
except (json.JSONDecodeError, TypeError) as exc:
raise OutputParserException(
f"Could not parse function call data: {exc}"
)
return function_call_info
args_only: bool = True
"""Whether to only return the arguments to the function call."""

def _diff(self, prev: Optional[Any], next: Any) -> Any:
return jsonpatch.make_patch(prev, next).patch

def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
if len(result) != 1:
raise OutputParserException(
f"Expected exactly one result, but got {len(result)}"
)
generation = result[0]
if not isinstance(generation, ChatGeneration):
raise OutputParserException(
"This output parser can only be used with a chat generation."
)
message = generation.message
try:
function_call = message.additional_kwargs["function_call"]
except KeyError as exc:
if partial:
return None
else:
raise OutputParserException(f"Could not parse function call: {exc}")
try:
if partial:
if self.args_only:
return parse_partial_json(
function_call["arguments"], strict=self.strict
)
else:
return {
**function_call,
"arguments": parse_partial_json(
function_call["arguments"], strict=self.strict
),
}
else:
if self.args_only:
try:
return json.loads(
function_call["arguments"], strict=self.strict
)
except (json.JSONDecodeError, TypeError) as exc:
raise OutputParserException(
f"Could not parse function call data: {exc}"
)
else:
try:
return {
**function_call,
"arguments": json.loads(
function_call["arguments"], strict=self.strict
),
}
except (json.JSONDecodeError, TypeError) as exc:
raise OutputParserException(
f"Could not parse function call data: {exc}"
)
except KeyError:
return None

# This method would be called by the default implementation of `parse_result`
# but we're overriding that method so it's not needed.
def parse(self, text: str) -> Any:
raise NotImplementedError()


class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
Expand All @@ -72,7 +125,7 @@ class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
key_name: str
"""The name of the key to return."""

def parse_result(self, result: List[Generation]) -> Any:
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
res = super().parse_result(result)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

propagate partial?

return res[self.key_name]

Expand All @@ -97,7 +150,7 @@ def validate_schema(cls, values: Dict) -> Dict:
)
return values

def parse_result(self, result: List[Generation]) -> Any:
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

propagate partial and have underlying implementation raise if it's not supported?

_result = super().parse_result(result)
if self.args_only:
pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore
Expand All @@ -114,6 +167,6 @@ class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
attr_name: str
"""The name of the attribute to return."""

def parse_result(self, result: List[Generation]) -> Any:
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
result = super().parse_result(result)
return getattr(result, self.attr_name)
Loading
Loading