Skip to content

Commit

Permalink
feat: DeepInfra Embeddings implementation (run-llama#890)
Browse files Browse the repository at this point in the history
Co-authored-by: Alex Yang <himself65@outlook.com>
  • Loading branch information
ovuruska and himself65 authored Jun 3, 2024
1 parent 631f000 commit 3d484da
Show file tree
Hide file tree
Showing 4 changed files with 249 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# DeepInfra

To use DeepInfra embeddings, you need to import `DeepInfraEmbedding` from llamaindex.
Check out available embedding models [here](https://deepinfra.com/models/embeddings).

```ts
import {
DeepInfraEmbedding,
Settings,
Document,
VectorStoreIndex,
} from "llamaindex";

// Update Embed Model
Settings.embedModel = new DeepInfraEmbedding();

const document = new Document({ text: essay, id_: "essay" });

const index = await VectorStoreIndex.fromDocuments([document]);

const queryEngine = index.asQueryEngine();

const query = "What is the meaning of life?";

const results = await queryEngine.query({
query,
});
```

By default, `DeepInfraEmbedding` uses the `sentence-transformers/clip-ViT-B-32` model. You can change the model by passing the `model` parameter to the constructor.
For example:

```ts
import { DeepInfraEmbedding, Settings } from "llamaindex";

const model = "intfloat/e5-large-v2";
Settings.embedModel = new DeepInfraEmbedding({
model,
});
```

You can also set the `maxRetries` and `timeout` parameters when initializing `DeepInfraEmbedding` for better control over the request behavior.

For example:

```ts
import { DeepInfraEmbedding, Settings } from "llamaindex";

const model = "intfloat/e5-large-v2";
const maxRetries = 5;
const timeout = 5000; // 5 seconds

Settings.embedModel = new DeepInfraEmbedding({
model,
maxRetries,
timeout,
});
```

Standalone usage:

```ts
import { DeepInfraEmbedding } from "llamaindex";
import { config } from "dotenv";
// For standalone usage, you need to configure DEEPINFRA_API_TOKEN in .env file
config();

const main = async () => {
const model = "intfloat/e5-large-v2";
const embeddings = new DeepInfraEmbedding({ model });
const text = "What is the meaning of life?";
const response = await embeddings.getTextEmbeddings([text]);
console.log(response);
};

main();
```

For questions or feedback, please contact us at [feedback@deepinfra.com](mailto:feedback@deepinfra.com)
17 changes: 17 additions & 0 deletions examples/deepinfra/embedding.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import { DeepInfraEmbedding } from "llamaindex";

async function main() {
  // Prefer the DEEPINFRA_API_TOKEN environment variable, falling back to a
  // placeholder the reader can edit. Note the order: the env var must be the
  // left operand of `??` — a non-nullish string literal on the left would
  // short-circuit and the environment variable would never be read.
  const apiToken = process.env.DEEPINFRA_API_TOKEN ?? "YOUR_API_TOKEN";
  const model = "BAAI/bge-large-en-v1.5";
  const embedModel = new DeepInfraEmbedding({
    model,
    apiToken,
  });
  const texts = ["hello", "world"];
  const embeddings = await embedModel.getTextEmbeddingsBatch(texts);
  console.log(`\nWe have ${embeddings.length} embeddings`);
}

main().catch(console.error);
152 changes: 152 additions & 0 deletions packages/core/src/embeddings/DeepInfraEmbedding.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import { getEnv } from "@llamaindex/env";
import type { MessageContentDetail } from "../llm/index.js";
import { extractSingleText } from "../llm/utils.js";
import { BaseEmbedding } from "./types.js";

// Embedding model used when none is supplied to the constructor.
const DEFAULT_MODEL = "sentence-transformers/clip-ViT-B-32";

// Environment variable consulted when `apiToken` is not passed explicitly.
const API_TOKEN_ENV_VARIABLE_NAME = "DEEPINFRA_API_TOKEN";

// Base URL of the DeepInfra inference API; the model id is appended as a path segment.
const API_ROOT = "https://api.deepinfra.com/v1/inference";

// Per-request timeout in milliseconds (60 s) before the fetch is aborted.
const DEFAULT_TIMEOUT = 60 * 1000;

// Total number of request attempts made before giving up.
const DEFAULT_MAX_RETRIES = 5;

/**
 * Shape of a successful response from the DeepInfra inference endpoint.
 */
export interface DeepInfraEmbeddingResponse {
  /** One embedding vector per input string, in input order. */
  embeddings: number[][];
  /** Server-assigned id for this inference request. */
  request_id: string;
  /** Execution metadata reported by the API. */
  inference_status: InferenceStatus;
}

/**
 * Execution metadata attached to a DeepInfra inference response.
 */
export interface InferenceStatus {
  status: string;
  runtime_ms: number;
  cost: number;
  tokens_input: number;
}

/**
 * Prepend `prefix` (separated by a single space) to each input string.
 *
 * When `prefix` is empty — the default for both `queryPrefix` and
 * `textPrefix` — the inputs are returned unchanged. The previous
 * implementation always interpolated `` `${prefix} ${input}` ``, which
 * silently added a leading space to every input in the default case.
 */
const mapPrefixWithInputs = (prefix: string, inputs: string[]): string[] => {
  if (!prefix) {
    return [...inputs];
  }
  return inputs.map((input) => `${prefix} ${input}`);
};

/**
 * DeepInfraEmbedding implements the BaseEmbedding interface on top of the
 * DeepInfra inference API (`https://api.deepinfra.com/v1/inference/<model>`).
 */
export class DeepInfraEmbedding extends BaseEmbedding {
  /**
   * DeepInfra model to use
   * @default "sentence-transformers/clip-ViT-B-32"
   * @see https://deepinfra.com/models/embeddings
   */
  model: string;

  /**
   * DeepInfra API token
   * @see https://deepinfra.com/dash/api_keys
   * If not provided, it will try to get the token from the environment variable `DEEPINFRA_API_TOKEN`
   */
  apiToken: string;

  /**
   * Prefix prepended (space-separated) to every query string
   * @default ""
   */
  queryPrefix: string;

  /**
   * Prefix prepended (space-separated) to every text string
   * @default ""
   */
  textPrefix: string;

  /**
   * Maximum number of request attempts before failing
   * @default 5
   */
  maxRetries: number;

  /**
   * Per-attempt timeout in milliseconds; the request is aborted after this
   * @default 60 * 1000
   */
  timeout: number;

  constructor(init?: Partial<DeepInfraEmbedding>) {
    super();

    this.model = init?.model ?? DEFAULT_MODEL;
    this.apiToken = init?.apiToken ?? getEnv(API_TOKEN_ENV_VARIABLE_NAME) ?? "";
    this.queryPrefix = init?.queryPrefix ?? "";
    this.textPrefix = init?.textPrefix ?? "";
    this.maxRetries = init?.maxRetries ?? DEFAULT_MAX_RETRIES;
    this.timeout = init?.timeout ?? DEFAULT_TIMEOUT;
  }

  /**
   * Embed a single text string, applying `textPrefix`.
   * @throws Error if the API responds without any embedding.
   */
  async getTextEmbedding(text: string): Promise<number[]> {
    const texts = mapPrefixWithInputs(this.textPrefix, [text]);
    const embeddings = await this.getDeepInfraEmbedding(texts);
    const embedding = embeddings[0];
    // Guard against a malformed/empty response instead of silently handing
    // callers `undefined` through a `number[]`-typed return.
    if (!embedding) {
      throw new Error("DeepInfra API returned no embeddings");
    }
    return embedding;
  }

  /**
   * Embed a single query, applying `queryPrefix`.
   * Returns null when the message content carries no extractable text.
   */
  async getQueryEmbedding(
    query: MessageContentDetail,
  ): Promise<number[] | null> {
    const text = extractSingleText(query);
    if (!text) {
      return null;
    }
    const queries = mapPrefixWithInputs(this.queryPrefix, [text]);
    const embeddings = await this.getDeepInfraEmbedding(queries);
    const embedding = embeddings[0];
    if (!embedding) {
      throw new Error("DeepInfra API returned no embeddings");
    }
    return embedding;
  }

  /** Embed a batch of texts in one request, applying `textPrefix` to each. */
  async getTextEmbeddings(texts: string[]): Promise<number[][]> {
    const textsWithPrefix = mapPrefixWithInputs(this.textPrefix, texts);
    return await this.getDeepInfraEmbedding(textsWithPrefix);
  }

  /** Embed a batch of queries in one request, applying `queryPrefix` to each. */
  async getQueryEmbeddings(queries: string[]): Promise<number[][]> {
    const queriesWithPrefix = mapPrefixWithInputs(this.queryPrefix, queries);
    return await this.getDeepInfraEmbedding(queriesWithPrefix);
  }

  /**
   * POST `inputs` to the DeepInfra inference endpoint, retrying up to
   * `maxRetries` times. Each attempt is independently aborted after
   * `timeout` milliseconds.
   * @throws Error when every attempt fails; the message includes the last
   *         underlying failure (previously it was discarded, leaving callers
   *         with an uninformative "Exceeded maximum retries").
   */
  private async getDeepInfraEmbedding(inputs: string[]): Promise<number[][]> {
    const url = this.getUrl(this.model);
    let lastError: unknown;

    for (let attempt = 0; attempt < this.maxRetries; attempt++) {
      const controller = new AbortController();
      // Abort this attempt's fetch when the configured timeout elapses.
      const id = setTimeout(() => controller.abort(), this.timeout);

      try {
        const response = await fetch(url, {
          method: "POST",
          headers: {
            "Content-Type": "application/json",
            Authorization: `Bearer ${this.apiToken}`,
          },
          body: JSON.stringify({ inputs }),
          signal: controller.signal,
        });
        if (!response.ok) {
          throw new Error(`Request failed with status ${response.status}`);
        }

        const responseJson: DeepInfraEmbeddingResponse = await response.json();
        return responseJson.embeddings;
      } catch (error) {
        lastError = error;
        console.error(`Attempt ${attempt + 1} failed: ${error}`);
      } finally {
        // Always clear the timer so a completed attempt doesn't abort later.
        clearTimeout(id);
      }
    }

    throw new Error(`Exceeded maximum retries: ${lastError}`);
  }

  /** Build the inference URL for `model`. */
  private getUrl(model: string): string {
    return `${API_ROOT}/${model}`;
  }
}
1 change: 1 addition & 0 deletions packages/core/src/embeddings/index.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
export { DeepInfraEmbedding } from "./DeepInfraEmbedding.js";
export * from "./GeminiEmbedding.js";
export * from "./JinaAIEmbedding.js";
export * from "./MistralAIEmbedding.js";
Expand Down

0 comments on commit 3d484da

Please sign in to comment.