From e96764a53c1d15b6f53cff6ec842f63f70d13842 Mon Sep 17 00:00:00 2001 From: MarcoMandar Date: Mon, 28 Oct 2024 16:18:23 +0200 Subject: [PATCH] update sqlite.ts to use VEC Signed-off-by: MarcoMandar --- src/adapters/sqlite.ts | 93 ++++++++++++++++++++---------------------- 1 file changed, 45 insertions(+), 48 deletions(-) diff --git a/src/adapters/sqlite.ts b/src/adapters/sqlite.ts index d7f598b3df..a519b12e28 100644 --- a/src/adapters/sqlite.ts +++ b/src/adapters/sqlite.ts @@ -210,18 +210,16 @@ export class SqliteDatabaseAdapter extends DatabaseAdapter { // Insert the memory with the appropriate 'unique' value const sql = `INSERT OR REPLACE INTO memories (id, type, content, embedding, userId, roomId, \`unique\`, createdAt) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`; - this.db - .prepare(sql) - .run( - memory.id ?? v4(), - tableName, - content, - JSON.stringify(memory.embedding ?? embeddingZeroVector), - memory.userId, - memory.roomId, - isUnique ? 1 : 0, - createdAt - ); + this.db.prepare(sql).run( + memory.id ?? v4(), + tableName, + content, + new Float32Array(memory.embedding ?? embeddingZeroVector), // Store as Float32Array + memory.userId, + memory.roomId, + isUnique ? 1 : 0, + createdAt + ); } async searchMemories(params: { @@ -232,37 +230,36 @@ export class SqliteDatabaseAdapter extends DatabaseAdapter { match_count: number; unique: boolean; }): Promise { + const queryParams = [ + new Float32Array(params.embedding), // Ensure embedding is Float32Array + params.tableName, + params.roomId, + params.match_count, + ]; + let sql = ` -SELECT *, (1 - vss_distance_l2(embedding, ?)) AS similarity -FROM memories -WHERE type = ? -AND roomId = ?`; + SELECT *, vec_distance_L2(embedding, ?) AS similarity + FROM memories + WHERE type = ?`; if (params.unique) { sql += " AND `unique` = 1"; } - sql += ` ORDER BY similarity DESC LIMIT ?`; - const queryParams = [ - JSON.stringify(params.embedding), - params.tableName, - params.roomId, - params.match_count, - ]; + sql += ` ORDER BY similarity ASC LIMIT ?`; // ASC for lower distance + // Updated queryParams order matches the placeholders const memories = this.db.prepare(sql).all(...queryParams) as (Memory & { similarity: number; })[]; - return memories.map((memory) => { - return { - ...memory, - createdAt: - typeof memory.createdAt === "string" - ? Date.parse(memory.createdAt as string) - : memory.createdAt, - content: JSON.parse(memory.content as unknown as string), - }; - }); + return memories.map((memory) => ({ + ...memory, + createdAt: + typeof memory.createdAt === "string" + ? Date.parse(memory.createdAt as string) + : memory.createdAt, + content: JSON.parse(memory.content as unknown as string), + })); } async searchMemoriesByEmbedding( @@ -276,15 +273,15 @@ AND roomId = ?`; } ): Promise { const queryParams = [ - JSON.stringify(embedding), - params.tableName, // JSON.stringify(embedding), + new Float32Array(embedding), + params.tableName, ]; let sql = ` - SELECT *, (1 - vss_distance_l2(embedding, ?)) AS similarity + SELECT *, vec_distance_L2(embedding, ?) AS similarity FROM memories - WHERE type = ?`; // AND vss_search(embedding, ?) + WHERE type = ?`; if (params.unique) { sql += " AND `unique` = 1"; @@ -330,21 +327,21 @@ AND roomId = ?`; SELECT * FROM memories WHERE type = ? - AND vss_search(${opts.query_field_name}, ?) - ORDER BY vss_search(${opts.query_field_name}, ?) DESC + AND vec_distance_L2(${opts.query_field_name}, ?) <= ? + ORDER BY vec_distance_L2(${opts.query_field_name}, ?) ASC LIMIT ? `; - const memories = this.db - .prepare(sql) - .all( - opts.query_table_name, - opts.query_input, - opts.query_input, - opts.query_match_count - ) as Memory[]; + const memories = this.db.prepare(sql).all( + opts.query_table_name, + new Float32Array(opts.query_input.split(",").map(Number)), // Convert string to Float32Array + opts.query_input, + new Float32Array(opts.query_input.split(",").map(Number)) + ) as Memory[]; return memories.map((memory) => ({ - embedding: JSON.parse(memory.embedding as unknown as string), + embedding: Array.from( + new Float32Array(memory.embedding as unknown as Buffer) + ), // Convert Buffer to number[] levenshtein_score: 0, })); }