From d0fa51dc49f9c744df94926a8fad6accb620c7c8 Mon Sep 17 00:00:00 2001
From: charnould
Date: Mon, 9 Dec 2024 16:13:42 +0100
Subject: [PATCH] refactor: improve vector search performance

---
 tests/llm/search-by-vectors.test.ts |  13 ++
 utils/search-by-vectors.ts          | 252 ++++++++++++++++++++++++----
 2 files changed, 232 insertions(+), 33 deletions(-)
 create mode 100644 tests/llm/search-by-vectors.test.ts

diff --git a/tests/llm/search-by-vectors.test.ts b/tests/llm/search-by-vectors.test.ts
new file mode 100644
index 0000000..38e7d67
--- /dev/null
+++ b/tests/llm/search-by-vectors.test.ts
@@ -0,0 +1,13 @@
+import { expect, test } from 'bun:test'
+import { query_db } from '../../utils/search-by-vectors'
+
+test('should perform a vector search on private knowledge', async () => {
+  const r = await query_db('proprietary.private', 'quel est le mode de chauffage ?', {
+    building: 'Racine',
+    process: 'none'
+  })
+
+  // The query must resolve to an array of chunks (possibly empty on a fresh database)
+  expect(Array.isArray(r)).toBe(true)
+  expect(r.every((c) => typeof c.chunk_text === 'string')).toBe(true)
+})
diff --git a/utils/search-by-vectors.ts b/utils/search-by-vectors.ts
index e7d2e14..e9be23b 100644
--- a/utils/search-by-vectors.ts
+++ b/utils/search-by-vectors.ts
@@ -1,5 +1,6 @@
 import { openai } from '@ai-sdk/openai'
 import { embed } from 'ai'
+import _ from 'lodash'
 import { z } from 'zod'
 import { db } from '../utils/database'
 import type { Db_Name } from '../utils/database'

 export const vector_search = async (query: string, context: AIContext) => {
   // Initialize the vector search results
   const r: Vector_Search_Result = { community: [], private: [], public: [] }

-  // Generate the embedding for the user’s query
-  const { embedding } = await embed({
-    model: openai.embedding('text-embedding-3-large'),
-    value: query
-  })
-
   // Perform searches based on knowledge access
-  if (k.community) r.community = query_db('community', new Float32Array(embedding))
-  if (k.proprietary.public) r.public = query_db('proprietary.public', new Float32Array(embedding))
-  if (k.proprietary.private)
-    r.private = query_db('proprietary.private', new Float32Array(embedding))
+  const entities = context.query?.named_entities
+  if (k.community) r.community = await query_db('community', query)
+  if (k.proprietary.public) r.public = await query_db('proprietary.public', query, entities)
+  if (k.proprietary.private) r.private = await query_db('proprietary.private', query, entities)

   // Parse results and return them to ensure
   // TypeScript types align with the schema

 //
-// Query the database with the embedding to retrieve the matching chunks
-const query_db = (db_name: Db_Name, embedding: Float32Array) => {
+// Query the database with the user query (and optional named entities) to retrieve the matching chunks
+export const query_db = async (
+  db_name: Db_Name,
+  query: string,
+  entities?: { building: string | null; process: string | null }
+) => {
+  // Initialize variables
   const db_instance = db(db_name)
-  const results = db_instance
-    .prepare('SELECT rowid, distance FROM vectors WHERE vector MATCH ? AND k = 10')
-    .all(embedding) as { rowid: number; distance: number }[]
-
-  return results.map((result) => ({
-    ...result,
-    ...(db_instance.prepare('SELECT * FROM chunks WHERE rowid = ?;').get(result.rowid) as {
-      rowid: number
-      chunk: string
-    })
-  }))
+  const entity_promises: Promise<string[]>[] = []
+  let user_query: {
+    entity_hash: string
+    chunk_hash: string
+    chunk_text: string
+    distance: number
+  }[]
+
+  // No need to apply entity-specific filters to `community` knowledge:
+  // semantic search alone handles that content well enough.
+  if (db_name !== 'community' && entities) {
+    if (entities.building !== null) {
+      entity_promises.push(
+        get_relevant_entities({
+          // Be stricter when matching building entities:
+          // a lower distance threshold tolerates fewer inconsistencies.
+          threshold: 0.5,
+          entity: `Building: ${entities.building}`,
+          db: db_instance
+        })
+      )
+    }
+
+    if (entities.process !== null) {
+      entity_promises.push(
+        get_relevant_entities({
+          // Be more lenient when matching process entities:
+          // a higher distance threshold accommodates variations and exceptions.
+          threshold: 0.8,
+          entity: `Process: ${entities.process}`,
+          db: db_instance
+        })
+      )
+    }
+  }
+
+  // Collect the relevant entity hashes (flattened; empty when nothing matched)
+  const results = _.flattenDeep(await Promise.all(entity_promises))
+
+  // Transform the string query into a query vector
+  const query_vectors = await generate_embeddings(query)
+
+  // If no relevant `entities` are identified, bypass entity-specific prefilters
+  // and directly search for relevant chunks to ensure comprehensive results.
+  if (results.length === 0) {
+    user_query = search_without_prefiltering(query_vectors, db_instance)
+  } else {
+    // If relevant `entities` are identified, apply pre-filters during the search
+    // to narrow down the results and focus on entity-specific chunks for precision.
+    user_query = search_with_prefiltering(query_vectors, db_instance, results)
+
+    // If no results are found after pre-filtering,
+    // fall back to a full database search.
+    if (user_query.length === 0)
+      user_query = search_without_prefiltering(query_vectors, db_instance)
+  }
+
+  return user_query
+}
+
+//
+// Get relevant entities
+export const get_relevant_entities = async ({ threshold, entity, db }) => {
+  // Step 1. Generate an embedding for the guessed entity
+  const entity_vectors = await generate_embeddings(entity)
+
+  // Step 2. Search the DB for similar-looking entities
+  const entity_query = db
+    .prepare(
+      `
+      SELECT entity_hash, entity_text, distance
+      FROM vectors
+      WHERE entity_vector MATCH ?
+      AND k = 15
+      `
+    )
+    .all(new Float32Array(entity_vectors)) as unknown as {
+    entity_hash: string
+    entity_text: string
+    distance: number
+  }[]
+
+  // Step 3. Keep only entities that are relevant (i.e. below the threshold)
+  const relevant_entities = _.chain(entity_query)
+    .filter((item) => item.distance < threshold)
+    .uniqBy('entity_hash')
+    .map('entity_hash')
+    .value()
+
+  return relevant_entities
+}

+//
 // prettier-ignore
 // biome-ignore format: readability
 // Typescript Type via Zod
 export const Vector_Search_Result = z.object({

   community : z.array(z.object({
-    rowid       : z.number(),
-    distance    : z.number(),
-    chunk       : z.string() })).default([]),
+    chunk_hash  : z.string(),
+    chunk_text  : z.string(),
+    entity_hash : z.string(),
+    distance    : z.number() })).default([]),

   private   : z.array(z.object({
-    rowid       : z.number(),
-    distance    : z.number(),
-    chunk       : z.string() })).default([]),
+    chunk_hash  : z.string(),
+    chunk_text  : z.string(),
+    entity_hash : z.string(),
+    distance    : z.number() })).default([]),

   public    : z.array(z.object({
-    rowid       : z.number(),
-    distance    : z.number(),
-    chunk       : z.string() })).default([])
-
+    chunk_hash  : z.string(),
+    chunk_text  : z.string(),
+    entity_hash : z.string(),
+    distance    : z.number() })).default([]),
 })

 export type Vector_Search_Result = z.infer<typeof Vector_Search_Result>
+
+/**
+ * Generate Embeddings
+ *
+ * This function generates a vector embedding for a given string using OpenAI's `text-embedding-3-large` model.
+ *
+ * Key Features:
+ * - Takes a text input and produces a high-dimensional embedding vector for semantic similarity tasks.
+ * - Uses the OpenAI API to generate the embedding.
+ * - Embeddings can be used in downstream tasks such as similarity search, classification, or clustering.
+ *
+ * Returns:
+ * - `embedding` (number[]): A high-dimensional vector representing the semantic meaning of the input text.
+ */
+export const generate_embeddings = async (input: string) => {
+  try {
+    const { embedding } = await embed({
+      model: openai.embedding('text-embedding-3-large'),
+      value: input
+    })
+
+    return embedding
+  } catch (e) {
+    console.error('Failed to generate embedding:', e)
+    throw new Error('Embedding generation failed.')
+  }
+}
+
+/**
+ * Search Without Prefiltering
+ *
+ * This function executes a vector similarity search on a database of chunked text data
+ * without applying any prefiltering, so that every chunk is considered.
+ *
+ * Key Features:
+ * - Accepts query vectors for semantic search using vector similarity techniques.
+ * - Searches across the entire dataset without narrowing results to specific entities.
+ * - Queries the database to fetch the top 10 matching chunks based on vector similarity.
+ *
+ * Returns:
+ * - An array of matching chunks, including:
+ *   - `chunk_text` (string): The text content of the matching chunk.
+ *   - `chunk_hash` (string): The unique hash identifier for the chunk.
+ *   - `entity_hash` (string): The entity hash associated with the chunk.
+ *   - `distance` (number): The vector similarity distance of the match.
+ */
+export const search_without_prefiltering = (query_vectors: number[], database: ReturnType<typeof db>) =>
+  database
+    .prepare(
+      `
+      SELECT chunk_text, chunk_hash, entity_hash, distance
+      FROM vectors
+      WHERE chunk_vector MATCH ?
+      AND k = 10
+      `
+    )
+    .all(new Float32Array(query_vectors)) as unknown as {
+    entity_hash: string
+    chunk_hash: string
+    chunk_text: string
+    distance: number
+  }[]
+
+/**
+ * Search with Prefiltering
+ *
+ * This function performs a prefiltered vector similarity search on a database of chunked text data.
+ *
+ * Key Features:
+ * - Accepts query vectors for semantic search using vector similarity techniques.
+ * - Applies prefilters using the provided entity hashes to narrow down the results.
+ * - Queries the database to fetch the top 10 matching chunks based on similarity and prefilter conditions.
+ *
+ * Returns:
+ * - An array of matching chunks, including:
+ *   - `chunk_text` (string): The text content of the matching chunk.
+ *   - `chunk_hash` (string): The unique hash identifier for the chunk.
+ *   - `entity_hash` (string): The entity hash associated with the chunk.
+ *   - `distance` (number): The vector similarity distance of the match.
+ */
+export const search_with_prefiltering = (query_vectors: number[], database: ReturnType<typeof db>, filters: string[]) =>
+  database
+    .prepare(
+      `
+      SELECT chunk_text, chunk_hash, entity_hash, distance
+      FROM vectors
+      WHERE chunk_vector MATCH ?
+      AND k = 10
+      AND entity_hash IN (${_.flattenDeep(filters)
+        .map(() => '?')
+        .join(', ')})
+      `
+    )
+    .all(new Float32Array(query_vectors), ..._.flattenDeep(filters)) as unknown as {
+    entity_hash: string
+    chunk_hash: string
+    chunk_text: string
+    distance: number
+  }[]
+
+//
+// Calculate Cosine Similarity
+export const cosine_similarity = (vec_1: number[], vec_2: number[]) => {
+  const dot_product = vec_1.reduce((sum, val, i) => sum + val * vec_2[i], 0)
+  const magnitude_1 = Math.sqrt(vec_1.reduce((sum, val) => sum + val * val, 0))
+  const magnitude_2 = Math.sqrt(vec_2.reduce((sum, val) => sum + val * val, 0))
+  return dot_product / (magnitude_1 * magnitude_2)
+}
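
Note (illustration only, not part of the patch): every query above relies on the `MATCH ? AND k = N` KNN syntax exposed by a sqlite-vec `vec0` virtual table. The project's actual `vectors` table is not declared in this diff, so the sketch below is an assumption: it demonstrates the mechanics with a throwaway single-column table (the real table also carries the chunk_hash / chunk_text / entity_hash / entity_text columns referenced by the SELECTs), using Bun's built-in SQLite driver and the sqlite-vec npm package, and binding Float32Array blobs exactly as the patch does.

import { Database } from 'bun:sqlite'
import * as sqliteVec from 'sqlite-vec'

// In-memory database with the sqlite-vec extension loaded (hypothetical setup).
const demo = new Database(':memory:')
sqliteVec.load(demo)

// text-embedding-3-large produces 3072-dimension vectors; 4 dimensions keep the demo readable.
demo.run('CREATE VIRTUAL TABLE demo_vectors USING vec0(chunk_vector float[4])')
demo
  .prepare('INSERT INTO demo_vectors (rowid, chunk_vector) VALUES (?, ?)')
  .run(1, new Float32Array([0.1, 0.2, 0.3, 0.4]))

// `k = 2` bounds the number of nearest neighbours; `distance` is computed by the extension.
const rows = demo
  .prepare('SELECT rowid, distance FROM demo_vectors WHERE chunk_vector MATCH ? AND k = 2')
  .all(new Float32Array([0.1, 0.2, 0.3, 0.4]))

console.log(rows) // => [ { rowid: 1, distance: 0 } ]

The prefiltered variant in search_with_prefiltering works the same way, with the extra `entity_hash IN (...)` constraint applied alongside the KNN scan.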
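Note (illustration only, not part of the patch): the exported helpers also compose outside SQLite, which is handy when inspecting embeddings or tuning the entity thresholds. A minimal sketch, assuming it runs from the repository root (the import path is illustrative) and that an OPENAI_API_KEY is configured for the `ai` SDK:

import { cosine_similarity, generate_embeddings } from './utils/search-by-vectors'

// Two semantically close French sentences should score well above unrelated ones.
const a = await generate_embeddings('Quel est le mode de chauffage ?')
const b = await generate_embeddings('Comment le bâtiment est-il chauffé ?')

// OpenAI embeddings are unit-normalised, so the value lies in [-1, 1], with 1 meaning identical direction.
console.log(cosine_similarity(a, b))
console.log(cosine_similarity(a, a)) // => 1 (up to floating-point error)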