Skip to content

Commit

Permalink
Add a streaming json parser (#11193)
Browse files Browse the repository at this point in the history
<img width="1728" alt="Screenshot 2023-09-28 at 20 15 01"
src="https://github.com/langchain-ai/langchain/assets/56902/ed0644c3-6db7-41b9-9543-e34fce46d3e5">


<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes (if applicable),
  - **Dependencies:** any dependencies required for this change,
- **Tag maintainer:** for a quicker response, tag the relevant
maintainer (see below),
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in `docs/extras`
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
  • Loading branch information
nfcampos authored Sep 29, 2023
2 parents b4354b7 + ee56c61 commit 1ddf9f7
Show file tree
Hide file tree
Showing 7 changed files with 612 additions and 98 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,16 @@ def _parse_ai_message(message: BaseMessage) -> Union[AgentAction, AgentFinish]:
return_values={"output": message.content}, log=message.content
)

def parse_result(self, result: List[Generation]) -> Union[AgentAction, AgentFinish]:
def parse_result(
self, result: List[Generation], *, partial: bool = False
) -> Union[AgentAction, AgentFinish]:
if not isinstance(result[0], ChatGeneration):
raise ValueError("This output parser only works on ChatGeneration output")
message = result[0].message
return self._parse_ai_message(message)

