Skip to content

Commit

Permalink
refactor: improve vector search performance
Browse files Browse the repository at this point in the history
  • Loading branch information
charnould committed Dec 15, 2024
1 parent 72a81cf commit d0fa51d
Show file tree
Hide file tree
Showing 2 changed files with 232 additions and 33 deletions.
13 changes: 13 additions & 0 deletions tests/llm/search-by-vectors.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import { expect, test } from 'bun:test'
import { query_db } from '../../utils/search-by-vectors'

test('should vector search works', async () => {
const r = await query_db('proprietary.private', 'quel est le mode de chauffage ?', {
building: 'Racine',
process: 'none'
})

console.log(r)

expect(2).toBe(2)
})
252 changes: 219 additions & 33 deletions utils/search-by-vectors.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { openai } from '@ai-sdk/openai'
import { embed } from 'ai'
import _ from 'lodash'
import { z } from 'zod'
import { db } from '../utils/database'
import type { Db_Name } from '../utils/database'
Expand Down Expand Up @@ -27,17 +28,11 @@ export const vector_search = async (query: string, context: AIContext) => {
// Initialize the vector search results
const r: Vector_Search_Result = { community: [], private: [], public: [] }

// Generate the embedding for the user’s query
const { embedding } = await embed({
model: openai.embedding('text-embedding-3-large'),
value: query
})

// Perform searches based on knowledge access
if (k.community) r.community = query_db('community', new Float32Array(embedding))
if (k.proprietary.public) r.public = query_db('proprietary.public', new Float32Array(embedding))
if (k.proprietary.private)
r.private = query_db('proprietary.private', new Float32Array(embedding))
const entities = context.query?.named_entities
if (k.community) r.community = await query_db('community', query)
if (k.proprietary.public) r.public = await query_db('proprietary.public', query, entities)
if (k.proprietary.private) r.private = await query_db('proprietary.private', query, entities)

// Parse results and return them to ensure
// TypeScript types align with the schema
Expand All @@ -60,41 +55,232 @@ export const vector_search = async (query: string, context: AIContext) => {

//
// Query the database with the embedding to retrieve the matching chunks
const query_db = (db_name: Db_Name, embedding: Float32Array) => {
export const query_db = async (db_name: Db_Name, query: string, entities?) => {
// Initialize variables
const db_instance = db(db_name)
const results = db_instance
.prepare('SELECT rowid, distance FROM vectors WHERE vector MATCH ? AND k = 10')
.all(embedding) as { rowid: number; distance: number }[]

return results.map((result) => ({
...result,
...(db_instance.prepare('SELECT * FROM chunks WHERE rowid = ?;').get(result.rowid) as {
rowid: number
chunk: string
})
}))
const entity_promises: Promise<string[]>[] = []
let user_query: {
entity_hash: string
chunk_hash: string
chunk_text: string
distance: number
}[]

// No need to apply entity-specific filters to `community` knowledge.
// Semantic analysis alone is robust and sufficient for handling the content effectively.
if (db_name !== 'community') {
if (entities.building !== null) {
entity_promises.push(
get_relevant_entities({
// Apply stricter validation and rules when processing data related to a building.
// Ensure higher standards and fewer allowances for inconsistencies.
threshold: 0.5,
entity: `Building: ${entities.building}`,
db: db_instance
})
)
}

if (entities.process !== null) {
entity_promises.push(
get_relevant_entities({
// Allow for flexibility and leniency when processing data related to a process.
// Prioritize adaptability to accommodate variations and exceptions.
threshold: 0.8,
entity: `Process: ${entities.process}`,
db: db_instance
})
)
}
}

// Get maybe relevant entities
const results = (await Promise.all(entity_promises)) ?? []

// Transfom string query into vector query
const query_vectors = await generate_embeddings(query)

// If no relevant `entities` are identified, bypass entity-specific prefilters
// and directly search for relevant chunks to ensure comprehensive results.
if (results.length === 0) {
user_query = search_without_prefiltering(query_vectors, db_instance)
} else {
// If relevant `entities` are identified, apply pre-filters during the search
// to narrow down the results and focus on entity-specific chunks for precision.
user_query = search_with_prefiltering(query_vectors, db_instance, results)

// If no results are found after pre-filtering,
// perform a full database search.
if (user_query.length === 0)
user_query = search_without_prefiltering(query_vectors, db_instance)
}

return user_query
}

//
// Get relevant entities
export const get_relevant_entities = async ({ threshold, entity, db }) => {
// Step 1. Generate guessed entity embedding
const entity_vectors = await generate_embeddings(entity)

// Step 2. Search in DB for near-looking entities
const entity_query = db
.prepare(
`
SELECT entity_hash, entity_text, distance
FROM vectors
WHERE entity_vector MATCH ?
AND k = 15
`
)
.all(new Float32Array(entity_vectors)) as unknown as {
entity_hash: string
entity_text: string
distance: number
}[]

// Step 3. Keep only entities that are relevant (= below a thresold)
const relevant_entities = _.chain(entity_query)
.filter((item) => item.distance < threshold)
.uniqBy('entity_hash')
.map('entity_hash')
.value()

return relevant_entities
}

//
// prettier-ignore
// biome-ignore format: readability
// Typescript Type via Zod
export const Vector_Search_Result = z.object({

community : z.array(z.object({
rowid : z.number(),
distance : z.number(),
chunk : z.string() })).default([]),
chunk_hash : z.string(),
chunk_text : z.string(),
entity_hash : z.string(),
distance : z.number() })).default([]),

private : z.array(z.object({
rowid : z.number(),
distance : z.number(),
chunk : z.string() })).default([]),
chunk_hash : z.string(),
chunk_text : z.string(),
entity_hash : z.string(),
distance : z.number() })).default([]),

