Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fixes embeddings generation via plugin-embeddings (#849) #852

Merged
merged 5 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion packages/orama/src/components/hooks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ export function runBeforeSearch<T extends AnyOrama>(
language: string | undefined
): Promise<void> | void {
const needAsync = hooks.some(isAsyncFunction)

if (needAsync) {
return (async () => {
for (const hook of hooks) {
Expand Down
1 change: 1 addition & 0 deletions packages/orama/src/methods/search-hybrid.ts
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ export function hybridSearch<T extends AnyOrama, ResultDocument = TypedDocument<
}

const asyncNeeded = orama.beforeSearch?.length || orama.afterSearch?.length

if (asyncNeeded) {
return executeSearchAsync()
}
Expand Down
2 changes: 2 additions & 0 deletions packages/orama/src/methods/search-vector.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ export function innerVectorSearch<T extends AnyOrama, ResultDocument = TypedDocu

const vectorIndex = orama.data.index.vectorIndexes[vector!.property]
const vectorSize = vectorIndex.node.size

if (vector?.value.length !== vectorSize) {
if (vector?.property === undefined || vector?.value.length === undefined) {
throw createError('INVALID_INPUT_VECTOR', 'undefined', vectorSize, 'undefined')
Expand Down Expand Up @@ -121,6 +122,7 @@ export function searchVector<T extends AnyOrama, ResultDocument = TypedDocument<
}

const asyncNeeded = orama.beforeSearch?.length || orama.afterSearch?.length

if (asyncNeeded) {
return executeSearchAsync()
}
Expand Down
1 change: 1 addition & 0 deletions packages/orama/src/trees/vector.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ export function findSimilarVectors(
const similarVectors: SimilarVector[] = []

const base = keys ? keys : vectors.keys()

for (const vectorId of base) {
const entry = vectors.get(vectorId)
if (!entry) {
Expand Down
2 changes: 1 addition & 1 deletion packages/orama/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -355,8 +355,8 @@ export function isAsyncFunction(func: any): boolean {
return func?.constructor?.name === 'AsyncFunction'
}


const withIntersection = 'intersection' in (new Set());

export function setIntersection<V>(...sets: Set<V>[]): Set<V> {
// Fast path 1
if (sets.length === 0) {
Expand Down
68 changes: 45 additions & 23 deletions packages/orama/tests/utils.test.ts
Original file line number Diff line number Diff line change
@@ -1,33 +1,33 @@
import t from 'tap'
import { formatBytes, formatNanoseconds, getOwnProperty, getNested, flattenObject, setUnion, setIntersection } from '../src/utils.js'
import { formatBytes, formatNanoseconds, getOwnProperty, getNested, flattenObject, setUnion, setIntersection, isAsyncFunction } from '../src/utils.js'

t.test('utils', async (t) => {
t.test('should correctly format bytes', async (t) => {
t.equal(await formatBytes(0), '0 Bytes')
t.equal(await formatBytes(1), '1 Bytes')
t.equal(await formatBytes(1024), '1 KB')
t.equal(await formatBytes(1024 ** 2), '1 MB')
t.equal(await formatBytes(1024 ** 3), '1 GB')
t.equal(await formatBytes(1024 ** 4), '1 TB')
t.equal(await formatBytes(1024 ** 5), '1 PB')
t.equal(await formatBytes(1024 ** 6), '1 EB')
t.equal(await formatBytes(1024 ** 7), '1 ZB')
t.equal(formatBytes(0), '0 Bytes')
t.equal(formatBytes(1), '1 Bytes')
t.equal(formatBytes(1024), '1 KB')
t.equal(formatBytes(1024 ** 2), '1 MB')
t.equal(formatBytes(1024 ** 3), '1 GB')
t.equal(formatBytes(1024 ** 4), '1 TB')
t.equal(formatBytes(1024 ** 5), '1 PB')
t.equal(formatBytes(1024 ** 6), '1 EB')
t.equal(formatBytes(1024 ** 7), '1 ZB')
})

t.test('should correctly format nanoseconds', async (t) => {
t.equal(await formatNanoseconds(1n), '1ns')
t.equal(await formatNanoseconds(10n), '10ns')
t.equal(await formatNanoseconds(100n), '100ns')
t.equal(await formatNanoseconds(1_000n), '1μs')
t.equal(await formatNanoseconds(10_000n), '10μs')
t.equal(await formatNanoseconds(100_000n), '100μs')
t.equal(await formatNanoseconds(1_000_000n), '1ms')
t.equal(await formatNanoseconds(10_000_000n), '10ms')
t.equal(await formatNanoseconds(100_000_000n), '100ms')
t.equal(await formatNanoseconds(1000_000_000n), '1s')
t.equal(await formatNanoseconds(10_000_000_000n), '10s')
t.equal(await formatNanoseconds(100_000_000_000n), '100s')
t.equal(await formatNanoseconds(1000_000_000_000n), '1000s')
t.equal(formatNanoseconds(1n), '1ns')
t.equal(formatNanoseconds(10n), '10ns')
t.equal(formatNanoseconds(100n), '100ns')
t.equal(formatNanoseconds(1_000n), '1μs')
t.equal(formatNanoseconds(10_000n), '10μs')
t.equal(formatNanoseconds(100_000n), '100μs')
t.equal(formatNanoseconds(1_000_000n), '1ms')
t.equal(formatNanoseconds(10_000_000n), '10ms')
t.equal(formatNanoseconds(100_000_000n), '100ms')
t.equal(formatNanoseconds(1000_000_000n), '1s')
t.equal(formatNanoseconds(10_000_000_000n), '10s')
t.equal(formatNanoseconds(100_000_000_000n), '100s')
t.equal(formatNanoseconds(1000_000_000_000n), '1000s')
})

t.test('should check object properties', async (t) => {
Expand Down Expand Up @@ -95,6 +95,28 @@ t.test('utils', async (t) => {
t.equal((flattened as Record<string, string>).foo, 'bar')
t.equal(flattened['nested.nested2.nested3.bar'], 'baz')
})

// This test is skipped because the implementation of isAsyncFunction is temporary and will be
// removed in a future version of Orama.
t.skip('should correctly detect an async function', t => {
async function asyncFunction() {
return 'async'
}

function returnPromise() {
return new Promise((resolve) => {
resolve('promise')
})
}

function syncFunction() {
return 'sync'
}

t.equal(isAsyncFunction(asyncFunction), true)
t.equal(isAsyncFunction(returnPromise), false) // Returing a promise is not async, JS cannot detect it as async
t.equal(isAsyncFunction(syncFunction), false)
})
})

t.test('setUnion', async t => {
Expand Down
21 changes: 15 additions & 6 deletions packages/plugin-embeddings/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ function getPropertiesValues(schema: object, properties: string[]) {
.join('. ')
}

function normalizeVector(v: number[]): number[] {
const norm = Math.sqrt(v.reduce((sum, val) => sum + val * val, 0));
return v.map(val => val / norm);
}

export const embeddingsType = 'vector[512]'

export async function pluginEmbeddings(pluginParams: PluginEmbeddingsParams): Promise<OramaPluginAsync> {
Expand All @@ -49,9 +54,9 @@ export async function pluginEmbeddings(pluginParams: PluginEmbeddingsParams): Pr
console.log(`Generating embeddings for properties "${properties.join(', ')}": "${values}"`)
}

const embeddings = await model.embed(values)
const embeddings = Array.from(await (await model.embed(values)).data())

params[pluginParams.embeddings.defaultProperty] = (await embeddings.data()) as unknown as number[]
params[pluginParams.embeddings.defaultProperty] = normalizeVector(embeddings)
},

async beforeSearch<T extends AnyOrama>(_db: AnyOrama, params: SearchParams<T, TypedDocument<any>>) {
Expand All @@ -64,21 +69,25 @@ export async function pluginEmbeddings(pluginParams: PluginEmbeddingsParams): Pr
}

if (!params.term) {
throw new Error('Neither "term" nor "vector" parameters were provided')
throw new Error('No "term" or "vector" parameters were provided')
}

const embeddings = await model.embed(params.term) as unknown as number[]
const embeddings = Array.from(await (await model.embed(params.term)).data()) as unknown as number[]

if (!params.vector) {
params.vector = {
// eslint-disable-next-line
// @ts-ignore
property: params?.vector?.property ?? pluginParams.embeddings.defaultProperty,
value: embeddings
value: normalizeVector(embeddings)
}
}

console.log({
vector: normalizeVector(embeddings)
})

params.vector.value = embeddings
params.vector.value = normalizeVector(embeddings)
}
}
}
Loading
Loading