Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lib: Checkpoint pending writes whenever a node finishes #976

Merged
merged 9 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 90 additions & 47 deletions libs/langgraph/langgraph/checkpoint/aiosqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,16 @@
import functools
from contextlib import AbstractAsyncContextManager
from types import TracebackType
from typing import Any, AsyncIterator, Dict, Iterator, Optional, TypeVar
from typing import (
Any,
AsyncIterator,
Dict,
Iterator,
Optional,
Sequence,
Tuple,
TypeVar,
)

import aiosqlite
from langchain_core.runnables import RunnableConfig
Expand Down Expand Up @@ -203,6 +212,15 @@ async def setup(self) -> None:
metadata BLOB,
PRIMARY KEY (thread_id, thread_ts)
);
CREATE TABLE IF NOT EXISTS writes (
thread_id TEXT NOT NULL,
thread_ts TEXT NOT NULL,
task_id TEXT NOT NULL,
idx INTEGER NOT NULL,
channel TEXT NOT NULL,
value BLOB,
PRIMARY KEY (thread_id, thread_ts, task_id, idx)
);
"""
):
await self.conn.commit()
Expand All @@ -224,56 +242,58 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found.
"""
await self.setup()
if config["configurable"].get("thread_ts"):
async with self.conn.execute(
"SELECT checkpoint, parent_ts, metadata FROM checkpoints WHERE thread_id = ? AND thread_ts = ?",
(
str(config["configurable"]["thread_id"]),
str(config["configurable"]["thread_ts"]),
),
) as cursor:
if value := await cursor.fetchone():
return CheckpointTuple(
config,
self.serde.loads(value[0]),
self.serde.loads(value[2]) if value[2] is not None else {},
(
{
"configurable": {
"thread_id": config["configurable"]["thread_id"],
"thread_ts": value[1],
}
}
if value[1]
else None
),
)
else:
async with self.conn.execute(
"SELECT thread_id, thread_ts, parent_ts, checkpoint, metadata FROM checkpoints WHERE thread_id = ? ORDER BY thread_ts DESC LIMIT 1",
(str(config["configurable"]["thread_id"]),),
) as cursor:
if value := await cursor.fetchone():
return CheckpointTuple(
async with self.conn.cursor() as cur:
# find the latest checkpoint for the thread_id
if config["configurable"].get("thread_ts"):
await cur.execute(
"SELECT thread_id, thread_ts, parent_ts, checkpoint, metadata FROM checkpoints WHERE thread_id = ? AND thread_ts = ?",
(
str(config["configurable"]["thread_id"]),
str(config["configurable"]["thread_ts"]),
),
)
else:
await cur.execute(
"SELECT thread_id, thread_ts, parent_ts, checkpoint, metadata FROM checkpoints WHERE thread_id = ? ORDER BY thread_ts DESC LIMIT 1",
(str(config["configurable"]["thread_id"]),),
)
# if a checkpoint is found, return it
if value := await cur.fetchone():
if not config["configurable"].get("thread_ts"):
config = {
"configurable": {
"thread_id": value[0],
"thread_ts": value[1],
}
}
# find any pending writes
await cur.execute(
"SELECT task_id, channel, value FROM writes WHERE thread_id = ? AND thread_ts = ?",
(
str(config["configurable"]["thread_id"]),
str(config["configurable"]["thread_ts"]),
),
)
# deserialize the checkpoint and metadata
return CheckpointTuple(
config,
self.serde.loads(value[3]),
self.serde.loads(value[4]) if value[4] is not None else {},
(
{
"configurable": {
"thread_id": value[0],
"thread_ts": value[1],
}
},
self.serde.loads(value[3]),
self.serde.loads(value[4]) if value[4] is not None else {},
(
{
"configurable": {
"thread_id": value[0],
"thread_ts": value[2],
}
"thread_ts": value[2],
}
if value[2]
else None
),
)
}
if value[2]
else None
),
[
(task_id, channel, self.serde.loads(value))
async for task_id, channel, value in cur
],
)

async def alist(
self,
Expand Down Expand Up @@ -358,3 +378,26 @@ async def aput(
"thread_ts": checkpoint["id"],
}
}

async def aput_writes(
    self,
    config: RunnableConfig,
    writes: Sequence[Tuple[str, Any]],
    task_id: str,
) -> None:
    """Persist a task's writes into the ``writes`` table.

    Awaits ``self.setup()`` first, then inserts one row per entry of
    ``writes`` and commits. Rows use ``INSERT OR REPLACE``, and each row
    carries the entry's position in ``writes`` as its ``idx`` column.

    Args:
        config: Config whose ``configurable`` mapping supplies ``thread_id``
            and ``thread_ts``; both are converted with ``str()``.
        writes: ``(channel, value)`` pairs; each value is serialized with
            ``self.serde.dumps`` before being stored.
        task_id: Identifier stored in the ``task_id`` column of every row.
    """
    await self.setup()
    statement = "INSERT OR REPLACE INTO writes (thread_id, thread_ts, task_id, idx, channel, value) VALUES (?, ?, ?, ?, ?, ?)"
    rows = []
    for position, (channel_name, payload) in enumerate(writes):
        # The config is read per row, so an empty ``writes`` never touches it.
        rows.append(
            (
                str(config["configurable"]["thread_id"]),
                str(config["configurable"]["thread_ts"]),
                task_id,
                position,
                channel_name,
                self.serde.dumps(payload),
            )
        )
    async with self.conn.executemany(statement, rows):
        await self.conn.commit()
