Skip to content

Commit

Permalink
Add RunnableGenerator
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Sep 29, 2023
1 parent db05ea2 commit a6996c8
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 8 deletions.
153 changes: 147 additions & 6 deletions libs/langchain/langchain/schema/runnable/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,7 @@ def _call_with_config(
input: Input,
config: Optional[RunnableConfig],
run_type: Optional[str] = None,
**kwargs: Optional[Any],
) -> Output:
"""Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses."""
Expand All @@ -463,7 +464,9 @@ def _call_with_config(
name=config.get("run_name"),
)
try:
output = call_func_with_variable_args(func, input, run_manager, config)
output = call_func_with_variable_args(
func, input, run_manager, config, **kwargs
)
except BaseException as e:
run_manager.on_chain_error(e)
raise
Expand All @@ -484,6 +487,7 @@ async def _acall_with_config(
input: Input,
config: Optional[RunnableConfig],
run_type: Optional[str] = None,
**kwargs: Optional[Any],
) -> Output:
"""Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement ainvoke() in subclasses."""
Expand All @@ -497,7 +501,7 @@ async def _acall_with_config(
)
try:
output = await acall_func_with_variable_args(
func, input, run_manager, config
func, input, run_manager, config, **kwargs
)
except BaseException as e:
await run_manager.on_chain_error(e)
Expand All @@ -524,6 +528,7 @@ def _batch_with_config(
*,
return_exceptions: bool = False,
run_type: Optional[str] = None,
**kwargs: Optional[Any],
) -> List[Output]:
"""Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses."""
Expand All @@ -544,7 +549,6 @@ def _batch_with_config(
)
]
try:
kwargs: Dict[str, Any] = {}
if accepts_config(func):
kwargs["config"] = [
patch_config(c, callbacks=rm.get_child())
Expand Down Expand Up @@ -595,6 +599,7 @@ async def _abatch_with_config(
*,
return_exceptions: bool = False,
run_type: Optional[str] = None,
**kwargs: Optional[Any],
) -> List[Output]:
"""Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses."""
Expand All @@ -617,7 +622,6 @@ async def _abatch_with_config(
)
)
try:
kwargs: Dict[str, Any] = {}
if accepts_config(func):
kwargs["config"] = [
patch_config(c, callbacks=rm.get_child())
Expand Down Expand Up @@ -666,6 +670,7 @@ def _transform_stream_with_config(
],
config: Optional[RunnableConfig],
run_type: Optional[str] = None,
**kwargs: Optional[Any],
) -> Iterator[Output]:
"""Helper method to transform an Iterator of Input values into an Iterator of
Output values, with callbacks.
Expand All @@ -687,7 +692,6 @@ def _transform_stream_with_config(
name=config.get("run_name"),
)
try:
kwargs: Dict[str, Any] = {}
if accepts_config(transformer):
kwargs["config"] = patch_config(
config, callbacks=run_manager.get_child()
Expand Down Expand Up @@ -744,6 +748,7 @@ async def _atransform_stream_with_config(
],
config: Optional[RunnableConfig],
run_type: Optional[str] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Output]:
"""Helper method to transform an Async Iterator of Input values into an Async
Iterator of Output values, with callbacks.
Expand All @@ -765,7 +770,6 @@ async def _atransform_stream_with_config(
name=config.get("run_name"),
)
try:
kwargs: Dict[str, Any] = {}
if accepts_config(transformer):
kwargs["config"] = patch_config(
config, callbacks=run_manager.get_child()
Expand Down Expand Up @@ -2046,6 +2050,139 @@ async def input_aiter() -> AsyncIterator[Input]:
yield chunk


class RunnableGenerator(Runnable[Input, Output]):
"""
A runnable that runs a generator function.
"""

def __init__(
self,
transform: Union[
Callable[[Iterator[Input]], Iterator[Output]],
Callable[[AsyncIterator[Input]], AsyncIterator[Output]],
],
atransform: Optional[
Callable[[AsyncIterator[Input]], AsyncIterator[Output]]
] = None,
) -> None:
if atransform is not None:
self._atransform = atransform

if inspect.isasyncgenfunction(transform):
self._atransform = transform
elif inspect.isgeneratorfunction(transform):
self._transform = transform
else:
raise TypeError(
"Expected a generator function type for `transform`."
f"Instead got an unsupported type: {type(transform)}"
)

@property
def InputType(self) -> Any:
func = getattr(self, "_transform", None) or getattr(self, "_atransform")
try:
params = inspect.signature(func).parameters
first_param = next(iter(params.values()), None)
if first_param and first_param.annotation != inspect.Parameter.empty:
return first_param.annotation
else:
return Any
except ValueError:
return Any

@property
def OutputType(self) -> Type[Output]:
func = getattr(self, "_transform", None) or getattr(self, "_atransform")
try:
sig = inspect.signature(func)
return (
sig.return_annotation
if sig.return_annotation != inspect.Signature.empty
else Any
)
except ValueError:
return Any

def __eq__(self, other: Any) -> bool:
if isinstance(other, RunnableGenerator):
if hasattr(self, "_transform") and hasattr(other, "_transform"):
return self._transform == other._transform
elif hasattr(self, "_atransform") and hasattr(other, "_atransform"):
return self._atransform == other._atransform
else:
return False
else:
return False

def __repr__(self) -> str:
return "RunnableGenerator(...)"

def transform(
self,
input: Iterator[Input],
config: Optional[RunnableConfig] = None,
**kwargs: Any | None,
) -> Iterator[Output]:
return self._transform_stream_with_config(
input, self._transform, config, **kwargs
)

def stream(
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Any | None,
) -> Iterator[Output]:
return self.transform(iter([input]), config, **kwargs)

def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
final = None
for output in self.stream(input, config, **kwargs):
if final is None:
final = output
else:
final += output
return final

async def atransform(
self,
input: AsyncIterator[Input],
config: Optional[RunnableConfig] = None,
**kwargs: Any | None,
) -> AsyncIterator[Output]:
if not hasattr(self, "_atransform"):
raise NotImplementedError("This runnable does not support async methods.")

return self._atransform_stream_with_config(
input, self._atransform, config, **kwargs
)

async def astream(
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Any | None,
) -> AsyncIterator[Output]:
async def input_aiter() -> AsyncIterator[Input]:
yield input

return self.atransform(input_aiter(), config, **kwargs)

async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Output:
final = None
async for output in self.astream(input, config):
if final is None:
final = output
else:
final += output
return final


class RunnableLambda(Runnable[Input, Output]):
"""
A runnable that runs a callable.
Expand Down Expand Up @@ -2523,13 +2660,17 @@ async def atransform(
Runnable[Input, Output],
Callable[[Input], Output],
Callable[[Input], Awaitable[Output]],
Callable[[Iterator[Input]], Iterator[Output]],
Callable[[AsyncIterator[Input]], AsyncIterator[Output]],
Mapping[str, Any],
]


def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]:
if isinstance(thing, Runnable):
return thing
elif inspect.isasyncgenfunction(thing) or inspect.isgeneratorfunction(thing):
return RunnableGenerator(thing)
elif callable(thing):
return RunnableLambda(thing)
elif isinstance(thing, dict):
Expand Down
4 changes: 2 additions & 2 deletions libs/langchain/langchain/schema/runnable/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,9 @@ def call_func_with_variable_args(
input: Input,
run_manager: CallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
) -> Output:
"""Call function that may optionally accept a run_manager and/or config."""
kwargs: Dict[str, Any] = {}
if accepts_config(func):
kwargs["config"] = patch_config(config, callbacks=run_manager.get_child())
if accepts_run_manager(func):
Expand All @@ -174,9 +174,9 @@ async def acall_func_with_variable_args(
input: Input,
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
) -> Output:
"""Call function that may optionally accept a run_manager and/or config."""
kwargs: Dict[str, Any] = {}
if accepts_config(func):
kwargs["config"] = patch_config(config, callbacks=run_manager.get_child())
if accepts_run_manager(func):
Expand Down

0 comments on commit a6996c8

Please sign in to comment.