public : z.array(z.object({
rowid : z.number(),
distance : z.number(),
chunk : z.string() })).default([])

chunk_hash : z.string(),
chunk_text : z.string(),
entity_hash : z.string(),
distance : z.number() })).default([]),
})

export type Vector_Search_Result = z.infer<typeof Vector_Search_Result>

/**
* Generate Embeddings
*
* This function generates a vector embedding for a given string using OpenAI's `text-embedding-3-large` model.
*
* Key Features:
* - Takes a text input and produces a high-dimensional embedding vector for semantic similarity tasks.
* - Uses the OpenAI API to generate embeddings, leveraging a state-of-the-art model.
* - Embeddings can be used in downstream tasks such as similarity search, classification, or clustering.
*
* Returns:
* - `embedding` (Array<number>): A high-dimensional vector representing the semantic meaning of the input text.
*/
export const generate_embeddings = async (string: string) => {
try {
const { embedding } = await embed({
model: openai.embedding('text-embedding-3-large'),
value: string
})

return embedding
} catch (e) {
console.error('Failed to generate embedding:', e)
throw new Error('Embedding generation failed.')
}
}

/**
* Search Without Prefiltering
*
* This function executes a vector similarity search on a database of chunked text data
* without applying any prefiltering, ensuring that all relevant results are considered.
*
* Key Features:
* - Accepts query vectors for semantic search using vector similarity techniques.
* - Searches across the entire dataset without narrowing results to specific entities.
* - Queries the database to fetch the top 10 matching chunks based on vector similarity.
*
* Returns:
* - An array of matching chunks, including:
* - `chunk_text` (string): The text content of the matching chunk.
* - `chunk_hash` (string): The unique hash identifier for the chunk.
* - `entity_hash` (string): The entity hash associated with the chunk.
* - `distance` (number): The vector similarity distance of the match.
*/
export const search_without_prefiltering = (query_vectors, database) =>
database
.prepare(
`
SELECT chunk_text, chunk_hash, entity_hash, distance
FROM vectors
WHERE chunk_vector MATCH ?
AND k = 10
`
)
.all(new Float32Array(query_vectors)) as unknown as {
entity_hash: string
chunk_hash: string
chunk_text: string
distance: number
}[]

/**
* Search with Prefiltering
*
* This function performs a prefiltered vector similarity search on a database of chunked text data.
*
* Key Features:
* - Accepts query vectors for semantic search using vector similarity techniques.
* - Applies prefilters using provided entity-specific hash filters to narrow down the results.
* - Queries the database to fetch the top 10 matching chunks based on similarity and prefilter conditions.
*
* Returns:
* - An array of matching chunks, including:
* - `chunk_text` (string): The text content of the matching chunk.
* - `chunk_hash` (string): The unique hash identifier for the chunk.
* - `entity_hash` (string): The entity hash associated with the chunk.
* - `distance` (number): The vector similarity distance of the match.
*/
export const search_with_prefiltering = (query_vectors, database, filters) =>
database
.prepare(
`
SELECT chunk_text, chunk_hash, entity_hash, distance
FROM vectors
WHERE chunk_vector MATCH ?
AND k = 10
AND entity_hash IN (${_.flattenDeep(filters)
.map(() => '?')
.join(', ')})
`
)
.all(new Float32Array(query_vectors), ..._.flattenDeep(filters)) as unknown as {
entity_hash: string
chunk_hash: string
chunk_text: string
distance: number
}[]

//
// Calculate Cosine Similarity
// biome-ignore lint:
export const cosine_similarity = (vec_1: any[], vec_2: any[]) => {
const dot_product = vec_1.reduce((sum, val, i) => sum + val * vec_2[i], 0)
const magnitude_A = Math.sqrt(vec_1.reduce((sum, val) => sum + val * val, 0))
const magnitude_B = Math.sqrt(vec_2.reduce((sum, val) => sum + val * val, 0))
return dot_product / (magnitude_A * magnitude_B)
}

0 comments on commit d0fa51d

Please sign in to comment.