Commit

feat (core): add text embedding model support to provider registry (#1959)

Co-authored-by: Grace Yun <74513600+iteratetograceness@users.noreply.github.com>
lgrammel and iteratetograceness authored Jun 14, 2024
1 parent 1121364 commit 4728c37
Showing 16 changed files with 345 additions and 113 deletions.
11 changes: 11 additions & 0 deletions .changeset/brave-poets-worry.md
@@ -0,0 +1,11 @@
---
'@ai-sdk/google-vertex': patch
'@ai-sdk/anthropic': patch
'@ai-sdk/mistral': patch
'@ai-sdk/google': patch
'@ai-sdk/openai': patch
'@ai-sdk/azure': patch
'ai': patch
---

feat (core): add text embedding model support to provider registry
19 changes: 17 additions & 2 deletions content/docs/03-ai-sdk-core/40-provider-management.mdx
@@ -14,7 +14,7 @@ The Vercel AI SDK provides a [`ProviderRegistry`](/docs/reference/ai-sdk-core/pr
You can register multiple providers. The provider id will become the prefix of the model id:
`providerId:modelId`.

### Setup (Example)
### Setup

You can create a registry with multiple providers and models using `experimental_createProviderRegistry`.

@@ -39,7 +39,7 @@ export const registry = createProviderRegistry({
});
```

### Usage (Example)
### Language models

You can access language models by using the `languageModel` method on the registry.
The provider id will become the prefix of the model id: `providerId:modelId`.
@@ -53,3 +53,18 @@ const { text } = await generateText({
prompt: 'Invent a new holiday and describe its traditions.',
});
```

### Text embedding models

You can access text embedding models by using the `textEmbeddingModel` method on the registry.
The provider id will become the prefix of the model id: `providerId:modelId`.

```ts highlight={"5"}
import { embed } from 'ai';
import { registry } from './registry';

const { embedding } = await embed({
model: registry.textEmbeddingModel('openai:text-embedding-3-small'),
value: 'sunny day at the beach',
});
```
93 changes: 58 additions & 35 deletions content/docs/07-reference/ai-sdk-core/40-provider-registry.mdx
@@ -13,40 +13,6 @@ in a central place and access the models through simple string ids.
`createProviderRegistry` lets you create a registry with multiple providers that you
can access by their ids.

### Setup (Example)

You can create a registry with multiple providers and models using `createProviderRegistry`.

```ts
import { anthropic } from '@ai-sdk/anthropic';
import { createOpenAI } from '@ai-sdk/openai';
import { experimental_createProviderRegistry as createProviderRegistry } from 'ai';

export const registry = createProviderRegistry({
// register provider with prefix and default setup:
anthropic,

// register provider with prefix and custom setup:
openai: createOpenAI({
apiKey: process.env.OPENAI_API_KEY,
}),
});
```

### Usage (Example)

You can access language models by using the `languageModel` method on the registry.
The provider id will become the prefix of the model id: `providerId:modelId`.

```ts highlight={"4"}
import { generateText } from 'ai';

const { text } = await generateText({
model: registry.languageModel('openai:gpt-4-turbo'),
prompt: 'Invent a new holiday and describe its traditions.',
});
```

## Import

<Snippet
@@ -64,7 +30,7 @@ Registers a language model provider with a given id.
content={[
{
name: 'providers',
type: 'Record<string, (id: string) => LanguageModel>',
type: 'Record<string, { languageModel: (id: string) => LanguageModel; textEmbedding: (id: string) => EmbeddingModel<string> }>',
description: `The unique identifier for the provider. It should be unique within the registry.`,
},
]}
@@ -81,5 +47,62 @@ The `experimental_createProviderRegistry` function returns a `experimental_Provi
type: '(id: string) => LanguageModel',
description: `A function that returns a language model by its id (format: providerId:modelId)`,
},
{
name: 'textEmbeddingModel',
type: '(id: string) => EmbeddingModel<string>',
description: `A function that returns a text embedding model by its id (format: providerId:modelId)`,
},
]}
/>

## Examples

### Setup

You can create a registry with multiple providers and models using `createProviderRegistry`.

```ts
import { anthropic } from '@ai-sdk/anthropic';
import { createOpenAI } from '@ai-sdk/openai';
import { experimental_createProviderRegistry as createProviderRegistry } from 'ai';

export const registry = createProviderRegistry({
// register provider with prefix and default setup:
anthropic,

// register provider with prefix and custom setup:
openai: createOpenAI({
apiKey: process.env.OPENAI_API_KEY,
}),
});
```

### Language models

You can access language models by using the `languageModel` method on the registry.
The provider id will become the prefix of the model id: `providerId:modelId`.

```ts highlight={"5"}
import { generateText } from 'ai';
import { registry } from './registry';

const { text } = await generateText({
model: registry.languageModel('openai:gpt-4-turbo'),
prompt: 'Invent a new holiday and describe its traditions.',
});
```

### Text embedding models

You can access text embedding models by using the `textEmbeddingModel` method on the registry.
The provider id will become the prefix of the model id: `providerId:modelId`.

```ts highlight={"5"}
import { embed } from 'ai';
import { registry } from './registry';

const { embedding } = await embed({
model: registry.textEmbeddingModel('openai:text-embedding-3-small'),
value: 'sunny day at the beach',
});
```
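
As an illustration of the `providerId:modelId` lookup documented above (not part of this commit), the following is a minimal sketch of how such a registry could dispatch to a provider. The registry implementation itself is not shown in this diff, so `createSketchRegistry`, `splitId`, the stand-in model types, and the choice to split on the first colon are all assumptions. Only the provider-side method names `languageModel` and `textEmbedding` and the registry-side names `languageModel` and `textEmbeddingModel` are taken from the reference docs.

```ts
// Stand-in model types; the real SDK types are richer.
type SketchLanguageModel = { kind: 'language'; modelId: string };
type SketchEmbeddingModel = { kind: 'embedding'; modelId: string };

// Provider shape assumed from the documented `providers` parameter type.
interface SketchProvider {
  languageModel?: (id: string) => SketchLanguageModel;
  textEmbedding?: (id: string) => SketchEmbeddingModel;
}

// Assumption: split on the first colon so model ids may contain colons themselves.
function splitId(id: string): [string, string] {
  const index = id.indexOf(':');
  if (index === -1) {
    throw new Error(`Invalid model id: ${id} (expected providerId:modelId)`);
  }
  return [id.slice(0, index), id.slice(index + 1)];
}

// Hypothetical registry factory, not the SDK implementation.
function createSketchRegistry(providers: Record<string, SketchProvider>) {
  const getProvider = (providerId: string): SketchProvider => {
    const provider = providers[providerId];
    if (provider == null) {
      throw new Error(
        `No such provider: ${providerId} (available providers: ${Object.keys(providers).join()})`,
      );
    }
    return provider;
  };

  return {
    languageModel(id: string): SketchLanguageModel {
      const [providerId, modelId] = splitId(id);
      const model = getProvider(providerId).languageModel?.(modelId);
      if (model == null) throw new Error(`No such language model: ${id}`);
      return model;
    },
    textEmbeddingModel(id: string): SketchEmbeddingModel {
      const [providerId, modelId] = splitId(id);
      const model = getProvider(providerId).textEmbedding?.(modelId);
      if (model == null) throw new Error(`No such text embedding model: ${id}`);
      return model;
    },
  };
}

// Usage with a fake provider that only echoes the model id.
const sketchRegistry = createSketchRegistry({
  openai: {
    languageModel: id => ({ kind: 'language', modelId: id }),
    textEmbedding: id => ({ kind: 'embedding', modelId: id }),
  },
});

const sketchEmbeddingModel = sketchRegistry.textEmbeddingModel(
  'openai:text-embedding-3-small',
);
```

Note that the provider-side method is `textEmbedding` while the registry-side method is `textEmbeddingModel`, matching the two type signatures in the reference docs.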
13 changes: 13 additions & 0 deletions examples/ai-core/src/registry/embed.ts
@@ -0,0 +1,13 @@
import { embed } from 'ai';
import { registry } from './setup-registry';

async function main() {
const { embedding } = await embed({
model: registry.textEmbeddingModel('openai:text-embedding-3-small'),
value: 'sunny day at the beach',
});

console.log(embedding);
}

main().catch(console.error);
2 changes: 2 additions & 0 deletions examples/ai-core/src/registry/setup-registry.ts
@@ -1,4 +1,5 @@
import { anthropic } from '@ai-sdk/anthropic';
import { mistral } from '@ai-sdk/mistral';
import { createOpenAI } from '@ai-sdk/openai';
import { experimental_createProviderRegistry as createProviderRegistry } from 'ai';
import dotenv from 'dotenv';
@@ -8,6 +9,7 @@ dotenv.config();
export const registry = createProviderRegistry({
// register provider with prefix and default setup:
anthropic,
mistral,

// register provider with prefix and custom setup:
openai: createOpenAI({
9 changes: 9 additions & 0 deletions packages/anthropic/src/anthropic-provider.ts
@@ -16,6 +16,14 @@ Creates a model for text generation.

/**
Creates a model for text generation.
*/
languageModel(
modelId: AnthropicMessagesModelId,
settings?: AnthropicMessagesSettings,
): AnthropicMessagesLanguageModel;

/**
Creates a model for text generation.
*/
chat(
modelId: AnthropicMessagesModelId,
@@ -108,6 +116,7 @@ export function createAnthropic(
return createChatModel(modelId, settings);
};

provider.languageModel = createChatModel;
provider.chat = createChatModel;
provider.messages = createChatModel;

9 changes: 9 additions & 0 deletions packages/azure/src/azure-openai-provider.ts
@@ -10,6 +10,14 @@ export interface AzureOpenAIProvider {
settings?: OpenAIChatSettings,
): OpenAIChatLanguageModel;

/**
Creates an Azure OpenAI chat model for text generation.
*/
languageModel(
deploymentId: string,
settings?: OpenAIChatSettings,
): OpenAIChatLanguageModel;

/**
Creates an Azure OpenAI chat model for text generation.
*/
@@ -85,6 +93,7 @@ export function createAzure(
return createChatModel(deploymentId, settings as OpenAIChatSettings);
};

provider.languageModel = createChatModel;
provider.chat = createChatModel;

return provider as AzureOpenAIProvider;
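
The Anthropic and Azure hunks above both attach `languageModel` as another alias of the existing chat model factory on a callable provider object. A minimal sketch of that pattern follows; `SketchChatModel`, `SketchCallableProvider`, and `createSketchProvider` are hypothetical names used for illustration, not SDK exports, and the stand-in model type omits settings.

```ts
// Stand-in for a provider-specific chat model class.
type SketchChatModel = { modelId: string };

// A provider that can be called directly or through named factory methods.
interface SketchCallableProvider {
  (modelId: string): SketchChatModel;
  languageModel(modelId: string): SketchChatModel;
  chat(modelId: string): SketchChatModel;
}

function createSketchProvider(): SketchCallableProvider {
  const createChatModel = (modelId: string): SketchChatModel => ({ modelId });

  // The callable form delegates to the same factory as the aliases.
  const provider = function (modelId: string): SketchChatModel {
    return createChatModel(modelId);
  };

  // All aliases share one factory, so a generic caller can rely on
  // `languageModel` while existing callers keep using `chat`.
  provider.languageModel = createChatModel;
  provider.chat = createChatModel;

  return provider as SketchCallableProvider;
}

const sketchProvider = createSketchProvider();
const viaCall = sketchProvider('model-a');
const viaAlias = sketchProvider.languageModel('model-a');
```

Giving every provider the same `languageModel` method name lets a caller that knows nothing about provider-specific names such as `chat` or `messages` still request a model.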
10 changes: 8 additions & 2 deletions packages/core/core/registry/no-such-model-error.ts
@@ -1,25 +1,30 @@
 export class NoSuchModelError extends Error {
   readonly modelId: string;
+  readonly modelType: string;

   constructor({
     modelId,
-    message = `No such model: ${modelId}`,
+    modelType,
+    message = `No such ${modelType}: ${modelId}`,
   }: {
     modelId: string;
+    modelType: string;
     message?: string;
   }) {
     super(message);

     this.name = 'AI_NoSuchModelError';

     this.modelId = modelId;
+    this.modelType = modelType;
   }

   static isNoSuchModelError(error: unknown): error is NoSuchModelError {
     return (
       error instanceof Error &&
       error.name === 'AI_NoSuchModelError' &&
-      typeof (error as NoSuchModelError).modelId === 'string'
+      typeof (error as NoSuchModelError).modelId === 'string' &&
+      typeof (error as NoSuchModelError).modelType === 'string'
     );
   }

@@ -30,6 +35,7 @@ export class NoSuchModelError extends Error {
       stack: this.stack,

       modelId: this.modelId,
+      modelType: this.modelType,
     };
   }
 }
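
To illustrate the effect of the new `modelType` field, here is a small standalone sketch. The class is a trimmed local copy of the one in the diff (without `toJSON`) so the example runs on its own. The `'text embedding model'` label and the `describeError` helper are assumptions, since the strings the registry actually passes are not shown here.

```ts
// Trimmed local copy of the class from the diff, kept standalone on purpose.
class NoSuchModelError extends Error {
  readonly modelId: string;
  readonly modelType: string;

  constructor({
    modelId,
    modelType,
    message = `No such ${modelType}: ${modelId}`,
  }: {
    modelId: string;
    modelType: string;
    message?: string;
  }) {
    super(message);
    this.name = 'AI_NoSuchModelError';
    this.modelId = modelId;
    this.modelType = modelType;
  }

  static isNoSuchModelError(error: unknown): error is NoSuchModelError {
    return (
      error instanceof Error &&
      error.name === 'AI_NoSuchModelError' &&
      typeof (error as NoSuchModelError).modelId === 'string' &&
      typeof (error as NoSuchModelError).modelType === 'string'
    );
  }
}

// Hypothetical helper: narrows an unknown error and reports which kind of model was missing.
function describeError(error: unknown): string {
  return NoSuchModelError.isNoSuchModelError(error)
    ? `${error.modelType} lookup failed for ${error.modelId}`
    : 'unrelated error';
}

// Assumed model-type label, chosen only for this example.
const missingModelError = new NoSuchModelError({
  modelId: 'openai:does-not-exist',
  modelType: 'text embedding model',
});
```

Because the type guard now also checks `modelType`, an error object that carries only `modelId` no longer passes `isNoSuchModelError`.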
10 changes: 8 additions & 2 deletions packages/core/core/registry/no-such-provider-error.ts
@@ -1,25 +1,30 @@
 export class NoSuchProviderError extends Error {
   readonly providerId: string;
+  readonly availableProviders: string[];

   constructor({
     providerId,
-    message = `No such provider: ${providerId}`,
+    availableProviders,
+    message = `No such provider: ${providerId} (available providers: ${availableProviders.join()})`,
   }: {
     providerId: string;
+    availableProviders: string[];
     message?: string;
   }) {
     super(message);

     this.name = 'AI_NoSuchProviderError';

     this.providerId = providerId;
+    this.availableProviders = availableProviders;
   }

   static isNoSuchProviderError(error: unknown): error is NoSuchProviderError {
     return (
       error instanceof Error &&
       error.name === 'AI_NoSuchProviderError' &&
-      typeof (error as NoSuchProviderError).providerId === 'string'
+      typeof (error as NoSuchProviderError).providerId === 'string' &&
+      Array.isArray((error as NoSuchProviderError).availableProviders)
     );
   }

@@ -30,6 +35,7 @@ export class NoSuchProviderError extends Error {
       stack: this.stack,

       providerId: this.providerId,
+      availableProviders: this.availableProviders,
     };
   }
 }
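
The new default message relies on `Array.prototype.join()` being called without a separator, which joins the entries with a plain comma and no space. A minimal sketch of the resulting text follows; `buildNoSuchProviderMessage` is a hypothetical helper that only mirrors the template string from the diff, and the provider ids are example values.

```ts
// Hypothetical helper mirroring the default message template from the diff.
function buildNoSuchProviderMessage(
  providerId: string,
  availableProviders: string[],
): string {
  // join() with no argument uses "," as the separator.
  return `No such provider: ${providerId} (available providers: ${availableProviders.join()})`;
}

// Example ids only; any registered provider ids would appear here.
const providerMessage = buildNoSuchProviderMessage('cohere', [
  'anthropic',
  'mistral',
  'openai',
]);
// providerMessage:
// "No such provider: cohere (available providers: anthropic,mistral,openai)"
```

If a more readable list is preferred, passing `', '` to `join` would add a space after each comma; the diff as written does not do this.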