22 changes: 22 additions & 0 deletions libs/langgraph/langgraph/checkpoint/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Literal,
NamedTuple,
Optional,
Tuple,
TypedDict,
TypeVar,
Union,
Expand Down Expand Up @@ -117,6 +118,7 @@ class CheckpointTuple(NamedTuple):
checkpoint: Checkpoint
metadata: CheckpointMetadata
parent_config: Optional[RunnableConfig] = None
pending_writes: Optional[List[Tuple[str, str, Any]]] = None


CheckpointThreadId = ConfigurableFieldSpec(
Expand Down Expand Up @@ -177,6 +179,16 @@ def put(
) -> RunnableConfig:
raise NotImplementedError

def put_writes(
    self,
    config: RunnableConfig,
    writes: List[Tuple[str, Any]],
    task_id: str,
) -> None:
    """Placeholder for saving ``writes`` from ``task_id`` under ``config``.

    This implementation stores nothing; it always raises.

    Raises:
        NotImplementedError: Always, with a message asking for the
            checkpointer to be updated to implement this method.
    """
    message = "This method was added in langgraph 0.1.7. Please update your checkpointer to implement it."
    raise NotImplementedError(message)

async def aget(self, config: RunnableConfig) -> Optional[Checkpoint]:
if value := await self.aget_tuple(config):
return value.checkpoint
Expand All @@ -203,6 +215,16 @@ async def aput(
) -> RunnableConfig:
raise NotImplementedError

async def aput_writes(
    self,
    config: RunnableConfig,
    writes: List[Tuple[str, Any]],
    task_id: str,
) -> None:
    """Async placeholder for saving ``writes`` from ``task_id`` under ``config``.

    This implementation stores nothing; awaiting it always raises.

    Raises:
        NotImplementedError: Always, with a message asking for the
            checkpointer to be updated to implement this method.
    """
    message = "This method was added in langgraph 0.1.7. Please update your checkpointer to implement it."
    raise NotImplementedError(message)

def get_next_version(self, current: Optional[V], channel: BaseChannel) -> V:
"""Get the next version of a channel. Default is to use int versions, incrementing by 1. If you override, you can use str/int/float versions,
as long as they are monotonically increasing."""
Expand Down
45 changes: 44 additions & 1 deletion libs/langgraph/langgraph/checkpoint/memory.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
from collections import defaultdict
from functools import partial
from typing import Any, AsyncIterator, Dict, Iterator, Optional
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Tuple

from langchain_core.runnables import RunnableConfig

Expand Down Expand Up @@ -53,6 +53,7 @@ def __init__(
) -> None:
super().__init__(serde=serde)
self.storage = defaultdict(dict)
self.writes = defaultdict(list)

def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
"""Get a checkpoint tuple from the in-memory storage.
Expand All @@ -72,19 +73,27 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
if ts := config["configurable"].get("thread_ts"):
if saved := self.storage[thread_id].get(ts):
checkpoint, metadata = saved
writes = self.writes[(thread_id, ts)]
return CheckpointTuple(
config=config,
checkpoint=self.serde.loads(checkpoint),
metadata=self.serde.loads(metadata),
pending_writes=[
(id, c, self.serde.loads(v)) for id, c, v in writes
],
)
else:
if checkpoints := self.storage[thread_id]:
ts = max(checkpoints.keys())
checkpoint, metadata = checkpoints[ts]
writes = self.writes[(thread_id, ts)]
return CheckpointTuple(
config={"configurable": {"thread_id": thread_id, "thread_ts": ts}},
checkpoint=self.serde.loads(checkpoint),
metadata=self.serde.loads(metadata),
pending_writes=[
(id, c, self.serde.loads(v)) for id, c, v in writes
],
)

def list(
Expand Down Expand Up @@ -168,6 +177,30 @@ def put(
}
}

def put_writes(
    self,
    config: RunnableConfig,
    writes: List[Tuple[str, Any]],
    task_id: str,
) -> None:
    """Save a list of writes to the in-memory storage.

    The writes are appended to the list kept in ``self.writes`` under the
    ``(thread_id, thread_ts)`` pair read from ``config``. Repeated calls for
    the same pair accumulate; nothing is replaced.

    Args:
        config (RunnableConfig): The config to associate with the writes. Its
            ``configurable`` mapping must contain ``thread_id`` and ``thread_ts``.
        writes (list[tuple[str, Any]]): The ``(channel, value)`` pairs to save.
            Each value is serialized with ``self.serde.dumps``.
        task_id (str): Identifier stored alongside every saved write.

    Returns:
        None

    Raises:
        KeyError: If ``config`` lacks ``configurable``, ``thread_id`` or
            ``thread_ts``.
    """
    thread_id = config["configurable"]["thread_id"]
    ts = config["configurable"]["thread_ts"]
    # Serialize everything before extending, so a failing ``dumps`` leaves
    # the stored list untouched.
    self.writes[(thread_id, ts)].extend(
        [(task_id, c, self.serde.dumps(v)) for c, v in writes]
    )

async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
"""Asynchronous version of get_tuple.

Expand Down Expand Up @@ -224,3 +257,13 @@ async def aput(
return await asyncio.get_running_loop().run_in_executor(
None, self.put, config, checkpoint, metadata
)

async def aput_writes(
    self,
    config: RunnableConfig,
    writes: List[Tuple[str, Any]],
    task_id: str,
) -> None:
    """Asynchronous version of put_writes.

    Runs ``self.put_writes`` through the running event loop's
    ``run_in_executor`` with ``None`` as the executor argument.

    Args:
        config (RunnableConfig): The config to associate with the writes.
        writes (list[tuple[str, Any]]): The writes to save.
        task_id (str): Identifier passed through to ``put_writes``.

    Returns:
        None
    """
    return await asyncio.get_running_loop().run_in_executor(
        None, self.put_writes, config, writes, task_id
    )
Loading
Loading