From f1aacf68a8dff579fbc204116d4c7ddeb47e7e1c Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 9 Jul 2024 17:52:29 -0700 Subject: [PATCH 1/9] lib: Checkpoint pending writes whenever a node finishes - Whenever a node finishes, checkpoint pending writes --- .../langgraph/checkpoint/aiosqlite.py | 43 ++++++++- libs/langgraph/langgraph/checkpoint/base.py | 18 ++++ libs/langgraph/langgraph/checkpoint/memory.py | 41 +++++++- libs/langgraph/langgraph/checkpoint/sqlite.py | 31 ++++++ libs/langgraph/langgraph/pregel/__init__.py | 96 ++++++++++++++----- libs/langgraph/langgraph/pregel/debug.py | 6 +- libs/langgraph/langgraph/pregel/io.py | 4 +- libs/langgraph/langgraph/pregel/types.py | 1 + libs/langgraph/tests/test_pregel.py | 11 +++ libs/langgraph/tests/test_pregel_async.py | 20 ++++ 10 files changed, 240 insertions(+), 31 deletions(-) diff --git a/libs/langgraph/langgraph/checkpoint/aiosqlite.py b/libs/langgraph/langgraph/checkpoint/aiosqlite.py index b636e02d3..f412162a1 100644 --- a/libs/langgraph/langgraph/checkpoint/aiosqlite.py +++ b/libs/langgraph/langgraph/checkpoint/aiosqlite.py @@ -2,7 +2,16 @@ import functools from contextlib import AbstractAsyncContextManager from types import TracebackType -from typing import Any, AsyncIterator, Dict, Iterator, Optional, TypeVar +from typing import ( + Any, + AsyncIterator, + Dict, + Iterator, + Optional, + Sequence, + Tuple, + TypeVar, +) import aiosqlite from langchain_core.runnables import RunnableConfig @@ -203,6 +212,15 @@ async def setup(self) -> None: metadata BLOB, PRIMARY KEY (thread_id, thread_ts) ); + CREATE TABLE IF NOT EXISTS writes ( + thread_id TEXT NOT NULL, + thread_ts TEXT NOT NULL, + task_id TEXT NOT NULL, + idx INTEGER NOT NULL, + channel TEXT NOT NULL, + value BLOB, + PRIMARY KEY (thread_id, thread_ts, task_id, idx) + ); """ ): await self.conn.commit() @@ -358,3 +376,26 @@ async def aput( "thread_ts": checkpoint["id"], } } + + async def aput_writes( + self, + config: RunnableConfig, + writes: Sequence[Tuple[str, Any]], + task_id: str, + ) -> None: + await self.setup() + async with self.conn.executemany( + "INSERT OR REPLACE INTO writes (thread_id, thread_ts, task_id, idx, channel, value) VALUES (?, ?, ?, ?, ?, ?)", + [ + ( + str(config["configurable"]["thread_id"]), + str(config["configurable"]["thread_ts"]), + task_id, + idx, + channel, + self.serde.dumps(value), + ) + for idx, (channel, value) in enumerate(writes) + ], + ): + await self.conn.commit() diff --git a/libs/langgraph/langgraph/checkpoint/base.py b/libs/langgraph/langgraph/checkpoint/base.py index fab38fa99..f9e5f5f48 100644 --- a/libs/langgraph/langgraph/checkpoint/base.py +++ b/libs/langgraph/langgraph/checkpoint/base.py @@ -10,6 +10,7 @@ Literal, NamedTuple, Optional, + Tuple, TypedDict, TypeVar, Union, @@ -117,6 +118,7 @@ class CheckpointTuple(NamedTuple): checkpoint: Checkpoint metadata: CheckpointMetadata parent_config: Optional[RunnableConfig] = None + pending_writes: Optional[List[Tuple[str, Any]]] = None CheckpointThreadId = ConfigurableFieldSpec( @@ -177,6 +179,14 @@ def put( ) -> RunnableConfig: raise NotImplementedError + def put_writes( + self, + config: RunnableConfig, + writes: List[Tuple[str, Any]], + task_id: str, + ) -> None: + raise NotImplementedError + async def aget(self, config: RunnableConfig) -> Optional[Checkpoint]: if value := await self.aget_tuple(config): return value.checkpoint @@ -203,6 +213,14 @@ async def aput( ) -> RunnableConfig: raise NotImplementedError + async def aput_writes( + self, + config: RunnableConfig, + writes: List[Tuple[str, Any]], + task_id: str, + ) -> None: + raise NotImplementedError + def get_next_version(self, current: Optional[V], channel: BaseChannel) -> V: """Get the next version of a channel. Default is to use int versions, incrementing by 1. If you override, you can use str/int/float versions, as long as they are monotonically increasing.""" diff --git a/libs/langgraph/langgraph/checkpoint/memory.py b/libs/langgraph/langgraph/checkpoint/memory.py index 5f87aa736..cc5bf8ca1 100644 --- a/libs/langgraph/langgraph/checkpoint/memory.py +++ b/libs/langgraph/langgraph/checkpoint/memory.py @@ -1,7 +1,7 @@ import asyncio from collections import defaultdict from functools import partial -from typing import Any, AsyncIterator, Dict, Iterator, Optional +from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Tuple from langchain_core.runnables import RunnableConfig @@ -53,6 +53,7 @@ def __init__( ) -> None: super().__init__(serde=serde) self.storage = defaultdict(dict) + self.writes = defaultdict(list) def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: """Get a checkpoint tuple from the in-memory storage. @@ -72,19 +73,23 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: if ts := config["configurable"].get("thread_ts"): if saved := self.storage[thread_id].get(ts): checkpoint, metadata = saved + writes = self.writes[(thread_id, ts)] return CheckpointTuple( config=config, checkpoint=self.serde.loads(checkpoint), metadata=self.serde.loads(metadata), + pending_writes=[(c, self.serde.loads(v)) for c, v in writes], ) else: if checkpoints := self.storage[thread_id]: ts = max(checkpoints.keys()) checkpoint, metadata = checkpoints[ts] + writes = self.writes[(thread_id, ts)] return CheckpointTuple( config={"configurable": {"thread_id": thread_id, "thread_ts": ts}}, checkpoint=self.serde.loads(checkpoint), metadata=self.serde.loads(metadata), + pending_writes=[(c, self.serde.loads(v)) for c, v in writes], ) def list( @@ -168,6 +173,30 @@ def put( } } + def put_writes( + self, + config: RunnableConfig, + writes: List[Tuple[str, Any]], + task_id: str, + ) -> RunnableConfig: + """Save a list of writes to the in-memory storage. + + This method saves a list of writes to the in-memory storage. The writes are associated + with the provided config. + + Args: + config (RunnableConfig): The config to associate with the writes. + writes (list[tuple[str, Any]]): The writes to save. + + Returns: + RunnableConfig: The updated config containing the saved writes' timestamp. + """ + thread_id = config["configurable"]["thread_id"] + ts = config["configurable"]["thread_ts"] + self.writes[(thread_id, ts)].extend( + [(c, self.serde.dumps(v)) for c, v in writes] + ) + async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: """Asynchronous version of get_tuple. @@ -224,3 +253,13 @@ async def aput( return await asyncio.get_running_loop().run_in_executor( None, self.put, config, checkpoint, metadata ) + + async def aput_writes( + self, + config: RunnableConfig, + writes: List[Tuple[str, Any]], + task_id: str, + ) -> RunnableConfig: + return await asyncio.get_running_loop().run_in_executor( + None, self.put_writes, config, writes, task_id + ) diff --git a/libs/langgraph/langgraph/checkpoint/sqlite.py b/libs/langgraph/langgraph/checkpoint/sqlite.py index 6a4c35c85..ef1a1b309 100644 --- a/libs/langgraph/langgraph/checkpoint/sqlite.py +++ b/libs/langgraph/langgraph/checkpoint/sqlite.py @@ -171,6 +171,15 @@ def setup(self) -> None: metadata BLOB, PRIMARY KEY (thread_id, thread_ts) ); + CREATE TABLE IF NOT EXISTS writes ( + thread_id TEXT NOT NULL, + thread_ts TEXT NOT NULL, + task_id TEXT NOT NULL, + idx INTEGER NOT NULL, + channel TEXT NOT NULL, + value BLOB, + PRIMARY KEY (thread_id, thread_ts, task_id, idx) + ); """ ) @@ -394,6 +403,28 @@ def put( } } + def put_writes( + self, + config: RunnableConfig, + writes: Sequence[Tuple[str, Any]], + task_id: str, + ) -> None: + with self.lock, self.cursor() as cur: + cur.executemany( + "INSERT OR REPLACE INTO writes (thread_id, thread_ts, task_id, idx, channel, value) VALUES (?, ?, ?, ?, ?, ?)", + [ + ( + str(config["configurable"]["thread_id"]), + str(config["configurable"]["thread_ts"]), + task_id, + idx, + channel, + self.serde.dumps(value), + ) + for idx, (channel, value) in enumerate(writes) + ], + ) + async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: """Get a checkpoint tuple from the database asynchronously. diff --git a/libs/langgraph/langgraph/pregel/__init__.py b/libs/langgraph/langgraph/pregel/__init__.py index a86674ca3..2fde0d1bd 100644 --- a/libs/langgraph/langgraph/pregel/__init__.py +++ b/libs/langgraph/langgraph/pregel/__init__.py @@ -2,6 +2,7 @@ import asyncio import concurrent.futures +import json import time from collections import defaultdict, deque from functools import partial @@ -23,6 +24,7 @@ get_type_hints, overload, ) +from uuid import UUID, uuid5 from langchain_core.callbacks.manager import AsyncParentRunManager, ParentRunManager from langchain_core.globals import get_debug @@ -442,7 +444,7 @@ def get_state_history( and signature(self.checkpointer.list).parameters.get("filter") is None ): raise ValueError("Checkpointer does not support filtering") - for config, checkpoint, metadata, parent_config in self.checkpointer.list( + for config, checkpoint, metadata, parent_config, _ in self.checkpointer.list( config, before=before, limit=limit, filter=filter ): with ChannelsManager( @@ -489,6 +491,7 @@ async def aget_state_history( checkpoint, metadata, parent_config, + _, ) in self.checkpointer.alist(config, before=before, limit=limit, filter=filter): async with AsyncChannelsManager( self.channels, checkpoint, config @@ -565,6 +568,7 @@ def update_state( deque(), None, [INTERRUPT], + str(uuid5(UUID(checkpoint["id"]), INTERRUPT)), ) # execute task task.proc.invoke( @@ -653,6 +657,7 @@ async def aupdate_state( deque(), None, [INTERRUPT], + str(uuid5(UUID(checkpoint["id"]), INTERRUPT)), ) # execute task await task.proc.ainvoke( @@ -879,6 +884,23 @@ def stream( self.managed_values_dict, config, self ) as managed: + def put_writes(id: str, writes: Sequence[tuple[str, Any]]) -> None: + if self.checkpointer is not None: + bg.append( + executor.submit( + self.checkpointer.put_writes, + { + **checkpoint_config, + "configurable": { + **checkpoint_config["configurable"], + "thread_ts": checkpoint["id"], + }, + }, + writes, + id, + ) + ) + def put_checkpoint(metadata: CheckpointMetadata) -> Iterator[Any]: nonlocal checkpoint, checkpoint_config, channels @@ -1050,6 +1072,9 @@ def put_checkpoint(metadata: CheckpointMetadata) -> Iterator[Any]: # exception will be handled in panic_or_proceed futures.clear() else: + # save task writes to checkpointer + if self.checkpointer is not None: + put_writes(task.id, task.writes) # yield updates output for the finished task if "updates" in stream_modes: yield from _with_mode( @@ -1076,7 +1101,7 @@ def put_checkpoint(metadata: CheckpointMetadata) -> Iterator[Any]: # combine pending writes from all tasks pending_writes = deque[tuple[str, Any]]() - for _, _, _, writes, _, _ in next_tasks: + for _, _, _, writes, _, _, _ in next_tasks: pending_writes.extend(writes) if debug: @@ -1240,6 +1265,24 @@ async def astream( self.managed_values_dict, config, self ) as managed: + def put_writes(id: str, writes: Sequence[tuple[str, Any]]) -> None: + if self.checkpointer is not None: + bg.append( + asyncio.create_task( + self.checkpointer.aput_writes( + { + **checkpoint_config, + "configurable": { + **checkpoint_config["configurable"], + "thread_ts": checkpoint["id"], + }, + }, + writes, + id, + ) + ) + ) + def put_checkpoint(metadata: CheckpointMetadata) -> Iterator[Any]: nonlocal checkpoint, checkpoint_config, channels @@ -1406,6 +1449,9 @@ def put_checkpoint(metadata: CheckpointMetadata) -> Iterator[Any]: # exception will be handle in panic_or_proceed futures.clear() else: + # save task writes to checkpointer + if self.checkpointer is not None: + put_writes(task.id, task.writes) # yield updates output for the finished task if "updates" in stream_modes: for chunk in _with_mode( @@ -1434,7 +1480,7 @@ def put_checkpoint(metadata: CheckpointMetadata) -> Iterator[Any]: # combine pending writes from all tasks pending_writes = deque[tuple[str, Any]]() - for _, _, _, writes, _, _ in next_tasks: + for _, _, _, writes, _, _, _ in next_tasks: pending_writes.extend(writes) if debug: @@ -1671,7 +1717,7 @@ def _should_interrupt( # and any triggered node is in interrupt_nodes list and any( node - for node, _, _, _, config, _ in tasks + for node, _, _, _, config, _, _ in tasks if ( (not config or TAG_HIDDEN not in config.get("tags")) if interrupt_nodes == "*" @@ -1825,6 +1871,14 @@ def _prepare_next_tasks( continue if for_execution: if node := processes[packet.node].get_node(): + triggers = [TASKS] + metadata = { + "langgraph_step": step, + "langgraph_node": packet.node, + "langgraph_triggers": triggers, + "langgraph_task_idx": len(tasks), + } + task_id = str(uuid5(UUID(checkpoint["id"]), json.dumps(metadata))) writes = deque() tasks.append( PregelExecutableTask( @@ -1836,14 +1890,7 @@ def _prepare_next_tasks( merge_configs( config, processes[packet.node].config, - { - "metadata": { - "langgraph_step": step, - "langgraph_node": packet.node, - "langgraph_triggers": [TASKS], - "langgraph_task_idx": len(tasks), - } - }, + {"metadata": metadata}, ), run_name=packet.node, callbacks=( @@ -1861,7 +1908,8 @@ def _prepare_next_tasks( ), }, ), - [TASKS], + triggers, + task_id, ) ) else: @@ -1879,7 +1927,7 @@ def _prepare_next_tasks( for name, proc in processes.items(): seen = checkpoint["versions_seen"][name] # If any of the channels read by this process were updated - if triggers := [ + if triggers := sorted( chan for chan in proc.triggers if not isinstance( @@ -1887,7 +1935,7 @@ def _prepare_next_tasks( ) and checkpoint["channel_versions"].get(chan, null_version) > seen.get(chan, null_version) - ]: + ): channels_to_consume.update(triggers) try: val = next(_proc_input(step, name, proc, managed, channels)) @@ -1906,8 +1954,14 @@ def _prepare_next_tasks( if for_execution: if node := proc.get_node(): + metadata = { + "langgraph_step": step, + "langgraph_node": name, + "langgraph_triggers": triggers, + "langgraph_task_idx": len(tasks), + } + task_id = str(uuid5(UUID(checkpoint["id"]), json.dumps(metadata))) writes = deque() - triggers = sorted(triggers) tasks.append( PregelExecutableTask( name, @@ -1918,14 +1972,7 @@ def _prepare_next_tasks( merge_configs( config, proc.config, - { - "metadata": { - "langgraph_step": step, - "langgraph_node": name, - "langgraph_triggers": triggers, - "langgraph_task_idx": len(tasks), - } - }, + {"metadata": metadata}, ), run_name=name, callbacks=( @@ -1948,6 +1995,7 @@ def _prepare_next_tasks( }, ), triggers, + task_id, ) ) else: diff --git a/libs/langgraph/langgraph/pregel/debug.py b/libs/langgraph/langgraph/pregel/debug.py index 7fb294903..0293690a7 100644 --- a/libs/langgraph/langgraph/pregel/debug.py +++ b/libs/langgraph/langgraph/pregel/debug.py @@ -66,7 +66,7 @@ def map_debug_tasks( step: int, tasks: list[PregelExecutableTask] ) -> Iterator[DebugOutputTask]: ts = datetime.now(timezone.utc).isoformat() - for name, input, _, _, config, triggers in tasks: + for name, input, _, _, config, triggers, _ in tasks: if config is not None and TAG_HIDDEN in config.get("tags", []): continue @@ -91,7 +91,7 @@ def map_debug_task_results( stream_channels_list: Sequence[str], ) -> Iterator[DebugOutputTaskResult]: ts = datetime.now(timezone.utc).isoformat() - for name, _, _, writes, config, _ in tasks: + for name, _, _, writes, config, _, _ in tasks: if config is not None and TAG_HIDDEN in config.get("tags", []): continue @@ -138,7 +138,7 @@ def print_step_tasks(step: int, next_tasks: list[PregelExecutableTask]) -> None: ) + "\n".join( f"- {get_colored_text(name, 'green')} -> {pformat(val)}" - for name, val, _, _, _, _ in next_tasks + for name, val, _, _, _, _, _ in next_tasks ) ) diff --git a/libs/langgraph/langgraph/pregel/io.py b/libs/langgraph/langgraph/pregel/io.py index fc1fc38ea..d33741697 100644 --- a/libs/langgraph/langgraph/pregel/io.py +++ b/libs/langgraph/langgraph/pregel/io.py @@ -105,7 +105,7 @@ def map_output_updates( if isinstance(output_channels, str): if updated := [ (node, value) - for node, _, _, writes, _, _ in output_tasks + for node, _, _, writes, _, _, _ in output_tasks for chan, value in writes if chan == output_channels ]: @@ -122,7 +122,7 @@ def map_output_updates( node, {chan: value for chan, value in writes if chan in output_channels}, ) - for node, _, _, writes, _, _ in output_tasks + for node, _, _, writes, _, _, _ in output_tasks if any(chan in output_channels for chan, _ in writes) ]: grouped = defaultdict(list) diff --git a/libs/langgraph/langgraph/pregel/types.py b/libs/langgraph/langgraph/pregel/types.py index e47d94b57..7d23eb009 100644 --- a/libs/langgraph/langgraph/pregel/types.py +++ b/libs/langgraph/langgraph/pregel/types.py @@ -18,6 +18,7 @@ class PregelExecutableTask(NamedTuple): writes: deque[tuple[str, Any]] config: RunnableConfig triggers: list[str] + id: str class StateSnapshot(NamedTuple): diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index 39b339edf..0ed619735 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -14,6 +14,7 @@ Literal, Optional, Sequence, + Tuple, TypedDict, Union, ) @@ -193,6 +194,12 @@ def put( ) -> RunnableConfig: raise ValueError("Faulty put") + class FaultyPutWritesCheckpointer(MemorySaver): + def put_writes( + self, config: RunnableConfig, writes: List[Tuple[str, Any]], task_id: str + ) -> RunnableConfig: + raise ValueError("Faulty put_writes") + class FaultyVersionCheckpointer(MemorySaver): def get_next_version(self, current: Optional[int], channel: BaseChannel) -> int: raise ValueError("Faulty get_next_version") @@ -213,6 +220,10 @@ def logic(inp: str) -> str: with pytest.raises(ValueError, match="Faulty put"): graph.invoke("", {"configurable": {"thread_id": "thread-1"}}) + graph = builder.compile(checkpointer=FaultyPutWritesCheckpointer()) + with pytest.raises(ValueError, match="Faulty put_writes"): + graph.invoke("", {"configurable": {"thread_id": "thread-1"}}) + graph = builder.compile(checkpointer=FaultyVersionCheckpointer()) with pytest.raises(ValueError, match="Faulty get_next_version"): graph.invoke("", {"configurable": {"thread_id": "thread-1"}}) diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index cf47292a1..6184ac09c 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -11,8 +11,10 @@ AsyncIterator, Dict, Generator, + List, Optional, Sequence, + Tuple, TypedDict, Union, ) @@ -75,6 +77,12 @@ async def aput( ) -> RunnableConfig: raise ValueError("Faulty put") + class FaultyPutWritesCheckpointer(MemorySaver): + async def aput_writes( + self, config: RunnableConfig, writes: List[Tuple[str, Any]], task_id: str + ) -> RunnableConfig: + raise ValueError("Faulty put_writes") + class FaultyVersionCheckpointer(MemorySaver): def get_next_version(self, current: Optional[int], channel: BaseChannel) -> int: raise ValueError("Faulty get_next_version") @@ -111,6 +119,18 @@ def logic(inp: str) -> str: ): pass + graph = builder.compile(checkpointer=FaultyPutWritesCheckpointer()) + with pytest.raises(ValueError, match="Faulty put_writes"): + await graph.ainvoke("", {"configurable": {"thread_id": "thread-1"}}) + with pytest.raises(ValueError, match="Faulty put_writes"): + async for _ in graph.astream("", {"configurable": {"thread_id": "thread-2"}}): + pass + with pytest.raises(ValueError, match="Faulty put_writes"): + async for _ in graph.astream_events( + "", {"configurable": {"thread_id": "thread-3"}}, version="v2" + ): + pass + graph = builder.compile(checkpointer=FaultyVersionCheckpointer()) with pytest.raises(ValueError, match="Faulty get_next_version"): await graph.ainvoke("", {"configurable": {"thread_id": "thread-1"}}) From fbe40cdda6a6db65d02dedee21743153c943ac74 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 9 Jul 2024 17:56:23 -0700 Subject: [PATCH 2/9] Rename arg --- libs/langgraph/langgraph/pregel/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/libs/langgraph/langgraph/pregel/__init__.py b/libs/langgraph/langgraph/pregel/__init__.py index 2fde0d1bd..2774800a0 100644 --- a/libs/langgraph/langgraph/pregel/__init__.py +++ b/libs/langgraph/langgraph/pregel/__init__.py @@ -884,7 +884,7 @@ def stream( self.managed_values_dict, config, self ) as managed: - def put_writes(id: str, writes: Sequence[tuple[str, Any]]) -> None: + def put_writes(task_id: str, writes: Sequence[tuple[str, Any]]) -> None: if self.checkpointer is not None: bg.append( executor.submit( @@ -897,7 +897,7 @@ def put_writes(id: str, writes: Sequence[tuple[str, Any]]) -> None: }, }, writes, - id, + task_id, ) ) @@ -1265,7 +1265,7 @@ async def astream( self.managed_values_dict, config, self ) as managed: - def put_writes(id: str, writes: Sequence[tuple[str, Any]]) -> None: + def put_writes(task_id: str, writes: Sequence[tuple[str, Any]]) -> None: if self.checkpointer is not None: bg.append( asyncio.create_task( @@ -1278,7 +1278,7 @@ def put_writes(id: str, writes: Sequence[tuple[str, Any]]) -> None: }, }, writes, - id, + task_id, ) ) ) From 83562c2cc52f916bab37d348a84743ac1d44eec9 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 10 Jul 2024 09:13:40 -0700 Subject: [PATCH 3/9] Add tests, resume from pending writes --- .../langgraph/checkpoint/aiosqlite.py | 94 ++++++++-------- libs/langgraph/langgraph/checkpoint/base.py | 2 +- libs/langgraph/langgraph/checkpoint/memory.py | 10 +- libs/langgraph/langgraph/checkpoint/sqlite.py | 69 ++++++------ libs/langgraph/langgraph/pregel/__init__.py | 30 +++++- libs/langgraph/tests/test_pregel.py | 90 ++++++++++++++++ libs/langgraph/tests/test_pregel_async.py | 101 +++++++++++++++++- 7 files changed, 306 insertions(+), 90 deletions(-) diff --git a/libs/langgraph/langgraph/checkpoint/aiosqlite.py b/libs/langgraph/langgraph/checkpoint/aiosqlite.py index f412162a1..431684d1a 100644 --- a/libs/langgraph/langgraph/checkpoint/aiosqlite.py +++ b/libs/langgraph/langgraph/checkpoint/aiosqlite.py @@ -242,56 +242,58 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found. """ await self.setup() - if config["configurable"].get("thread_ts"): - async with self.conn.execute( - "SELECT checkpoint, parent_ts, metadata FROM checkpoints WHERE thread_id = ? AND thread_ts = ?", - ( - str(config["configurable"]["thread_id"]), - str(config["configurable"]["thread_ts"]), - ), - ) as cursor: - if value := await cursor.fetchone(): - return CheckpointTuple( - config, - self.serde.loads(value[0]), - self.serde.loads(value[2]) if value[2] is not None else {}, - ( - { - "configurable": { - "thread_id": config["configurable"]["thread_id"], - "thread_ts": value[1], - } - } - if value[1] - else None - ), - ) - else: - async with self.conn.execute( - "SELECT thread_id, thread_ts, parent_ts, checkpoint, metadata FROM checkpoints WHERE thread_id = ? ORDER BY thread_ts DESC LIMIT 1", - (str(config["configurable"]["thread_id"]),), - ) as cursor: - if value := await cursor.fetchone(): - return CheckpointTuple( + async with self.conn.cursor() as cur: + # find the latest checkpoint for the thread_id + if config["configurable"].get("thread_ts"): + await cur.execute( + "SELECT thread_id, thread_ts, parent_ts, checkpoint, metadata FROM checkpoints WHERE thread_id = ? AND thread_ts = ?", + ( + str(config["configurable"]["thread_id"]), + str(config["configurable"]["thread_ts"]), + ), + ) + else: + await cur.execute( + "SELECT thread_id, thread_ts, parent_ts, checkpoint, metadata FROM checkpoints WHERE thread_id = ? ORDER BY thread_ts DESC LIMIT 1", + (str(config["configurable"]["thread_id"]),), + ) + # if a checkpoint is found, return it + if value := await cur.fetchone(): + if not config["configurable"].get("thread_ts"): + config = { + "configurable": { + "thread_id": value[0], + "thread_ts": value[1], + } + } + # find any pending writes + await cur.execute( + "SELECT task_id, channel, value FROM writes WHERE thread_id = ? AND thread_ts = ?", + ( + str(config["configurable"]["thread_id"]), + str(config["configurable"]["thread_ts"]), + ), + ) + # deserialize the checkpoint and metadata + return CheckpointTuple( + config, + self.serde.loads(value[3]), + self.serde.loads(value[4]) if value[4] is not None else {}, + ( { "configurable": { "thread_id": value[0], - "thread_ts": value[1], - } - }, - self.serde.loads(value[3]), - self.serde.loads(value[4]) if value[4] is not None else {}, - ( - { - "configurable": { - "thread_id": value[0], - "thread_ts": value[2], - } + "thread_ts": value[2], } - if value[2] - else None - ), - ) + } + if value[2] + else None + ), + [ + (task_id, channel, self.serde.loads(value)) + async for task_id, channel, value in cur + ], + ) async def alist( self, diff --git a/libs/langgraph/langgraph/checkpoint/base.py b/libs/langgraph/langgraph/checkpoint/base.py index f9e5f5f48..2cee04adc 100644 --- a/libs/langgraph/langgraph/checkpoint/base.py +++ b/libs/langgraph/langgraph/checkpoint/base.py @@ -118,7 +118,7 @@ class CheckpointTuple(NamedTuple): checkpoint: Checkpoint metadata: CheckpointMetadata parent_config: Optional[RunnableConfig] = None - pending_writes: Optional[List[Tuple[str, Any]]] = None + pending_writes: Optional[List[Tuple[str, str, Any]]] = None CheckpointThreadId = ConfigurableFieldSpec( diff --git a/libs/langgraph/langgraph/checkpoint/memory.py b/libs/langgraph/langgraph/checkpoint/memory.py index cc5bf8ca1..bd5dc6fd1 100644 --- a/libs/langgraph/langgraph/checkpoint/memory.py +++ b/libs/langgraph/langgraph/checkpoint/memory.py @@ -78,7 +78,9 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: config=config, checkpoint=self.serde.loads(checkpoint), metadata=self.serde.loads(metadata), - pending_writes=[(c, self.serde.loads(v)) for c, v in writes], + pending_writes=[ + (id, c, self.serde.loads(v)) for id, c, v in writes + ], ) else: if checkpoints := self.storage[thread_id]: @@ -89,7 +91,9 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: config={"configurable": {"thread_id": thread_id, "thread_ts": ts}}, checkpoint=self.serde.loads(checkpoint), metadata=self.serde.loads(metadata), - pending_writes=[(c, self.serde.loads(v)) for c, v in writes], + pending_writes=[ + (id, c, self.serde.loads(v)) for id, c, v in writes + ], ) def list( @@ -194,7 +198,7 @@ def put_writes( thread_id = config["configurable"]["thread_id"] ts = config["configurable"]["thread_ts"] self.writes[(thread_id, ts)].extend( - [(c, self.serde.dumps(v)) for c, v in writes] + [(task_id, c, self.serde.dumps(v)) for c, v in writes] ) async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: diff --git a/libs/langgraph/langgraph/checkpoint/sqlite.py b/libs/langgraph/langgraph/checkpoint/sqlite.py index ef1a1b309..6ac9ee593 100644 --- a/libs/langgraph/langgraph/checkpoint/sqlite.py +++ b/libs/langgraph/langgraph/checkpoint/sqlite.py @@ -242,56 +242,57 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: CheckpointTuple(...) """ # noqa with self.cursor(transaction=False) as cur: + # find the latest checkpoint for the thread_id if config["configurable"].get("thread_ts"): cur.execute( - "SELECT checkpoint, parent_ts, metadata FROM checkpoints WHERE thread_id = ? AND thread_ts = ?", + "SELECT thread_id, thread_ts, parent_ts, checkpoint, metadata FROM checkpoints WHERE thread_id = ? AND thread_ts = ?", ( str(config["configurable"]["thread_id"]), str(config["configurable"]["thread_ts"]), ), ) - if value := cur.fetchone(): - return CheckpointTuple( - config, - self.serde.loads(value[0]), - self.serde.loads(value[2]) if value[2] is not None else {}, - ( - { - "configurable": { - "thread_id": config["configurable"]["thread_id"], - "thread_ts": value[1], - } - } - if value[1] - else None - ), - ) else: cur.execute( "SELECT thread_id, thread_ts, parent_ts, checkpoint, metadata FROM checkpoints WHERE thread_id = ? ORDER BY thread_ts DESC LIMIT 1", (str(config["configurable"]["thread_id"]),), ) - if value := cur.fetchone(): - return CheckpointTuple( + # if a checkpoint is found, return it + if value := cur.fetchone(): + if not config["configurable"].get("thread_ts"): + config = { + "configurable": { + "thread_id": value[0], + "thread_ts": value[1], + } + } + # find any pending writes + cur.execute( + "SELECT task_id, channel, value FROM writes WHERE thread_id = ? AND thread_ts = ?", + ( + str(config["configurable"]["thread_id"]), + str(config["configurable"]["thread_ts"]), + ), + ) + # deserialize the checkpoint and metadata + return CheckpointTuple( + config, + self.serde.loads(value[3]), + self.serde.loads(value[4]) if value[4] is not None else {}, + ( { "configurable": { "thread_id": value[0], - "thread_ts": value[1], + "thread_ts": value[2], } - }, - self.serde.loads(value[3]), - self.serde.loads(value[4]) if value[4] is not None else {}, - ( - { - "configurable": { - "thread_id": value[0], - "thread_ts": value[2], - } - } - if value[2] - else None - ), - ) + } + if value[2] + else None + ), + [ + (task_id, channel, self.serde.loads(value)) + for task_id, channel, value in cur + ], + ) def list( self, diff --git a/libs/langgraph/langgraph/pregel/__init__.py b/libs/langgraph/langgraph/pregel/__init__.py index 2774800a0..03b28509f 100644 --- a/libs/langgraph/langgraph/pregel/__init__.py +++ b/libs/langgraph/langgraph/pregel/__init__.py @@ -985,8 +985,7 @@ def put_checkpoint(metadata: CheckpointMetadata) -> Iterator[Any]: # increment start to 0 start += 1 else: - # if received no input, take that as signal to proceed - # past previous interrupt, if any + # no input is taken as signal to proceed past previous interrupt checkpoint = copy_checkpoint(checkpoint) for k in self.stream_channels_list: if k in checkpoint["channel_versions"]: @@ -1016,6 +1015,15 @@ def put_checkpoint(metadata: CheckpointMetadata) -> Iterator[Any]: ), ) + # assign pending writes to tasks + if saved and saved.pending_writes: + for task in next_tasks: + task.writes.extend( + (c, v) + for tid, c, v in saved.pending_writes + if tid == task.id + ) + # if no more tasks, we're done if not next_tasks: if step == start: @@ -1049,12 +1057,15 @@ def put_checkpoint(metadata: CheckpointMetadata) -> Iterator[Any]: futures = { executor.submit(run_with_retry, task, self.retry_policy): task for task in next_tasks + if not task.writes } end_time = ( self.step_timeout + time.monotonic() if self.step_timeout else None ) + if not futures: + done, inflight = set(), set() while futures: done, inflight = concurrent.futures.wait( futures, @@ -1363,8 +1374,7 @@ def put_checkpoint(metadata: CheckpointMetadata) -> Iterator[Any]: # increment start to 0 start += 1 else: - # if received no input, take that as signal to proceed - # past previous interrupt, if any + # no input is taken as signal to proceed past previous interrupt checkpoint = copy_checkpoint(checkpoint) for k in self.stream_channels_list: if k in checkpoint["channel_versions"]: @@ -1394,6 +1404,15 @@ def put_checkpoint(metadata: CheckpointMetadata) -> Iterator[Any]: ), ) + # assign pending writes to tasks + if saved and saved.pending_writes: + for task in next_tasks: + task.writes.extend( + (c, v) + for tid, c, v in saved.pending_writes + if tid == task.id + ) + # if no more tasks, we're done if not next_tasks: if step == start: @@ -1430,10 +1449,13 @@ def put_checkpoint(metadata: CheckpointMetadata) -> Iterator[Any]: arun_with_retry(task, self.retry_policy, do_stream) ): task for task in next_tasks + if not task.writes } end_time = ( self.step_timeout + loop.time() if self.step_timeout else None ) + if not futures: + done, inflight = set(), set() while futures: done, inflight = await asyncio.wait( futures, diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index 0ed619735..8eaf631b9 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -38,6 +38,7 @@ from langgraph.channels.last_value import LastValue from langgraph.channels.topic import Topic from langgraph.checkpoint.base import ( + BaseCheckpointSaver, Checkpoint, CheckpointMetadata, CheckpointTuple, @@ -955,6 +956,95 @@ def raise_if_above_10(input: int) -> int: assert checkpoint["channel_values"].get("total") == 5 +@pytest.mark.parametrize( + "checkpointer", + [ + MemorySaverAssertImmutable(), + SqliteSaver.from_conn_string(":memory:"), + ], + ids=[ + "memory", + "sqlite", + ], +) +def test_pending_writes_resume(checkpointer: BaseCheckpointSaver) -> None: + try: + + class State(TypedDict): + value: Annotated[int, operator.add] + + class AwhileMaker: + def __init__(self, sleep: float, rtn: Union[Dict, Exception]) -> None: + self.sleep = sleep + self.rtn = rtn + self.reset() + + def __call__(self, input: State) -> Any: + self.calls += 1 + time.sleep(self.sleep) + if isinstance(self.rtn, Exception): + raise self.rtn + else: + return self.rtn + + def reset(self): + self.calls = 0 + + one = AwhileMaker(0.2, {"value": 2}) + two = AwhileMaker(0.6, ValueError("I'm not good")) + builder = StateGraph(State) + builder.add_node("one", one) + builder.add_node("two", two) + builder.add_edge(START, "one") + builder.add_edge(START, "two") + graph = builder.compile(checkpointer=checkpointer) + + # test interrupting astream + thread1: RunnableConfig = {"configurable": {"thread_id": 1}} + with pytest.raises(ValueError, match="I'm not good"): + graph.invoke({"value": 1}, thread1) + + # both nodes should have been called once + assert one.calls == 1 + assert two.calls == 1 + + # latest checkpoint should be before nodes "one", "two" + state = graph.get_state(thread1) + assert state is not None + assert state.values == {"value": 1} + assert state.next == ("one", "two") + assert state.metadata == {"source": "loop", "step": 0, "writes": None} + # should contain pending write of "one" + checkpoint = checkpointer.get_tuple(thread1) + assert checkpoint is not None + assert checkpoint.pending_writes == [ + (AnyStr(), "one", "one"), + (AnyStr(), "value", 2), + ] + # both pending writes come from same task + assert checkpoint.pending_writes[0][0] == checkpoint.pending_writes[1][0] + + # resume execution + with pytest.raises(ValueError, match="I'm not good"): + graph.invoke(None, thread1) + + # node "one" succeded previously, so shouldn't be called again + assert one.calls == 1 + # node "two" should have been called once again + assert two.calls == 2 + + # confirm no new checkpoints saved + state_two = graph.get_state(thread1) + assert state_two == state + + # resume execution, without exception + two.rtn = {"value": 3} + assert graph.invoke(None, thread1) == {"value": 6} + finally: + if getattr(checkpointer, "__exit__", None): + checkpointer.__exit__(None, None, None) + + def test_invoke_checkpoint_sqlite(mocker: MockerFixture) -> None: adder = mocker.Mock(side_effect=lambda x: x["total"] + x["input"]) diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index 6184ac09c..91e7255ae 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -1,7 +1,6 @@ import asyncio import json import operator -import time from collections import Counter from contextlib import asynccontextmanager, contextmanager from typing import ( @@ -233,6 +232,11 @@ async def alittlewhile(input: Any) -> None: AsyncSqliteSaver.from_conn_string(":memory:"), None, ], + ids=[ + "memory", + "aiosqlite", + "none", + ], ) async def test_cancel_graph_astream( checkpointer: Optional[BaseCheckpointSaver], @@ -299,6 +303,11 @@ async def alittlewhile(input: State) -> None: AsyncSqliteSaver.from_conn_string(":memory:"), None, ], + ids=[ + "memory", + "aiosqlite", + "none", + ], ) async def test_cancel_graph_astream_events_v2( checkpointer: Optional[BaseCheckpointSaver], @@ -347,7 +356,6 @@ async def alittlewhile(input: State) -> None: ) as stream: async for chunk in stream: if chunk["event"] == "on_chain_stream" and not chunk["parent_ids"]: - print(time.perf_counter(), "got event out here", chunk) got_event = True assert chunk["data"]["chunk"] == {"alittlewhile": {"value": 2}} break @@ -1056,6 +1064,95 @@ def raise_if_above_10(input: int) -> int: assert checkpoint["channel_values"].get("total") == 5 +@pytest.mark.parametrize( + "checkpointer", + [ + MemorySaverAssertImmutable(), + AsyncSqliteSaver.from_conn_string(":memory:"), + ], + ids=[ + "memory", + "sqlite", + ], +) +async def test_pending_writes_resume(checkpointer: BaseCheckpointSaver) -> None: + try: + + class State(TypedDict): + value: Annotated[int, operator.add] + + class AwhileMaker: + def __init__(self, sleep: float, rtn: Union[Dict, Exception]) -> None: + self.sleep = sleep + self.rtn = rtn + self.reset() + + async def __call__(self, input: State) -> Any: + self.calls += 1 + await asyncio.sleep(self.sleep) + if isinstance(self.rtn, Exception): + raise self.rtn + else: + return self.rtn + + def reset(self): + self.calls = 0 + + one = AwhileMaker(0.2, {"value": 2}) + two = AwhileMaker(0.6, ValueError("I'm not good")) + builder = StateGraph(State) + builder.add_node("one", one) + builder.add_node("two", two) + builder.add_edge(START, "one") + builder.add_edge(START, "two") + graph = builder.compile(checkpointer=checkpointer) + + # test interrupting astream + thread1: RunnableConfig = {"configurable": {"thread_id": 1}} + with pytest.raises(ValueError, match="I'm not good"): + await graph.ainvoke({"value": 1}, thread1) + + # both nodes should have been called once + assert one.calls == 1 + assert two.calls == 1 + + # latest checkpoint should be before nodes "one", "two" + state = await graph.aget_state(thread1) + assert state is not None + assert state.values == {"value": 1} + assert state.next == ("one", "two") + assert state.metadata == {"source": "loop", "step": 0, "writes": None} + # should contain pending write of "one" + checkpoint = await checkpointer.aget_tuple(thread1) + assert checkpoint is not None + assert checkpoint.pending_writes == [ + (AnyStr(), "one", "one"), + (AnyStr(), "value", 2), + ] + # both pending writes come from same task + assert checkpoint.pending_writes[0][0] == checkpoint.pending_writes[1][0] + + # resume execution + with pytest.raises(ValueError, match="I'm not good"): + await graph.ainvoke(None, thread1) + + # node "one" succeded previously, so shouldn't be called again + assert one.calls == 1 + # node "two" should have been called once again + assert two.calls == 2 + + # confirm no new checkpoints saved + state_two = await graph.aget_state(thread1) + assert state_two == state + + # resume execution, without exception + two.rtn = {"value": 3} + assert await graph.ainvoke(None, thread1) == {"value": 6} + finally: + if getattr(checkpointer, "__aexit__", None): + await checkpointer.__aexit__(None, None, None) + + async def test_invoke_checkpoint_aiosqlite(mocker: MockerFixture) -> None: add_one = mocker.Mock(side_effect=lambda x: x["total"] + x["input"]) From fa743a9d9e98e3d344dc4733f8a233e1eeac9e94 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 10 Jul 2024 09:16:07 -0700 Subject: [PATCH 4/9] Add comment --- libs/langgraph/tests/test_pregel.py | 1 + libs/langgraph/tests/test_pregel_async.py | 1 + 2 files changed, 2 insertions(+) diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index 8eaf631b9..49832ef5f 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -1039,6 +1039,7 @@ def reset(self): # resume execution, without exception two.rtn = {"value": 3} + # both the pending write and the new write were applied, 1 + 2 + 3 = 6 assert graph.invoke(None, thread1) == {"value": 6} finally: if getattr(checkpointer, "__exit__", None): diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index 91e7255ae..4a65696a9 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -1147,6 +1147,7 @@ def reset(self): # resume execution, without exception two.rtn = {"value": 3} + # both the pending write and the new write were applied, 1 + 2 + 3 = 6 assert await graph.ainvoke(None, thread1) == {"value": 6} finally: if getattr(checkpointer, "__aexit__", None): From 9be374f1c4835d50b7cf4e760175fa5885b18382 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 10 Jul 2024 11:26:18 -0700 Subject: [PATCH 5/9] Add descriptive error --- libs/langgraph/langgraph/checkpoint/base.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/libs/langgraph/langgraph/checkpoint/base.py b/libs/langgraph/langgraph/checkpoint/base.py index 2cee04adc..7fa64b241 100644 --- a/libs/langgraph/langgraph/checkpoint/base.py +++ b/libs/langgraph/langgraph/checkpoint/base.py @@ -185,7 +185,9 @@ def put_writes( writes: List[Tuple[str, Any]], task_id: str, ) -> None: - raise NotImplementedError + raise NotImplementedError( + "This method was added in langgraph 0.1.7. Please update your checkpointer to implement it." + ) async def aget(self, config: RunnableConfig) -> Optional[Checkpoint]: if value := await self.aget_tuple(config): @@ -219,7 +221,9 @@ async def aput_writes( writes: List[Tuple[str, Any]], task_id: str, ) -> None: - raise NotImplementedError + raise NotImplementedError( + "This method was added in langgraph 0.1.7. Please update your checkpointer to implement it." + ) def get_next_version(self, current: Optional[V], channel: BaseChannel) -> V: """Get the next version of a channel. Default is to use int versions, incrementing by 1. If you override, you can use str/int/float versions, From 7afb3fe6711ac7bb939da3a450c0ccbbeefd0612 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 10 Jul 2024 13:12:15 -0700 Subject: [PATCH 6/9] Fix bug found by will --- libs/langgraph/langgraph/pregel/__init__.py | 2 +- libs/langgraph/tests/test_pregel.py | 27 ++++++++++++++++++++ libs/langgraph/tests/test_pregel_async.py | 28 +++++++++++++++++++++ 3 files changed, 56 insertions(+), 1 deletion(-) diff --git a/libs/langgraph/langgraph/pregel/__init__.py b/libs/langgraph/langgraph/pregel/__init__.py index 03b28509f..be605dd98 100644 --- a/libs/langgraph/langgraph/pregel/__init__.py +++ b/libs/langgraph/langgraph/pregel/__init__.py @@ -1926,7 +1926,7 @@ def _prepare_next_tasks( _local_write, writes.extend, processes, channels ), CONFIG_KEY_READ: partial( - _local_read, checkpoint, channels, tasks, config + _local_read, checkpoint, channels, writes, config ), }, ), diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index 49832ef5f..4c5adaef5 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -1046,6 +1046,33 @@ def reset(self): checkpointer.__exit__(None, None, None) +def test_cond_edge_after_send() -> None: + class Node: + def __init__(self, name: str): + self.name = name + setattr(self, "__name__", name) + + def __call__(self, state): + return state + [self.name] + + def send_for_fun(state): + return [Send("2", state)] + + def route_to_three(state) -> Literal["3"]: + return "3" + + builder = StateGraph(list) + builder.add_node(Node("1")) + builder.add_node(Node("2")) + builder.add_node(Node("3")) + builder.add_edge(START, "1") + builder.add_conditional_edges("1", send_for_fun) + builder.add_conditional_edges("2", route_to_three) + graph = builder.compile() + + assert graph.invoke(["0"]) == ["0", "1", "2", "3"] + + def test_invoke_checkpoint_sqlite(mocker: MockerFixture) -> None: adder = mocker.Mock(side_effect=lambda x: x["total"] + x["input"]) diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index 4a65696a9..e6641e3f7 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -11,6 +11,7 @@ Dict, Generator, List, + Literal, Optional, Sequence, Tuple, @@ -1154,6 +1155,33 @@ def reset(self): await checkpointer.__aexit__(None, None, None) +async def test_cond_edge_after_send() -> None: + class Node: + def __init__(self, name: str): + self.name = name + setattr(self, "__name__", name) + + async def __call__(self, state): + return state + [self.name] + + async def send_for_fun(state): + return [Send("2", state)] + + async def route_to_three(state) -> Literal["3"]: + return "3" + + builder = StateGraph(list) + builder.add_node(Node("1")) + builder.add_node(Node("2")) + builder.add_node(Node("3")) + builder.add_edge(START, "1") + builder.add_conditional_edges("1", send_for_fun) + builder.add_conditional_edges("2", route_to_three) + graph = builder.compile() + + assert await graph.ainvoke(["0"]) == ["0", "1", "2", "3"] + + async def test_invoke_checkpoint_aiosqlite(mocker: MockerFixture) -> None: add_one = mocker.Mock(side_effect=lambda x: x["total"] + x["input"]) From ca7041ffc10f56a4f45d7781a80cf6c24bc8b971 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 10 Jul 2024 13:13:11 -0700 Subject: [PATCH 7/9] Fix comments --- libs/langgraph/tests/test_pregel.py | 1 - libs/langgraph/tests/test_pregel_async.py | 1 - 2 files changed, 2 deletions(-) diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index 4c5adaef5..0904887d3 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -999,7 +999,6 @@ def reset(self): builder.add_edge(START, "two") graph = builder.compile(checkpointer=checkpointer) - # test interrupting astream thread1: RunnableConfig = {"configurable": {"thread_id": 1}} with pytest.raises(ValueError, match="I'm not good"): graph.invoke({"value": 1}, thread1) diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index e6641e3f7..f0e78eef0 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -1108,7 +1108,6 @@ def reset(self): builder.add_edge(START, "two") graph = builder.compile(checkpointer=checkpointer) - # test interrupting astream thread1: RunnableConfig = {"configurable": {"thread_id": 1}} with pytest.raises(ValueError, match="I'm not good"): await graph.ainvoke({"value": 1}, thread1) From 2df8ab85e1c4c21e73fe06f7a1f645483c2c9f8c Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 10 Jul 2024 13:15:34 -0700 Subject: [PATCH 8/9] Lint --- libs/langgraph/tests/test_pregel.py | 2 +- libs/langgraph/tests/test_pregel_async.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index 0904887d3..70a65d4f5 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -1027,7 +1027,7 @@ def reset(self): with pytest.raises(ValueError, match="I'm not good"): graph.invoke(None, thread1) - # node "one" succeded previously, so shouldn't be called again + # node "one" succeeded previously, so shouldn't be called again assert one.calls == 1 # node "two" should have been called once again assert two.calls == 2 diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index f0e78eef0..d814fd426 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -1136,7 +1136,7 @@ def reset(self): with pytest.raises(ValueError, match="I'm not good"): await graph.ainvoke(None, thread1) - # node "one" succeded previously, so shouldn't be called again + # node "one" succeeded previously, so shouldn't be called again assert one.calls == 1 # node "two" should have been called once again assert two.calls == 2 From f93e306dfdc16019143dde57a950e868bf44dd37 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 10 Jul 2024 14:14:32 -0700 Subject: [PATCH 9/9] Don't save pending write if executing only one node in step --- libs/langgraph/langgraph/pregel/__init__.py | 10 +++++---- libs/langgraph/tests/test_pregel.py | 16 ++++++++------ libs/langgraph/tests/test_pregel_async.py | 24 +++++++++++---------- 3 files changed, 28 insertions(+), 22 deletions(-) diff --git a/libs/langgraph/langgraph/pregel/__init__.py b/libs/langgraph/langgraph/pregel/__init__.py index be605dd98..79f678fd8 100644 --- a/libs/langgraph/langgraph/pregel/__init__.py +++ b/libs/langgraph/langgraph/pregel/__init__.py @@ -1083,8 +1083,9 @@ def put_checkpoint(metadata: CheckpointMetadata) -> Iterator[Any]: # exception will be handled in panic_or_proceed futures.clear() else: - # save task writes to checkpointer - if self.checkpointer is not None: + # save task writes to checkpointer, unless this + # is the single or last task in this step + if futures: put_writes(task.id, task.writes) # yield updates output for the finished task if "updates" in stream_modes: @@ -1471,8 +1472,9 @@ def put_checkpoint(metadata: CheckpointMetadata) -> Iterator[Any]: # exception will be handle in panic_or_proceed futures.clear() else: - # save task writes to checkpointer - if self.checkpointer is not None: + # save task writes to checkpointer, unless this + # is the single or last task in this step + if futures: put_writes(task.id, task.writes) # yield updates output for the finished task if "updates" in stream_modes: diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index 70a65d4f5..a304534f2 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -208,10 +208,9 @@ def get_next_version(self, current: Optional[int], channel: BaseChannel) -> int: def logic(inp: str) -> str: return "" - builder = Graph() + builder = StateGraph(Annotated[str, operator.add]) builder.add_node("agent", logic) - builder.set_entry_point("agent") - builder.set_finish_point("agent") + builder.add_edge(START, "agent") graph = builder.compile(checkpointer=FaultyGetCheckpointer()) with pytest.raises(ValueError, match="Faulty get_tuple"): @@ -221,14 +220,17 @@ def logic(inp: str) -> str: with pytest.raises(ValueError, match="Faulty put"): graph.invoke("", {"configurable": {"thread_id": "thread-1"}}) - graph = builder.compile(checkpointer=FaultyPutWritesCheckpointer()) - with pytest.raises(ValueError, match="Faulty put_writes"): - graph.invoke("", {"configurable": {"thread_id": "thread-1"}}) - graph = builder.compile(checkpointer=FaultyVersionCheckpointer()) with pytest.raises(ValueError, match="Faulty get_next_version"): graph.invoke("", {"configurable": {"thread_id": "thread-1"}}) + # add parallel node + builder.add_node("parallel", logic) + builder.add_edge(START, "parallel") + graph = builder.compile(checkpointer=FaultyPutWritesCheckpointer()) + with pytest.raises(ValueError, match="Faulty put_writes"): + graph.invoke("", {"configurable": {"thread_id": "thread-1"}}) + def test_reducer_before_first_node() -> None: from langchain_core.messages import HumanMessage diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index d814fd426..41f99d53f 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -90,10 +90,9 @@ def get_next_version(self, current: Optional[int], channel: BaseChannel) -> int: def logic(inp: str) -> str: return "" - builder = Graph() + builder = StateGraph(Annotated[str, operator.add]) builder.add_node("agent", logic) - builder.set_entry_point("agent") - builder.set_finish_point("agent") + builder.add_edge(START, "agent") graph = builder.compile(checkpointer=FaultyGetCheckpointer()) with pytest.raises(ValueError, match="Faulty get_tuple"): @@ -119,25 +118,28 @@ def logic(inp: str) -> str: ): pass - graph = builder.compile(checkpointer=FaultyPutWritesCheckpointer()) - with pytest.raises(ValueError, match="Faulty put_writes"): + graph = builder.compile(checkpointer=FaultyVersionCheckpointer()) + with pytest.raises(ValueError, match="Faulty get_next_version"): await graph.ainvoke("", {"configurable": {"thread_id": "thread-1"}}) - with pytest.raises(ValueError, match="Faulty put_writes"): + with pytest.raises(ValueError, match="Faulty get_next_version"): async for _ in graph.astream("", {"configurable": {"thread_id": "thread-2"}}): pass - with pytest.raises(ValueError, match="Faulty put_writes"): + with pytest.raises(ValueError, match="Faulty get_next_version"): async for _ in graph.astream_events( "", {"configurable": {"thread_id": "thread-3"}}, version="v2" ): pass - graph = builder.compile(checkpointer=FaultyVersionCheckpointer()) - with pytest.raises(ValueError, match="Faulty get_next_version"): + # add a parallel node + builder.add_node("parallel", logic) + builder.add_edge(START, "parallel") + graph = builder.compile(checkpointer=FaultyPutWritesCheckpointer()) + with pytest.raises(ValueError, match="Faulty put_writes"): await graph.ainvoke("", {"configurable": {"thread_id": "thread-1"}}) - with pytest.raises(ValueError, match="Faulty get_next_version"): + with pytest.raises(ValueError, match="Faulty put_writes"): async for _ in graph.astream("", {"configurable": {"thread_id": "thread-2"}}): pass - with pytest.raises(ValueError, match="Faulty get_next_version"): + with pytest.raises(ValueError, match="Faulty put_writes"): async for _ in graph.astream_events( "", {"configurable": {"thread_id": "thread-3"}}, version="v2" ):