Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a streaming json parser #11193

Merged
merged 15 commits into from
Sep 29, 2023
119 changes: 115 additions & 4 deletions libs/langchain/langchain/output_parsers/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@
import json
import re
from json import JSONDecodeError
from typing import Any, List
from typing import Any, Callable, List, Optional

import jsonpatch

from langchain.schema import BaseOutputParser, OutputParserException
from langchain.schema.output import ChatGeneration, Generation
from langchain.schema.output_parser import BaseCumulativeTransformOutputParser


def _replace_new_line(match: re.Match[str]) -> str:
Expand Down Expand Up @@ -38,7 +42,70 @@ def _custom_parser(multiline_string: str) -> str:
return multiline_string


def parse_json_markdown(json_string: str) -> dict:
# Adapted from https://github.com/KillianLucas/open-interpreter/blob/main/interpreter/utils/parse_partial_json.py
# MIT License
def parse_partial_json(s: str) -> Any:
# Attempt to parse the string as-is.
try:
return json.loads(s)
except json.JSONDecodeError:
pass

# Initialize variables.
new_s = ""
stack = []
is_inside_string = False
escaped = False

# Process each character in the string one at a time.
for char in s:
if is_inside_string:
if char == '"' and not escaped:
is_inside_string = False
elif char == "\n" and not escaped:
char = "\\n" # Replace the newline character with the escape sequence.
elif char == "\\":
escaped = not escaped
else:
escaped = False
else:
if char == '"':
is_inside_string = True
escaped = False
elif char == "{":
stack.append("}")
elif char == "[":
stack.append("]")
elif char == "}" or char == "]":
if stack and stack[-1] == char:
stack.pop()
else:
# Mismatched closing character; the input is malformed.
return None

# Append the processed character to the new string.
new_s += char

# If we're still inside a string at the end of processing,
# we need to close the string.
if is_inside_string:
new_s += '"'

# Close any remaining open structures in the reverse order that they were opened.
for closing_char in reversed(stack):
new_s += closing_char

# Attempt to parse the modified string as JSON.
try:
return json.loads(new_s)
except json.JSONDecodeError:
# If we still can't parse the string as JSON, return None to indicate failure.
return None


def parse_json_markdown(
json_string: str, parser: Callable[[str], Any] = json.loads
nfcampos marked this conversation as resolved.
Show resolved Hide resolved
) -> dict:
"""
Parse a JSON string from a Markdown string.

Expand All @@ -65,7 +132,7 @@ def parse_json_markdown(json_string: str) -> dict:
json_str = _custom_parser(json_str)

# Parse the JSON string into a Python dictionary
parsed = json.loads(json_str)
parsed = parser(json_str)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if I'm misreading the regexp, but it seems to require closing ```? Should that requirement be relaxed?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure, but I don't want to change that in this pr


return parsed

Expand Down Expand Up @@ -101,10 +168,54 @@ class SimpleJsonOutputParser(BaseOutputParser[Any]):
def parse(self, text: str) -> Any:
text = text.strip()
try:
return json.loads(text)
return parse_partial_json(text)
except JSONDecodeError as e:
raise OutputParserException(f"Invalid json output: {text}") from e

@property
def _type(self) -> str:
return "simple_json_output_parser"


class PartialFunctionsJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
@property
def _type(self) -> str:
return "partial_functions_json"

def parse_result(self, result: List[Generation]) -> Any:
if len(result) != 1:
raise OutputParserException(
f"Expected exactly one result, but got {len(result)}"
)
generation = result[0]
if not isinstance(generation, ChatGeneration):
raise OutputParserException(
"This output parser can only be used with a chat generation."
)
message = generation.message
try:
function_call = message.additional_kwargs["function_call"]
except KeyError:
return None
try:
return parse_partial_json(function_call["arguments"])
except KeyError:
return None

def _diff(self, prev: Optional[Any], next: Any) -> Any:
return jsonpatch.make_patch(prev, next).patch

def parse(self, text: str) -> Any:
nfcampos marked this conversation as resolved.
Show resolved Hide resolved
pass


class PartialJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
@property
def _type(self) -> str:
return "partial_functions_json"

def _diff(self, prev: Optional[Any], next: Any) -> Any:
return jsonpatch.make_patch(prev, next).patch

def parse(self, text: str) -> Any:
return parse_json_markdown(text, parse_partial_json)
72 changes: 70 additions & 2 deletions libs/langchain/langchain/schema/output_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,13 @@
from typing_extensions import get_args

from langchain.load.serializable import Serializable
from langchain.schema.messages import AnyMessage, BaseMessage
from langchain.schema.output import ChatGeneration, Generation
from langchain.schema.messages import AnyMessage, BaseMessage, BaseMessageChunk
from langchain.schema.output import (
ChatGeneration,
ChatGenerationChunk,
Generation,
GenerationChunk,
)
from langchain.schema.prompt import PromptValue
from langchain.schema.runnable import Runnable, RunnableConfig

Expand Down Expand Up @@ -329,6 +334,69 @@ async def atransform(
yield chunk


class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
"""Base class for an output parser that can handle streaming input."""

diff: bool = False

def _diff(self, prev: Optional[T], next: T) -> T:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps document this method?

Do you forsee a situation in which _diff is defined but the user passes diff=False ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea some users might want the full cumulative output every time

raise NotImplementedError()

def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
prev_parsed = None
acc_gen = None
for chunk in input:
if isinstance(chunk, BaseMessageChunk):
chunk_gen: Generation = ChatGenerationChunk(message=chunk)
elif isinstance(chunk, BaseMessage):
chunk_gen = ChatGenerationChunk(
message=BaseMessageChunk(**chunk.dict())
)
else:
chunk_gen = GenerationChunk(text=chunk)

if acc_gen is None:
acc_gen = chunk_gen
else:
acc_gen += chunk_gen

parsed = self.parse_result([acc_gen])
if parsed is not None and parsed != prev_parsed:
if self.diff:
yield self._diff(prev_parsed, parsed)
else:
yield parsed
prev_parsed = parsed

async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]]
) -> AsyncIterator[T]:
prev_parsed = None
acc_gen = None
async for chunk in input:
if isinstance(chunk, BaseMessageChunk):
chunk_gen: Generation = ChatGenerationChunk(message=chunk)
elif isinstance(chunk, BaseMessage):
chunk_gen = ChatGenerationChunk(
message=BaseMessageChunk(**chunk.dict())
)
else:
chunk_gen = GenerationChunk(text=chunk)

if acc_gen is None:
acc_gen = chunk_gen
else:
acc_gen += chunk_gen

parsed = self.parse_result([acc_gen])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be using aparse_result if it's available?

if parsed is not None and parsed != prev_parsed:
if self.diff:
yield self._diff(prev_parsed, parsed)
else:
yield parsed
prev_parsed = parsed


class StrOutputParser(BaseTransformOutputParser[str]):
"""OutputParser that parses LLMResult into the top likely string."""

Expand Down
Loading