-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat (ai/core): add embedMany function (#1617)
- Loading branch information
Showing
17 changed files
with
444 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
--- | ||
'ai': patch | ||
--- | ||
|
||
feat (ai/core): add embedMany function |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
--- | ||
title: embedMany | ||
description: Embed several values using the AI SDK Core (batch embedding) | ||
--- | ||
|
||
# `embedMany` | ||
|
||
Embed several values using an embedding model. The type of the value is defined | ||
by the embedding model. | ||
|
||
`embedMany` automatically splits large requests into smaller chunks if the model | ||
has a limit on how many embeddings can be generated in a single call. | ||
|
||
## Import | ||
|
||
<Snippet text={`import { embedMany } from "ai"`} prompt={false} /> | ||
|
||
<ReferenceTable packageName="core" functionName="embedMany" /> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import { mistral } from '@ai-sdk/mistral'; | ||
import { embedMany } from 'ai'; | ||
import dotenv from 'dotenv'; | ||
|
||
dotenv.config(); | ||
|
||
async function main() { | ||
const { embeddings } = await embedMany({ | ||
model: mistral.embedding('mistral-embed'), | ||
values: [ | ||
'sunny day at the beach', | ||
'rainy afternoon in the city', | ||
'snowy night in the mountains', | ||
], | ||
}); | ||
|
||
console.log(embeddings); | ||
} | ||
|
||
main().catch(console.error); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import { openai } from '@ai-sdk/openai'; | ||
import { embedMany } from 'ai'; | ||
import dotenv from 'dotenv'; | ||
|
||
dotenv.config(); | ||
|
||
async function main() { | ||
const { embeddings } = await embedMany({ | ||
model: openai.embedding('text-embedding-3-small'), | ||
values: [ | ||
'sunny day at the beach', | ||
'rainy afternoon in the city', | ||
'snowy night in the mountains', | ||
], | ||
}); | ||
|
||
console.log(embeddings); | ||
} | ||
|
||
main().catch(console.error); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import assert from 'node:assert'; | ||
import { | ||
MockEmbeddingModelV1, | ||
mockEmbed, | ||
} from '../test/mock-embedding-model-v1'; | ||
import { embedMany } from './embed-many'; | ||
|
||
const dummyEmbeddings = [ | ||
[0.1, 0.2, 0.3], | ||
[0.4, 0.5, 0.6], | ||
[0.7, 0.8, 0.9], | ||
]; | ||
|
||
const testValues = [ | ||
'sunny day at the beach', | ||
'rainy afternoon in the city', | ||
'snowy night in the mountains', | ||
]; | ||
|
||
describe('result.embedding', () => { | ||
it('should generate embeddings', async () => { | ||
const result = await embedMany({ | ||
model: new MockEmbeddingModelV1({ | ||
maxEmbeddingsPerCall: 5, | ||
doEmbed: mockEmbed(testValues, dummyEmbeddings), | ||
}), | ||
values: testValues, | ||
}); | ||
|
||
assert.deepStrictEqual(result.embeddings, dummyEmbeddings); | ||
}); | ||
|
||
it('should generate embeddings when several calls are required', async () => { | ||
let callCount = 0; | ||
|
||
const result = await embedMany({ | ||
model: new MockEmbeddingModelV1({ | ||
maxEmbeddingsPerCall: 2, | ||
doEmbed: async ({ values }) => { | ||
if (callCount === 0) { | ||
assert.deepStrictEqual(values, testValues.slice(0, 2)); | ||
callCount++; | ||
return { embeddings: dummyEmbeddings.slice(0, 2) }; | ||
} | ||
|
||
if (callCount === 1) { | ||
assert.deepStrictEqual(values, testValues.slice(2)); | ||
callCount++; | ||
return { embeddings: dummyEmbeddings.slice(2) }; | ||
} | ||
|
||
throw new Error('Unexpected call'); | ||
}, | ||
}), | ||
values: testValues, | ||
}); | ||
|
||
assert.deepStrictEqual(result.embeddings, dummyEmbeddings); | ||
}); | ||
}); | ||
|
||
describe('result.values', () => { | ||
it('should include values in the result', async () => { | ||
const result = await embedMany({ | ||
model: new MockEmbeddingModelV1({ | ||
maxEmbeddingsPerCall: 5, | ||
doEmbed: mockEmbed(testValues, dummyEmbeddings), | ||
}), | ||
values: testValues, | ||
}); | ||
|
||
assert.deepStrictEqual(result.values, testValues); | ||
}); | ||
}); |
Oops, something went wrong.