Skip to content

Commit

Permalink
feat: support npm:postgres
Browse files Browse the repository at this point in the history
  • Loading branch information
himself65 committed Sep 24, 2024
1 parent 50e6b57 commit 55a1122
Show file tree
Hide file tree
Showing 14 changed files with 219 additions and 34 deletions.
8 changes: 8 additions & 0 deletions examples/tsconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@
"incremental": true,
"composite": true
},
"references": [
{

Check failure on line 16 in examples/tsconfig.json

View workflow job for this annotation

GitHub Actions / typecheck

Referenced project '/home/runner/work/LlamaIndexTS/LlamaIndexTS/examples/readers/tsconfig.json' must have setting "composite": true.
"path": "./readers/tsconfig.json"
},
{

Check failure on line 19 in examples/tsconfig.json

View workflow job for this annotation

GitHub Actions / typecheck

Cannot write file '/home/runner/work/LlamaIndexTS/LlamaIndexTS/examples/lib/.tsbuildinfo' because it will overwrite '.tsbuildinfo' file generated by referenced project '/home/runner/work/LlamaIndexTS/LlamaIndexTS/examples/vector-store/pg/tsconfig.json'
"path": "./vector-store/pg/tsconfig.json"
}
],
"ts-node": {
"files": true,
"compilerOptions": {
Expand Down
6 changes: 6 additions & 0 deletions examples/vector-store/pg/.env.template
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# neon template
PGHOST=
PGDATABASE=
PGUSER=
PGPASSWORD=
ENDPOINT_ID=
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
// load-docs.ts
import fs from "fs/promises";
import {
PGVectorStore,
SimpleDirectoryReader,
storageContextFromDefaults,
VectorStoreIndex,
} from "llamaindex";
import fs from "node:fs/promises";

async function getSourceFilenames(sourceDir: string) {
return await fs
Expand Down
45 changes: 45 additions & 0 deletions examples/vector-store/pg/neon.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/* eslint-disable turbo/no-undeclared-env-vars */
import dotenv from "dotenv";
import { Document, PGVectorStore, VectorStoreQueryMode } from "llamaindex";
import postgres from "postgres";

Check failure on line 4 in examples/vector-store/pg/neon.ts

View workflow job for this annotation

GitHub Actions / typecheck

Cannot find module 'postgres' or its corresponding type declarations.

dotenv.config();

const { PGHOST, PGDATABASE, PGUSER, ENDPOINT_ID } = process.env;
const PGPASSWORD = decodeURIComponent(process.env.PGPASSWORD!);

const sql = postgres({
host: PGHOST,
database: PGDATABASE,
username: PGUSER,
password: PGPASSWORD,
port: 5432,
ssl: "require",
connection: {
options: `project=${ENDPOINT_ID}`,
},
});

await sql`CREATE EXTENSION IF NOT EXISTS vector`;

const vectorStore = new PGVectorStore({
dimensions: 3,
client: sql,
});

await vectorStore.add([
new Document({
text: "hello, world",
embedding: [1, 2, 3],
}),
]);

const results = await vectorStore.query({
mode: VectorStoreQueryMode.DEFAULT,
similarityTopK: 1,
queryEmbedding: [1, 2, 3],
});

console.log("result", results);

await sql.end();
8 changes: 8 additions & 0 deletions examples/vector-store/pg/package.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"name": "pg-vector-store",
"type": "module",
"private": true,
"devDependencies": {
"postgres": "^3.4.4"
}
}
File renamed without changes.
9 changes: 9 additions & 0 deletions examples/vector-store/pg/tsconfig.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"extends": "../../tsconfig.json",
"compilerOptions": {
"outDir": "./dist",
"types": ["node"],
"skipLibCheck": true
},
"include": ["./**/*.ts"]
}
14 changes: 14 additions & 0 deletions packages/core/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,20 @@
"types": "./dist/retriever/index.d.ts",
"default": "./dist/retriever/index.js"
}
},
"./vector-store": {
"require": {
"types": "./dist/vector-store/index.d.cts",
"default": "./dist/vector-store/index.cjs"
},
"import": {
"types": "./dist/vector-store/index.d.ts",
"default": "./dist/vector-store/index.js"
},
"default": {
"types": "./dist/vector-store/index.d.ts",
"default": "./dist/vector-store/index.js"
}
}
},
"files": [
Expand Down
13 changes: 13 additions & 0 deletions packages/core/src/vector-store/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
/**
* should compatible with npm:pg and npm:postgres
*/
export interface IsomorphicDB {
query: (sql: string, params?: any[]) => Promise<any[]>;
// begin will wrap the multiple queries in a transaction
begin: <T>(fn: (query: IsomorphicDB["query"]) => Promise<T>) => Promise<T>;

// event handler
connect: () => Promise<void>;
close: () => Promise<void>;
onCloseEvent: (listener: () => void) => void;
}
1 change: 1 addition & 0 deletions packages/llamaindex/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@
"glob": "^11.0.0",
"pg": "^8.12.0",
"pgvector": "0.2.0",
"postgres": "^3.4.4",
"typescript": "^5.6.2"
},
"engines": {
Expand Down
127 changes: 95 additions & 32 deletions packages/llamaindex/src/vector-store/PGVectorStore.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import type pg from "pg";

import type { IsomorphicDB } from "@llamaindex/core/vector-store";
import type { Sql } from "postgres";
import {
FilterCondition,
FilterOperator,
Expand All @@ -11,6 +13,60 @@ import {
type VectorStoreQueryResult,
} from "./types.js";

function fromPostgres(client: Sql): IsomorphicDB {
return {
query: async (sql: string, params?: any[]): Promise<any[]> => {
return client.unsafe(sql, params);
},
begin: async (fn) => {
let res: any;
await client.begin(async (scopedClient) => {
const queryFn = async (sql: string, params?: any[]): Promise<any[]> => {
return scopedClient.unsafe(sql, params);
};
res = await fn(queryFn);
});
return res;
},
connect: () => Promise.resolve(),
close: async () => client.end(),
onCloseEvent: (fn) => {
// no close event
},
};
}

function fromPG(client: pg.Client | pg.PoolClient): IsomorphicDB {
const queryFn = async (sql: string, params?: any[]): Promise<any[]> => {
return (await client.query(sql, params)).rows;
};
return {
query: queryFn,
begin: async (fn) => {
await client.query("BEGIN");
try {
const result = await fn(queryFn);
await client.query("COMMIT");
return result;
} catch (e) {
await client.query("ROLLBACK");
throw e;
}
},
connect: () => client.connect(),
close: async () => {
if ("end" in client) {
await client.end();
} else if ("release" in client) {
client.release();
}
},
onCloseEvent: (fn) => {
client.on("end", fn);
},
};
}

import { escapeLikeString } from "./utils.js";

import type { BaseEmbedding } from "@llamaindex/core/embeddings";
Expand Down Expand Up @@ -47,6 +103,13 @@ export type PGVectorStoreConfig = PGVectorStoreBaseConfig &
shouldConnect?: boolean | undefined;
client: pg.Client | pg.PoolClient;
}
| {
/**
* No need to connect to the database, the client is already connected.
*/
shouldConnect?: false;
client: Sql;
}
);

