-
Notifications
You must be signed in to change notification settings - Fork 16k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement RunnablePassthrough.assign(...) #11222
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -53,6 +53,7 @@ | |
patch_config, | ||
) | ||
from langchain.schema.runnable.utils import ( | ||
AddableDict, | ||
Input, | ||
Output, | ||
accepts_config, | ||
|
@@ -1748,30 +1749,6 @@ async def input_aiter() -> AsyncIterator[Input]: | |
yield chunk | ||
|
||
|
||
class RunnableMapChunk(Dict[str, Any]):
    """
    Partial output from a RunnableMap.

    Two chunks can be combined with ``+``. The result is a new
    RunnableMapChunk; neither operand is modified. For every key:

    * if only one side has the key, or one side holds ``None``, the other
      side's value is used as-is;
    * otherwise the two values are combined with their own ``+``
      (left operand first).
    """

    @staticmethod
    def _merge(left: Dict[str, Any], right: Dict[str, Any]) -> "RunnableMapChunk":
        """Return a new chunk holding ``left`` with ``right`` added key by key."""
        merged = RunnableMapChunk(left)
        for key in right:
            value = right[key]
            if key not in merged or merged[key] is None:
                # Nothing to add to yet: take the right-hand value verbatim.
                merged[key] = value
            elif value is not None:
                # Both sides have a value: rely on the value type's own "+".
                merged[key] = merged[key] + value
            # value is None and merged[key] is set: keep what we have.
        return merged

    def __add__(self, other: "RunnableMapChunk") -> "RunnableMapChunk":
        """Return ``self + other`` as a new chunk."""
        return self._merge(self, other)

    def __radd__(self, other: "RunnableMapChunk") -> "RunnableMapChunk":
        """Return ``other + self`` as a new chunk (``other`` is the left operand)."""
        return self._merge(other, self)
|
||
|
||
class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): | ||
""" | ||
A runnable that runs a mapping of runnables in parallel, | ||
|
@@ -1814,14 +1791,18 @@ def InputType(self) -> Any: | |
|
||
@property | ||
def input_schema(self) -> type[BaseModel]: | ||
if all(not s.input_schema.__custom_root_type__ for s in self.steps.values()): | ||
if all( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This works in more cases. |
||
s.input_schema.schema().get("type", "object") == "object" | ||
for s in self.steps.values() | ||
): | ||
# This is correct, but pydantic typings/mypy don't think so. | ||
return create_model( # type: ignore[call-overload] | ||
"RunnableMapInput", | ||
**{ | ||
k: (v.type_, v.default) | ||
for step in self.steps.values() | ||
for k, v in step.input_schema.__fields__.items() | ||
if k != "__root__" | ||
}, | ||
) | ||
|
||
|
@@ -1934,7 +1915,7 @@ def _transform( | |
input: Iterator[Input], | ||
run_manager: CallbackManagerForChainRun, | ||
config: RunnableConfig, | ||
) -> Iterator[RunnableMapChunk]: | ||
) -> Iterator[AddableDict]: | ||
# Shallow copy steps to ignore mutations while in progress | ||
steps = dict(self.steps) | ||
# Each step gets a copy of the input iterator, | ||
|
@@ -1967,7 +1948,7 @@ def _transform( | |
for future in completed_futures: | ||
(step_name, generator) = futures.pop(future) | ||
try: | ||
chunk = RunnableMapChunk({step_name: future.result()}) | ||
chunk = AddableDict({step_name: future.result()}) | ||
yield chunk | ||
futures[executor.submit(next, generator)] = ( | ||
step_name, | ||
|
@@ -1999,7 +1980,7 @@ async def _atransform( | |
input: AsyncIterator[Input], | ||
run_manager: AsyncCallbackManagerForChainRun, | ||
config: RunnableConfig, | ||
) -> AsyncIterator[RunnableMapChunk]: | ||
) -> AsyncIterator[AddableDict]: | ||
# Shallow copy steps to ignore mutations while in progress | ||
steps = dict(self.steps) | ||
# Each step gets a copy of the input iterator, | ||
|
@@ -2038,7 +2019,7 @@ async def get_next_chunk(generator: AsyncIterator) -> Optional[Output]: | |
for task in completed_tasks: | ||
(step_name, generator) = tasks.pop(task) | ||
try: | ||
chunk = RunnableMapChunk({step_name: task.result()}) | ||
chunk = AddableDict({step_name: task.result()}) | ||
yield chunk | ||
new_task = asyncio.create_task(get_next_chunk(generator)) | ||
tasks[new_task] = (step_name, generator) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,28 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any, AsyncIterator, Iterator, List, Optional, Type | ||
import asyncio | ||
import threading | ||
from typing import ( | ||
Any, | ||
AsyncIterator, | ||
Callable, | ||
Dict, | ||
Iterator, | ||
List, | ||
Mapping, | ||
Optional, | ||
Type, | ||
Union, | ||
cast, | ||
) | ||
|
||
from langchain.load.serializable import Serializable | ||
from langchain.schema.runnable.base import Input, Runnable | ||
from langchain.schema.runnable.config import RunnableConfig | ||
from langchain.pydantic_v1 import BaseModel, create_model | ||
from langchain.schema.runnable.base import Input, Runnable, RunnableMap | ||
from langchain.schema.runnable.config import RunnableConfig, get_executor_for_config | ||
from langchain.schema.runnable.utils import AddableDict | ||
from langchain.utils.aiter import atee, py_anext | ||
from langchain.utils.iter import safetee | ||
|
||
|
||
def identity(x: Input) -> Input: | ||
|
@@ -38,6 +56,30 @@ def InputType(self) -> Any: | |
def OutputType(self) -> Any: | ||
return self.input_type or Any | ||
|
||
@classmethod
def assign(
    cls,
    **kwargs: Union[
        Runnable[Dict[str, Any], Any],
        Callable[[Dict[str, Any]], Any],
        Mapping[
            str,
            Union[Runnable[Dict[str, Any], Any], Callable[[Dict[str, Any]], Any]],
        ],
    ],
) -> RunnableAssign:
    """
    Merge the Dict input with the output produced by the given steps.

    Args:
        **kwargs: One entry per key to assign. Each value is a runnable, a
            callable, or a mapping of those; the keyword arguments are
            passed together, as a single dict, to RunnableMap.

    Returns:
        A RunnableAssign wrapping a RunnableMap built from ``kwargs``, i.e.
        a runnable that merges the Dict input with the output produced by
        those steps.
    """
    return RunnableAssign(RunnableMap(kwargs))
|
||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
    """Run ``identity`` on ``input`` through ``_call_with_config`` and return the result."""
    result = self._call_with_config(identity, input, config)
    return result
|
||
|
@@ -65,3 +107,155 @@ async def atransform( | |
) -> AsyncIterator[Input]: | ||
async for chunk in self._atransform_stream_with_config(input, identity, config): | ||
yield chunk | ||
|
||
|
||
class RunnableAssign(Serializable, Runnable[Dict[str, Any], Dict[str, Any]]):
    """
    A runnable that assigns key-value pairs to Dict[str, Any] inputs.

    The wrapped ``mapper`` is run on the input dict and its output is merged
    on top of that input, so a mapper key replaces an input key of the same
    name.
    """

    # RunnableMap whose output is merged into the input dict.
    mapper: RunnableMap[Dict[str, Any]]

    def __init__(self, mapper: RunnableMap[Dict[str, Any]], **kwargs: Any) -> None:
        super().__init__(mapper=mapper, **kwargs)

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        # Module path without its last component (the module's own name).
        return cls.__module__.split(".")[:-1]

    @property
    def input_schema(self) -> type[BaseModel]:
        """Return the mapper's input schema when it describes a dict, else the default."""
        map_input_schema = self.mapper.input_schema
        if not map_input_schema.__custom_root_type__:
            # ie. it's a dict
            return map_input_schema

        return super().input_schema

    @property
    def output_schema(self) -> type[BaseModel]:
        """
        Return a model combining the mapper's input and output fields.

        Only built when both mapper schemas describe dicts; otherwise the
        default output schema is returned. Output fields are listed last, so
        they replace input fields of the same name.
        """
        map_input_schema = self.mapper.input_schema
        map_output_schema = self.mapper.output_schema
        if (
            not map_input_schema.__custom_root_type__
            and not map_output_schema.__custom_root_type__
        ):
            # ie. both are dicts
            return create_model(  # type: ignore[call-overload]
                "RunnableAssignOutput",
                **{
                    k: (v.type_, v.default)
                    for s in (map_input_schema, map_output_schema)
                    for k, v in s.__fields__.items()
                },
            )

        return super().output_schema

    def invoke(
        self,
        input: Dict[str, Any],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Return ``input`` merged with the mapper's output for ``input``."""
        assert isinstance(input, dict)
        # The mapper output is spread last on purpose, so that it can modify
        # an existing key (same semantics as JavaScript's Object.assign).
        return {
            **input,
            **self.mapper.invoke(input, config, **kwargs),
        }

    async def ainvoke(
        self,
        input: Dict[str, Any],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Async counterpart of ``invoke``: mapper output overrides input keys."""
        assert isinstance(input, dict)
        return {
            **input,
            **await self.mapper.ainvoke(input, config, **kwargs),
        }

    def transform(
        self,
        input: Iterator[Dict[str, Any]],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Iterator[Dict[str, Any]]:
        """
        Stream the passthrough chunks (minus mapper keys), then the map output.

        If the map stream produces no chunk at all, the generator simply ends
        after the passthrough chunks.
        """
        # Sentinel meaning "the map stream was empty". Without it, the
        # StopIteration raised by next() would be re-raised by result()
        # inside this generator and surface as a RuntimeError (PEP 479).
        exhausted = object()
        # collect mapper keys
        mapper_keys = set(self.mapper.steps.keys())
        # create two streams, one for the map and one for the passthrough
        for_passthrough, for_map = safetee(input, 2, lock=threading.Lock())
        # create map output stream
        map_output = self.mapper.transform(for_map, config, **kwargs)
        # get executor to start map output stream in background
        with get_executor_for_config(config or {}) as executor:
            # start map output stream
            first_map_chunk_future = executor.submit(  # type: ignore
                next, map_output, exhausted
            )
            # consume passthrough stream
            for chunk in for_passthrough:
                assert isinstance(chunk, dict)
                # remove mapper keys from passthrough chunk, to be overwritten by map
                filtered = AddableDict(
                    {k: v for k, v in chunk.items() if k not in mapper_keys}
                )
                if filtered:
                    yield filtered
            # yield map output
            first_map_chunk = first_map_chunk_future.result()
            if first_map_chunk is not exhausted:
                yield cast(Dict[str, Any], first_map_chunk)
            for chunk in map_output:
                yield chunk

    async def atransform(
        self,
        input: AsyncIterator[Dict[str, Any]],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> AsyncIterator[Dict[str, Any]]:
        """
        Async counterpart of ``transform``.

        If the map stream produces no chunk at all, the generator simply ends
        after the passthrough chunks.
        """
        # Sentinel meaning "the map stream was empty" (see first_map_chunk).
        exhausted = object()
        # collect mapper keys
        mapper_keys = set(self.mapper.steps.keys())
        # create two streams, one for the map and one for the passthrough
        for_passthrough, for_map = atee(input, 2, lock=asyncio.Lock())
        # create map output stream
        map_output = self.mapper.atransform(for_map, config, **kwargs)

        async def first_map_chunk() -> Any:
            # Turn an empty map stream into the sentinel instead of letting
            # StopAsyncIteration escape into this async generator, where it
            # would be converted into a RuntimeError.
            try:
                return await py_anext(map_output)
            except StopAsyncIteration:
                return exhausted

        # start map output stream
        first_map_chunk_task: asyncio.Task = asyncio.create_task(first_map_chunk())
        # consume passthrough stream
        async for chunk in for_passthrough:
            assert isinstance(chunk, dict)
            # remove mapper keys from passthrough chunk, to be overwritten by map output
            filtered = AddableDict(
                {k: v for k, v in chunk.items() if k not in mapper_keys}
            )
            if filtered:
                yield filtered
        # yield map output
        first_chunk = await first_map_chunk_task
        if first_chunk is not exhausted:
            yield first_chunk
        async for chunk in map_output:
            yield chunk

    def stream(
        self,
        input: Dict[str, Any],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Iterator[Dict[str, Any]]:
        """Stream the result for a single input by feeding it to ``transform``."""
        return self.transform(iter([input]), config, **kwargs)

    async def astream(
        self,
        input: Dict[str, Any],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> AsyncIterator[Dict[str, Any]]:
        """Stream the result for a single input by feeding it to ``atransform``."""

        async def input_aiter() -> AsyncIterator[Dict[str, Any]]:
            yield input

        async for chunk in self.atransform(input_aiter(), config, **kwargs):
            yield chunk
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Moved to utils