Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RunnableGenerator #11214

Merged
merged 3 commits into from
Sep 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 162 additions & 17 deletions libs/langchain/langchain/schema/runnable/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def __or__(
other: Union[
Runnable[Any, Other],
Callable[[Any], Other],
Callable[[Iterator[Any]], Iterator[Other]],
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
],
) -> RunnableSequence[Input, Other]:
Expand All @@ -132,7 +133,8 @@ def __ror__(
self,
other: Union[
Runnable[Other, Any],
Callable[[Any], Other],
Callable[[Other], Any],
Callable[[Iterator[Other]], Iterator[Any]],
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]],
],
) -> RunnableSequence[Other, Output]:
Expand Down Expand Up @@ -353,7 +355,7 @@ def transform(
else:
# Make a best effort to gather, for any type that supports `+`
# This method should throw an error if gathering fails.
final += chunk # type: ignore[operator]
final = final + chunk # type: ignore[operator]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was incorrect before; e.g., on lists, `+=` would mutate the original list in place.


if got_first_val:
yield from self.stream(final, config, **kwargs)
Expand All @@ -379,7 +381,7 @@ async def atransform(
else:
# Make a best effort to gather, for any type that supports `+`
# This method should throw an error if gathering fails.
final += chunk # type: ignore[operator]
final = final + chunk # type: ignore[operator]

if got_first_val:
async for output in self.astream(final, config, **kwargs):
Expand Down Expand Up @@ -453,6 +455,7 @@ def _call_with_config(
input: Input,
config: Optional[RunnableConfig],
run_type: Optional[str] = None,
**kwargs: Optional[Any],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this can be just `Any`; the `Optional` is already implied for `**kwargs`? (Doesn't really matter.)

) -> Output:
"""Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses."""
Expand All @@ -465,7 +468,9 @@ def _call_with_config(
name=config.get("run_name"),
)
try:
output = call_func_with_variable_args(func, input, run_manager, config)
output = call_func_with_variable_args(
func, input, run_manager, config, **kwargs
)
except BaseException as e:
run_manager.on_chain_error(e)
raise
Expand All @@ -486,6 +491,7 @@ async def _acall_with_config(
input: Input,
config: Optional[RunnableConfig],
run_type: Optional[str] = None,
**kwargs: Optional[Any],
) -> Output:
"""Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement ainvoke() in subclasses."""
Expand All @@ -499,7 +505,7 @@ async def _acall_with_config(
)
try:
output = await acall_func_with_variable_args(
func, input, run_manager, config
func, input, run_manager, config, **kwargs
)
except BaseException as e:
await run_manager.on_chain_error(e)
Expand All @@ -526,6 +532,7 @@ def _batch_with_config(
*,
return_exceptions: bool = False,
run_type: Optional[str] = None,
**kwargs: Optional[Any],
) -> List[Output]:
"""Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses."""
Expand All @@ -546,7 +553,6 @@ def _batch_with_config(
)
]
try:
kwargs: Dict[str, Any] = {}
if accepts_config(func):
kwargs["config"] = [
patch_config(c, callbacks=rm.get_child())
Expand Down Expand Up @@ -597,6 +603,7 @@ async def _abatch_with_config(
*,
return_exceptions: bool = False,
run_type: Optional[str] = None,
**kwargs: Optional[Any],
) -> List[Output]:
"""Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses."""
Expand All @@ -619,7 +626,6 @@ async def _abatch_with_config(
)
)
try:
kwargs: Dict[str, Any] = {}
if accepts_config(func):
kwargs["config"] = [
patch_config(c, callbacks=rm.get_child())
Expand Down Expand Up @@ -668,6 +674,7 @@ def _transform_stream_with_config(
],
config: Optional[RunnableConfig],
run_type: Optional[str] = None,
**kwargs: Optional[Any],
) -> Iterator[Output]:
"""Helper method to transform an Iterator of Input values into an Iterator of
Output values, with callbacks.
Expand All @@ -689,7 +696,6 @@ def _transform_stream_with_config(
name=config.get("run_name"),
)
try:
kwargs: Dict[str, Any] = {}
if accepts_config(transformer):
kwargs["config"] = patch_config(
config, callbacks=run_manager.get_child()
Expand All @@ -706,7 +712,7 @@ def _transform_stream_with_config(
final_output = chunk
else:
try:
final_output += chunk # type: ignore[operator]
final_output = final_output + chunk # type: ignore
except TypeError:
final_output = None
final_output_supported = False
Expand All @@ -716,7 +722,7 @@ def _transform_stream_with_config(
final_input = ichunk
else:
try:
final_input += ichunk # type: ignore[operator]
final_input = final_input + ichunk # type: ignore
except TypeError:
final_input = None
final_input_supported = False
Expand Down Expand Up @@ -746,6 +752,7 @@ async def _atransform_stream_with_config(
],
config: Optional[RunnableConfig],
run_type: Optional[str] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Output]:
"""Helper method to transform an Async Iterator of Input values into an Async
Iterator of Output values, with callbacks.
Expand All @@ -767,7 +774,6 @@ async def _atransform_stream_with_config(
name=config.get("run_name"),
)
try:
kwargs: Dict[str, Any] = {}
if accepts_config(transformer):
kwargs["config"] = patch_config(
config, callbacks=run_manager.get_child()
Expand All @@ -784,7 +790,7 @@ async def _atransform_stream_with_config(
final_output = chunk
else:
try:
final_output += chunk # type: ignore[operator]
final_output = final_output + chunk # type: ignore
except TypeError:
final_output = None
final_output_supported = False
Expand All @@ -794,7 +800,7 @@ async def _atransform_stream_with_config(
final_input = ichunk
else:
try:
final_input += ichunk # type: ignore[operator]
final_input = final_input + ichunk # type: ignore[operator]
except TypeError:
final_input = None
final_input_supported = False
Expand Down Expand Up @@ -1311,6 +1317,7 @@ def __or__(
other: Union[
Runnable[Any, Other],
Callable[[Any], Other],
Callable[[Iterator[Any]], Iterator[Other]],
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
],
) -> RunnableSequence[Input, Other]:
Expand All @@ -1331,7 +1338,8 @@ def __ror__(
self,
other: Union[
Runnable[Other, Any],
Callable[[Any], Other],
Callable[[Other], Any],
Callable[[Iterator[Other]], Iterator[Any]],
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]],
],
) -> RunnableSequence[Other, Output]:
Expand Down Expand Up @@ -1751,7 +1759,7 @@ def __add__(self, other: RunnableMapChunk) -> RunnableMapChunk:
if key not in chunk or chunk[key] is None:
chunk[key] = other[key]
elif other[key] is not None:
chunk[key] += other[key]
chunk[key] = chunk[key] + other[key]
return chunk

def __radd__(self, other: RunnableMapChunk) -> RunnableMapChunk:
Expand All @@ -1760,7 +1768,7 @@ def __radd__(self, other: RunnableMapChunk) -> RunnableMapChunk:
if key not in chunk or chunk[key] is None:
chunk[key] = self[key]
elif self[key] is not None:
chunk[key] += self[key]
chunk[key] = chunk[key] + self[key]
return chunk


Expand Down Expand Up @@ -2061,6 +2069,139 @@ async def input_aiter() -> AsyncIterator[Input]:
yield chunk


class RunnableGenerator(Runnable[Input, Output]):
    """
    A runnable that runs a generator function.

    The wrapped function receives an iterator of input chunks and yields output
    chunks. A sync generator function backs `transform`/`stream`/`invoke`; an
    async generator function backs `atransform`/`astream`/`ainvoke`.
    """

    def __init__(
        self,
        transform: Union[
            Callable[[Iterator[Input]], Iterator[Output]],
            Callable[[AsyncIterator[Input]], AsyncIterator[Output]],
        ],
        atransform: Optional[
            Callable[[AsyncIterator[Input]], AsyncIterator[Output]]
        ] = None,
    ) -> None:
        """Wrap a sync and/or async generator function.

        Args:
            transform: A generator function (stored as the sync implementation)
                or an async generator function (stored as the async
                implementation).
            atransform: Optional async generator function used as the async
                implementation.

        Raises:
            TypeError: If `transform` is neither a generator function nor an
                async generator function.
        """
        if atransform is not None:
            self._atransform = atransform

        # NOTE: when `transform` is itself an async generator function it is
        # stored as the async implementation, replacing any `atransform` given.
        if inspect.isasyncgenfunction(transform):
            self._atransform = transform
        elif inspect.isgeneratorfunction(transform):
            self._transform = transform
        else:
            raise TypeError(
                "Expected a generator function type for `transform`. "
                f"Instead got an unsupported type: {type(transform)}"
            )

    @property
    def InputType(self) -> Any:
        """Input type inferred from the first parameter's annotation.

        Returns the first type argument of the annotation (e.g. `str` for
        `Iterator[str]`), or `Any` when there is no usable annotation or the
        signature cannot be inspected.
        """
        # Prefer the sync function; fall back to the async one.
        func = getattr(self, "_transform", None) or getattr(self, "_atransform")
        try:
            params = inspect.signature(func).parameters
            first_param = next(iter(params.values()), None)
            if first_param and first_param.annotation != inspect.Parameter.empty:
                return getattr(first_param.annotation, "__args__", (Any,))[0]
            else:
                return Any
        except ValueError:
            return Any

    @property
    def OutputType(self) -> Any:
        """Output type inferred from the return annotation.

        Returns the first type argument of the return annotation, or `Any` when
        it is missing or the signature cannot be inspected.
        """
        # Prefer the sync function; fall back to the async one.
        func = getattr(self, "_transform", None) or getattr(self, "_atransform")
        try:
            sig = inspect.signature(func)
            return (
                getattr(sig.return_annotation, "__args__", (Any,))[0]
                if sig.return_annotation != inspect.Signature.empty
                else Any
            )
        except ValueError:
            return Any

    def __eq__(self, other: Any) -> bool:
        """Compare by wrapped function: sync functions if both instances have
        one, otherwise async functions if both have one, otherwise unequal."""
        if isinstance(other, RunnableGenerator):
            if hasattr(self, "_transform") and hasattr(other, "_transform"):
                return self._transform == other._transform
            elif hasattr(self, "_atransform") and hasattr(other, "_atransform"):
                return self._atransform == other._atransform
            else:
                return False
        else:
            return False

    def __repr__(self) -> str:
        return "RunnableGenerator(...)"

    def transform(
        self,
        input: Iterator[Input],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Iterator[Output]:
        """Run the wrapped sync generator over an iterator of input chunks.

        Raises:
            NotImplementedError: If only an async generator was provided.
        """
        # Mirror the guard in `atransform` so that an async-only instance fails
        # with a descriptive error instead of a bare AttributeError.
        if not hasattr(self, "_transform"):
            raise NotImplementedError("This runnable only supports async methods.")

        return self._transform_stream_with_config(
            input, self._transform, config, **kwargs
        )

    def stream(
        self,
        input: Input,
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Iterator[Output]:
        """Stream output chunks for a single input, fed as a one-item iterator."""
        return self.transform(iter([input]), config, **kwargs)

    def invoke(
        self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Output:
        """Run on a single input and return all output chunks combined with `+`.

        Returns None (cast to Output) if the generator yields nothing.
        """
        final = None
        for output in self.stream(input, config, **kwargs):
            if final is None:
                final = output
            else:
                # Use `+` rather than `+=` so the first chunk is never mutated.
                final = final + output
        return cast(Output, final)

    def atransform(
        self,
        input: AsyncIterator[Input],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> AsyncIterator[Output]:
        """Run the wrapped async generator over an async iterator of input chunks.

        Raises:
            NotImplementedError: If no async generator was provided.
        """
        if not hasattr(self, "_atransform"):
            raise NotImplementedError("This runnable does not support async methods.")

        return self._atransform_stream_with_config(
            input, self._atransform, config, **kwargs
        )

    def astream(
        self,
        input: Input,
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> AsyncIterator[Output]:
        """Async-stream output chunks for a single input."""

        async def input_aiter() -> AsyncIterator[Input]:
            # Present the single input as a one-item async iterator.
            yield input

        return self.atransform(input_aiter(), config, **kwargs)

    async def ainvoke(
        self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Output:
        """Async counterpart of `invoke`: combine all output chunks with `+`.

        Returns None (cast to Output) if the generator yields nothing.
        """
        final = None
        async for output in self.astream(input, config, **kwargs):
            if final is None:
                final = output
            else:
                # Use `+` rather than `+=` so the first chunk is never mutated.
                final = final + output
        return cast(Output, final)


class RunnableLambda(Runnable[Input, Output]):
"""
A runnable that runs a callable.
Expand Down Expand Up @@ -2538,15 +2679,19 @@ async def atransform(
Runnable[Input, Output],
Callable[[Input], Output],
Callable[[Input], Awaitable[Output]],
Callable[[Iterator[Input]], Iterator[Output]],
Callable[[AsyncIterator[Input]], AsyncIterator[Output]],
Mapping[str, Any],
]


def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]:
if isinstance(thing, Runnable):
return thing
elif inspect.isasyncgenfunction(thing) or inspect.isgeneratorfunction(thing):
return RunnableGenerator(thing)
elif callable(thing):
return RunnableLambda(thing)
return RunnableLambda(cast(Callable[[Input], Output], thing))
elif isinstance(thing, dict):
runnables: Mapping[str, Runnable[Any, Any]] = {
key: coerce_to_runnable(r) for key, r in thing.items()
Expand Down
4 changes: 2 additions & 2 deletions libs/langchain/langchain/schema/runnable/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,9 @@ def call_func_with_variable_args(
input: Input,
run_manager: CallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
) -> Output:
"""Call function that may optionally accept a run_manager and/or config."""
kwargs: Dict[str, Any] = {}
if accepts_config(func):
kwargs["config"] = patch_config(config, callbacks=run_manager.get_child())
if accepts_run_manager(func):
Expand All @@ -174,9 +174,9 @@ async def acall_func_with_variable_args(
input: Input,
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
) -> Output:
"""Call function that may optionally accept a run_manager and/or config."""
kwargs: Dict[str, Any] = {}
if accepts_config(func):
kwargs["config"] = patch_config(config, callbacks=run_manager.get_child())
if accepts_run_manager(func):
Expand Down
Loading