/**
Expand All @@ -65,7 +128,7 @@ export class PGVectorStore
private readonly dimensions: number = DEFAULT_DIMENSIONS;

private isDBConnected: boolean = false;
private db: pg.ClientBase | null = null;
private db: IsomorphicDB | null = null;
private readonly clientConfig: pg.ClientConfig | null = null;

constructor(config: PGVectorStoreConfig) {
Expand All @@ -76,9 +139,14 @@ export class PGVectorStore
if ("clientConfig" in config) {
this.clientConfig = config.clientConfig;
} else {
this.isDBConnected =
config.shouldConnect !== undefined ? !config.shouldConnect : false;
this.db = config.client;
if (typeof config.client === "function") {
this.isDBConnected = true;
this.db = fromPostgres(config.client);
} else {
this.isDBConnected =
config.shouldConnect !== undefined ? !config.shouldConnect : false;
this.db = fromPG(config.client);
}
}
}

Expand All @@ -104,7 +172,7 @@ export class PGVectorStore
return this.collection;
}

private async getDb(): Promise<pg.ClientBase> {
private async getDb(): Promise<IsomorphicDB> {
if (!this.db) {
const pg = await import("pg");
const { Client } = pg.default ? pg.default : pg;
Expand All @@ -124,16 +192,15 @@ export class PGVectorStore
await registerTypes(db);

// All good? Keep the connection reference
this.db = db;
this.db = fromPG(db);
}

if (this.db && !this.isDBConnected) {
await this.db.connect();
this.isDBConnected = true;
}

this.db.on("end", () => {
// Connection closed
this.db.onCloseEvent(() => {
this.isDBConnected = false;
});

Expand All @@ -143,22 +210,23 @@ export class PGVectorStore
return this.db;
}

private async checkSchema(db: pg.ClientBase) {
private async checkSchema(db: IsomorphicDB) {
await db.query(`CREATE SCHEMA IF NOT EXISTS ${this.schemaName}`);

const tbl = `CREATE TABLE IF NOT EXISTS ${this.schemaName}.${this.tableName}(
await db.query(`CREATE TABLE IF NOT EXISTS ${this.schemaName}.${this.tableName}(
id uuid DEFAULT gen_random_uuid() PRIMARY KEY,
external_id VARCHAR,
collection VARCHAR,
document TEXT,
metadata JSONB DEFAULT '{}',
embeddings VECTOR(${this.dimensions})
)`;
await db.query(tbl);

const idxs = `CREATE INDEX IF NOT EXISTS idx_${this.tableName}_external_id ON ${this.schemaName}.${this.tableName} (external_id);
CREATE INDEX IF NOT EXISTS idx_${this.tableName}_collection ON ${this.schemaName}.${this.tableName} (collection);`;
await db.query(idxs);
)`);
await db.query(
`CREATE INDEX IF NOT EXISTS idx_${this.tableName}_external_id ON ${this.schemaName}.${this.tableName} (external_id);`,
);
await db.query(
`CREATE INDEX IF NOT EXISTS idx_${this.tableName}_collection ON ${this.schemaName}.${this.tableName} (collection);`,
);

// TODO add IVFFlat or HNSW indexing?
return db;
Expand Down Expand Up @@ -222,9 +290,7 @@ export class PGVectorStore

const db = await this.getDb();

try {
await db.query("BEGIN");

return db.begin(async (query) => {
const data = this.getDataToInsert(embeddingResults);

const placeholders = data
Expand Down Expand Up @@ -253,15 +319,9 @@ export class PGVectorStore
`;

const flattenedParams = data.flat();
const result = await db.query(sql, flattenedParams);

await db.query("COMMIT");

return result.rows.map((row) => row.id as string);
} catch (error) {
await db.query("ROLLBACK");
throw error;
}
const result = await query(sql, flattenedParams);
return result.map((row) => row.id as string);
});
}

/**
Expand Down Expand Up @@ -455,19 +515,22 @@ export class PGVectorStore
const db = await this.getDb();
const results = await db.query(sql, params);

const nodes = results.rows.map((row) => {
const nodes = results.map((row) => {
return new Document({
id_: row.id,
text: row.document,
metadata: row.metadata,
embedding: row.embeddings,
embedding:
typeof row.embeddings === "string"
? JSON.parse(row.embeddings)
: row.embeddings,
});
});

const ret = {
nodes: nodes,
similarities: results.rows.map((row) => 1 - row.s),
ids: results.rows.map((row) => row.id),
similarities: results.map((row) => 1 - row.s),
ids: results.map((row) => row.id),
};

return Promise.resolve(ret);
Expand Down
Loading

0 comments on commit 55a1122

Please sign in to comment.