async def aparse_result(
self, result: List[Generation]
self, result: List[Generation], *, partial: bool = False
) -> Union[AgentAction, AgentFinish]:
return await asyncio.get_running_loop().run_in_executor(
None, self.parse_result, result
Expand Down
92 changes: 85 additions & 7 deletions libs/langchain/langchain/output_parsers/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,14 @@
import json
import re
from json import JSONDecodeError
from typing import Any, List
from typing import Any, Callable, List, Optional

from langchain.schema import BaseOutputParser, OutputParserException
import jsonpatch

from langchain.schema.output_parser import (
BaseCumulativeTransformOutputParser,
OutputParserException,
)


def _replace_new_line(match: re.Match[str]) -> str:
Expand Down Expand Up @@ -38,7 +43,70 @@ def _custom_parser(multiline_string: str) -> str:
return multiline_string


def parse_json_markdown(json_string: str) -> dict:
# Adapted from https://github.com/KillianLucas/open-interpreter/blob/main/interpreter/utils/parse_partial_json.py
# MIT License
def parse_partial_json(s: str, *, strict: bool = False) -> Any:
# Attempt to parse the string as-is.
try:
return json.loads(s, strict=strict)
except json.JSONDecodeError:
pass

# Initialize variables.
new_s = ""
stack = []
is_inside_string = False
escaped = False

# Process each character in the string one at a time.
for char in s:
if is_inside_string:
if char == '"' and not escaped:
is_inside_string = False
elif char == "\n" and not escaped:
char = "\\n" # Replace the newline character with the escape sequence.
elif char == "\\":
escaped = not escaped
else:
escaped = False
else:
if char == '"':
is_inside_string = True
escaped = False
elif char == "{":
stack.append("}")
elif char == "[":
stack.append("]")
elif char == "}" or char == "]":
if stack and stack[-1] == char:
stack.pop()
else:
# Mismatched closing character; the input is malformed.
return None

# Append the processed character to the new string.
new_s += char

# If we're still inside a string at the end of processing,
# we need to close the string.
if is_inside_string:
new_s += '"'

# Close any remaining open structures in the reverse order that they were opened.
for closing_char in reversed(stack):
new_s += closing_char

# Attempt to parse the modified string as JSON.
try:
return json.loads(new_s, strict=strict)
except json.JSONDecodeError:
# If we still can't parse the string as JSON, return None to indicate failure.
return None


def parse_json_markdown(
json_string: str, *, parser: Callable[[str], Any] = json.loads
) -> dict:
"""
Parse a JSON string from a Markdown string.
Expand All @@ -65,7 +133,7 @@ def parse_json_markdown(json_string: str) -> dict:
json_str = _custom_parser(json_str)

# Parse the JSON string into a Python dictionary
parsed = json.loads(json_str)
parsed = parser(json_str)

return parsed

Expand Down Expand Up @@ -95,13 +163,23 @@ def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
return json_obj


class SimpleJsonOutputParser(BaseOutputParser[Any]):
"""Parse the output of an LLM call to a JSON object."""
class SimpleJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
"""Parse the output of an LLM call to a JSON object.
When used in streaming mode, it will yield partial JSON objects containing
all the keys that have been returned so far.
In streaming, if `diff` is set to `True`, yields JSONPatch operations
describing the difference between the previous and the current object.
"""

def _diff(self, prev: Optional[Any], next: Any) -> Any:
return jsonpatch.make_patch(prev, next).patch

def parse(self, text: str) -> Any:
text = text.strip()
try:
return json.loads(text)
return parse_json_markdown(text.strip(), parser=parse_partial_json)
except JSONDecodeError as e:
raise OutputParserException(f"Invalid json output: {text}") from e

Expand Down
105 changes: 79 additions & 26 deletions libs/langchain/langchain/output_parsers/openai_functions.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
import copy
import json
from typing import Any, Dict, List, Type, Union
from typing import Any, Dict, List, Optional, Type, Union

import jsonpatch

from langchain.output_parsers.json import parse_partial_json
from langchain.pydantic_v1 import BaseModel, root_validator
from langchain.schema import (
ChatGeneration,
Generation,
OutputParserException,
)
from langchain.schema.output_parser import BaseGenerationOutputParser
from langchain.schema.output_parser import (
BaseCumulativeTransformOutputParser,
BaseGenerationOutputParser,
)


class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
Expand All @@ -17,7 +23,7 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
args_only: bool = True
"""Whether to only return the arguments to the function call."""

def parse_result(self, result: List[Generation]) -> Any:
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
generation = result[0]
if not isinstance(generation, ChatGeneration):
raise OutputParserException(
Expand All @@ -34,7 +40,7 @@ def parse_result(self, result: List[Generation]) -> Any:
return func_call


class JsonOutputFunctionsParser(OutputFunctionsParser):
class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
"""Parse an output as the Json object."""

strict: bool = False
Expand All @@ -45,25 +51,72 @@ class JsonOutputFunctionsParser(OutputFunctionsParser):
Useful when the parsed output may include unicode characters or new lines.
"""

def parse_result(self, result: List[Generation]) -> Any:
function_call_info = super().parse_result(result)
if self.args_only:
try:
return json.loads(function_call_info, strict=self.strict)
except (json.JSONDecodeError, TypeError) as exc:
raise OutputParserException(
f"Could not parse function call data: {exc}"
)
else:
try:
function_call_info["arguments"] = json.loads(
function_call_info["arguments"], strict=self.strict
)
except (json.JSONDecodeError, TypeError) as exc:
raise OutputParserException(
f"Could not parse function call data: {exc}"
)
return function_call_info
args_only: bool = True
"""Whether to only return the arguments to the function call."""

def _diff(self, prev: Optional[Any], next: Any) -> Any:
return jsonpatch.make_patch(prev, next).patch

def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
if len(result) != 1:
raise OutputParserException(
f"Expected exactly one result, but got {len(result)}"
)
generation = result[0]
if not isinstance(generation, ChatGeneration):
raise OutputParserException(
"This output parser can only be used with a chat generation."
)
message = generation.message
try:
function_call = message.additional_kwargs["function_call"]
except KeyError as exc:
if partial:
return None
else:
raise OutputParserException(f"Could not parse function call: {exc}")
try:
if partial:
if self.args_only:
return parse_partial_json(
function_call["arguments"], strict=self.strict
)
else:
return {
**function_call,
"arguments": parse_partial_json(
function_call["arguments"], strict=self.strict
),
}
else:
if self.args_only:
try:
return json.loads(
function_call["arguments"], strict=self.strict
)
except (json.JSONDecodeError, TypeError) as exc:
raise OutputParserException(
f"Could not parse function call data: {exc}"
)
else:
try:
return {
**function_call,
"arguments": json.loads(
function_call["arguments"], strict=self.strict
),
}
except (json.JSONDecodeError, TypeError) as exc:
raise OutputParserException(
f"Could not parse function call data: {exc}"
)
except KeyError:
return None

# This method would be called by the default implementation of `parse_result`
# but we're overriding that method so it's not needed.
def parse(self, text: str) -> Any:
raise NotImplementedError()


class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
Expand All @@ -72,7 +125,7 @@ class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
key_name: str
"""The name of the key to return."""

def parse_result(self, result: List[Generation]) -> Any:
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
res = super().parse_result(result)
return res[self.key_name]

Expand All @@ -97,7 +150,7 @@ def validate_schema(cls, values: Dict) -> Dict:
)
return values

def parse_result(self, result: List[Generation]) -> Any:
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
_result = super().parse_result(result)
if self.args_only:
pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore
Expand All @@ -114,6 +167,6 @@ class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
attr_name: str
"""The name of the attribute to return."""

def parse_result(self, result: List[Generation]) -> Any:
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
result = super().parse_result(result)
return getattr(result, self.attr_name)
Loading

0 comments on commit 1ddf9f7

Please sign in to comment.