diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 7ee4be6960256..9d45d86551fbc 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -453,6 +453,7 @@ def _call_with_config( input: Input, config: Optional[RunnableConfig], run_type: Optional[str] = None, + **kwargs: Optional[Any], ) -> Output: """Helper method to transform an Input value to an Output value, with callbacks. Use this method to implement invoke() in subclasses.""" @@ -465,7 +466,9 @@ def _call_with_config( name=config.get("run_name"), ) try: - output = call_func_with_variable_args(func, input, run_manager, config) + output = call_func_with_variable_args( + func, input, run_manager, config, **kwargs + ) except BaseException as e: run_manager.on_chain_error(e) raise @@ -486,6 +489,7 @@ async def _acall_with_config( input: Input, config: Optional[RunnableConfig], run_type: Optional[str] = None, + **kwargs: Optional[Any], ) -> Output: """Helper method to transform an Input value to an Output value, with callbacks. Use this method to implement ainvoke() in subclasses.""" @@ -499,7 +503,7 @@ async def _acall_with_config( ) try: output = await acall_func_with_variable_args( - func, input, run_manager, config + func, input, run_manager, config, **kwargs ) except BaseException as e: await run_manager.on_chain_error(e) @@ -526,6 +530,7 @@ def _batch_with_config( *, return_exceptions: bool = False, run_type: Optional[str] = None, + **kwargs: Optional[Any], ) -> List[Output]: """Helper method to transform an Input value to an Output value, with callbacks. Use this method to implement invoke() in subclasses.""" @@ -546,7 +551,6 @@ def _batch_with_config( ) ] try: - kwargs: Dict[str, Any] = {} if accepts_config(func): kwargs["config"] = [ patch_config(c, callbacks=rm.get_child()) @@ -597,6 +601,7 @@ async def _abatch_with_config( *, return_exceptions: bool = False, run_type: Optional[str] = None, + **kwargs: Optional[Any], ) -> List[Output]: """Helper method to transform an Input value to an Output value, with callbacks. Use this method to implement invoke() in subclasses.""" @@ -619,7 +624,6 @@ async def _abatch_with_config( ) ) try: - kwargs: Dict[str, Any] = {} if accepts_config(func): kwargs["config"] = [ patch_config(c, callbacks=rm.get_child()) @@ -668,6 +672,7 @@ def _transform_stream_with_config( ], config: Optional[RunnableConfig], run_type: Optional[str] = None, + **kwargs: Optional[Any], ) -> Iterator[Output]: """Helper method to transform an Iterator of Input values into an Iterator of Output values, with callbacks. @@ -689,7 +694,6 @@ def _transform_stream_with_config( name=config.get("run_name"), ) try: - kwargs: Dict[str, Any] = {} if accepts_config(transformer): kwargs["config"] = patch_config( config, callbacks=run_manager.get_child() @@ -746,6 +750,7 @@ async def _atransform_stream_with_config( ], config: Optional[RunnableConfig], run_type: Optional[str] = None, + **kwargs: Optional[Any], ) -> AsyncIterator[Output]: """Helper method to transform an Async Iterator of Input values into an Async Iterator of Output values, with callbacks. @@ -767,7 +772,6 @@ async def _atransform_stream_with_config( name=config.get("run_name"), ) try: - kwargs: Dict[str, Any] = {} if accepts_config(transformer): kwargs["config"] = patch_config( config, callbacks=run_manager.get_child() @@ -2061,6 +2065,139 @@ async def input_aiter() -> AsyncIterator[Input]: yield chunk +class RunnableGenerator(Runnable[Input, Output]): + """ + A runnable that runs a generator function. + """ + + def __init__( + self, + transform: Union[ + Callable[[Iterator[Input]], Iterator[Output]], + Callable[[AsyncIterator[Input]], AsyncIterator[Output]], + ], + atransform: Optional[ + Callable[[AsyncIterator[Input]], AsyncIterator[Output]] + ] = None, + ) -> None: + if atransform is not None: + self._atransform = atransform + + if inspect.isasyncgenfunction(transform): + self._atransform = transform + elif inspect.isgeneratorfunction(transform): + self._transform = transform + else: + raise TypeError( + "Expected a generator function type for `transform`." + f"Instead got an unsupported type: {type(transform)}" + ) + + @property + def InputType(self) -> Any: + func = getattr(self, "_transform", None) or getattr(self, "_atransform") + try: + params = inspect.signature(func).parameters + first_param = next(iter(params.values()), None) + if first_param and first_param.annotation != inspect.Parameter.empty: + return first_param.annotation + else: + return Any + except ValueError: + return Any + + @property + def OutputType(self) -> Type[Output]: + func = getattr(self, "_transform", None) or getattr(self, "_atransform") + try: + sig = inspect.signature(func) + return ( + sig.return_annotation + if sig.return_annotation != inspect.Signature.empty + else Any + ) + except ValueError: + return Any + + def __eq__(self, other: Any) -> bool: + if isinstance(other, RunnableGenerator): + if hasattr(self, "_transform") and hasattr(other, "_transform"): + return self._transform == other._transform + elif hasattr(self, "_atransform") and hasattr(other, "_atransform"): + return self._atransform == other._atransform + else: + return False + else: + return False + + def __repr__(self) -> str: + return "RunnableGenerator(...)" + + def transform( + self, + input: Iterator[Input], + config: Optional[RunnableConfig] = None, + **kwargs: Any | None, + ) -> Iterator[Output]: + return self._transform_stream_with_config( + input, self._transform, config, **kwargs + ) + + def stream( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Any | None, + ) -> Iterator[Output]: + return self.transform(iter([input]), config, **kwargs) + + def invoke( + self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> Output: + final = None + for output in self.stream(input, config, **kwargs): + if final is None: + final = output + else: + final += output + return final + + async def atransform( + self, + input: AsyncIterator[Input], + config: Optional[RunnableConfig] = None, + **kwargs: Any | None, + ) -> AsyncIterator[Output]: + if not hasattr(self, "_atransform"): + raise NotImplementedError("This runnable does not support async methods.") + + return self._atransform_stream_with_config( + input, self._atransform, config, **kwargs + ) + + async def astream( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Any | None, + ) -> AsyncIterator[Output]: + async def input_aiter() -> AsyncIterator[Input]: + yield input + + return self.atransform(input_aiter(), config, **kwargs) + + async def ainvoke( + self, input: Input, config: Optional[RunnableConfig] = None + ) -> Output: + final = None + async for output in self.astream(input, config): + if final is None: + final = output + else: + final += output + return final + + class RunnableLambda(Runnable[Input, Output]): """ A runnable that runs a callable. @@ -2538,6 +2675,8 @@ async def atransform( Runnable[Input, Output], Callable[[Input], Output], Callable[[Input], Awaitable[Output]], + Callable[[Iterator[Input]], Iterator[Output]], + Callable[[AsyncIterator[Input]], AsyncIterator[Output]], Mapping[str, Any], ] @@ -2545,6 +2684,8 @@ async def atransform( def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]: if isinstance(thing, Runnable): return thing + elif inspect.isasyncgenfunction(thing) or inspect.isgeneratorfunction(thing): + return RunnableGenerator(thing) elif callable(thing): return RunnableLambda(thing) elif isinstance(thing, dict): diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index 6ae120ad7f3ab..06d979cff082c 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -152,9 +152,9 @@ def call_func_with_variable_args( input: Input, run_manager: CallbackManagerForChainRun, config: RunnableConfig, + **kwargs: Any, ) -> Output: """Call function that may optionally accept a run_manager and/or config.""" - kwargs: Dict[str, Any] = {} if accepts_config(func): kwargs["config"] = patch_config(config, callbacks=run_manager.get_child()) if accepts_run_manager(func): @@ -174,9 +174,9 @@ async def acall_func_with_variable_args( input: Input, run_manager: AsyncCallbackManagerForChainRun, config: RunnableConfig, + **kwargs: Any, ) -> Output: """Call function that may optionally accept a run_manager and/or config.""" - kwargs: Dict[str, Any] = {} if accepts_config(func): kwargs["config"] = patch_config(config, callbacks=run_manager.get_child()) if accepts_run_manager(func):