From 019423ef6645a2324d0a7db2765f9a844eb4be88 Mon Sep 17 00:00:00 2001 From: Jacob Lee Date: Thu, 25 Jul 2024 14:27:33 -0700 Subject: [PATCH] cloudflare[minor]: Adds Cloudflare D1 checkpointer (#6212) * Adds Cloudflare D1 checkpointer * Fix lint + format --- libs/langchain-cloudflare/.gitignore | 4 + libs/langchain-cloudflare/langchain.config.js | 3 +- libs/langchain-cloudflare/package.json | 26 ++- .../src/langgraph/checkpointers.ts | 212 ++++++++++++++++++ yarn.lock | 22 ++ 5 files changed, 264 insertions(+), 3 deletions(-) create mode 100644 libs/langchain-cloudflare/src/langgraph/checkpointers.ts diff --git a/libs/langchain-cloudflare/.gitignore b/libs/langchain-cloudflare/.gitignore index c10034e2f1be..3443681d0418 100644 --- a/libs/langchain-cloudflare/.gitignore +++ b/libs/langchain-cloudflare/.gitignore @@ -2,6 +2,10 @@ index.cjs index.js index.d.ts index.d.cts +langgraph/checkpointers.cjs +langgraph/checkpointers.js +langgraph/checkpointers.d.ts +langgraph/checkpointers.d.cts node_modules dist .yarn diff --git a/libs/langchain-cloudflare/langchain.config.js b/libs/langchain-cloudflare/langchain.config.js index 416001cb4772..e112dcf90c6c 100644 --- a/libs/langchain-cloudflare/langchain.config.js +++ b/libs/langchain-cloudflare/langchain.config.js @@ -11,9 +11,10 @@ function abs(relativePath) { export const config = { - internals: [/node\:/, /@langchain\/core\//], + internals: [/node\:/, /@langchain\/core\//, /@langchain\/langgraph\/web/], entrypoints: { index: "index", + "langgraph/checkpointers": "langgraph/checkpointers", }, tsConfigPath: resolve("./tsconfig.json"), cjsSource: "./dist-cjs", diff --git a/libs/langchain-cloudflare/package.json b/libs/langchain-cloudflare/package.json index 31a6165706ca..06eaa43f981b 100644 --- a/libs/langchain-cloudflare/package.json +++ b/libs/langchain-cloudflare/package.json @@ -1,6 +1,6 @@ { "name": "@langchain/cloudflare", - "version": "0.0.6", + "version": "0.0.7-rc.0", "description": "Cloudflare integration for LangChain.js", "type": "module", "engines": { @@ -42,6 +42,7 @@ "devDependencies": { "@cloudflare/workers-types": "^4.20231218.0", "@jest/globals": "^29.5.0", + "@langchain/langgraph": "~0.0.31", "@langchain/scripts": "~0.0.20", "@langchain/standard-tests": "0.0.0", "@swc/core": "^1.3.90", @@ -66,6 +67,14 @@ "ts-jest": "^29.1.0", "typescript": "<5.2.0" }, + "peerDependencies": { + "@langchain/langgraph": "<0.1.0" + }, + "peerDependenciesMeta": { + "@langchain/langgraph": { + "optional": true + } + }, "publishConfig": { "access": "public" }, @@ -79,6 +88,15 @@ "import": "./index.js", "require": "./index.cjs" }, + "./langgraph/checkpointers": { + "types": { + "import": "./langgraph/checkpointers.d.ts", + "require": "./langgraph/checkpointers.d.cts", + "default": "./langgraph/checkpointers.d.ts" + }, + "import": "./langgraph/checkpointers.js", + "require": "./langgraph/checkpointers.cjs" + }, "./package.json": "./package.json" }, "files": [ @@ -86,6 +104,10 @@ "index.cjs", "index.js", "index.d.ts", - "index.d.cts" + "index.d.cts", + "langgraph/checkpointers.cjs", + "langgraph/checkpointers.js", + "langgraph/checkpointers.d.ts", + "langgraph/checkpointers.d.cts" ] } diff --git a/libs/langchain-cloudflare/src/langgraph/checkpointers.ts b/libs/langchain-cloudflare/src/langgraph/checkpointers.ts new file mode 100644 index 000000000000..1d3ad8dfb194 --- /dev/null +++ b/libs/langchain-cloudflare/src/langgraph/checkpointers.ts @@ -0,0 +1,212 @@ +import { D1Database } from "@cloudflare/workers-types"; + +import { RunnableConfig } from "@langchain/core/runnables"; +import { + BaseCheckpointSaver, + Checkpoint, + CheckpointMetadata, + CheckpointTuple, + SerializerProtocol, +} from "@langchain/langgraph/web"; + +// snake_case is used to match Python implementation +interface Row { + checkpoint: string; + metadata: string; + parent_id?: string; + thread_id: string; + checkpoint_id: string; +} + +export type CloudflareD1SaverFields = { + db: D1Database; +}; + +export class CloudflareD1Saver extends BaseCheckpointSaver { + db: D1Database; + + protected isSetup: boolean; + + constructor( + fields: CloudflareD1SaverFields, + serde?: SerializerProtocol + ) { + super(serde); + this.db = fields.db; + this.isSetup = false; + } + + private async setup() { + if (this.isSetup) { + return; + } + + try { + await this.db.exec(` +CREATE TABLE IF NOT EXISTS checkpoints (thread_id TEXT NOT NULL, checkpoint_id TEXT NOT NULL, parent_id TEXT, checkpoint BLOB, metadata BLOB, PRIMARY KEY (thread_id, checkpoint_id));`); + } catch (error) { + console.log("Error creating checkpoints table", error); + throw error; + } + + this.isSetup = true; + } + + async getTuple(config: RunnableConfig): Promise { + await this.setup(); + const thread_id = config.configurable?.thread_id; + const checkpoint_id = config.configurable?.checkpoint_id; + + if (checkpoint_id) { + try { + const row: Row | null = await this.db + .prepare( + `SELECT checkpoint, parent_id, metadata FROM checkpoints WHERE thread_id = ? AND checkpoint_id = ?` + ) + .bind(thread_id, checkpoint_id) + .first(); + + if (row) { + return { + config, + checkpoint: (await this.serde.parse(row.checkpoint)) as Checkpoint, + metadata: (await this.serde.parse( + row.metadata + )) as CheckpointMetadata, + parentConfig: row.parent_id + ? { + configurable: { + thread_id, + checkpoint_id: row.parent_id, + }, + } + : undefined, + }; + } + } catch (error) { + console.log("Error retrieving checkpoint", error); + throw error; + } + } else { + const row: Row | null = await this.db + .prepare( + `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM checkpoints WHERE thread_id = ? ORDER BY checkpoint_id DESC LIMIT 1` + ) + .bind(thread_id) + .first(); + + if (row) { + return { + config: { + configurable: { + thread_id: row.thread_id, + checkpoint_id: row.checkpoint_id, + }, + }, + checkpoint: (await this.serde.parse(row.checkpoint)) as Checkpoint, + metadata: (await this.serde.parse( + row.metadata + )) as CheckpointMetadata, + parentConfig: row.parent_id + ? { + configurable: { + thread_id: row.thread_id, + checkpoint_id: row.parent_id, + }, + } + : undefined, + }; + } + } + + return undefined; + } + + async *list( + config: RunnableConfig, + limit?: number, + before?: RunnableConfig + ): AsyncGenerator { + await this.setup(); + const thread_id = config.configurable?.thread_id; + let sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM checkpoints WHERE thread_id = ? ${ + before ? "AND checkpoint_id < ?" : "" + } ORDER BY checkpoint_id DESC`; + if (limit) { + sql += ` LIMIT ${limit}`; + } + const args = [thread_id, before?.configurable?.checkpoint_id].filter( + Boolean + ); + + try { + const { results: rows }: { results: Row[] } = await this.db + .prepare(sql) + .bind(...args) + .all(); + + if (rows) { + for (const row of rows) { + yield { + config: { + configurable: { + thread_id: row.thread_id, + checkpoint_id: row.checkpoint_id, + }, + }, + checkpoint: (await this.serde.parse(row.checkpoint)) as Checkpoint, + metadata: (await this.serde.parse( + row.metadata + )) as CheckpointMetadata, + parentConfig: row.parent_id + ? { + configurable: { + thread_id: row.thread_id, + checkpoint_id: row.parent_id, + }, + } + : undefined, + }; + } + } + } catch (error) { + console.log("Error listing checkpoints", error); + throw error; + } + } + + async put( + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata + ): Promise { + await this.setup(); + + try { + const row = [ + config.configurable?.thread_id ?? null, + checkpoint.id, + config.configurable?.checkpoint_id ?? null, + this.serde.stringify(checkpoint), + this.serde.stringify(metadata), + ]; + + await this.db + .prepare( + `INSERT OR REPLACE INTO checkpoints (thread_id, checkpoint_id, parent_id, checkpoint, metadata) VALUES (?, ?, ?, ?, ?)` + ) + .bind(...row) + .run(); + } catch (error) { + console.log("Error saving checkpoint", error); + throw error; + } + + return { + configurable: { + thread_id: config.configurable?.thread_id, + checkpoint_id: checkpoint.id, + }, + }; + } +} diff --git a/yarn.lock b/yarn.lock index c387d6eb2ded..7a03517cb874 100644 --- a/yarn.lock +++ b/yarn.lock @@ -10999,6 +10999,7 @@ __metadata: "@cloudflare/workers-types": ^4.20231218.0 "@jest/globals": ^29.5.0 "@langchain/core": ">0.1.0 <0.3.0" + "@langchain/langgraph": ~0.0.31 "@langchain/scripts": ~0.0.20 "@langchain/standard-tests": 0.0.0 "@swc/core": ^1.3.90 @@ -11023,6 +11024,11 @@ __metadata: ts-jest: ^29.1.0 typescript: <5.2.0 uuid: ^10.0.0 + peerDependencies: + "@langchain/langgraph": <0.1.0 + peerDependenciesMeta: + "@langchain/langgraph": + optional: true languageName: unknown linkType: soft @@ -11960,6 +11966,22 @@ __metadata: languageName: node linkType: hard +"@langchain/langgraph@npm:~0.0.31": + version: 0.0.31 + resolution: "@langchain/langgraph@npm:0.0.31" + dependencies: + "@langchain/core": ">=0.2.18 <0.3.0" + uuid: ^10.0.0 + zod: ^3.23.8 + peerDependencies: + better-sqlite3: ^9.5.0 + peerDependenciesMeta: + better-sqlite3: + optional: true + checksum: 74c0af490dab5c1f38d426cdeb0530fd300606bd28bb099d27b0ace029a02800a75fcc047f6755d853b485e78728b472170a19173803014dcc54bafe85939d9f + languageName: node + linkType: hard + "@langchain/mistralai@^0.0.26, @langchain/mistralai@workspace:*, @langchain/mistralai@workspace:libs/langchain-mistralai": version: 0.0.0-use.local resolution: "@langchain/mistralai@workspace:libs/langchain-mistralai"