forked from run-llama/LlamaIndexTS
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: DeepInfra Embeddings implementation (run-llama#890)
Co-authored-by: Alex Yang <himself65@outlook.com>
- Loading branch information
Showing
4 changed files
with
249 additions
and
0 deletions.
There are no files selected for viewing
79 changes: 79 additions & 0 deletions
79
apps/docs/docs/modules/embeddings/available_embeddings/deepinfra.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# DeepInfra | ||
|
||
To use DeepInfra embeddings, you need to import `DeepInfraEmbedding` from llamaindex. | ||
Check out available embedding models [here](https://deepinfra.com/models/embeddings). | ||
|
||
```ts | ||
import { | ||
DeepInfraEmbedding, | ||
Settings, | ||
Document, | ||
VectorStoreIndex, | ||
} from "llamaindex"; | ||
|
||
// Update Embed Model | ||
Settings.embedModel = new DeepInfraEmbedding(); | ||
|
||
const document = new Document({ text: essay, id_: "essay" }); | ||
|
||
const index = await VectorStoreIndex.fromDocuments([document]); | ||
|
||
const queryEngine = index.asQueryEngine(); | ||
|
||
const query = "What is the meaning of life?"; | ||
|
||
const results = await queryEngine.query({ | ||
query, | ||
}); | ||
``` | ||
|
||
By default, DeepInfraEmbedding is using the sentence-transformers/clip-ViT-B-32 model. You can change the model by passing the model parameter to the constructor. | ||
For example: | ||
|
||
```ts | ||
import { DeepInfraEmbedding } from "llamaindex"; | ||
|
||
const model = "intfloat/e5-large-v2"; | ||
Settings.embedModel = new DeepInfraEmbedding({ | ||
model, | ||
}); | ||
``` | ||
|
||
You can also set the `maxRetries` and `timeout` parameters when initializing `DeepInfraEmbedding` for better control over the request behavior. | ||
|
||
For example: | ||
|
||
```ts | ||
import { DeepInfraEmbedding, Settings } from "llamaindex"; | ||
|
||
const model = "intfloat/e5-large-v2"; | ||
const maxRetries = 5; | ||
const timeout = 5000; // 5 seconds | ||
|
||
Settings.embedModel = new DeepInfraEmbedding({ | ||
model, | ||
maxRetries, | ||
timeout, | ||
}); | ||
``` | ||
|
||
Standalone usage: | ||
|
||
```ts | ||
import { DeepInfraEmbedding } from "llamaindex"; | ||
import { config } from "dotenv"; | ||
// For standalone usage, you need to configure DEEPINFRA_API_TOKEN in .env file | ||
config(); | ||
|
||
const main = async () => { | ||
const model = "intfloat/e5-large-v2"; | ||
const embeddings = new DeepInfraEmbedding({ model }); | ||
const text = "What is the meaning of life?"; | ||
const response = await embeddings.embed([text]); | ||
console.log(response); | ||
}; | ||
|
||
main(); | ||
``` | ||
|
||
For questions or feedback, please contact us at [feedback@deepinfra.com](mailto:feedback@deepinfra.com) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
import { DeepInfraEmbedding } from "llamaindex"; | ||
|
||
async function main() { | ||
// API token can be provided as an environment variable too | ||
// using DEEPINFRA_API_TOKEN variable | ||
const apiToken = "YOUR_API_TOKEN" ?? process.env.DEEPINFRA_API_TOKEN; | ||
const model = "BAAI/bge-large-en-v1.5"; | ||
const embedModel = new DeepInfraEmbedding({ | ||
model, | ||
apiToken, | ||
}); | ||
const texts = ["hello", "world"]; | ||
const embeddings = await embedModel.getTextEmbeddingsBatch(texts); | ||
console.log(`\nWe have ${embeddings.length} embeddings`); | ||
} | ||
|
||
main().catch(console.error); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
import { getEnv } from "@llamaindex/env"; | ||
import type { MessageContentDetail } from "../llm/index.js"; | ||
import { extractSingleText } from "../llm/utils.js"; | ||
import { BaseEmbedding } from "./types.js"; | ||
|
||
const DEFAULT_MODEL = "sentence-transformers/clip-ViT-B-32"; | ||
|
||
const API_TOKEN_ENV_VARIABLE_NAME = "DEEPINFRA_API_TOKEN"; | ||
|
||
const API_ROOT = "https://api.deepinfra.com/v1/inference"; | ||
|
||
const DEFAULT_TIMEOUT = 60 * 1000; | ||
|
||
const DEFAULT_MAX_RETRIES = 5; | ||
|
||
export interface DeepInfraEmbeddingResponse { | ||
embeddings: number[][]; | ||
request_id: string; | ||
inference_status: InferenceStatus; | ||
} | ||
|
||
export interface InferenceStatus { | ||
status: string; | ||
runtime_ms: number; | ||
cost: number; | ||
tokens_input: number; | ||
} | ||
|
||
const mapPrefixWithInputs = (prefix: string, inputs: string[]): string[] => { | ||
return inputs.map((input) => `${prefix} ${input}`); | ||
}; | ||
|
||
/** | ||
* DeepInfraEmbedding is an alias for DeepInfra that implements the BaseEmbedding interface. | ||
*/ | ||
export class DeepInfraEmbedding extends BaseEmbedding { | ||
/** | ||
* DeepInfra model to use | ||
* @default "sentence-transformers/clip-ViT-B-32" | ||
* @see https://deepinfra.com/models/embeddings | ||
*/ | ||
model: string; | ||
|
||
/** | ||
* DeepInfra API token | ||
* @see https://deepinfra.com/dash/api_keys | ||
* If not provided, it will try to get the token from the environment variable `DEEPINFRA_API_TOKEN` | ||
* | ||
*/ | ||
apiToken: string; | ||
|
||
/** | ||
* Prefix to add to the query | ||
* @default "" | ||
*/ | ||
queryPrefix: string; | ||
|
||
/** | ||
* Prefix to add to the text | ||
* @default "" | ||
*/ | ||
textPrefix: string; | ||
|
||
/** | ||
* | ||
* @default 5 | ||
*/ | ||
maxRetries: number; | ||
|
||
/** | ||
* | ||
* @default 60 * 1000 | ||
*/ | ||
timeout: number; | ||
|
||
constructor(init?: Partial<DeepInfraEmbedding>) { | ||
super(); | ||
|
||
this.model = init?.model ?? DEFAULT_MODEL; | ||
this.apiToken = init?.apiToken ?? getEnv(API_TOKEN_ENV_VARIABLE_NAME) ?? ""; | ||
this.queryPrefix = init?.queryPrefix ?? ""; | ||
this.textPrefix = init?.textPrefix ?? ""; | ||
this.maxRetries = init?.maxRetries ?? DEFAULT_MAX_RETRIES; | ||
this.timeout = init?.timeout ?? DEFAULT_TIMEOUT; | ||
} | ||
|
||
async getTextEmbedding(text: string): Promise<number[]> { | ||
const texts = mapPrefixWithInputs(this.textPrefix, [text]); | ||
const embeddings = await this.getDeepInfraEmbedding(texts); | ||
return embeddings[0]; | ||
} | ||
|
||
async getQueryEmbedding( | ||
query: MessageContentDetail, | ||
): Promise<number[] | null> { | ||
const text = extractSingleText(query); | ||
if (text) { | ||
const queries = mapPrefixWithInputs(this.queryPrefix, [text]); | ||
const embeddings = await this.getDeepInfraEmbedding(queries); | ||
return embeddings[0]; | ||
} else { | ||
return null; | ||
} | ||
} | ||
|
||
async getTextEmbeddings(texts: string[]): Promise<number[][]> { | ||
const textsWithPrefix = mapPrefixWithInputs(this.textPrefix, texts); | ||
return await this.getDeepInfraEmbedding(textsWithPrefix); | ||
} | ||
|
||
async getQueryEmbeddings(queries: string[]): Promise<number[][]> { | ||
const queriesWithPrefix = mapPrefixWithInputs(this.queryPrefix, queries); | ||
return await this.getDeepInfraEmbedding(queriesWithPrefix); | ||
} | ||
|
||
private async getDeepInfraEmbedding(inputs: string[]): Promise<number[][]> { | ||
const url = this.getUrl(this.model); | ||
|
||
for (let attempt = 0; attempt < this.maxRetries; attempt++) { | ||
const controller = new AbortController(); | ||
const id = setTimeout(() => controller.abort(), this.timeout); | ||
|
||
try { | ||
const response = await fetch(url, { | ||
method: "POST", | ||
headers: { | ||
"Content-Type": "application/json", | ||
Authorization: `Bearer ${this.apiToken}`, | ||
}, | ||
body: JSON.stringify({ inputs }), | ||
signal: controller.signal, | ||
}); | ||
if (!response.ok) { | ||
throw new Error(`Request failed with status ${response.status}`); | ||
} | ||
|
||
const responseJson: DeepInfraEmbeddingResponse = await response.json(); | ||
return responseJson.embeddings; | ||
} catch (error) { | ||
console.error(`Attempt ${attempt + 1} failed: ${error}`); | ||
} finally { | ||
clearTimeout(id); | ||
} | ||
} | ||
|
||
throw new Error("Exceeded maximum retries"); | ||
} | ||
|
||
private getUrl(model: string): string { | ||
return `${API_ROOT}/${model}`; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters