From 6cce3b12eac37ae2d7116d04eccc8a3f8286285a Mon Sep 17 00:00:00 2001 From: Alex Yang Date: Wed, 25 Sep 2024 19:40:20 -0700 Subject: [PATCH] feat: support `npm:postgres` (#1248) --- .changeset/weak-camels-perform.md | 8 ++ examples/package.json | 3 +- examples/vector-store/pg/.env.template | 6 + .../pg}/README.md | 0 .../pg}/load-docs.ts | 2 +- examples/vector-store/pg/neon.ts | 45 ++++++ examples/vector-store/pg/package.json | 5 + .../pg}/query.ts | 0 examples/vector-store/pg/tsconfig.json | 9 ++ packages/core/package.json | 14 ++ packages/core/src/vector-store/index.ts | 13 ++ .../node/vector-store/pg-vector-store.e2e.ts | 14 +- packages/llamaindex/package.json | 1 + .../src/vector-store/PGVectorStore.ts | 128 +++++++++++++----- pnpm-lock.yaml | 14 ++ pnpm-workspace.yaml | 2 +- tsconfig.json | 3 + 17 files changed, 225 insertions(+), 42 deletions(-) create mode 100644 .changeset/weak-camels-perform.md create mode 100644 examples/vector-store/pg/.env.template rename examples/{pg-vector-store => vector-store/pg}/README.md (100%) rename examples/{pg-vector-store => vector-store/pg}/load-docs.ts (98%) create mode 100644 examples/vector-store/pg/neon.ts create mode 100644 examples/vector-store/pg/package.json rename examples/{pg-vector-store => vector-store/pg}/query.ts (100%) create mode 100644 examples/vector-store/pg/tsconfig.json create mode 100644 packages/core/src/vector-store/index.ts diff --git a/.changeset/weak-camels-perform.md b/.changeset/weak-camels-perform.md new file mode 100644 index 0000000000..c41d4d1823 --- /dev/null +++ b/.changeset/weak-camels-perform.md @@ -0,0 +1,8 @@ +--- +"@llamaindex/core": patch +"llamaindex": patch +"@llamaindex/core-e2e": patch +"pg-vector-store": patch +--- + +feat: support `npm:postgres` diff --git a/examples/package.json b/examples/package.json index e2f79fd962..6bf993cc59 100644 --- a/examples/package.json +++ b/examples/package.json @@ -16,7 +16,8 @@ "js-tiktoken": "^1.0.14", "llamaindex": "^0.6.0", "mongodb": "^6.7.0", - "pathe": "^1.1.2" + "pathe": "^1.1.2", + "postgres": "^3.4.4" }, "devDependencies": { "@types/node": "^22.5.1", diff --git a/examples/vector-store/pg/.env.template b/examples/vector-store/pg/.env.template new file mode 100644 index 0000000000..6655a64848 --- /dev/null +++ b/examples/vector-store/pg/.env.template @@ -0,0 +1,6 @@ +# neon template +PGHOST= +PGDATABASE= +PGUSER= +PGPASSWORD= +ENDPOINT_ID= diff --git a/examples/pg-vector-store/README.md b/examples/vector-store/pg/README.md similarity index 100% rename from examples/pg-vector-store/README.md rename to examples/vector-store/pg/README.md diff --git a/examples/pg-vector-store/load-docs.ts b/examples/vector-store/pg/load-docs.ts similarity index 98% rename from examples/pg-vector-store/load-docs.ts rename to examples/vector-store/pg/load-docs.ts index 0c411b3a66..6befb9f1e8 100755 --- a/examples/pg-vector-store/load-docs.ts +++ b/examples/vector-store/pg/load-docs.ts @@ -1,11 +1,11 @@ // load-docs.ts -import fs from "fs/promises"; import { PGVectorStore, SimpleDirectoryReader, storageContextFromDefaults, VectorStoreIndex, } from "llamaindex"; +import fs from "node:fs/promises"; async function getSourceFilenames(sourceDir: string) { return await fs diff --git a/examples/vector-store/pg/neon.ts b/examples/vector-store/pg/neon.ts new file mode 100644 index 0000000000..b2256ca305 --- /dev/null +++ b/examples/vector-store/pg/neon.ts @@ -0,0 +1,45 @@ +/* eslint-disable turbo/no-undeclared-env-vars */ +import dotenv from "dotenv"; +import { Document, PGVectorStore, VectorStoreQueryMode } from "llamaindex"; +import postgres from "postgres"; + +dotenv.config(); + +const { PGHOST, PGDATABASE, PGUSER, ENDPOINT_ID } = process.env; +const PGPASSWORD = decodeURIComponent(process.env.PGPASSWORD!); + +const sql = postgres({ + host: PGHOST, + database: PGDATABASE, + username: PGUSER, + password: PGPASSWORD, + port: 5432, + ssl: "require", + connection: { + options: `project=${ENDPOINT_ID}`, + }, +}); + +await sql`CREATE EXTENSION IF NOT EXISTS vector`; + +const vectorStore = new PGVectorStore({ + dimensions: 3, + client: sql, +}); + +await vectorStore.add([ + new Document({ + text: "hello, world", + embedding: [1, 2, 3], + }), +]); + +const results = await vectorStore.query({ + mode: VectorStoreQueryMode.DEFAULT, + similarityTopK: 1, + queryEmbedding: [1, 2, 3], +}); + +console.log("result", results); + +await sql.end(); diff --git a/examples/vector-store/pg/package.json b/examples/vector-store/pg/package.json new file mode 100644 index 0000000000..ca6fc3000b --- /dev/null +++ b/examples/vector-store/pg/package.json @@ -0,0 +1,5 @@ +{ + "name": "pg-vector-store", + "type": "module", + "private": true +} diff --git a/examples/pg-vector-store/query.ts b/examples/vector-store/pg/query.ts similarity index 100% rename from examples/pg-vector-store/query.ts rename to examples/vector-store/pg/query.ts diff --git a/examples/vector-store/pg/tsconfig.json b/examples/vector-store/pg/tsconfig.json new file mode 100644 index 0000000000..ad7583bab6 --- /dev/null +++ b/examples/vector-store/pg/tsconfig.json @@ -0,0 +1,9 @@ +{ + "extends": "../../tsconfig.json", + "compilerOptions": { + "outDir": "./dist", + "types": ["node"], + "skipLibCheck": true + }, + "include": ["./**/*.ts"] +} diff --git a/packages/core/package.json b/packages/core/package.json index 5d3282cd45..ec57848f17 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -255,6 +255,20 @@ "types": "./retriever/dist/index.d.ts", "default": "./retriever/dist/index.js" } + }, + "./vector-store": { + "require": { + "types": "./dist/vector-store/index.d.cts", + "default": "./dist/vector-store/index.cjs" + }, + "import": { + "types": "./dist/vector-store/index.d.ts", + "default": "./dist/vector-store/index.js" + }, + "default": { + "types": "./dist/vector-store/index.d.ts", + "default": "./dist/vector-store/index.js" + } } }, "files": [ diff --git a/packages/core/src/vector-store/index.ts b/packages/core/src/vector-store/index.ts new file mode 100644 index 0000000000..bc8b58bf7a --- /dev/null +++ b/packages/core/src/vector-store/index.ts @@ -0,0 +1,13 @@ +/** + * should compatible with npm:pg and npm:postgres + */ +export interface IsomorphicDB { + query: (sql: string, params?: any[]) => Promise; + // begin will wrap the multiple queries in a transaction + begin: (fn: (query: IsomorphicDB["query"]) => Promise) => Promise; + + // event handler + connect: () => Promise; + close: () => Promise; + onCloseEvent: (listener: () => void) => void; +} diff --git a/packages/llamaindex/e2e/node/vector-store/pg-vector-store.e2e.ts b/packages/llamaindex/e2e/node/vector-store/pg-vector-store.e2e.ts index 5d05435abc..44e5a10fbd 100644 --- a/packages/llamaindex/e2e/node/vector-store/pg-vector-store.e2e.ts +++ b/packages/llamaindex/e2e/node/vector-store/pg-vector-store.e2e.ts @@ -27,7 +27,7 @@ await test("init with client", async (t) => { client: pgClient, shouldConnect: false, }); - assert.deepStrictEqual(await vectorStore.client(), pgClient); + assert.notDeepStrictEqual(await vectorStore.client(), undefined); }); await test("init with pool", async (t) => { @@ -44,16 +44,16 @@ await test("init with pool", async (t) => { shouldConnect: false, client, }); - assert.deepStrictEqual(await vectorStore.client(), client); + assert.notDeepStrictEqual(await vectorStore.client(), undefined); }); await test("init without client", async (t) => { const vectorStore = new PGVectorStore({ clientConfig: pgConfig }); - const pgClient = (await vectorStore.client()) as pg.Client; + const db = await vectorStore.client(); t.after(async () => { - await pgClient.end(); + await db.close(); }); - assert.notDeepStrictEqual(pgClient, undefined); + assert.notDeepStrictEqual(db, undefined); }); await test("simple node", async (t) => { @@ -71,9 +71,9 @@ await test("simple node", async (t) => { dimensions, schemaName, }); - const pgClient = (await vectorStore.client()) as pg.Client; + const db = await vectorStore.client(); t.after(async () => { - await pgClient.end(); + await db.close(); }); await vectorStore.add([node]); diff --git a/packages/llamaindex/package.json b/packages/llamaindex/package.json index 8139f58296..6d4fb8e710 100644 --- a/packages/llamaindex/package.json +++ b/packages/llamaindex/package.json @@ -97,6 +97,7 @@ "glob": "^11.0.0", "pg": "^8.12.0", "pgvector": "0.2.0", + "postgres": "^3.4.4", "typescript": "^5.6.2" }, "engines": { diff --git a/packages/llamaindex/src/vector-store/PGVectorStore.ts b/packages/llamaindex/src/vector-store/PGVectorStore.ts index d2ac5bfc0b..f915701a7f 100644 --- a/packages/llamaindex/src/vector-store/PGVectorStore.ts +++ b/packages/llamaindex/src/vector-store/PGVectorStore.ts @@ -1,5 +1,7 @@ import type pg from "pg"; +import type { IsomorphicDB } from "@llamaindex/core/vector-store"; +import type { Sql } from "postgres"; import { FilterCondition, FilterOperator, @@ -18,6 +20,61 @@ import { DEFAULT_COLLECTION } from "@llamaindex/core/global"; import type { BaseNode, Metadata } from "@llamaindex/core/schema"; import { Document, MetadataMode } from "@llamaindex/core/schema"; +// todo: create adapter for postgres client +function fromPostgres(client: Sql): IsomorphicDB { + return { + query: async (sql: string, params?: any[]): Promise => { + return client.unsafe(sql, params); + }, + begin: async (fn) => { + let res: any; + await client.begin(async (scopedClient) => { + const queryFn = async (sql: string, params?: any[]): Promise => { + return scopedClient.unsafe(sql, params); + }; + res = await fn(queryFn); + }); + return res; + }, + connect: () => Promise.resolve(), + close: async () => client.end(), + onCloseEvent: () => { + // no close event + }, + }; +} + +function fromPG(client: pg.Client | pg.PoolClient): IsomorphicDB { + const queryFn = async (sql: string, params?: any[]): Promise => { + return (await client.query(sql, params)).rows; + }; + return { + query: queryFn, + begin: async (fn) => { + await client.query("BEGIN"); + try { + const result = await fn(queryFn); + await client.query("COMMIT"); + return result; + } catch (e) { + await client.query("ROLLBACK"); + throw e; + } + }, + connect: () => client.connect(), + close: async () => { + if ("end" in client) { + await client.end(); + } else if ("release" in client) { + client.release(); + } + }, + onCloseEvent: (fn) => { + client.on("end", fn); + }, + }; +} + export const PGVECTOR_SCHEMA = "public"; export const PGVECTOR_TABLE = "llamaindex_embedding"; export const DEFAULT_DIMENSIONS = 1536; @@ -47,6 +104,13 @@ export type PGVectorStoreConfig = PGVectorStoreBaseConfig & shouldConnect?: boolean | undefined; client: pg.Client | pg.PoolClient; } + | { + /** + * No need to connect to the database, the client is already connected. + */ + shouldConnect?: false; + client: Sql; + } ); /** @@ -65,7 +129,7 @@ export class PGVectorStore private readonly dimensions: number = DEFAULT_DIMENSIONS; private isDBConnected: boolean = false; - private db: pg.ClientBase | null = null; + private db: IsomorphicDB | null = null; private readonly clientConfig: pg.ClientConfig | null = null; constructor(config: PGVectorStoreConfig) { @@ -76,9 +140,14 @@ export class PGVectorStore if ("clientConfig" in config) { this.clientConfig = config.clientConfig; } else { - this.isDBConnected = - config.shouldConnect !== undefined ? !config.shouldConnect : false; - this.db = config.client; + if (typeof config.client === "function") { + this.isDBConnected = true; + this.db = fromPostgres(config.client); + } else { + this.isDBConnected = + config.shouldConnect !== undefined ? !config.shouldConnect : false; + this.db = fromPG(config.client); + } } } @@ -104,7 +173,7 @@ export class PGVectorStore return this.collection; } - private async getDb(): Promise { + private async getDb(): Promise { if (!this.db) { const pg = await import("pg"); const { Client } = pg.default ? pg.default : pg; @@ -124,7 +193,7 @@ export class PGVectorStore await registerTypes(db); // All good? Keep the connection reference - this.db = db; + this.db = fromPG(db); } if (this.db && !this.isDBConnected) { @@ -132,8 +201,7 @@ export class PGVectorStore this.isDBConnected = true; } - this.db.on("end", () => { - // Connection closed + this.db.onCloseEvent(() => { this.isDBConnected = false; }); @@ -143,22 +211,23 @@ export class PGVectorStore return this.db; } - private async checkSchema(db: pg.ClientBase) { + private async checkSchema(db: IsomorphicDB) { await db.query(`CREATE SCHEMA IF NOT EXISTS ${this.schemaName}`); - const tbl = `CREATE TABLE IF NOT EXISTS ${this.schemaName}.${this.tableName}( + await db.query(`CREATE TABLE IF NOT EXISTS ${this.schemaName}.${this.tableName}( id uuid DEFAULT gen_random_uuid() PRIMARY KEY, external_id VARCHAR, collection VARCHAR, document TEXT, metadata JSONB DEFAULT '{}', embeddings VECTOR(${this.dimensions}) - )`; - await db.query(tbl); - - const idxs = `CREATE INDEX IF NOT EXISTS idx_${this.tableName}_external_id ON ${this.schemaName}.${this.tableName} (external_id); - CREATE INDEX IF NOT EXISTS idx_${this.tableName}_collection ON ${this.schemaName}.${this.tableName} (collection);`; - await db.query(idxs); + )`); + await db.query( + `CREATE INDEX IF NOT EXISTS idx_${this.tableName}_external_id ON ${this.schemaName}.${this.tableName} (external_id);`, + ); + await db.query( + `CREATE INDEX IF NOT EXISTS idx_${this.tableName}_collection ON ${this.schemaName}.${this.tableName} (collection);`, + ); // TODO add IVFFlat or HNSW indexing? return db; @@ -222,9 +291,7 @@ export class PGVectorStore const db = await this.getDb(); - try { - await db.query("BEGIN"); - + return db.begin(async (query) => { const data = this.getDataToInsert(embeddingResults); const placeholders = data @@ -253,15 +320,9 @@ export class PGVectorStore `; const flattenedParams = data.flat(); - const result = await db.query(sql, flattenedParams); - - await db.query("COMMIT"); - - return result.rows.map((row) => row.id as string); - } catch (error) { - await db.query("ROLLBACK"); - throw error; - } + const result = await query(sql, flattenedParams); + return result.map((row) => row.id as string); + }); } /** @@ -455,19 +516,22 @@ export class PGVectorStore const db = await this.getDb(); const results = await db.query(sql, params); - const nodes = results.rows.map((row) => { + const nodes = results.map((row) => { return new Document({ id_: row.id, text: row.document, metadata: row.metadata, - embedding: row.embeddings, + embedding: + typeof row.embeddings === "string" + ? JSON.parse(row.embeddings) + : row.embeddings, }); }); const ret = { nodes: nodes, - similarities: results.rows.map((row) => 1 - row.s), - ids: results.rows.map((row) => row.id), + similarities: results.map((row) => 1 - row.s), + ids: results.map((row) => row.id), }; return Promise.resolve(ret); diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 989a37cc82..b5f664d1c1 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -171,6 +171,9 @@ importers: pathe: specifier: ^1.1.2 version: 1.1.2 + postgres: + specifier: ^3.4.4 + version: 3.4.4 devDependencies: '@types/node': specifier: ^22.5.1 @@ -198,6 +201,8 @@ importers: specifier: ^5.6.2 version: 5.6.2 + examples/vector-store/pg: {} + packages/autotool: dependencies: '@swc/core': @@ -677,6 +682,9 @@ importers: pgvector: specifier: 0.2.0 version: 0.2.0 + postgres: + specifier: ^3.4.4 + version: 3.4.4 typescript: specifier: ^5.6.2 version: 5.6.2 @@ -9559,6 +9567,10 @@ packages: postgres-range@1.1.4: resolution: {integrity: sha512-i/hbxIE9803Alj/6ytL7UHQxRvZkI9O4Sy+J3HGc4F4oo/2eQAjTSNJ0bfxyse3bH0nuVesCk+3IRLaMtG3H6w==} + postgres@3.4.4: + resolution: {integrity: sha512-IbyN+9KslkqcXa8AO9fxpk97PA4pzewvpi2B3Dwy9u4zpV32QicaEdgmF3eSQUzdRk7ttDHQejNgAEr4XoeH4A==} + engines: {node: '>=12'} + prebuild-install@7.1.2: resolution: {integrity: sha512-UnNke3IQb6sgarcZIDU3gbMeTp/9SSU1DAIkil7PrqG1vZlBtY5msYccSKSHDqa3hNg436IXK+SNImReuA1wEQ==} engines: {node: '>=10'} @@ -23356,6 +23368,8 @@ snapshots: postgres-range@1.1.4: {} + postgres@3.4.4: {} + prebuild-install@7.1.2: dependencies: detect-libc: 2.0.3 diff --git a/pnpm-workspace.yaml b/pnpm-workspace.yaml index 78e2bf3501..d6fd2f01e7 100644 --- a/pnpm-workspace.yaml +++ b/pnpm-workspace.yaml @@ -8,4 +8,4 @@ packages: - "packages/llamaindex/e2e/examples/*" - "packages/autotool/examples/*" - "examples/" - - "examples/*" + - "examples/**" diff --git a/tsconfig.json b/tsconfig.json index a03509a020..772625f842 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -86,6 +86,9 @@ { "path": "./examples/readers" }, + { + "path": "./examples/vector-store/pg/tsconfig.json" + }, { "path": "./packages/experimental/tsconfig.json" }