From 0e159848c8a1161cee60fc0325a89a3a4c20f2ed Mon Sep 17 00:00:00 2001 From: Walter Korman Date: Mon, 18 Nov 2024 14:23:16 -0800 Subject: [PATCH 01/13] feat (provider/together): Add togetherai provider with base openai-compat. --- examples/ai-core/package.json | 1 + .../ai-core/src/generate-text/togetherai.ts | 8 +- .../src/js-test/rename-format-stream-part.js | 4 + packages/openai-compat/CHANGELOG.md | 1 + packages/openai-compat/README.md | 7 + packages/openai-compat/package.json | 70 ++ ...ert-to-openai-compat-chat-messages.test.ts | 65 ++ .../convert-to-openai-compat-chat-messages.ts | 110 +++ .../src/get-response-metadata.ts | 15 + packages/openai-compat/src/index.ts | 6 + .../src/map-openai-compat-finish-reason.ts | 19 + .../src/openai-compat-api-types.ts | 52 + .../openai-compat-chat-language-model.test.ts | 919 ++++++++++++++++++ .../src/openai-compat-chat-language-model.ts | 530 ++++++++++ .../src/openai-compat-chat-settings.ts | 11 + .../openai-compat/src/openai-compat-error.ts | 16 + .../src/openai-compat-prepare-tools.ts | 95 ++ .../src/openai-compat-provider.ts | 118 +++ packages/openai-compat/tsconfig.json | 9 + packages/openai-compat/tsup.config.ts | 10 + packages/openai-compat/turbo.json | 12 + packages/openai-compat/vitest.edge.config.js | 10 + packages/openai-compat/vitest.node.config.js | 10 + packages/togetherai/CHANGELOG.md | 1 + packages/togetherai/README.md | 3 + packages/togetherai/package.json | 71 ++ packages/togetherai/src/index.ts | 5 + .../src/togetherai-chat-settings.ts | 5 + .../togetherai/src/togetherai-provider.ts | 82 ++ packages/togetherai/tsconfig.json | 9 + packages/togetherai/tsup.config.ts | 10 + packages/togetherai/turbo.json | 12 + packages/togetherai/vitest.edge.config.js | 10 + packages/togetherai/vitest.node.config.js | 10 + pnpm-lock.yaml | 56 ++ 35 files changed, 2368 insertions(+), 4 deletions(-) create mode 100644 examples/ai-core/src/js-test/rename-format-stream-part.js create mode 100644 packages/openai-compat/CHANGELOG.md create mode 100644 packages/openai-compat/README.md create mode 100644 packages/openai-compat/package.json create mode 100644 packages/openai-compat/src/convert-to-openai-compat-chat-messages.test.ts create mode 100644 packages/openai-compat/src/convert-to-openai-compat-chat-messages.ts create mode 100644 packages/openai-compat/src/get-response-metadata.ts create mode 100644 packages/openai-compat/src/index.ts create mode 100644 packages/openai-compat/src/map-openai-compat-finish-reason.ts create mode 100644 packages/openai-compat/src/openai-compat-api-types.ts create mode 100644 packages/openai-compat/src/openai-compat-chat-language-model.test.ts create mode 100644 packages/openai-compat/src/openai-compat-chat-language-model.ts create mode 100644 packages/openai-compat/src/openai-compat-chat-settings.ts create mode 100644 packages/openai-compat/src/openai-compat-error.ts create mode 100644 packages/openai-compat/src/openai-compat-prepare-tools.ts create mode 100644 packages/openai-compat/src/openai-compat-provider.ts create mode 100644 packages/openai-compat/tsconfig.json create mode 100644 packages/openai-compat/tsup.config.ts create mode 100644 packages/openai-compat/turbo.json create mode 100644 packages/openai-compat/vitest.edge.config.js create mode 100644 packages/openai-compat/vitest.node.config.js create mode 100644 packages/togetherai/CHANGELOG.md create mode 100644 packages/togetherai/README.md create mode 100644 packages/togetherai/package.json create mode 100644 packages/togetherai/src/index.ts create mode 
100644 packages/togetherai/src/togetherai-chat-settings.ts create mode 100644 packages/togetherai/src/togetherai-provider.ts create mode 100644 packages/togetherai/tsconfig.json create mode 100644 packages/togetherai/tsup.config.ts create mode 100644 packages/togetherai/turbo.json create mode 100644 packages/togetherai/vitest.edge.config.js create mode 100644 packages/togetherai/vitest.node.config.js diff --git a/examples/ai-core/package.json b/examples/ai-core/package.json index 905733db626d..347fc934b0ea 100644 --- a/examples/ai-core/package.json +++ b/examples/ai-core/package.json @@ -12,6 +12,7 @@ "@ai-sdk/groq": "1.0.1", "@ai-sdk/mistral": "1.0.2", "@ai-sdk/openai": "1.0.2", + "@ai-sdk/togetherai": "0.0.0", "@ai-sdk/xai": "1.0.2", "@opentelemetry/sdk-node": "0.54.2", "@opentelemetry/auto-instrumentations-node": "0.47.0", diff --git a/examples/ai-core/src/generate-text/togetherai.ts b/examples/ai-core/src/generate-text/togetherai.ts index d584ea67af4f..644a75dffff5 100644 --- a/examples/ai-core/src/generate-text/togetherai.ts +++ b/examples/ai-core/src/generate-text/togetherai.ts @@ -1,16 +1,16 @@ -import { createOpenAI } from '@ai-sdk/openai'; +import { createTogetherAI } from '@ai-sdk/togetherai'; import { generateText } from 'ai'; import 'dotenv/config'; -const togetherai = createOpenAI({ - name: 'togetherai', +const togetherai = createTogetherAI({ apiKey: process.env.TOGETHER_AI_API_KEY!, baseURL: 'https://api.together.xyz/v1/', }); async function main() { const { text, usage } = await generateText({ - model: togetherai('google/gemma-2-9b-it'), + // model: togetherai('google/gemma-2-9b-it'), + model: togetherai('meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo'), prompt: 'Invent a new holiday and describe its traditions.', }); diff --git a/examples/ai-core/src/js-test/rename-format-stream-part.js b/examples/ai-core/src/js-test/rename-format-stream-part.js new file mode 100644 index 000000000000..27f763edcfdd --- /dev/null +++ b/examples/ai-core/src/js-test/rename-format-stream-part.js @@ -0,0 +1,4 @@ +// @ts-nocheck +import { formatStreamPart } from 'ai'; + +const response = new Response(formatStreamPart('text', cached)); diff --git a/packages/openai-compat/CHANGELOG.md b/packages/openai-compat/CHANGELOG.md new file mode 100644 index 000000000000..70b1e470d95f --- /dev/null +++ b/packages/openai-compat/CHANGELOG.md @@ -0,0 +1 @@ +# @ai-sdk/openai-compat diff --git a/packages/openai-compat/README.md b/packages/openai-compat/README.md new file mode 100644 index 000000000000..1a2daa57c041 --- /dev/null +++ b/packages/openai-compat/README.md @@ -0,0 +1,7 @@ +# AI SDK - OpenAI Compatible Provider + +This provider aims to support a core subset of functionality common to a wide +range of OpenAI compatible LLM providers. The intent is to allow code sharing +across multiple concrete provider implementations. + +The primary OpenAI provider is heavier-weight than what this package offers. 
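For orientation, a minimal usage sketch of the base package this patch introduces, mirroring the together.ai example above. This is illustrative only, not part of the diff: `createOpenAICompat` and its `baseURL`/`apiKey` options are defined in `openai-compat-provider.ts` later in this patch, and the model id and environment variable are taken from the example file.

```ts
import { createOpenAICompat } from '@ai-sdk/openai-compat';
import { generateText } from 'ai';
import 'dotenv/config';

// Point the shared OpenAI-compatible base provider at a conforming endpoint.
// baseURL and apiKey come straight from the together.ai example above.
const togetherai = createOpenAICompat({
  baseURL: 'https://api.together.xyz/v1',
  apiKey: process.env.TOGETHER_AI_API_KEY!,
});

async function main() {
  const { text, usage } = await generateText({
    model: togetherai('meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo'),
    prompt: 'Invent a new holiday and describe its traditions.',
  });

  console.log(text);
  console.log('Usage:', usage);
}

main().catch(console.error);
```

The `@ai-sdk/togetherai` package added in this patch layers its own defaults over the same base provider.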
diff --git a/packages/openai-compat/package.json b/packages/openai-compat/package.json new file mode 100644 index 000000000000..234ba883b59d --- /dev/null +++ b/packages/openai-compat/package.json @@ -0,0 +1,70 @@ +{ + "name": "@ai-sdk/openai-compat", + "version": "0.0.0", + "license": "Apache-2.0", + "sideEffects": false, + "main": "./dist/index.js", + "module": "./dist/index.mjs", + "types": "./dist/index.d.ts", + "files": [ + "dist/**/*", + "internal/dist/**/*", + "CHANGELOG.md" + ], + "scripts": { + "build": "tsup", + "build:watch": "tsup --watch", + "clean": "rm -rf dist && rm -rf internal/dist", + "lint": "eslint \"./**/*.ts*\"", + "type-check": "tsc --noEmit", + "prettier-check": "prettier --check \"./**/*.ts*\"", + "test": "pnpm test:node && pnpm test:edge", + "test:edge": "vitest --config vitest.edge.config.js --run", + "test:node": "vitest --config vitest.node.config.js --run" + }, + "exports": { + "./package.json": "./package.json", + ".": { + "types": "./dist/index.d.ts", + "import": "./dist/index.mjs", + "require": "./dist/index.js" + }, + "./internal": { + "types": "./internal/dist/index.d.ts", + "import": "./internal/dist/index.mjs", + "module": "./internal/dist/index.mjs", + "require": "./internal/dist/index.js" + } + }, + "dependencies": { + "@ai-sdk/provider": "1.0.0", + "@ai-sdk/provider-utils": "2.0.0" + }, + "devDependencies": { + "@types/node": "^18", + "@vercel/ai-tsconfig": "workspace:*", + "tsup": "^8", + "typescript": "5.6.3", + "zod": "3.23.8" + }, + "peerDependencies": { + "zod": "^3.0.0" + }, + "engines": { + "node": ">=18" + }, + "publishConfig": { + "access": "public" + }, + "homepage": "https://sdk.vercel.ai/docs", + "repository": { + "type": "git", + "url": "git+https://github.com/vercel/ai.git" + }, + "bugs": { + "url": "https://github.com/vercel/ai/issues" + }, + "keywords": [ + "ai" + ] +} diff --git a/packages/openai-compat/src/convert-to-openai-compat-chat-messages.test.ts b/packages/openai-compat/src/convert-to-openai-compat-chat-messages.test.ts new file mode 100644 index 000000000000..9551240a02c4 --- /dev/null +++ b/packages/openai-compat/src/convert-to-openai-compat-chat-messages.test.ts @@ -0,0 +1,65 @@ +import { convertToOpenAICompatChatMessages } from './convert-to-openai-compat-chat-messages'; + +describe('user messages', () => { + it('should convert messages with only a text part to a string content', async () => { + const result = convertToOpenAICompatChatMessages([ + { + role: 'user', + content: [{ type: 'text', text: 'Hello' }], + }, + ]); + + expect(result).toEqual([{ role: 'user', content: 'Hello' }]); + }); +}); + +describe('tool calls', () => { + it('should stringify arguments to tool calls', () => { + const result = convertToOpenAICompatChatMessages([ + { + role: 'assistant', + content: [ + { + type: 'tool-call', + args: { foo: 'bar123' }, + toolCallId: 'quux', + toolName: 'thwomp', + }, + ], + }, + { + role: 'tool', + content: [ + { + type: 'tool-result', + toolCallId: 'quux', + toolName: 'thwomp', + result: { oof: '321rab' }, + }, + ], + }, + ]); + + expect(result).toEqual([ + { + role: 'assistant', + content: '', + tool_calls: [ + { + type: 'function', + id: 'quux', + function: { + name: 'thwomp', + arguments: JSON.stringify({ foo: 'bar123' }), + }, + }, + ], + }, + { + role: 'tool', + content: JSON.stringify({ oof: '321rab' }), + tool_call_id: 'quux', + }, + ]); + }); +}); diff --git a/packages/openai-compat/src/convert-to-openai-compat-chat-messages.ts b/packages/openai-compat/src/convert-to-openai-compat-chat-messages.ts new 
file mode 100644 index 000000000000..134c2bd64e07 --- /dev/null +++ b/packages/openai-compat/src/convert-to-openai-compat-chat-messages.ts @@ -0,0 +1,110 @@ +import { + LanguageModelV1Prompt, + UnsupportedFunctionalityError, +} from '@ai-sdk/provider'; +import { convertUint8ArrayToBase64 } from '@ai-sdk/provider-utils'; +import { OpenAICompatChatPrompt } from './openai-compat-api-types'; + +export function convertToOpenAICompatChatMessages( + prompt: LanguageModelV1Prompt, +): OpenAICompatChatPrompt { + const messages: OpenAICompatChatPrompt = []; + + for (const { role, content } of prompt) { + switch (role) { + case 'system': { + messages.push({ role: 'system', content }); + break; + } + + case 'user': { + if (content.length === 1 && content[0].type === 'text') { + messages.push({ role: 'user', content: content[0].text }); + break; + } + + messages.push({ + role: 'user', + content: content.map(part => { + switch (part.type) { + case 'text': { + return { type: 'text', text: part.text }; + } + case 'image': { + throw new UnsupportedFunctionalityError({ + functionality: 'Image content parts in user messages', + }); + } + case 'file': { + throw new UnsupportedFunctionalityError({ + functionality: 'File content parts in user messages', + }); + } + } + }), + }); + + break; + } + + case 'assistant': { + let text = ''; + const toolCalls: Array<{ + id: string; + type: 'function'; + function: { name: string; arguments: string }; + }> = []; + + for (const part of content) { + switch (part.type) { + case 'text': { + text += part.text; + break; + } + case 'tool-call': { + toolCalls.push({ + id: part.toolCallId, + type: 'function', + function: { + name: part.toolName, + arguments: JSON.stringify(part.args), + }, + }); + break; + } + default: { + const _exhaustiveCheck: never = part; + throw new Error(`Unsupported part: ${_exhaustiveCheck}`); + } + } + } + + messages.push({ + role: 'assistant', + content: text, + tool_calls: toolCalls.length > 0 ? toolCalls : undefined, + }); + + break; + } + + case 'tool': { + for (const toolResponse of content) { + messages.push({ + role: 'tool', + tool_call_id: toolResponse.toolCallId, + content: JSON.stringify(toolResponse.result), + }); + } + break; + } + + default: { + const _exhaustiveCheck: never = role; + throw new Error(`Unsupported role: ${_exhaustiveCheck}`); + } + } + } + + return messages; +} diff --git a/packages/openai-compat/src/get-response-metadata.ts b/packages/openai-compat/src/get-response-metadata.ts new file mode 100644 index 000000000000..bd358b23f704 --- /dev/null +++ b/packages/openai-compat/src/get-response-metadata.ts @@ -0,0 +1,15 @@ +export function getResponseMetadata({ + id, + model, + created, +}: { + id?: string | undefined | null; + created?: number | undefined | null; + model?: string | undefined | null; +}) { + return { + id: id ?? undefined, + modelId: model ?? undefined, + timestamp: created != null ? 
new Date(created * 1000) : undefined, + }; +} diff --git a/packages/openai-compat/src/index.ts b/packages/openai-compat/src/index.ts new file mode 100644 index 000000000000..79f3cf8f88d5 --- /dev/null +++ b/packages/openai-compat/src/index.ts @@ -0,0 +1,6 @@ +export { createOpenAICompat, openaiCompat } from './openai-compat-provider'; +export type { + OpenAICompatProvider, + OpenAICompatProviderSettings, +} from './openai-compat-provider'; +export type { OpenAICompatChatSettings } from './openai-compat-chat-settings'; diff --git a/packages/openai-compat/src/map-openai-compat-finish-reason.ts b/packages/openai-compat/src/map-openai-compat-finish-reason.ts new file mode 100644 index 000000000000..a6011eb0b4cb --- /dev/null +++ b/packages/openai-compat/src/map-openai-compat-finish-reason.ts @@ -0,0 +1,19 @@ +import { LanguageModelV1FinishReason } from '@ai-sdk/provider'; + +export function mapOpenAICompatFinishReason( + finishReason: string | null | undefined, +): LanguageModelV1FinishReason { + switch (finishReason) { + case 'stop': + return 'stop'; + case 'length': + return 'length'; + case 'content_filter': + return 'content-filter'; + case 'function_call': + case 'tool_calls': + return 'tool-calls'; + default: + return 'unknown'; + } +} diff --git a/packages/openai-compat/src/openai-compat-api-types.ts b/packages/openai-compat/src/openai-compat-api-types.ts new file mode 100644 index 000000000000..65720888c33d --- /dev/null +++ b/packages/openai-compat/src/openai-compat-api-types.ts @@ -0,0 +1,52 @@ +export type OpenAICompatChatPrompt = Array<OpenAICompatMessage>; + +export type OpenAICompatMessage = + | OpenAICompatSystemMessage + | OpenAICompatUserMessage + | OpenAICompatAssistantMessage + | OpenAICompatToolMessage; + +export interface OpenAICompatSystemMessage { + role: 'system'; + content: string; +} + +export interface OpenAICompatUserMessage { + role: 'user'; + content: string | Array<OpenAICompatContentPart>; +} + +export type OpenAICompatContentPart = + | OpenAICompatContentPartText + | OpenAICompatContentPartImage; + +export interface OpenAICompatContentPartImage { + type: 'image_url'; + image_url: { url: string }; +} + +export interface OpenAICompatContentPartText { + type: 'text'; + text: string; +} + +export interface OpenAICompatAssistantMessage { + role: 'assistant'; + content?: string | null; + tool_calls?: Array<OpenAICompatMessageToolCall>; +} + +export interface OpenAICompatMessageToolCall { + type: 'function'; + id: string; + function: { + arguments: string; + name: string; + }; +} + +export interface OpenAICompatToolMessage { + role: 'tool'; + content: string; + tool_call_id: string; +} diff --git a/packages/openai-compat/src/openai-compat-chat-language-model.test.ts b/packages/openai-compat/src/openai-compat-chat-language-model.test.ts new file mode 100644 index 000000000000..5006d36527ab --- /dev/null +++ b/packages/openai-compat/src/openai-compat-chat-language-model.test.ts @@ -0,0 +1,919 @@ +import { LanguageModelV1Prompt } from '@ai-sdk/provider'; +import { + JsonTestServer, + StreamingTestServer, + convertReadableStreamToArray, +} from '@ai-sdk/provider-utils/test'; +import { createOpenAICompat } from './openai-compat-provider'; + +const TEST_PROMPT: LanguageModelV1Prompt = [ + { role: 'user', content: [{ type: 'text', text: 'Hello' }] }, +]; + +const provider = createOpenAICompat({ + apiKey: 'test-api-key', +}); + +const model = provider('grok-beta'); + +describe('doGenerate', () => { + const server = new JsonTestServer('https://api.x.ai/v1/chat/completions'); + + server.setupTestEnvironment(); + + function prepareJsonResponse({
content = '', + tool_calls, + function_call, + usage = { + prompt_tokens: 4, + total_tokens: 34, + completion_tokens: 30, + }, + finish_reason = 'stop', + id = 'chatcmpl-95ZTZkhr0mHNKqerQfiwkuox3PHAd', + created = 1711115037, + model = 'grok-beta', + }: { + content?: string; + tool_calls?: Array<{ + id: string; + type: 'function'; + function: { + name: string; + arguments: string; + }; + }>; + function_call?: { + name: string; + arguments: string; + }; + usage?: { + prompt_tokens?: number; + total_tokens?: number; + completion_tokens?: number; + }; + finish_reason?: string; + created?: number; + id?: string; + model?: string; + } = {}) { + server.responseBodyJson = { + id, + object: 'chat.completion', + created, + model, + choices: [ + { + index: 0, + message: { + role: 'assistant', + content, + tool_calls, + function_call, + }, + finish_reason, + }, + ], + usage, + system_fingerprint: 'fp_3bc1b5746c', + }; + } + + it('should pass user setting to requests', async () => { + prepareJsonResponse({ content: 'Hello, World!' }); + const modelWithUser = provider('grok-beta', { + user: 'test-user-id', + }); + await modelWithUser.doGenerate({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); + expect(await server.getRequestBodyJson()).toMatchObject({ + user: 'test-user-id', + }); + }); + + it('should extract text response', async () => { + prepareJsonResponse({ content: 'Hello, World!' }); + + const { text } = await model.doGenerate({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); + + expect(text).toStrictEqual('Hello, World!'); + }); + + it('should extract usage', async () => { + prepareJsonResponse({ + content: '', + usage: { prompt_tokens: 20, total_tokens: 25, completion_tokens: 5 }, + }); + + const { usage } = await model.doGenerate({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); + + expect(usage).toStrictEqual({ + promptTokens: 20, + completionTokens: 5, + }); + }); + + it('should send additional response information', async () => { + prepareJsonResponse({ + id: 'test-id', + created: 123, + model: 'test-model', + }); + + const { response } = await model.doGenerate({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); + + expect(response).toStrictEqual({ + id: 'test-id', + timestamp: new Date(123 * 1000), + modelId: 'test-model', + }); + }); + + it('should support partial usage', async () => { + prepareJsonResponse({ + content: '', + usage: { prompt_tokens: 20, total_tokens: 20 }, + }); + + const { usage } = await model.doGenerate({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); + + expect(usage).toStrictEqual({ + promptTokens: 20, + completionTokens: NaN, + }); + }); + + it('should extract finish reason', async () => { + prepareJsonResponse({ + content: '', + finish_reason: 'stop', + }); + + const response = await model.doGenerate({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); + + expect(response.finishReason).toStrictEqual('stop'); + }); + + it('should support unknown finish reason', async () => { + prepareJsonResponse({ + content: '', + finish_reason: 'eos', + }); + + const response = await model.doGenerate({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); + + expect(response.finishReason).toStrictEqual('unknown'); + }); + + it('should expose the raw response headers', async () => { + prepareJsonResponse({ content: '' }); + + server.responseHeaders 
= { + 'test-header': 'test-value', + }; + + const { rawResponse } = await model.doGenerate({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); + + expect(rawResponse?.headers).toStrictEqual({ + // default headers: + 'content-length': '312', + 'content-type': 'application/json', + + // custom header + 'test-header': 'test-value', + }); + }); + + it('should pass the model and the messages', async () => { + prepareJsonResponse({ content: '' }); + + await model.doGenerate({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); + + expect(await server.getRequestBodyJson()).toStrictEqual({ + model: 'grok-beta', + messages: [{ role: 'user', content: 'Hello' }], + }); + }); + + it('should pass settings', async () => { + prepareJsonResponse(); + + await provider('grok-beta', { + user: 'test-user-id', + }).doGenerate({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); + + expect(await server.getRequestBodyJson()).toStrictEqual({ + model: 'grok-beta', + messages: [{ role: 'user', content: 'Hello' }], + user: 'test-user-id', + }); + }); + + it('should pass tools and toolChoice', async () => { + prepareJsonResponse({ content: '' }); + + await model.doGenerate({ + inputFormat: 'prompt', + mode: { + type: 'regular', + tools: [ + { + type: 'function', + name: 'test-tool', + parameters: { + type: 'object', + properties: { value: { type: 'string' } }, + required: ['value'], + additionalProperties: false, + $schema: 'http://json-schema.org/draft-07/schema#', + }, + }, + ], + toolChoice: { + type: 'tool', + toolName: 'test-tool', + }, + }, + prompt: TEST_PROMPT, + }); + + expect(await server.getRequestBodyJson()).toStrictEqual({ + model: 'grok-beta', + messages: [{ role: 'user', content: 'Hello' }], + tools: [ + { + type: 'function', + function: { + name: 'test-tool', + parameters: { + type: 'object', + properties: { value: { type: 'string' } }, + required: ['value'], + additionalProperties: false, + $schema: 'http://json-schema.org/draft-07/schema#', + }, + }, + }, + ], + tool_choice: { + type: 'function', + function: { name: 'test-tool' }, + }, + }); + }); + + it('should pass headers', async () => { + prepareJsonResponse({ content: '' }); + + const provider = createOpenAICompat({ + apiKey: 'test-api-key', + headers: { + 'Custom-Provider-Header': 'provider-header-value', + }, + }); + + await provider('grok-beta').doGenerate({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + headers: { + 'Custom-Request-Header': 'request-header-value', + }, + }); + + const requestHeaders = await server.getRequestHeaders(); + + expect(requestHeaders).toStrictEqual({ + authorization: 'Bearer test-api-key', + 'content-type': 'application/json', + 'custom-provider-header': 'provider-header-value', + 'custom-request-header': 'request-header-value', + }); + }); + + it('should parse tool results', async () => { + prepareJsonResponse({ + tool_calls: [ + { + id: 'call_O17Uplv4lJvD6DVdIvFFeRMw', + type: 'function', + function: { + name: 'test-tool', + arguments: '{"value":"Spark"}', + }, + }, + ], + }); + + const result = await model.doGenerate({ + inputFormat: 'prompt', + mode: { + type: 'regular', + tools: [ + { + type: 'function', + name: 'test-tool', + parameters: { + type: 'object', + properties: { value: { type: 'string' } }, + required: ['value'], + additionalProperties: false, + $schema: 'http://json-schema.org/draft-07/schema#', + }, + }, + ], + toolChoice: { + type: 'tool', + toolName: 'test-tool', + }, 
+ }, + prompt: TEST_PROMPT, + }); + + expect(result.toolCalls).toStrictEqual([ + { + args: '{"value":"Spark"}', + toolCallId: 'call_O17Uplv4lJvD6DVdIvFFeRMw', + toolCallType: 'function', + toolName: 'test-tool', + }, + ]); + }); + + it('should send request body', async () => { + prepareJsonResponse({ content: '' }); + + const { request } = await model.doGenerate({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); + + expect(request).toStrictEqual({ + body: '{"model":"grok-beta","messages":[{"role":"user","content":"Hello"}]}', + }); + }); +}); + +describe('doStream', () => { + const server = new StreamingTestServer( + 'https://api.x.ai/v1/chat/completions', + ); + + server.setupTestEnvironment(); + + function prepareStreamResponse({ + content, + finish_reason = 'stop', + }: { + content: string[]; + finish_reason?: string; + }) { + server.responseChunks = [ + `data: {"id":"chatcmpl-e7f8e220-656c-4455-a132-dacfc1370798","object":"chat.completion.chunk","created":1702657020,"model":"grok-beta",` + + `"system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}\n\n`, + ...content.map(text => { + return ( + `data: {"id":"chatcmpl-e7f8e220-656c-4455-a132-dacfc1370798","object":"chat.completion.chunk","created":1702657020,"model":"grok-beta",` + + `"system_fingerprint":null,"choices":[{"index":1,"delta":{"content":"${text}"},"finish_reason":null}]}\n\n` + ); + }), + `data: {"id":"chatcmpl-e7f8e220-656c-4455-a132-dacfc1370798","object":"chat.completion.chunk","created":1702657020,"model":"grok-beta",` + + `"system_fingerprint":null,"choices":[{"index":0,"delta":{},"finish_reason":"${finish_reason}"}]}\n\n`, + `data: {"id":"chatcmpl-e7f8e220-656c-4455-a132-dacfc1370798","object":"chat.completion.chunk","created":1729171479,"model":"grok-beta",` + + `"system_fingerprint":"fp_10c08bf97d","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"${finish_reason}"}],` + + `"usage":{"queue_time":0.061348671,"prompt_tokens":18,"prompt_time":0.000211569,` + + `"completion_tokens":439,"completion_time":0.798181818,"total_tokens":457,"total_time":0.798393387}}\n\n`, + 'data: [DONE]\n\n', + ]; + } + + it('should stream text deltas', async () => { + prepareStreamResponse({ + content: ['Hello', ', ', 'World!'], + finish_reason: 'stop', + }); + + const { stream } = await model.doStream({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); + + // note: space moved to last chunk bc of trimming + expect(await convertReadableStreamToArray(stream)).toStrictEqual([ + { + type: 'response-metadata', + id: 'chatcmpl-e7f8e220-656c-4455-a132-dacfc1370798', + modelId: 'grok-beta', + timestamp: new Date('2023-12-15T16:17:00.000Z'), + }, + { type: 'text-delta', textDelta: '' }, + { type: 'text-delta', textDelta: 'Hello' }, + { type: 'text-delta', textDelta: ', ' }, + { type: 'text-delta', textDelta: 'World!' 
}, + { + type: 'finish', + finishReason: 'stop', + usage: { promptTokens: 18, completionTokens: 439 }, + }, + ]); + }); + + it('should stream tool deltas', async () => { + server.responseChunks = [ + `data: {"id":"chatcmpl-e7f8e220-656c-4455-a132-dacfc1370798","object":"chat.completion.chunk","created":1711357598,"model":"grok-beta",` + + `"system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"role":"assistant","content":null,` + + `"tool_calls":[{"index":0,"id":"call_O17Uplv4lJvD6DVdIvFFeRMw","type":"function","function":{"name":"test-tool","arguments":""}}]},` + + `"finish_reason":null}]}\n\n`, + `data: {"id":"chatcmpl-e7f8e220-656c-4455-a132-dacfc1370798","object":"chat.completion.chunk","created":1711357598,"model":"grok-beta",` + + `"system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\\""}}]},` + + `"finish_reason":null}]}\n\n`, + `data: {"id":"chatcmpl-e7f8e220-656c-4455-a132-dacfc1370798","object":"chat.completion.chunk","created":1711357598,"model":"grok-beta",` + + `"system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"value"}}]},` + + `"finish_reason":null}]}\n\n`, + `data: {"id":"chatcmpl-e7f8e220-656c-4455-a132-dacfc1370798","object":"chat.completion.chunk","created":1711357598,"model":"grok-beta",` + + `"system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\\":\\""}}]},` + + `"finish_reason":null}]}\n\n`, + `data: {"id":"chatcmpl-e7f8e220-656c-4455-a132-dacfc1370798","object":"chat.completion.chunk","created":1711357598,"model":"grok-beta",` + + `"system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Spark"}}]},` + + `"finish_reason":null}]}\n\n`, + `data: {"id":"chatcmpl-e7f8e220-656c-4455-a132-dacfc1370798","object":"chat.completion.chunk","created":1711357598,"model":"grok-beta",` + + `"system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"le"}}]},` + + `"finish_reason":null}]}\n\n`, + `data: {"id":"chatcmpl-e7f8e220-656c-4455-a132-dacfc1370798","object":"chat.completion.chunk","created":1711357598,"model":"grok-beta",` + + `"system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" Day"}}]},` + + `"finish_reason":null}]}\n\n`, + `data: {"id":"chatcmpl-e7f8e220-656c-4455-a132-dacfc1370798","object":"chat.completion.chunk","created":1711357598,"model":"grok-beta",` + + `"system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\\"}"}}]},` + + `"finish_reason":null}]}\n\n`, + `data: {"id":"chatcmpl-e7f8e220-656c-4455-a132-dacfc1370798","object":"chat.completion.chunk","created":1729171479,"model":"grok-beta",` + + `"system_fingerprint":"fp_10c08bf97d","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}],` + + `"usage":{"queue_time":0.061348671,"prompt_tokens":18,"prompt_time":0.000211569,` + + `"completion_tokens":439,"completion_time":0.798181818,"total_tokens":457,"total_time":0.798393387}}\n\n`, + 'data: [DONE]\n\n', + ]; + + const { stream } = await model.doStream({ + inputFormat: 'prompt', + mode: { + type: 'regular', + tools: [ + { + type: 'function', + name: 'test-tool', + parameters: { + type: 'object', + properties: { value: { type: 'string' } }, + required: ['value'], + additionalProperties: 
false, + $schema: 'http://json-schema.org/draft-07/schema#', + }, + }, + ], + }, + prompt: TEST_PROMPT, + }); + + expect(await convertReadableStreamToArray(stream)).toStrictEqual([ + { + type: 'response-metadata', + id: 'chatcmpl-e7f8e220-656c-4455-a132-dacfc1370798', + modelId: 'grok-beta', + timestamp: new Date('2024-03-25T09:06:38.000Z'), + }, + { + type: 'tool-call-delta', + toolCallId: 'call_O17Uplv4lJvD6DVdIvFFeRMw', + toolCallType: 'function', + toolName: 'test-tool', + argsTextDelta: '{"', + }, + { + type: 'tool-call-delta', + toolCallId: 'call_O17Uplv4lJvD6DVdIvFFeRMw', + toolCallType: 'function', + toolName: 'test-tool', + argsTextDelta: 'value', + }, + { + type: 'tool-call-delta', + toolCallId: 'call_O17Uplv4lJvD6DVdIvFFeRMw', + toolCallType: 'function', + toolName: 'test-tool', + argsTextDelta: '":"', + }, + { + type: 'tool-call-delta', + toolCallId: 'call_O17Uplv4lJvD6DVdIvFFeRMw', + toolCallType: 'function', + toolName: 'test-tool', + argsTextDelta: 'Spark', + }, + { + type: 'tool-call-delta', + toolCallId: 'call_O17Uplv4lJvD6DVdIvFFeRMw', + toolCallType: 'function', + toolName: 'test-tool', + argsTextDelta: 'le', + }, + { + type: 'tool-call-delta', + toolCallId: 'call_O17Uplv4lJvD6DVdIvFFeRMw', + toolCallType: 'function', + toolName: 'test-tool', + argsTextDelta: ' Day', + }, + { + type: 'tool-call-delta', + toolCallId: 'call_O17Uplv4lJvD6DVdIvFFeRMw', + toolCallType: 'function', + toolName: 'test-tool', + argsTextDelta: '"}', + }, + { + type: 'tool-call', + toolCallId: 'call_O17Uplv4lJvD6DVdIvFFeRMw', + toolCallType: 'function', + toolName: 'test-tool', + args: '{"value":"Sparkle Day"}', + }, + { + type: 'finish', + finishReason: 'tool-calls', + usage: { promptTokens: 18, completionTokens: 439 }, + }, + ]); + }); + + it('should stream tool call deltas when tool call arguments are passed in the first chunk', async () => { + server.responseChunks = [ + `data: {"id":"chatcmpl-e7f8e220-656c-4455-a132-dacfc1370798","object":"chat.completion.chunk","created":1711357598,"model":"grok-beta",` + + `"system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"role":"assistant","content":null,` + + `"tool_calls":[{"index":0,"id":"call_O17Uplv4lJvD6DVdIvFFeRMw","type":"function","function":{"name":"test-tool","arguments":"{\\""}}]},` + + `"finish_reason":null}]}\n\n`, + `data: {"id":"chatcmpl-e7f8e220-656c-4455-a132-dacfc1370798","object":"chat.completion.chunk","created":1711357598,"model":"grok-beta",` + + `"system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"va"}}]},` + + `"finish_reason":null}]}\n\n`, + `data: {"id":"chatcmpl-e7f8e220-656c-4455-a132-dacfc1370798","object":"chat.completion.chunk","created":1711357598,"model":"grok-beta",` + + `"system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"lue"}}]},` + + `"finish_reason":null}]}\n\n`, + `data: {"id":"chatcmpl-e7f8e220-656c-4455-a132-dacfc1370798","object":"chat.completion.chunk","created":1711357598,"model":"grok-beta",` + + `"system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\\":\\""}}]},` + + `"finish_reason":null}]}\n\n`, + `data: {"id":"chatcmpl-e7f8e220-656c-4455-a132-dacfc1370798","object":"chat.completion.chunk","created":1711357598,"model":"grok-beta",` + + `"system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Spark"}}]},` + + 
`"finish_reason":null}]}\n\n`, + `data: {"id":"chatcmpl-e7f8e220-656c-4455-a132-dacfc1370798","object":"chat.completion.chunk","created":1711357598,"model":"grok-beta",` + + `"system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"le"}}]},` + + `"finish_reason":null}]}\n\n`, + `data: {"id":"chatcmpl-e7f8e220-656c-4455-a132-dacfc1370798","object":"chat.completion.chunk","created":1711357598,"model":"grok-beta",` + + `"system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" Day"}}]},` + + `"finish_reason":null}]}\n\n`, + `data: {"id":"chatcmpl-e7f8e220-656c-4455-a132-dacfc1370798","object":"chat.completion.chunk","created":1711357598,"model":"grok-beta",` + + `"system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\\"}"}}]},` + + `"finish_reason":null}]}\n\n`, + `data: {"id":"chatcmpl-e7f8e220-656c-4455-a132-dacfc1370798","object":"chat.completion.chunk","created":1729171479,"model":"grok-beta",` + + `"system_fingerprint":"fp_10c08bf97d","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}],` + + `"usage":{"queue_time":0.061348671,"prompt_tokens":18,"prompt_time":0.000211569,` + + `"completion_tokens":439,"completion_time":0.798181818,"total_tokens":457,"total_time":0.798393387}}\n\n`, + 'data: [DONE]\n\n', + ]; + + const { stream } = await model.doStream({ + inputFormat: 'prompt', + mode: { + type: 'regular', + tools: [ + { + type: 'function', + name: 'test-tool', + parameters: { + type: 'object', + properties: { value: { type: 'string' } }, + required: ['value'], + additionalProperties: false, + $schema: 'http://json-schema.org/draft-07/schema#', + }, + }, + ], + }, + prompt: TEST_PROMPT, + }); + + expect(await convertReadableStreamToArray(stream)).toStrictEqual([ + { + type: 'response-metadata', + id: 'chatcmpl-e7f8e220-656c-4455-a132-dacfc1370798', + modelId: 'grok-beta', + timestamp: new Date('2024-03-25T09:06:38.000Z'), + }, + { + type: 'tool-call-delta', + toolCallId: 'call_O17Uplv4lJvD6DVdIvFFeRMw', + toolCallType: 'function', + toolName: 'test-tool', + argsTextDelta: '{"', + }, + { + type: 'tool-call-delta', + toolCallId: 'call_O17Uplv4lJvD6DVdIvFFeRMw', + toolCallType: 'function', + toolName: 'test-tool', + argsTextDelta: 'va', + }, + { + type: 'tool-call-delta', + toolCallId: 'call_O17Uplv4lJvD6DVdIvFFeRMw', + toolCallType: 'function', + toolName: 'test-tool', + argsTextDelta: 'lue', + }, + { + type: 'tool-call-delta', + toolCallId: 'call_O17Uplv4lJvD6DVdIvFFeRMw', + toolCallType: 'function', + toolName: 'test-tool', + argsTextDelta: '":"', + }, + { + type: 'tool-call-delta', + toolCallId: 'call_O17Uplv4lJvD6DVdIvFFeRMw', + toolCallType: 'function', + toolName: 'test-tool', + argsTextDelta: 'Spark', + }, + { + type: 'tool-call-delta', + toolCallId: 'call_O17Uplv4lJvD6DVdIvFFeRMw', + toolCallType: 'function', + toolName: 'test-tool', + argsTextDelta: 'le', + }, + { + type: 'tool-call-delta', + toolCallId: 'call_O17Uplv4lJvD6DVdIvFFeRMw', + toolCallType: 'function', + toolName: 'test-tool', + argsTextDelta: ' Day', + }, + { + type: 'tool-call-delta', + toolCallId: 'call_O17Uplv4lJvD6DVdIvFFeRMw', + toolCallType: 'function', + toolName: 'test-tool', + argsTextDelta: '"}', + }, + { + type: 'tool-call', + toolCallId: 'call_O17Uplv4lJvD6DVdIvFFeRMw', + toolCallType: 'function', + toolName: 'test-tool', + args: '{"value":"Sparkle Day"}', + }, + { + type: 'finish', + 
finishReason: 'tool-calls', + usage: { promptTokens: 18, completionTokens: 439 }, + }, + ]); + }); + + it('should stream tool call that is sent in one chunk', async () => { + server.responseChunks = [ + `data: {"id":"chatcmpl-e7f8e220-656c-4455-a132-dacfc1370798","object":"chat.completion.chunk","created":1711357598,"model":"grok-beta",` + + `"system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"role":"assistant","content":null,` + + `"tool_calls":[{"index":0,"id":"call_O17Uplv4lJvD6DVdIvFFeRMw","type":"function","function":{"name":"test-tool","arguments":"{\\"value\\":\\"Sparkle Day\\"}"}}]},` + + `"finish_reason":null}]}\n\n`, + `data: {"id":"chatcmpl-e7f8e220-656c-4455-a132-dacfc1370798","object":"chat.completion.chunk","created":1729171479,"model":"grok-beta",` + + `"system_fingerprint":"fp_10c08bf97d","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}],` + + `"usage":{"queue_time":0.061348671,"prompt_tokens":18,"prompt_time":0.000211569,` + + `"completion_tokens":439,"completion_time":0.798181818,"total_tokens":457,"total_time":0.798393387}}\n\n`, + 'data: [DONE]\n\n', + ]; + + const { stream } = await model.doStream({ + inputFormat: 'prompt', + mode: { + type: 'regular', + tools: [ + { + type: 'function', + name: 'test-tool', + parameters: { + type: 'object', + properties: { value: { type: 'string' } }, + required: ['value'], + additionalProperties: false, + $schema: 'http://json-schema.org/draft-07/schema#', + }, + }, + ], + }, + prompt: TEST_PROMPT, + }); + + expect(await convertReadableStreamToArray(stream)).toStrictEqual([ + { + type: 'response-metadata', + id: 'chatcmpl-e7f8e220-656c-4455-a132-dacfc1370798', + modelId: 'grok-beta', + timestamp: new Date('2024-03-25T09:06:38.000Z'), + }, + { + type: 'tool-call-delta', + toolCallId: 'call_O17Uplv4lJvD6DVdIvFFeRMw', + toolCallType: 'function', + toolName: 'test-tool', + argsTextDelta: '{"value":"Sparkle Day"}', + }, + { + type: 'tool-call', + toolCallId: 'call_O17Uplv4lJvD6DVdIvFFeRMw', + toolCallType: 'function', + toolName: 'test-tool', + args: '{"value":"Sparkle Day"}', + }, + { + type: 'finish', + finishReason: 'tool-calls', + usage: { promptTokens: 18, completionTokens: 439 }, + }, + ]); + }); + + it('should handle error stream parts', async () => { + server.responseChunks = [ + `data: {"code":"Client specified an invalid argument","error":"Incorrect API key provided: as***T7. You can obtain an API key from https://console.x.ai."}\n\n`, + 'data: [DONE]\n\n', + ]; + + const { stream } = await model.doStream({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); + + expect(await convertReadableStreamToArray(stream)).toStrictEqual([ + { + type: 'error', + error: + 'Incorrect API key provided: as***T7. 
You can obtain an API key from https://console.x.ai.', + }, + { + finishReason: 'error', + type: 'finish', + usage: { + completionTokens: NaN, + promptTokens: NaN, + }, + }, + ]); + }); + + it('should handle unparsable stream parts', async () => { + server.responseChunks = [`data: {unparsable}\n\n`, 'data: [DONE]\n\n']; + + const { stream } = await model.doStream({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); + + const elements = await convertReadableStreamToArray(stream); + + expect(elements.length).toBe(2); + expect(elements[0].type).toBe('error'); + expect(elements[1]).toStrictEqual({ + finishReason: 'error', + type: 'finish', + usage: { + completionTokens: NaN, + promptTokens: NaN, + }, + }); + }); + + it('should expose the raw response headers', async () => { + prepareStreamResponse({ content: [] }); + + server.responseHeaders = { + 'test-header': 'test-value', + }; + + const { rawResponse } = await model.doStream({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); + + expect(rawResponse?.headers).toStrictEqual({ + // default headers: + 'content-type': 'text/event-stream', + 'cache-control': 'no-cache', + connection: 'keep-alive', + + // custom header + 'test-header': 'test-value', + }); + }); + + it('should pass the messages and the model', async () => { + prepareStreamResponse({ content: [] }); + + await model.doStream({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); + + expect(await server.getRequestBodyJson()).toStrictEqual({ + stream: true, + model: 'grok-beta', + messages: [{ role: 'user', content: 'Hello' }], + }); + }); + + it('should pass headers', async () => { + prepareStreamResponse({ content: [] }); + + const provider = createOpenAICompat({ + apiKey: 'test-api-key', + headers: { + 'Custom-Provider-Header': 'provider-header-value', + }, + }); + + await provider('grok-beta').doStream({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + headers: { + 'Custom-Request-Header': 'request-header-value', + }, + }); + + const requestHeaders = await server.getRequestHeaders(); + + expect(requestHeaders).toStrictEqual({ + authorization: 'Bearer test-api-key', + 'content-type': 'application/json', + 'custom-provider-header': 'provider-header-value', + 'custom-request-header': 'request-header-value', + }); + }); + + it('should send request body', async () => { + prepareStreamResponse({ content: [] }); + + const { request } = await model.doStream({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); + + expect(request).toStrictEqual({ + body: '{"model":"grok-beta","messages":[{"role":"user","content":"Hello"}],"stream":true}', + }); + }); +}); diff --git a/packages/openai-compat/src/openai-compat-chat-language-model.ts b/packages/openai-compat/src/openai-compat-chat-language-model.ts new file mode 100644 index 000000000000..e0c05924d28c --- /dev/null +++ b/packages/openai-compat/src/openai-compat-chat-language-model.ts @@ -0,0 +1,530 @@ +import { + InvalidResponseDataError, + LanguageModelV1, + LanguageModelV1CallWarning, + LanguageModelV1FinishReason, + LanguageModelV1ProviderMetadata, + LanguageModelV1StreamPart, + UnsupportedFunctionalityError, +} from '@ai-sdk/provider'; +import { + FetchFunction, + ParseResult, + combineHeaders, + createEventSourceResponseHandler, + createJsonResponseHandler, + generateId, + isParsableJson, + postJsonToApi, +} from '@ai-sdk/provider-utils'; +import { z } from 'zod'; +import { 
convertToOpenAICompatChatMessages } from './convert-to-openai-compat-chat-messages'; +import { getResponseMetadata } from './get-response-metadata'; +import { + OpenAICompatChatModelId, + OpenAICompatChatSettings, +} from './openai-compat-chat-settings'; +import { + openaiCompatErrorDataSchema, + openaiCompatFailedResponseHandler, +} from './openai-compat-error'; +import { prepareTools } from './openai-compat-prepare-tools'; +import { mapOpenAICompatFinishReason } from './map-openai-compat-finish-reason'; + +type OpenAICompatChatConfig = { + provider: string; + headers: () => Record<string, string | undefined>; + url: (options: { modelId: string; path: string }) => string; + fetch?: FetchFunction; +}; + +export class OpenAICompatChatLanguageModel implements LanguageModelV1 { + readonly specificationVersion = 'v1'; + + readonly supportsStructuredOutputs = false; + readonly defaultObjectGenerationMode = 'tool'; + + readonly modelId: OpenAICompatChatModelId; + readonly settings: OpenAICompatChatSettings; + + private readonly config: OpenAICompatChatConfig; + + constructor( + modelId: OpenAICompatChatModelId, + settings: OpenAICompatChatSettings, + config: OpenAICompatChatConfig, + ) { + this.modelId = modelId; + this.settings = settings; + this.config = config; + } + + get provider(): string { + return this.config.provider; + } + + private getArgs({ + mode, + prompt, + maxTokens, + temperature, + topP, + topK, + frequencyPenalty, + presencePenalty, + stopSequences, + responseFormat, + seed, + stream, + }: Parameters<LanguageModelV1['doGenerate']>[0] & { + stream: boolean; + }) { + const type = mode.type; + + const warnings: LanguageModelV1CallWarning[] = []; + + if (topK != null) { + warnings.push({ + type: 'unsupported-setting', + setting: 'topK', + }); + } + + if ( + responseFormat != null && + responseFormat.type === 'json' && + responseFormat.schema != null + ) { + warnings.push({ + type: 'unsupported-setting', + setting: 'responseFormat', + details: 'JSON response format schema is not supported', + }); + } + + const baseArgs = { + // model id: + model: this.modelId, + + // model specific settings: + user: this.settings.user, + + // standardized settings: + max_tokens: maxTokens, + temperature, + top_p: topP, + frequency_penalty: frequencyPenalty, + presence_penalty: presencePenalty, + stop: stopSequences, + seed, + + // response format: + response_format: + // json object response format is not currently supported + undefined, + + // messages: + messages: convertToOpenAICompatChatMessages(prompt), + }; + + switch (type) { + case 'regular': { + const { tools, tool_choice, toolWarnings } = prepareTools({ mode }); + return { + args: { + ...baseArgs, + tools, + tool_choice, + }, + warnings: [...warnings, ...toolWarnings], + }; + } + + case 'object-json': { + throw new UnsupportedFunctionalityError({ + functionality: 'object-json mode', + }); + } + + case 'object-tool': { + return { + args: { + ...baseArgs, + tool_choice: { + type: 'function', + function: { name: mode.tool.name }, + }, + tools: [ + { + type: 'function', + function: { + name: mode.tool.name, + description: mode.tool.description, + parameters: mode.tool.parameters, + }, + }, + ], + }, + warnings, + }; + } + + default: { + const _exhaustiveCheck: never = type; + throw new Error(`Unsupported type: ${_exhaustiveCheck}`); + } + } + } + + async doGenerate( + options: Parameters<LanguageModelV1['doGenerate']>[0], + ): Promise<Awaited<ReturnType<LanguageModelV1['doGenerate']>>> { + const { args, warnings } = this.getArgs({ ...options, stream: false }); + + const body = JSON.stringify(args); + + const { responseHeaders, value: response } = await postJsonToApi({ + url:
this.config.url({ + path: '/chat/completions', + modelId: this.modelId, + }), + headers: combineHeaders(this.config.headers(), options.headers), + body: args, + failedResponseHandler: openaiCompatFailedResponseHandler, + successfulResponseHandler: createJsonResponseHandler( + openaiCompatChatResponseSchema, + ), + abortSignal: options.abortSignal, + fetch: this.config.fetch, + }); + + const { messages: rawPrompt, ...rawSettings } = args; + const choice = response.choices[0]; + + return { + text: choice.message.content ?? undefined, + toolCalls: choice.message.tool_calls?.map(toolCall => ({ + toolCallType: 'function', + toolCallId: toolCall.id ?? generateId(), + toolName: toolCall.function.name, + args: toolCall.function.arguments!, + })), + finishReason: mapOpenAICompatFinishReason(choice.finish_reason), + usage: { + promptTokens: response.usage?.prompt_tokens ?? NaN, + completionTokens: response.usage?.completion_tokens ?? NaN, + }, + rawCall: { rawPrompt, rawSettings }, + rawResponse: { headers: responseHeaders }, + response: getResponseMetadata(response), + warnings, + request: { body }, + }; + } + + async doStream( + options: Parameters<LanguageModelV1['doStream']>[0], + ): Promise<Awaited<ReturnType<LanguageModelV1['doStream']>>> { + const { args, warnings } = this.getArgs({ ...options, stream: true }); + + const body = JSON.stringify({ ...args, stream: true }); + + const { responseHeaders, value: response } = await postJsonToApi({ + url: this.config.url({ + path: '/chat/completions', + modelId: this.modelId, + }), + headers: combineHeaders(this.config.headers(), options.headers), + body: { + ...args, + stream: true, + }, + failedResponseHandler: openaiCompatFailedResponseHandler, + successfulResponseHandler: createEventSourceResponseHandler( + openaiCompatChatChunkSchema, + ), + abortSignal: options.abortSignal, + fetch: this.config.fetch, + }); + + const { messages: rawPrompt, ...rawSettings } = args; + + const toolCalls: Array<{ + id: string; + type: 'function'; + function: { + name: string; + arguments: string; + }; + }> = []; + + let finishReason: LanguageModelV1FinishReason = 'unknown'; + let usage: { + promptTokens: number | undefined; + completionTokens: number | undefined; + } = { + promptTokens: undefined, + completionTokens: undefined, + }; + let isFirstChunk = true; + + let providerMetadata: LanguageModelV1ProviderMetadata | undefined; + return { + stream: response.pipeThrough( + new TransformStream< + ParseResult<z.infer<typeof openaiCompatChatChunkSchema>>, + LanguageModelV1StreamPart + >({ + transform(chunk, controller) { + // handle failed chunk parsing / validation: + if (!chunk.success) { + finishReason = 'error'; + controller.enqueue({ type: 'error', error: chunk.error }); + return; + } + + const value = chunk.value; + + // handle error chunks: + if ('error' in value) { + finishReason = 'error'; + controller.enqueue({ type: 'error', error: value.error }); + return; + } + + if (isFirstChunk) { + isFirstChunk = false; + + controller.enqueue({ + type: 'response-metadata', + ...getResponseMetadata(value), + }); + } + + if (value.usage != null) { + usage = { + promptTokens: value.usage.prompt_tokens ?? undefined, + completionTokens: value.usage.completion_tokens ??
undefined, + }; + } + + const choice = value.choices[0]; + + if (choice?.finish_reason != null) { + finishReason = mapOpenAICompatFinishReason(choice.finish_reason); + } + + if (choice?.delta == null) { + return; + } + + const delta = choice.delta; + + if (delta.content != null) { + controller.enqueue({ + type: 'text-delta', + textDelta: delta.content, + }); + } + + if (delta.tool_calls != null) { + for (const toolCallDelta of delta.tool_calls) { + const index = toolCallDelta.index; + + if (toolCalls[index] == null) { + if (toolCallDelta.type !== 'function') { + throw new InvalidResponseDataError({ + data: toolCallDelta, + message: `Expected 'function' type.`, + }); + } + + if (toolCallDelta.id == null) { + throw new InvalidResponseDataError({ + data: toolCallDelta, + message: `Expected 'id' to be a string.`, + }); + } + + if (toolCallDelta.function?.name == null) { + throw new InvalidResponseDataError({ + data: toolCallDelta, + message: `Expected 'function.name' to be a string.`, + }); + } + + toolCalls[index] = { + id: toolCallDelta.id, + type: 'function', + function: { + name: toolCallDelta.function.name, + arguments: toolCallDelta.function.arguments ?? '', + }, + }; + + const toolCall = toolCalls[index]; + + if ( + toolCall.function?.name != null && + toolCall.function?.arguments != null + ) { + // send delta if the argument text has already started: + if (toolCall.function.arguments.length > 0) { + controller.enqueue({ + type: 'tool-call-delta', + toolCallType: 'function', + toolCallId: toolCall.id, + toolName: toolCall.function.name, + argsTextDelta: toolCall.function.arguments, + }); + } + + // check if tool call is complete + // (some providers send the full tool call in one chunk): + if (isParsableJson(toolCall.function.arguments)) { + controller.enqueue({ + type: 'tool-call', + toolCallType: 'function', + toolCallId: toolCall.id ?? generateId(), + toolName: toolCall.function.name, + args: toolCall.function.arguments, + }); + } + } + + continue; + } + + // existing tool call, merge + const toolCall = toolCalls[index]; + + if (toolCallDelta.function?.arguments != null) { + toolCall.function!.arguments += + toolCallDelta.function?.arguments ?? ''; + } + + // send delta + controller.enqueue({ + type: 'tool-call-delta', + toolCallType: 'function', + toolCallId: toolCall.id, + toolName: toolCall.function.name, + argsTextDelta: toolCallDelta.function.arguments ?? '', + }); + + // check if tool call is complete + if ( + toolCall.function?.name != null && + toolCall.function?.arguments != null && + isParsableJson(toolCall.function.arguments) + ) { + controller.enqueue({ + type: 'tool-call', + toolCallType: 'function', + toolCallId: toolCall.id ?? generateId(), + toolName: toolCall.function.name, + args: toolCall.function.arguments, + }); + } + } + } + }, + + flush(controller) { + controller.enqueue({ + type: 'finish', + finishReason, + usage: { + promptTokens: usage.promptTokens ?? NaN, + completionTokens: usage.completionTokens ?? NaN, + }, + ...(providerMetadata != null ? 
{ providerMetadata } : {}), + }); + }, + }), + ), + rawCall: { rawPrompt, rawSettings }, + rawResponse: { headers: responseHeaders }, + warnings, + request: { body }, + }; + } +} + +// limited version of the schema, focussed on what is needed for the implementation +// this approach limits breakages when the API changes and increases efficiency +const openaiCompatChatResponseSchema = z.object({ + id: z.string().nullish(), + created: z.number().nullish(), + model: z.string().nullish(), + choices: z.array( + z.object({ + message: z.object({ + role: z.literal('assistant').nullish(), + content: z.string().nullish(), + tool_calls: z + .array( + z.object({ + id: z.string().nullish(), + type: z.literal('function'), + function: z.object({ + name: z.string(), + arguments: z.string(), + }), + }), + ) + .nullish(), + }), + index: z.number(), + finish_reason: z.string().nullish(), + }), + ), + usage: z + .object({ + prompt_tokens: z.number().nullish(), + completion_tokens: z.number().nullish(), + }) + .nullish(), +}); + +// limited version of the schema, focussed on what is needed for the implementation +// this approach limits breakages when the API changes and increases efficiency +const openaiCompatChatChunkSchema = z.union([ + z.object({ + id: z.string().nullish(), + created: z.number().nullish(), + model: z.string().nullish(), + choices: z.array( + z.object({ + delta: z + .object({ + role: z.enum(['assistant']).nullish(), + content: z.string().nullish(), + tool_calls: z + .array( + z.object({ + index: z.number(), + id: z.string().nullish(), + type: z.literal('function').optional(), + function: z.object({ + name: z.string().nullish(), + arguments: z.string().nullish(), + }), + }), + ) + .nullish(), + }) + .nullish(), + finish_reason: z.string().nullable().optional(), + index: z.number(), + }), + ), + usage: z + .object({ + prompt_tokens: z.number().nullish(), + completion_tokens: z.number().nullish(), + }) + .nullish(), + }), + openaiCompatErrorDataSchema, +]); diff --git a/packages/openai-compat/src/openai-compat-chat-settings.ts b/packages/openai-compat/src/openai-compat-chat-settings.ts new file mode 100644 index 000000000000..e5554224b5a2 --- /dev/null +++ b/packages/openai-compat/src/openai-compat-chat-settings.ts @@ -0,0 +1,11 @@ +// TODO(shaper): Need to generalize/fix the below to use an interface somehow. +// https://console.x.ai and see "View models" +export type OpenAICompatChatModelId = string; + +export interface OpenAICompatChatSettings { + /** +A unique identifier representing your end-user, which can help the provider to +monitor and detect abuse. 
+*/ + user?: string; +} diff --git a/packages/openai-compat/src/openai-compat-error.ts b/packages/openai-compat/src/openai-compat-error.ts new file mode 100644 index 000000000000..debbee8270e1 --- /dev/null +++ b/packages/openai-compat/src/openai-compat-error.ts @@ -0,0 +1,16 @@ +import { z } from 'zod'; +import { createJsonErrorResponseHandler } from '@ai-sdk/provider-utils'; + +export const openaiCompatErrorDataSchema = z.object({ + code: z.string(), + error: z.string(), +}); + +export type OpenAICompatErrorData = z.infer<typeof openaiCompatErrorDataSchema>; + +export const openaiCompatFailedResponseHandler = createJsonErrorResponseHandler( + { + errorSchema: openaiCompatErrorDataSchema, + errorToMessage: data => data.error, + }, +); diff --git a/packages/openai-compat/src/openai-compat-prepare-tools.ts b/packages/openai-compat/src/openai-compat-prepare-tools.ts new file mode 100644 index 000000000000..7a81d04a03f6 --- /dev/null +++ b/packages/openai-compat/src/openai-compat-prepare-tools.ts @@ -0,0 +1,95 @@ +import { + LanguageModelV1, + LanguageModelV1CallWarning, + UnsupportedFunctionalityError, +} from '@ai-sdk/provider'; + +export function prepareTools({ + mode, +}: { + mode: Parameters<LanguageModelV1['doGenerate']>[0]['mode'] & { + type: 'regular'; + }; +}): { + tools: + | undefined + | Array<{ + type: 'function'; + function: { + name: string; + description: string | undefined; + parameters: unknown; + }; + }>; + tool_choice: + | { type: 'function'; function: { name: string } } + | 'auto' + | 'none' + | 'required' + | undefined; + toolWarnings: LanguageModelV1CallWarning[]; +} { + // when the tools array is empty, change it to undefined to prevent errors: + const tools = mode.tools?.length ? mode.tools : undefined; + const toolWarnings: LanguageModelV1CallWarning[] = []; + + if (tools == null) { + return { tools: undefined, tool_choice: undefined, toolWarnings }; + } + + const toolChoice = mode.toolChoice; + + const openaiCompatTools: Array<{ + type: 'function'; + function: { + name: string; + description: string | undefined; + parameters: unknown; + }; + }> = []; + + for (const tool of tools) { + if (tool.type === 'provider-defined') { + toolWarnings.push({ type: 'unsupported-tool', tool }); + } else { + openaiCompatTools.push({ + type: 'function', + function: { + name: tool.name, + description: tool.description, + parameters: tool.parameters, + }, + }); + } + } + + if (toolChoice == null) { + return { tools: openaiCompatTools, tool_choice: undefined, toolWarnings }; + } + + const type = toolChoice.type; + + switch (type) { + case 'auto': + case 'none': + case 'required': + return { tools: openaiCompatTools, tool_choice: type, toolWarnings }; + case 'tool': + return { + tools: openaiCompatTools, + tool_choice: { + type: 'function', + function: { + name: toolChoice.toolName, + }, + }, + toolWarnings, + }; + default: { + const _exhaustiveCheck: never = type; + throw new UnsupportedFunctionalityError({ + functionality: `Unsupported tool choice type: ${_exhaustiveCheck}`, + }); + } + } +} diff --git a/packages/openai-compat/src/openai-compat-provider.ts b/packages/openai-compat/src/openai-compat-provider.ts new file mode 100644 index 000000000000..4fac23cd9891 --- /dev/null +++ b/packages/openai-compat/src/openai-compat-provider.ts @@ -0,0 +1,118 @@ +import { + LanguageModelV1, + NoSuchModelError, + ProviderV1, +} from '@ai-sdk/provider'; +import { + FetchFunction, + loadApiKey, + withoutTrailingSlash, +} from '@ai-sdk/provider-utils'; +import { OpenAICompatChatLanguageModel } from './openai-compat-chat-language-model'; +import {
OpenAICompatChatModelId, + OpenAICompatChatSettings, +} from './openai-compat-chat-settings'; + +export interface OpenAICompatProvider + extends ProviderV1 { + /** +Creates a model for text generation. +*/ + (modelId: M, settings?: OpenAICompatChatSettings): LanguageModelV1; + + /** +Creates an OpenAICompat chat model for text generation. + */ + languageModel( + modelId: M, + settings?: OpenAICompatChatSettings, + ): LanguageModelV1; +} + +export interface OpenAICompatProviderSettings { + /** +Base URL for the OpenAICompat API calls. + */ + baseURL?: string; + + /** +API key for authenticating requests. + */ + apiKey?: string; + + /** +Custom headers to include in the requests. + */ + headers?: Record; + + /** +Custom fetch implementation. You can use it as a middleware to intercept requests, +or to provide a custom fetch implementation for e.g. testing. + */ + fetch?: FetchFunction; +} + +/** +Create an OpenAICompat provider instance. + */ +export function createOpenAICompat( + options: OpenAICompatProviderSettings = {}, +): OpenAICompatProvider { + // TODO(shaper): Generalize: + // - base url + // - api key name + const baseURL = + withoutTrailingSlash(options.baseURL) ?? 'https://api.x.ai/v1'; + + const getHeaders = () => ({ + // TODO(shaper): Need to use an interface for the below, and/or throw. + Authorization: `Bearer ${loadApiKey({ + apiKey: options.apiKey, + environmentVariableName: 'OPENAI_COMPAT_API_KEY', + description: 'OpenAICompat API key', + })}`, + ...options.headers, + }); + + const createChatModel = ( + modelId: M, + settings: OpenAICompatChatSettings = {}, + ) => + new OpenAICompatChatLanguageModel(modelId, settings, { + provider: 'openaiCompat.chat', + url: ({ path }) => `${baseURL}${path}`, + headers: getHeaders, + fetch: options.fetch, + }); + + const createLanguageModel = ( + modelId: M, + settings?: OpenAICompatChatSettings, + ) => { + if (new.target) { + throw new Error( + 'The OpenAICompat model function cannot be called with the new keyword.', + ); + } + + return createChatModel(modelId, settings); + }; + + const provider = function (modelId: M, settings?: OpenAICompatChatSettings) { + return createLanguageModel(modelId, settings); + }; + + provider.languageModel = createLanguageModel; + provider.chat = createChatModel; + provider.textEmbeddingModel = (modelId: string) => { + throw new NoSuchModelError({ modelId, modelType: 'textEmbeddingModel' }); + }; + + return provider as OpenAICompatProvider; +} + +/** +Default OpenAICompat provider instance. 
+ */ +export const openaiCompat = createOpenAICompat(); diff --git a/packages/openai-compat/tsconfig.json b/packages/openai-compat/tsconfig.json new file mode 100644 index 000000000000..3dc0ba4f10f5 --- /dev/null +++ b/packages/openai-compat/tsconfig.json @@ -0,0 +1,9 @@ +{ + "extends": "./node_modules/@vercel/ai-tsconfig/react-library.json", + "compilerOptions": { + "target": "ES2018", + "stripInternal": true + }, + "include": ["."], + "exclude": ["dist", "build", "node_modules"] +} diff --git a/packages/openai-compat/tsup.config.ts b/packages/openai-compat/tsup.config.ts new file mode 100644 index 000000000000..3f92041b987c --- /dev/null +++ b/packages/openai-compat/tsup.config.ts @@ -0,0 +1,10 @@ +import { defineConfig } from 'tsup'; + +export default defineConfig([ + { + entry: ['src/index.ts'], + format: ['cjs', 'esm'], + dts: true, + sourcemap: true, + }, +]); diff --git a/packages/openai-compat/turbo.json b/packages/openai-compat/turbo.json new file mode 100644 index 000000000000..620b8380e744 --- /dev/null +++ b/packages/openai-compat/turbo.json @@ -0,0 +1,12 @@ +{ + "extends": [ + "//" + ], + "tasks": { + "build": { + "outputs": [ + "**/dist/**" + ] + } + } +} diff --git a/packages/openai-compat/vitest.edge.config.js b/packages/openai-compat/vitest.edge.config.js new file mode 100644 index 000000000000..700660e913f5 --- /dev/null +++ b/packages/openai-compat/vitest.edge.config.js @@ -0,0 +1,10 @@ +import { defineConfig } from 'vite'; + +// https://vitejs.dev/config/ +export default defineConfig({ + test: { + environment: 'edge-runtime', + globals: true, + include: ['**/*.test.ts', '**/*.test.tsx'], + }, +}); diff --git a/packages/openai-compat/vitest.node.config.js b/packages/openai-compat/vitest.node.config.js new file mode 100644 index 000000000000..b1d14b21fc11 --- /dev/null +++ b/packages/openai-compat/vitest.node.config.js @@ -0,0 +1,10 @@ +import { defineConfig } from 'vite'; + +// https://vitejs.dev/config/ +export default defineConfig({ + test: { + environment: 'node', + globals: true, + include: ['**/*.test.ts', '**/*.test.tsx'], + }, +}); diff --git a/packages/togetherai/CHANGELOG.md b/packages/togetherai/CHANGELOG.md new file mode 100644 index 000000000000..a0fc74cf662c --- /dev/null +++ b/packages/togetherai/CHANGELOG.md @@ -0,0 +1 @@ +# @ai-sdk/togetherai diff --git a/packages/togetherai/README.md b/packages/togetherai/README.md new file mode 100644 index 000000000000..7930eb5f53b3 --- /dev/null +++ b/packages/togetherai/README.md @@ -0,0 +1,3 @@ +# AI SDK - Together.ai Provider + +TODO diff --git a/packages/togetherai/package.json b/packages/togetherai/package.json new file mode 100644 index 000000000000..94812940f4cb --- /dev/null +++ b/packages/togetherai/package.json @@ -0,0 +1,71 @@ +{ + "name": "@ai-sdk/togetherai", + "version": "0.0.0", + "license": "Apache-2.0", + "sideEffects": false, + "main": "./dist/index.js", + "module": "./dist/index.mjs", + "types": "./dist/index.d.ts", + "files": [ + "dist/**/*", + "internal/dist/**/*", + "CHANGELOG.md" + ], + "scripts": { + "build": "tsup", + "build:watch": "tsup --watch", + "clean": "rm -rf dist && rm -rf internal/dist", + "lint": "eslint \"./**/*.ts*\"", + "type-check": "tsc --noEmit", + "prettier-check": "prettier --check \"./**/*.ts*\"", + "test": "pnpm test:node && pnpm test:edge", + "test:edge": "vitest --config vitest.edge.config.js --run", + "test:node": "vitest --config vitest.node.config.js --run" + }, + "exports": { + "./package.json": "./package.json", + ".": { + "types": "./dist/index.d.ts", + "import": 
"./dist/index.mjs", + "require": "./dist/index.js" + }, + "./internal": { + "types": "./internal/dist/index.d.ts", + "import": "./internal/dist/index.mjs", + "module": "./internal/dist/index.mjs", + "require": "./internal/dist/index.js" + } + }, + "dependencies": { + "@ai-sdk/openai-compat": "0.0.0", + "@ai-sdk/provider": "1.0.0", + "@ai-sdk/provider-utils": "2.0.0" + }, + "devDependencies": { + "@types/node": "^18", + "@vercel/ai-tsconfig": "workspace:*", + "tsup": "^8", + "typescript": "5.6.3", + "zod": "3.23.8" + }, + "peerDependencies": { + "zod": "^3.0.0" + }, + "engines": { + "node": ">=18" + }, + "publishConfig": { + "access": "public" + }, + "homepage": "https://sdk.vercel.ai/docs", + "repository": { + "type": "git", + "url": "git+https://github.com/vercel/ai.git" + }, + "bugs": { + "url": "https://github.com/vercel/ai/issues" + }, + "keywords": [ + "ai" + ] +} diff --git a/packages/togetherai/src/index.ts b/packages/togetherai/src/index.ts new file mode 100644 index 000000000000..53931dbe2f35 --- /dev/null +++ b/packages/togetherai/src/index.ts @@ -0,0 +1,5 @@ +export { createTogetherAI, togetherai } from './togetherai-provider'; +export type { + TogetherAIProvider, + TogetherAIProviderSettings, +} from './togetherai-provider'; diff --git a/packages/togetherai/src/togetherai-chat-settings.ts b/packages/togetherai/src/togetherai-chat-settings.ts new file mode 100644 index 000000000000..f12fdf6b3f7c --- /dev/null +++ b/packages/togetherai/src/togetherai-chat-settings.ts @@ -0,0 +1,5 @@ +import { OpenAICompatChatSettings } from '@ai-sdk/openai-compat'; + +export interface TogetherChatSettings : OpenAICompatChatSettings { + +} diff --git a/packages/togetherai/src/togetherai-provider.ts b/packages/togetherai/src/togetherai-provider.ts new file mode 100644 index 000000000000..b740c1f51da6 --- /dev/null +++ b/packages/togetherai/src/togetherai-provider.ts @@ -0,0 +1,82 @@ +import { + OpenAICompatProvider, + createOpenAICompat, + OpenAICompatChatSettings, + OpenAICompatProviderSettings, +} from '@ai-sdk/openai-compat'; +import { LanguageModelV1 } from '@ai-sdk/provider'; + +// https://api.together.ai/models +// https://docs.together.ai/docs/serverless-models +// https://docs.together.ai/docs/dedicated-models +export type TogetherAIChatModelId = + | 'google/gemma-2-9b-it' + | 'meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo'; + +// TODO(shaper): Add Language and Embedding model ids. + +export interface TogetherAIProviderSettings + extends OpenAICompatProviderSettings { + /** + * Additional Together-specific settings can be added here. + */ + togetherOption?: string; +} + +export interface TogetherAIProvider + extends OpenAICompatProvider { + /** + * Example of a Together-specific method. + */ + togetherSpecificMethod(): void; +} + +export function createTogetherAI( + options: TogetherAIProviderSettings = {}, +): TogetherAIProvider { + // Create an instance of OpenAICompatProvider with the provided options + const openAICompatProvider = + createOpenAICompat(options); + + /** + * Implement Together-specific methods here. + * For example, a method that performs additional logging. + */ + const togetherSpecificMethod = () => { + console.log('Together-specific method invoked.'); + // Add any Together-specific logic here + }; + + /** + * Combine OpenAICompatProvider with Together-specific methods. + * Object.assign is used to merge the functions and methods. 
+ */ + const togetheraiProvider: TogetherAIProvider = Object.assign( + // The provider function + ( + modelId: TogetherAIChatModelId, + settings?: OpenAICompatChatSettings, + ): LanguageModelV1 => { + return openAICompatProvider(modelId, settings); + }, + { + // Delegate the languageModel method to OpenAICompatProvider + languageModel: openAICompatProvider.languageModel, + + // Delegate the chat method to OpenAICompatProvider + // chat: openAICompatProvider.chat, + + // // Delegate the textEmbeddingModel method to OpenAICompatProvider + // textEmbeddingModel: openAICompatProvider.textEmbeddingModel, + + // // Add Together-specific methods + // togetherSpecificMethod, + + // You can add more Together-specific methods or override existing ones if needed + }, + ) as TogetherAIProvider; + + return togetheraiProvider; +} + +export const togetherai = createTogetherAI(); diff --git a/packages/togetherai/tsconfig.json b/packages/togetherai/tsconfig.json new file mode 100644 index 000000000000..3dc0ba4f10f5 --- /dev/null +++ b/packages/togetherai/tsconfig.json @@ -0,0 +1,9 @@ +{ + "extends": "./node_modules/@vercel/ai-tsconfig/react-library.json", + "compilerOptions": { + "target": "ES2018", + "stripInternal": true + }, + "include": ["."], + "exclude": ["dist", "build", "node_modules"] +} diff --git a/packages/togetherai/tsup.config.ts b/packages/togetherai/tsup.config.ts new file mode 100644 index 000000000000..3f92041b987c --- /dev/null +++ b/packages/togetherai/tsup.config.ts @@ -0,0 +1,10 @@ +import { defineConfig } from 'tsup'; + +export default defineConfig([ + { + entry: ['src/index.ts'], + format: ['cjs', 'esm'], + dts: true, + sourcemap: true, + }, +]); diff --git a/packages/togetherai/turbo.json b/packages/togetherai/turbo.json new file mode 100644 index 000000000000..620b8380e744 --- /dev/null +++ b/packages/togetherai/turbo.json @@ -0,0 +1,12 @@ +{ + "extends": [ + "//" + ], + "tasks": { + "build": { + "outputs": [ + "**/dist/**" + ] + } + } +} diff --git a/packages/togetherai/vitest.edge.config.js b/packages/togetherai/vitest.edge.config.js new file mode 100644 index 000000000000..700660e913f5 --- /dev/null +++ b/packages/togetherai/vitest.edge.config.js @@ -0,0 +1,10 @@ +import { defineConfig } from 'vite'; + +// https://vitejs.dev/config/ +export default defineConfig({ + test: { + environment: 'edge-runtime', + globals: true, + include: ['**/*.test.ts', '**/*.test.tsx'], + }, +}); diff --git a/packages/togetherai/vitest.node.config.js b/packages/togetherai/vitest.node.config.js new file mode 100644 index 000000000000..b1d14b21fc11 --- /dev/null +++ b/packages/togetherai/vitest.node.config.js @@ -0,0 +1,10 @@ +import { defineConfig } from 'vite'; + +// https://vitejs.dev/config/ +export default defineConfig({ + test: { + environment: 'node', + globals: true, + include: ['**/*.test.ts', '**/*.test.tsx'], + }, +}); diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 73bd88fe9970..98abc7023434 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -83,6 +83,9 @@ importers: '@ai-sdk/openai': specifier: 1.0.2 version: link:../../packages/openai + '@ai-sdk/togetherai': + specifier: 0.0.0 + version: link:../../packages/togetherai '@ai-sdk/xai': specifier: 1.0.2 version: link:../../packages/xai @@ -1361,6 +1364,31 @@ importers: specifier: 3.23.8 version: 3.23.8 + packages/openai-compat: + dependencies: + '@ai-sdk/provider': + specifier: 1.0.0 + version: link:../provider + '@ai-sdk/provider-utils': + specifier: 2.0.0 + version: link:../provider-utils + devDependencies: + '@types/node': + 
specifier: ^18 + version: 18.19.54 + '@vercel/ai-tsconfig': + specifier: workspace:* + version: link:../../tools/tsconfig + tsup: + specifier: ^8 + version: 8.3.0(jiti@2.4.0)(postcss@8.4.49)(tsx@4.7.1)(typescript@5.6.3)(yaml@2.5.0) + typescript: + specifier: 5.6.3 + version: 5.6.3 + zod: + specifier: 3.23.8 + version: 3.23.8 + packages/provider: dependencies: json-schema: @@ -1589,6 +1617,34 @@ importers: specifier: 5.6.3 version: 5.6.3 + packages/togetherai: + dependencies: + '@ai-sdk/openai-compat': + specifier: 0.0.0 + version: link:../openai-compat + '@ai-sdk/provider': + specifier: 1.0.0 + version: link:../provider + '@ai-sdk/provider-utils': + specifier: 2.0.0 + version: link:../provider-utils + devDependencies: + '@types/node': + specifier: ^18 + version: 18.19.54 + '@vercel/ai-tsconfig': + specifier: workspace:* + version: link:../../tools/tsconfig + tsup: + specifier: ^8 + version: 8.3.0(jiti@2.4.0)(postcss@8.4.49)(tsx@4.7.1)(typescript@5.6.3)(yaml@2.5.0) + typescript: + specifier: 5.6.3 + version: 5.6.3 + zod: + specifier: 3.23.8 + version: 3.23.8 + packages/ui-utils: dependencies: '@ai-sdk/provider': From e7da088b1fe5c2fab77dff68060f4ce2d240ee4f Mon Sep 17 00:00:00 2001 From: Walter Korman Date: Mon, 18 Nov 2024 18:02:34 -0800 Subject: [PATCH 02/13] More exploration. --- .../ai-core/src/generate-object/togetherai.ts | 30 ++++ packages/openai-compat/src/index.ts | 4 +- .../src/openai-compat-chat-settings.ts | 2 - .../src/openai-compat-completion-settings.ts | 48 +++++++ .../src/openai-compat-embedding-model.ts | 106 ++++++++++++++ .../src/openai-compat-embedding-settings.ts | 25 ++++ .../src/openai-compat-provider.ts | 100 +++++++------ .../src/togetherai-chat-settings.ts | 133 +++++++++++++++++- .../src/togetherai-completion-settings.ts | 59 ++++++++ .../src/togetherai-embedding-settings.ts | 29 ++++ .../togetherai/src/togetherai-provider.ts | 107 +++++++------- 11 files changed, 549 insertions(+), 94 deletions(-) create mode 100644 examples/ai-core/src/generate-object/togetherai.ts create mode 100644 packages/openai-compat/src/openai-compat-completion-settings.ts create mode 100644 packages/openai-compat/src/openai-compat-embedding-model.ts create mode 100644 packages/openai-compat/src/openai-compat-embedding-settings.ts create mode 100644 packages/togetherai/src/togetherai-completion-settings.ts create mode 100644 packages/togetherai/src/togetherai-embedding-settings.ts diff --git a/examples/ai-core/src/generate-object/togetherai.ts b/examples/ai-core/src/generate-object/togetherai.ts new file mode 100644 index 000000000000..8c44b0eddce3 --- /dev/null +++ b/examples/ai-core/src/generate-object/togetherai.ts @@ -0,0 +1,30 @@ +import { togetherai } from '@ai-sdk/togetherai'; +import { generateObject } from 'ai'; +import 'dotenv/config'; +import { z } from 'zod'; + +async function main() { + const result = await generateObject({ + model: togetherai('google/gemma-2b-it'), + schema: z.object({ + recipe: z.object({ + name: z.string(), + ingredients: z.array( + z.object({ + name: z.string(), + amount: z.string(), + }), + ), + steps: z.array(z.string()), + }), + }), + prompt: 'Generate a lasagna recipe.', + }); + + console.log(JSON.stringify(result.object.recipe, null, 2)); + console.log(); + console.log('Token usage:', result.usage); + console.log('Finish reason:', result.finishReason); +} + +main().catch(console.error); diff --git a/packages/openai-compat/src/index.ts b/packages/openai-compat/src/index.ts index 79f3cf8f88d5..dfdbbb22f50e 100644 --- 
a/packages/openai-compat/src/index.ts +++ b/packages/openai-compat/src/index.ts @@ -1,6 +1,8 @@ -export { createOpenAICompat, openaiCompat } from './openai-compat-provider'; +export { createOpenAICompat } from './openai-compat-provider'; export type { OpenAICompatProvider, OpenAICompatProviderSettings, } from './openai-compat-provider'; export type { OpenAICompatChatSettings } from './openai-compat-chat-settings'; +export type { OpenAICompatCompletionSettings } from './openai-compat-completion-settings'; +export type { OpenAICompatEmbeddingSettings } from './openai-compat-embedding-settings'; diff --git a/packages/openai-compat/src/openai-compat-chat-settings.ts b/packages/openai-compat/src/openai-compat-chat-settings.ts index e5554224b5a2..d7c6dbccec2d 100644 --- a/packages/openai-compat/src/openai-compat-chat-settings.ts +++ b/packages/openai-compat/src/openai-compat-chat-settings.ts @@ -1,5 +1,3 @@ -// TODO(shaper): Need to generalize/fix the below to use an interface somehow. -// https://console.x.ai and see "View models" -export type OpenAICompatChatModelId = string; +export type OpenAICompatChatModelId = string; export interface OpenAICompatChatSettings { diff --git a/packages/openai-compat/src/openai-compat-completion-settings.ts b/packages/openai-compat/src/openai-compat-completion-settings.ts new file mode 100644 index 000000000000..e64feac878e9 --- /dev/null +++ b/packages/openai-compat/src/openai-compat-completion-settings.ts @@ -0,0 +1,48 @@ +export type OpenAICompatCompletionModelId = string; + +export interface OpenAICompatCompletionSettings { + /** +Echo back the prompt in addition to the completion. + */ + echo?: boolean; + + /** +Modify the likelihood of specified tokens appearing in the completion. + +Accepts a JSON object that maps tokens (specified by their token ID in +the GPT tokenizer) to an associated bias value from -100 to 100. You +can use this tokenizer tool to convert text to token IDs. Mathematically, +the bias is added to the logits generated by the model prior to sampling. +The exact effect will vary per model, but values between -1 and 1 should +decrease or increase likelihood of selection; values like -100 or 100 +should result in a ban or exclusive selection of the relevant token. + +As an example, you can pass {"50256": -100} to prevent the <|endoftext|> +token from being generated. + */ + logitBias?: Record<number, number>; + + /** +Return the log probabilities of the tokens. Including logprobs will increase +the response size and can slow down response times. However, it can +be useful to better understand how the model is behaving. + +Setting to true will return the log probabilities of the tokens that +were generated. + +Setting to a number will return the log probabilities of the top n +tokens that were generated. + */ + logprobs?: boolean | number; + + /** +The suffix that comes after a completion of inserted text. + */ + suffix?: string; + + /** +A unique identifier representing your end-user, which can help OpenAI to +monitor and detect abuse. Learn more.
+ */ + user?: string; +} diff --git a/packages/openai-compat/src/openai-compat-embedding-model.ts b/packages/openai-compat/src/openai-compat-embedding-model.ts new file mode 100644 index 000000000000..92638aeaed6a --- /dev/null +++ b/packages/openai-compat/src/openai-compat-embedding-model.ts @@ -0,0 +1,106 @@ +import { + EmbeddingModelV1, + TooManyEmbeddingValuesForCallError, +} from '@ai-sdk/provider'; +import { + combineHeaders, + createJsonResponseHandler, + FetchFunction, + postJsonToApi, +} from '@ai-sdk/provider-utils'; +import { z } from 'zod'; +import { + OpenAICompatEmbeddingModelId, + OpenAICompatEmbeddingSettings, +} from './openai-compat-embedding-settings'; +import { openaiCompatFailedResponseHandler } from './openai-compat-error'; + +type OpenAIEmbeddingConfig = { + provider: string; + url: (options: { modelId: string; path: string }) => string; + headers: () => Record<string, string | undefined>; + fetch?: FetchFunction; +}; + +export class OpenAICompatEmbeddingModel implements EmbeddingModelV1<string> { + readonly specificationVersion = 'v1'; + readonly modelId: OpenAICompatEmbeddingModelId; + + private readonly config: OpenAIEmbeddingConfig; + private readonly settings: OpenAICompatEmbeddingSettings; + + get provider(): string { + return this.config.provider; + } + + get maxEmbeddingsPerCall(): number { + return this.settings.maxEmbeddingsPerCall ?? 2048; + } + + get supportsParallelCalls(): boolean { + return this.settings.supportsParallelCalls ?? true; + } + + constructor( + modelId: OpenAICompatEmbeddingModelId, + settings: OpenAICompatEmbeddingSettings, + config: OpenAIEmbeddingConfig, + ) { + this.modelId = modelId; + this.settings = settings; + this.config = config; + } + + async doEmbed({ + values, + headers, + abortSignal, + }: Parameters<EmbeddingModelV1<string>['doEmbed']>[0]): Promise< + Awaited<ReturnType<EmbeddingModelV1<string>['doEmbed']>> + > { + if (values.length > this.maxEmbeddingsPerCall) { + throw new TooManyEmbeddingValuesForCallError({ + provider: this.provider, + modelId: this.modelId, + maxEmbeddingsPerCall: this.maxEmbeddingsPerCall, + values, + }); + } + + const { responseHeaders, value: response } = await postJsonToApi({ + url: this.config.url({ + path: '/embeddings', + modelId: this.modelId, + }), + headers: combineHeaders(this.config.headers(), headers), + body: { + model: this.modelId, + input: values, + encoding_format: 'float', + dimensions: this.settings.dimensions, + user: this.settings.user, + }, + failedResponseHandler: openaiCompatFailedResponseHandler, + successfulResponseHandler: createJsonResponseHandler( + openaiTextEmbeddingResponseSchema, + ), + abortSignal, + fetch: this.config.fetch, + }); + + return { + embeddings: response.data.map(item => item.embedding), + usage: response.usage + ?
{ tokens: response.usage.prompt_tokens } + : undefined, + rawResponse: { headers: responseHeaders }, + }; + } +} + +// minimal version of the schema, focussed on what is needed for the implementation +// this approach limits breakages when the API changes and increases efficiency +const openaiTextEmbeddingResponseSchema = z.object({ + data: z.array(z.object({ embedding: z.array(z.number()) })), + usage: z.object({ prompt_tokens: z.number() }).nullish(), +}); diff --git a/packages/openai-compat/src/openai-compat-embedding-settings.ts b/packages/openai-compat/src/openai-compat-embedding-settings.ts new file mode 100644 index 000000000000..42a8d730dc37 --- /dev/null +++ b/packages/openai-compat/src/openai-compat-embedding-settings.ts @@ -0,0 +1,25 @@ +export type OpenAICompatEmbeddingModelId = string; + +export interface OpenAICompatEmbeddingSettings { + /** +Override the maximum number of embeddings per call. + */ + maxEmbeddingsPerCall?: number; + + /** +Override the parallelism of embedding calls. + */ + supportsParallelCalls?: boolean; + + /** +The number of dimensions the resulting output embeddings should have. +Only supported in text-embedding-3 and later models. + */ + dimensions?: number; + + /** +A unique identifier representing your end-user, which can help OpenAI to +monitor and detect abuse. Learn more. +*/ + user?: string; +} diff --git a/packages/openai-compat/src/openai-compat-provider.ts b/packages/openai-compat/src/openai-compat-provider.ts index 4fac23cd9891..1de1c127fd30 100644 --- a/packages/openai-compat/src/openai-compat-provider.ts +++ b/packages/openai-compat/src/openai-compat-provider.ts @@ -1,6 +1,6 @@ import { + EmbeddingModelV1, LanguageModelV1, - NoSuchModelError, ProviderV1, } from '@ai-sdk/provider'; import { @@ -9,25 +9,26 @@ import { withoutTrailingSlash, } from '@ai-sdk/provider-utils'; import { OpenAICompatChatLanguageModel } from './openai-compat-chat-language-model'; -import { - OpenAICompatChatModelId, - OpenAICompatChatSettings, -} from './openai-compat-chat-settings'; +import { OpenAICompatChatSettings } from './openai-compat-chat-settings'; +import { OpenAICompatCompletionSettings } from './openai-compat-completion-settings'; +import { OpenAICompatEmbeddingSettings } from './openai-compat-embedding-settings'; +import { OpenAICompatEmbeddingModel } from './openai-compat-embedding-model'; export interface OpenAICompatProvider extends ProviderV1 { - /** -Creates a model for text generation. -*/ (modelId: M, settings?: OpenAICompatChatSettings): LanguageModelV1; - /** -Creates an OpenAICompat chat model for text generation. - */ languageModel( modelId: M, - settings?: OpenAICompatChatSettings, + settings?: OpenAICompatCompletionSettings, ): LanguageModelV1; + + chatModel(modelId: M, settings?: OpenAICompatChatSettings): LanguageModelV1; + + textEmbeddingModel( + modelId: M, + settings?: OpenAICompatEmbeddingSettings, + ): EmbeddingModelV1; } export interface OpenAICompatProviderSettings { @@ -51,30 +52,50 @@ Custom fetch implementation. You can use it as a middleware to intercept request or to provide a custom fetch implementation for e.g. testing. */ fetch?: FetchFunction; + + /** +The name of the environment variable from which to load the API key if not explicitly provided. + */ + apiKeyEnvVarName?: string; + + /** +Description of the API key environment variable for error messages. + */ + apiKeyEnvVarDescription?: string; } /** Create an OpenAICompat provider instance. 
*/ export function createOpenAICompat( - options: OpenAICompatProviderSettings = {}, + options: OpenAICompatProviderSettings, ): OpenAICompatProvider { - // TODO(shaper): Generalize: - // - base url - // - api key name - const baseURL = - withoutTrailingSlash(options.baseURL) ?? 'https://api.x.ai/v1'; + // TODO(shaper): Throw if baseURL isn't set. + const baseURL = withoutTrailingSlash(options.baseURL); const getHeaders = () => ({ - // TODO(shaper): Need to use an interface for the below, and/or throw. Authorization: `Bearer ${loadApiKey({ apiKey: options.apiKey, - environmentVariableName: 'OPENAI_COMPAT_API_KEY', - description: 'OpenAICompat API key', + environmentVariableName: options.apiKeyEnvVarName ?? '', + description: options.apiKeyEnvVarDescription ?? '', })}`, ...options.headers, }); + const createLanguageModel = ( + modelId: M, + settings?: OpenAICompatCompletionSettings, + ) => { + if (new.target) { + throw new Error( + 'The OpenAICompat model function cannot be called with the new keyword.', + ); + } + + // TODO(shaper): Do we need to pull in and strip down the OpenAI Completion Model? + return createChatModel(modelId, settings); + }; + const createChatModel = ( modelId: M, settings: OpenAICompatChatSettings = {}, @@ -86,33 +107,30 @@ export function createOpenAICompat( fetch: options.fetch, }); - const createLanguageModel = ( + const createEmbeddingModel = ( modelId: M, - settings?: OpenAICompatChatSettings, - ) => { - if (new.target) { - throw new Error( - 'The OpenAICompat model function cannot be called with the new keyword.', - ); - } - - return createChatModel(modelId, settings); - }; + settings: OpenAICompatEmbeddingSettings = {}, + ) => + new OpenAICompatEmbeddingModel(modelId, settings, { + provider: 'openaiCompat.embedding', + url: ({ path }) => `${baseURL}${path}`, + headers: getHeaders, + fetch: options.fetch, + }); const provider = function (modelId: M, settings?: OpenAICompatChatSettings) { return createLanguageModel(modelId, settings); }; provider.languageModel = createLanguageModel; - provider.chat = createChatModel; - provider.textEmbeddingModel = (modelId: string) => { - throw new NoSuchModelError({ modelId, modelType: 'textEmbeddingModel' }); - }; + provider.chatModel = createChatModel; + provider.textEmbeddingModel = createEmbeddingModel; + + // TODO(shaper): Need a way for concrete impls to note if they don't support + // one of the model types. + // provider.textEmbeddingModel = (modelId: string) => { + // throw new NoSuchModelError({ modelId, modelType: 'textEmbeddingModel' }); + // }; return provider as OpenAICompatProvider; } - -/** -Default OpenAICompat provider instance. 
- */ -export const openaiCompat = createOpenAICompat(); diff --git a/packages/togetherai/src/togetherai-chat-settings.ts b/packages/togetherai/src/togetherai-chat-settings.ts index f12fdf6b3f7c..4a964e866181 100644 --- a/packages/togetherai/src/togetherai-chat-settings.ts +++ b/packages/togetherai/src/togetherai-chat-settings.ts @@ -1,5 +1,134 @@ import { OpenAICompatChatSettings } from '@ai-sdk/openai-compat'; -export interface TogetherChatSettings : OpenAICompatChatSettings { +// https://docs.together.ai/docs/serverless-models#chat-models +export type TogetherAIChatModelId = + | 'databricks/dbrx-instruct' + | 'deepseek-ai/deepseek-llm-67b-chat' + | 'google/gemma-2-27b-it' + | 'google/gemma-2-9b-it' + | 'google/gemma-2b-it' + | 'Gryphe/MythoMax-L2-13b' + | 'meta-llama/Llama-2-13b-chat-hf' + | 'meta-llama/Llama-3-70b-chat-hf' + | 'meta-llama/Llama-3-8b-chat-hf' + | 'meta-llama/Llama-3.2-3B-Instruct-Turbo' + | 'meta-llama/Meta-Llama-3-70B-Instruct-Lite' + | 'meta-llama/Meta-Llama-3-70B-Instruct-Turbo' + | 'meta-llama/Meta-Llama-3-8B-Instruct-Lite' + | 'meta-llama/Meta-Llama-3-8B-Instruct-Turbo' + | 'meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo' + | 'meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo' + | 'meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo' + | 'microsoft/WizardLM-2-8x22B' + | 'mistralai/Mistral-7B-Instruct-v0.1' + | 'mistralai/Mistral-7B-Instruct-v0.2' + | 'mistralai/Mistral-7B-Instruct-v0.3' + | 'mistralai/Mixtral-8x22B-Instruct-v0.1' + | 'mistralai/Mixtral-8x7B-Instruct-v0.1' + | 'NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO' + | 'nvidia/Llama-3.1-Nemotron-70B-Instruct-HF' + | 'Qwen/Qwen2-72B-Instruct' + | 'Qwen/Qwen2.5-72B-Instruct-Turbo' + | 'Qwen/Qwen2.5-7B-Instruct-Turbo' + | 'Qwen/Qwen2.5-Coder-32B-Instruct' + | 'togethercomputer/StripedHyena-Nous-7B' + | 'upstage/SOLAR-10.7B-Instruct-v1.0' + | (string & {}); -} +// https://docs.together.ai/docs/dedicated-models#chat-models +// export type TogetherAIChatModelId_Dedicated = +// | 'databricks/dbrx-instruct' +// | 'deepseek-ai/deepseek-coder-33b-instruct' +// | 'deepseek-ai/deepseek-llm-67b-chat' +// | 'google/gemma-2-27b-it' +// | 'google/gemma-2-9b-it' +// | 'google/gemma-2b-it' +// | 'google/gemma-7b-it' +// | 'gradientai/Llama-3-70B-Instruct-Gradient-1048k' +// | 'Gryphe/MythoMax-L2-13b-Lite' +// | 'Gryphe/MythoMax-L2-13b' +// | 'Haotian Liu/LLaVa-Next (Mistral-7B)' +// | 'HuggingFaceH4/zephyr-7b-beta' +// | 'lmSys/Koala (13B)' +// | 'lmSys/Koala-7B' +// | 'lmSys/Vicuna v1.3 (13B)' +// | 'lmSys/Vicuna v1.3 (7B)' +// | 'lmSys/Vicuna v1.5 (13B)' +// | 'lmSys/Vicuna v1.5 (7B)' +// | 'lmSys/Vicuna v1.5 16K (13B)' +// | 'Meta/Code Llama Instruct (13B)' +// | 'Meta/Code Llama Instruct (13B)' +// | 'Meta/Code Llama Instruct (34B)' +// | 'Meta/Code Llama Instruct (34B)' +// | 'Meta/Code Llama Instruct (70B)' +// | 'Meta/Code Llama Instruct (7B)' +// | 'Meta/LLaMA-2 Chat (13B)' +// | 'Meta/LLaMA-2 Chat (13B)' +// | 'Meta/LLaMA-2 Chat (70B)' +// | 'Meta/LLaMA-2 Chat (7B)' +// | 'Meta/LLaMA-2 Chat (7B)' +// | 'Meta/Llama3 8B Chat HF INT4' +// | 'Meta/Meta Llama 3 70B Instruct Lite' +// | 'Meta/Meta Llama 3 70B Instruct Reference' +// | 'Meta/Meta Llama 3 70B Instruct Turbo' +// | 'Meta/Meta Llama 3 70B Instruct' +// | 'Meta/Meta Llama 3 8B Instruct Lite' +// | 'Meta/Meta Llama 3 8B Instruct Reference' +// | 'Meta/Meta Llama 3 8B Instruct Turbo' +// | 'Meta/Meta Llama 3 8B Instruct' +// | 'Meta/Meta Llama 3.1 405B Instruct Turbo' +// | 'Meta/Meta Llama 3.1 405B Instruct Turbo' +// | 'Meta/Meta Llama 3.1 70B Instruct Turbo' +// | 'Meta/Meta 
Llama 3.1 8B Instruct Turbo' +// | 'Meta/Meta Llama 3.2 11B Vision Instruct Turbo' +// | 'Meta/Meta Llama 3.2 3B Instruct Turbo' +// | 'Meta/Meta Llama 3.2 90B Vision Instruct Turbo' +// | 'Meta/Meta Llama Vision Free' +// | 'Meta/Togethercomputer Llama3 8B Instruct Int8' +// | 'microsoft/WizardLM-2-8x22B' +// | 'mistralai/Mistral-7B-Instruct-v0.1' +// | 'mistralai/Mistral-7B-Instruct-v0.2' +// | 'mistralai/Mistral-7B-Instruct-v0.3' +// | 'mistralai/Mixtral-8x22B-Instruct-v0.1' +// | 'mistralai/Mixtral-8x7B-Instruct-v0.1' +// | 'NousResearch/Hermes-2-Theta-Llama-3-70B' +// | 'NousResearch/Nous-Capybara-7B-V1p9' +// | 'NousResearch/Nous-Hermes-2-Mistral-7B-DPO' +// | 'NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO' +// | 'NousResearch/Nous-Hermes-2-Mixtral-8x7B-SFT' +// | 'NousResearch/Nous-Hermes-Llama2-13b' +// | 'NousResearch/Nous-Hermes-Llama2-70b' +// | 'Open-Orca/Mistral-7B-OpenOrca' +// | 'openchat/openchat-3.5-1210' +// | 'Qwen/Qwen1.5-0.5B-Chat' +// | 'Qwen/Qwen1.5-1.8B-Chat' +// | 'Qwen/Qwen1.5-110B-Chat' +// | 'Qwen/Qwen1.5-14B-Chat' +// | 'Qwen/Qwen1.5-32B-Chat' +// | 'Qwen/Qwen1.5-4B-Chat' +// | 'Qwen/Qwen1.5-72B-Chat' +// | 'Qwen/Qwen1.5-7B-Chat' +// | 'Qwen/Qwen2-1.5B-Instruct' +// | 'Qwen/Qwen2-72B-Instruct' +// | 'Qwen/Qwen2-7B-Instruct' +// | 'Qwen/Qwen2.5-72B-Instruct-Turbo' +// | 'Qwen/Qwen2.5-7B-Instruct-Turbo' +// | 'Qwen/Qwen2.5-Coder-32B-Instruct' +// | 'snorkelai/Snorkel-Mistral-PairRM-DPO' +// | 'Snowflake/snowflake-arctic-instruct' +// | 'teknium/OpenHermes-2-Mistral-7B' +// | 'teknium/OpenHermes-2p5-Mistral-7B' +// | 'test/test11' +// | 'togethercomputer/alpaca-7b' +// | 'togethercomputer/guanaco-13b' +// | 'togethercomputer/guanaco-33b' +// | 'togethercomputer/guanaco-65b' +// | 'togethercomputer/guanaco-7b' +// | 'togethercomputer/SOLAR-10.7B-Instruct-v1.0-int4' +// | 'togethercomputer/SOLAR-10.7B-Instruct-v1.0' +// | 'Undi95/ReMM-SLERP-L2-13B' +// | 'Undi95/Toppy-M-7B' +// | 'WizardLM/WizardLM-13B-V1.2' +// | (string & {}); + +export interface TogetherAIChatSettings extends OpenAICompatChatSettings {} diff --git a/packages/togetherai/src/togetherai-completion-settings.ts b/packages/togetherai/src/togetherai-completion-settings.ts new file mode 100644 index 000000000000..562d08f49da2 --- /dev/null +++ b/packages/togetherai/src/togetherai-completion-settings.ts @@ -0,0 +1,59 @@ +import { OpenAICompatCompletionSettings } from '@ai-sdk/openai-compat'; + +// https://docs.together.ai/docs/serverless-models#language-models +export type TogetherAICompletionModelId = + | 'codellama/CodeLlama-34b-Instruct-hf' + | 'Qwen/Qwen2.5-Coder-32B-Instruct' + | (string & {}); + +// https://docs.together.ai/docs/dedicated-models#language-models +// export type TogetherAICompletionModelId = +// | 'allenai/OLMo-7B-Instruct' +// | 'EleutherAI/llemma_7b' +// | 'google/gemma-2-9b' +// | 'google/gemma-2b' +// | 'google/gemma-7b' +// | 'gpt-3.5-turbo-instruct' +// | 'huggyllama/llama-13b' +// | 'huggyllama/llama-30b' +// | 'huggyllama/llama-65b' +// | 'huggyllama/llama-7b' +// | 'meta-llama/Llama-2-13b-hf' +// | 'meta-llama/Llama-2-70b-hf' +// | 'meta-llama/Llama-2-7b-hf' +// | 'meta-llama/Llama-3-8b-hf' +// | 'meta-llama/Meta-Llama-3-70b-hf' +// | 'meta-llama/Meta-Llama-3-70B' +// | 'meta-llama/Meta-Llama-3-8B' +// | 'meta-llama/Meta-Llama-3.1-70B-Reference' +// | 'meta-llama/Meta-Llama-3.1-8B-Reference' +// | 'microsoft/phi-2' +// | 'mistralai/Mistral-7B-v0.1' +// | 'mistralai/Mixtral-8x22B' +// | 'mistralai/Mixtral-8x7B-v0.1' +// | 'Nexusflow/NexusRaven-V2-13B' +// | 
'NousResearch/Nous-Hermes-13b' +// | 'Qwen/Qwen1.5-0.5B' +// | 'Qwen/Qwen1.5-1.8B' +// | 'Qwen/Qwen1.5-14B' +// | 'Qwen/Qwen1.5-32B' +// | 'Qwen/Qwen1.5-4B' +// | 'Qwen/Qwen1.5-72B' +// | 'Qwen/Qwen1.5-7B' +// | 'Qwen/Qwen2-1.5B' +// | 'Qwen/Qwen2-72B' +// | 'Qwen/Qwen2-7B' +// | 'togethercomputer/evo-1-131k-base' +// | 'togethercomputer/evo-1-8k-base' +// | 'togethercomputer/llama-2-13b' +// | 'togethercomputer/llama-2-70b' +// | 'togethercomputer/LLaMA-2-7B-32K' +// | 'togethercomputer/llama-2-7b' +// | 'togethercomputer/StripedHyena-Hessian-7B' +// | 'WizardLM/WizardLM-70B-V1.0' +// | 'zero-one-ai/Yi-34B' +// | 'zero-one-ai/Yi-6B' +// | (string & {}); + +export interface TogetherAICompletionSettings + extends OpenAICompatCompletionSettings {} diff --git a/packages/togetherai/src/togetherai-embedding-settings.ts b/packages/togetherai/src/togetherai-embedding-settings.ts new file mode 100644 index 000000000000..f010e8016d1b --- /dev/null +++ b/packages/togetherai/src/togetherai-embedding-settings.ts @@ -0,0 +1,29 @@ +import { OpenAICompatEmbeddingSettings } from '@ai-sdk/openai-compat'; + +// https://docs.together.ai/docs/serverless-models#embedding-models +export type TogetherAIEmbeddingModelId = + | 'BAAI/bge-base-en-v1.5' + | 'BAAI/bge-large-en-v1.5' + | 'bert-base-uncased' + | 'sentence-transformers/msmarco-bert-base-dot-v5' + | 'togethercomputer/m2-bert-80M-2k-retrieval' + | 'togethercomputer/m2-bert-80M-32k-retrieval' + | 'togethercomputer/m2-bert-80M-8k-retrieval' + | 'WhereIsAI/UAE-Large-V1' + | (string & {}); + +// https://docs.together.ai/docs/dedicated-models#embedding-models +// export type TogetherAIEmbeddingModelId = +// | 'BAAI/bge-base-en-v1.5' +// | 'BAAI/bge-large-en-v1.5' +// | 'bert-base-uncased' +// | 'hazyresearch/M2-BERT-2k-Retrieval-Encoder-V1' +// | 'sentence-transformers/msmarco-bert-base-dot-v5' +// | 'togethercomputer/m2-bert-80M-2k-retrieval' +// | 'togethercomputer/m2-bert-80M-32k-retrieval' +// | 'togethercomputer/m2-bert-80M-8k-retrieval' +// | 'WhereIsAI/UAE-Large-V1' +// | (string & {}); + +export interface TogetherAIEmbeddingSettings + extends OpenAICompatEmbeddingSettings {} diff --git a/packages/togetherai/src/togetherai-provider.ts b/packages/togetherai/src/togetherai-provider.ts index b740c1f51da6..a2aa163b2dfb 100644 --- a/packages/togetherai/src/togetherai-provider.ts +++ b/packages/togetherai/src/togetherai-provider.ts @@ -4,55 +4,59 @@ import { OpenAICompatChatSettings, OpenAICompatProviderSettings, } from '@ai-sdk/openai-compat'; -import { LanguageModelV1 } from '@ai-sdk/provider'; - -// https://api.together.ai/models -// https://docs.together.ai/docs/serverless-models -// https://docs.together.ai/docs/dedicated-models -export type TogetherAIChatModelId = - | 'google/gemma-2-9b-it' - | 'meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo'; - -// TODO(shaper): Add Language and Embedding model ids. +import { LanguageModelV1, EmbeddingModelV1 } from '@ai-sdk/provider'; +import { TogetherAIChatModelId } from './togetherai-chat-settings'; +import { + TogetherAIEmbeddingModelId, + TogetherAIEmbeddingSettings, +} from './togetherai-embedding-settings'; +import { + TogetherAICompletionModelId, + TogetherAICompletionSettings, +} from './togetherai-completion-settings'; export interface TogetherAIProviderSettings - extends OpenAICompatProviderSettings { - /** - * Additional Together-specific settings can be added here. 
- */ - togetherOption?: string; -} + extends OpenAICompatProviderSettings {} export interface TogetherAIProvider - extends OpenAICompatProvider { - /** - * Example of a Together-specific method. - */ - togetherSpecificMethod(): void; + extends OpenAICompatProvider< + | TogetherAIChatModelId + | TogetherAICompletionModelId + | TogetherAIEmbeddingModelId + > { + chatModel( + modelId: TogetherAIChatModelId, + settings?: OpenAICompatChatSettings, + ): LanguageModelV1; + + completionModel( + modelId: TogetherAICompletionModelId, + settings?: OpenAICompatChatSettings, + ): LanguageModelV1; + + textEmbeddingModel( + modelId: TogetherAIEmbeddingModelId, + settings?: TogetherAIEmbeddingSettings, + ): EmbeddingModelV1; } export function createTogetherAI( options: TogetherAIProviderSettings = {}, ): TogetherAIProvider { - // Create an instance of OpenAICompatProvider with the provided options - const openAICompatProvider = - createOpenAICompat(options); - - /** - * Implement Together-specific methods here. - * For example, a method that performs additional logging. - */ - const togetherSpecificMethod = () => { - console.log('Together-specific method invoked.'); - // Add any Together-specific logic here + const providerOptions: OpenAICompatProviderSettings = { + baseURL: 'https://api.together.xyz/v1/', + apiKeyEnvVarName: 'TOGETHER_AI_API_KEY', + apiKeyEnvVarDescription: "TogetherAI's API key", + ...options, }; + // TODO(shaper): Consider separating generics in the ctor. + const openAICompatProvider = createOpenAICompat< + | TogetherAIChatModelId + | TogetherAICompletionModelId + | TogetherAIEmbeddingModelId + >(providerOptions); - /** - * Combine OpenAICompatProvider with Together-specific methods. - * Object.assign is used to merge the functions and methods. - */ const togetheraiProvider: TogetherAIProvider = Object.assign( - // The provider function ( modelId: TogetherAIChatModelId, settings?: OpenAICompatChatSettings, @@ -60,19 +64,26 @@ export function createTogetherAI( return openAICompatProvider(modelId, settings); }, { - // Delegate the languageModel method to OpenAICompatProvider - languageModel: openAICompatProvider.languageModel, - - // Delegate the chat method to OpenAICompatProvider - // chat: openAICompatProvider.chat, - - // // Delegate the textEmbeddingModel method to OpenAICompatProvider - // textEmbeddingModel: openAICompatProvider.textEmbeddingModel, + chatModel: ( + modelId: TogetherAIChatModelId, + settings?: OpenAICompatChatSettings, + ) => { + return openAICompatProvider.chatModel(modelId, settings); + }, - // // Add Together-specific methods - // togetherSpecificMethod, + completionModel: ( + modelId: TogetherAICompletionModelId, + settings?: TogetherAICompletionSettings, + ) => { + return openAICompatProvider.languageModel(modelId, settings); + }, - // You can add more Together-specific methods or override existing ones if needed + textEmbeddingModel: ( + modelId: TogetherAIEmbeddingModelId, + settings?: TogetherAIEmbeddingSettings, + ) => { + return openAICompatProvider.textEmbeddingModel(modelId, settings); + }, }, ) as TogetherAIProvider; From e88a55a46eb307f2ef91caa24ae70ce9cf85951b Mon Sep 17 00:00:00 2001 From: Walter Korman Date: Tue, 19 Nov 2024 11:10:53 -0800 Subject: [PATCH 03/13] Rename openai-compat to openai-compatible. 
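
As a rough sketch, a concrete provider built on the renamed package is
expected to be wired up roughly as follows once this lands (the base URL,
model ids, and environment variable below are hypothetical placeholders,
and the exact generic parameters may differ from what ships):

    import { createOpenAICompatible } from '@ai-sdk/openai-compatible';

    // Configure a provider instance against an arbitrary
    // OpenAI-compatible endpoint.
    const example = createOpenAICompatible({
      baseURL: 'https://api.example.com/v1',
      apiKeyEnvVarName: 'EXAMPLE_API_KEY',
      apiKeyEnvVarDescription: 'Example provider API key',
    });

    // Chat-style and embedding models are created from the same
    // provider instance.
    const chatModel = example.chatModel('example/chat-model-v1');
    const embeddingModel = example.textEmbeddingModel(
      'example/embedding-model-v1',
    );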
--- packages/openai-compat/CHANGELOG.md | 1 - packages/openai-compat/src/index.ts | 8 --- .../src/openai-compat-api-types.ts | 52 ----------------- .../openai-compat/src/openai-compat-error.ts | 16 ----- packages/openai-compatible/CHANGELOG.md | 1 + .../README.md | 4 +- .../package.json | 2 +- ...o-openai-compatible-chat-messages.test.ts} | 6 +- ...ert-to-openai-compatible-chat-messages.ts} | 10 ++-- .../src/get-response-metadata.ts | 0 packages/openai-compatible/src/index.ts | 8 +++ .../map-openai-compatible-finish-reason.ts} | 2 +- .../src/openai-compatible-api-types.ts | 52 +++++++++++++++++ ...ai-compatible-chat-language-model.test.ts} | 2 +- .../openai-compatible-chat-language-model.ts} | 58 ++++++++++--------- .../src/openai-compatible-chat-settings.ts} | 4 +- .../openai-compatible-completion-settings.ts} | 4 +- .../src/openai-compatible-embedding-model.ts} | 22 +++---- .../openai-compatible-embedding-settings.ts} | 4 +- .../src/openai-compatible-error.ts | 17 ++++++ .../src/openai-compatible-prepare-tools.ts} | 0 .../src/openai-compatible-provider.ts} | 58 ++++++++++--------- .../tsconfig.json | 0 .../tsup.config.ts | 0 .../turbo.json | 0 .../vitest.edge.config.js | 0 .../vitest.node.config.js | 0 packages/togetherai/package.json | 2 +- .../src/togetherai-chat-settings.ts | 4 +- .../src/togetherai-completion-settings.ts | 4 +- .../src/togetherai-embedding-settings.ts | 4 +- .../togetherai/src/togetherai-provider.ts | 34 +++++------ pnpm-lock.yaml | 6 +- 33 files changed, 198 insertions(+), 187 deletions(-) delete mode 100644 packages/openai-compat/CHANGELOG.md delete mode 100644 packages/openai-compat/src/index.ts delete mode 100644 packages/openai-compat/src/openai-compat-api-types.ts delete mode 100644 packages/openai-compat/src/openai-compat-error.ts create mode 100644 packages/openai-compatible/CHANGELOG.md rename packages/{openai-compat => openai-compatible}/README.md (51%) rename packages/{openai-compat => openai-compatible}/package.json (97%) rename packages/{openai-compat/src/convert-to-openai-compat-chat-messages.test.ts => openai-compatible/src/convert-to-openai-compatible-chat-messages.test.ts} (85%) rename packages/{openai-compat/src/convert-to-openai-compat-chat-messages.ts => openai-compatible/src/convert-to-openai-compatible-chat-messages.ts} (90%) rename packages/{openai-compat => openai-compatible}/src/get-response-metadata.ts (100%) create mode 100644 packages/openai-compatible/src/index.ts rename packages/{openai-compat/src/map-openai-compat-finish-reason.ts => openai-compatible/src/map-openai-compatible-finish-reason.ts} (89%) create mode 100644 packages/openai-compatible/src/openai-compatible-api-types.ts rename packages/{openai-compat/src/openai-compat-chat-language-model.test.ts => openai-compatible/src/openai-compatible-chat-language-model.test.ts} (99%) rename packages/{openai-compat/src/openai-compat-chat-language-model.ts => openai-compatible/src/openai-compatible-chat-language-model.ts} (90%) rename packages/{openai-compat/src/openai-compat-chat-settings.ts => openai-compatible/src/openai-compatible-chat-settings.ts} (57%) rename packages/{openai-compat/src/openai-compat-completion-settings.ts => openai-compatible/src/openai-compatible-completion-settings.ts} (93%) rename packages/{openai-compat/src/openai-compat-embedding-model.ts => openai-compatible/src/openai-compatible-embedding-model.ts} (81%) rename packages/{openai-compat/src/openai-compat-embedding-settings.ts => openai-compatible/src/openai-compatible-embedding-settings.ts} (81%) create mode 
100644 packages/openai-compatible/src/openai-compatible-error.ts rename packages/{openai-compat/src/openai-compat-prepare-tools.ts => openai-compatible/src/openai-compatible-prepare-tools.ts} (100%) rename packages/{openai-compat/src/openai-compat-provider.ts => openai-compatible/src/openai-compatible-provider.ts} (59%) rename packages/{openai-compat => openai-compatible}/tsconfig.json (100%) rename packages/{openai-compat => openai-compatible}/tsup.config.ts (100%) rename packages/{openai-compat => openai-compatible}/turbo.json (100%) rename packages/{openai-compat => openai-compatible}/vitest.edge.config.js (100%) rename packages/{openai-compat => openai-compatible}/vitest.node.config.js (100%) diff --git a/packages/openai-compat/CHANGELOG.md b/packages/openai-compat/CHANGELOG.md deleted file mode 100644 index 70b1e470d95f..000000000000 --- a/packages/openai-compat/CHANGELOG.md +++ /dev/null @@ -1 +0,0 @@ -# @ai-sdk/openai-compat diff --git a/packages/openai-compat/src/index.ts b/packages/openai-compat/src/index.ts deleted file mode 100644 index dfdbbb22f50e..000000000000 --- a/packages/openai-compat/src/index.ts +++ /dev/null @@ -1,8 +0,0 @@ -export { createOpenAICompat } from './openai-compat-provider'; -export type { - OpenAICompatProvider, - OpenAICompatProviderSettings, -} from './openai-compat-provider'; -export type { OpenAICompatChatSettings } from './openai-compat-chat-settings'; -export type { OpenAICompatCompletionSettings } from './openai-compat-completion-settings'; -export type { OpenAICompatEmbeddingSettings } from './openai-compat-embedding-settings'; diff --git a/packages/openai-compat/src/openai-compat-api-types.ts b/packages/openai-compat/src/openai-compat-api-types.ts deleted file mode 100644 index 65720888c33d..000000000000 --- a/packages/openai-compat/src/openai-compat-api-types.ts +++ /dev/null @@ -1,52 +0,0 @@ -export type OpenAICompatChatPrompt = Array; - -export type OpenAICompatMessage = - | OpenAICompatSystemMessage - | OpenAICompatUserMessage - | OpenAICompatAssistantMessage - | OpenAICompatToolMessage; - -export interface OpenAICompatSystemMessage { - role: 'system'; - content: string; -} - -export interface OpenAICompatUserMessage { - role: 'user'; - content: string | Array; -} - -export type OpenAICompatContentPart = - | OpenAICompatContentPartText - | OpenAICompatContentPartImage; - -export interface OpenAICompatContentPartImage { - type: 'image_url'; - image_url: { url: string }; -} - -export interface OpenAICompatContentPartText { - type: 'text'; - text: string; -} - -export interface OpenAICompatAssistantMessage { - role: 'assistant'; - content?: string | null; - tool_calls?: Array; -} - -export interface OpenAICompatMessageToolCall { - type: 'function'; - id: string; - function: { - arguments: string; - name: string; - }; -} - -export interface OpenAICompatToolMessage { - role: 'tool'; - content: string; - tool_call_id: string; -} diff --git a/packages/openai-compat/src/openai-compat-error.ts b/packages/openai-compat/src/openai-compat-error.ts deleted file mode 100644 index debbee8270e1..000000000000 --- a/packages/openai-compat/src/openai-compat-error.ts +++ /dev/null @@ -1,16 +0,0 @@ -import { z } from 'zod'; -import { createJsonErrorResponseHandler } from '@ai-sdk/provider-utils'; - -export const openaiCompatErrorDataSchema = z.object({ - code: z.string(), - error: z.string(), -}); - -export type OpenAICompatErrorData = z.infer<typeof openaiCompatErrorDataSchema>; - -export const openaiCompatFailedResponseHandler = createJsonErrorResponseHandler( - { - errorSchema: openaiCompatErrorDataSchema, - errorToMessage: data => data.error, - }, -); diff --git a/packages/openai-compatible/CHANGELOG.md b/packages/openai-compatible/CHANGELOG.md new file mode 100644 index 000000000000..a3fa7ffb25eb --- /dev/null +++ b/packages/openai-compatible/CHANGELOG.md @@ -0,0 +1 @@ +# @ai-sdk/openai-compatible diff --git a/packages/openai-compat/README.md b/packages/openai-compatible/README.md similarity index 51% rename from packages/openai-compat/README.md rename to packages/openai-compatible/README.md index 1a2daa57c041..a06670c922f4 100644 --- a/packages/openai-compat/README.md +++ b/packages/openai-compatible/README.md @@ -1,7 +1,7 @@ # AI SDK - OpenAI Compatible Provider -This provider aims to support a core subset of functionality common to a wide -range of OpenAI compatible LLM providers. The intent is to allow code sharing +This package aims to speed up and support the implementation of new +OpenAI-compatible providers. The intent is to allow more effective code sharing across multiple concrete provider implementations. The primary OpenAI provider is heavier-weight than what this package offers. diff --git a/packages/openai-compat/package.json b/packages/openai-compatible/package.json similarity index 97% rename from packages/openai-compat/package.json rename to packages/openai-compatible/package.json index 234ba883b59d..23a744a4a90c 100644 --- a/packages/openai-compat/package.json +++ b/packages/openai-compatible/package.json @@ -1,5 +1,5 @@ { - "name": "@ai-sdk/openai-compat", + "name": "@ai-sdk/openai-compatible", "version": "0.0.0", "license": "Apache-2.0", "sideEffects": false, diff --git a/packages/openai-compat/src/convert-to-openai-compat-chat-messages.test.ts b/packages/openai-compatible/src/convert-to-openai-compatible-chat-messages.test.ts similarity index 85% rename from packages/openai-compat/src/convert-to-openai-compat-chat-messages.test.ts rename to packages/openai-compatible/src/convert-to-openai-compatible-chat-messages.test.ts index 9551240a02c4..2c6e05dec02f 100644 --- a/packages/openai-compat/src/convert-to-openai-compat-chat-messages.test.ts +++ b/packages/openai-compatible/src/convert-to-openai-compatible-chat-messages.test.ts @@ -1,8 +1,8 @@ -import { convertToOpenAICompatChatMessages } from './convert-to-openai-compat-chat-messages'; +import { convertToOpenAICompatibleChatMessages } from './convert-to-openai-compatible-chat-messages'; describe('user messages', () => { it('should convert messages with only a text part to a string content', async () => { - const result = convertToOpenAICompatChatMessages([ + const result = convertToOpenAICompatibleChatMessages([ { role: 'user', content: [{ type: 'text', text: 'Hello' }], @@ -15,7 +15,7 @@ describe('user messages', () => { describe('tool calls', () => { it('should stringify arguments to tool calls', () => { - const result = convertToOpenAICompatChatMessages([ + const result = convertToOpenAICompatibleChatMessages([ { role: 'assistant', content: [ diff --git a/packages/openai-compat/src/convert-to-openai-compat-chat-messages.ts b/packages/openai-compatible/src/convert-to-openai-compatible-chat-messages.ts similarity index 90% rename from packages/openai-compat/src/convert-to-openai-compat-chat-messages.ts rename to packages/openai-compatible/src/convert-to-openai-compatible-chat-messages.ts index 134c2bd64e07..f3bb13cb6576 100644 --- a/packages/openai-compat/src/convert-to-openai-compat-chat-messages.ts +++ b/packages/openai-compatible/src/convert-to-openai-compatible-chat-messages.ts @@ -2,13
+2,12 @@ import { LanguageModelV1Prompt, UnsupportedFunctionalityError, } from '@ai-sdk/provider'; -import { convertUint8ArrayToBase64 } from '@ai-sdk/provider-utils'; -import { OpenAICompatChatPrompt } from './openai-compat-api-types'; +import { OpenAICompatibleChatPrompt } from './openai-compatible-api-types'; -export function convertToOpenAICompatChatMessages( +export function convertToOpenAICompatibleChatMessages( prompt: LanguageModelV1Prompt, -): OpenAICompatChatPrompt { - const messages: OpenAICompatChatPrompt = []; +): OpenAICompatibleChatPrompt { + const messages: OpenAICompatibleChatPrompt = []; for (const { role, content } of prompt) { switch (role) { @@ -31,6 +30,7 @@ export function convertToOpenAICompatChatMessages( return { type: 'text', text: part.text }; } case 'image': { + // TODO(shaper): Add back the below. throw new UnsupportedFunctionalityError({ functionality: 'Image content parts in user messages', }); diff --git a/packages/openai-compat/src/get-response-metadata.ts b/packages/openai-compatible/src/get-response-metadata.ts similarity index 100% rename from packages/openai-compat/src/get-response-metadata.ts rename to packages/openai-compatible/src/get-response-metadata.ts diff --git a/packages/openai-compatible/src/index.ts b/packages/openai-compatible/src/index.ts new file mode 100644 index 000000000000..448e7aa57ff8 --- /dev/null +++ b/packages/openai-compatible/src/index.ts @@ -0,0 +1,8 @@ +export { createOpenAICompatible } from './openai-compatible-provider'; +export type { + OpenAICompatibleProvider, + OpenAICompatibleProviderSettings, +} from './openai-compatible-provider'; +export type { OpenAICompatibleChatSettings } from './openai-compatible-chat-settings'; +export type { OpenAICompatibleCompletionSettings } from './openai-compatible-completion-settings'; +export type { OpenAICompatibleEmbeddingSettings } from './openai-compatible-embedding-settings'; diff --git a/packages/openai-compat/src/map-openai-compat-finish-reason.ts b/packages/openai-compatible/src/map-openai-compatible-finish-reason.ts similarity index 89% rename from packages/openai-compat/src/map-openai-compat-finish-reason.ts rename to packages/openai-compatible/src/map-openai-compatible-finish-reason.ts index a6011eb0b4cb..6d2b232aba23 100644 --- a/packages/openai-compat/src/map-openai-compat-finish-reason.ts +++ b/packages/openai-compatible/src/map-openai-compatible-finish-reason.ts @@ -1,6 +1,6 @@ import { LanguageModelV1FinishReason } from '@ai-sdk/provider'; -export function mapOpenAICompatFinishReason( +export function mapOpenAICompatibleFinishReason( finishReason: string | null | undefined, ): LanguageModelV1FinishReason { switch (finishReason) { diff --git a/packages/openai-compatible/src/openai-compatible-api-types.ts b/packages/openai-compatible/src/openai-compatible-api-types.ts new file mode 100644 index 000000000000..e26c2d8a9ede --- /dev/null +++ b/packages/openai-compatible/src/openai-compatible-api-types.ts @@ -0,0 +1,52 @@ +export type OpenAICompatibleChatPrompt = Array; + +export type OpenAICompatibleMessage = + | OpenAICompatibleSystemMessage + | OpenAICompatibleUserMessage + | OpenAICompatibleAssistantMessage + | OpenAICompatibleToolMessage; + +export interface OpenAICompatibleSystemMessage { + role: 'system'; + content: string; +} + +export interface OpenAICompatibleUserMessage { + role: 'user'; + content: string | Array; +} + +export type OpenAICompatibleContentPart = + | OpenAICompatibleContentPartText + | OpenAICompatibleContentPartImage; + +export interface 
OpenAICompatibleContentPartImage { + type: 'image_url'; + image_url: { url: string }; +} + +export interface OpenAICompatibleContentPartText { + type: 'text'; + text: string; +} + +export interface OpenAICompatibleAssistantMessage { + role: 'assistant'; + content?: string | null; + tool_calls?: Array<OpenAICompatibleMessageToolCall>; +} + +export interface OpenAICompatibleMessageToolCall { + type: 'function'; + id: string; + function: { + arguments: string; + name: string; + }; +} + +export interface OpenAICompatibleToolMessage { + role: 'tool'; + content: string; + tool_call_id: string; +} diff --git a/packages/openai-compat/src/openai-compat-chat-language-model.test.ts b/packages/openai-compatible/src/openai-compatible-chat-language-model.test.ts similarity index 99% rename from packages/openai-compat/src/openai-compat-chat-language-model.test.ts rename to packages/openai-compatible/src/openai-compatible-chat-language-model.test.ts index 5006d36527ab..925c718e3a5b 100644 --- a/packages/openai-compat/src/openai-compat-chat-language-model.test.ts +++ b/packages/openai-compatible/src/openai-compatible-chat-language-model.test.ts @@ -4,7 +4,7 @@ import { StreamingTestServer, convertReadableStreamToArray, } from '@ai-sdk/provider-utils/test'; -import { createOpenAICompat } from './openai-compat-provider'; +import { createOpenAICompat } from './openai-compatible-provider'; const TEST_PROMPT: LanguageModelV1Prompt = [ { role: 'user', content: [{ type: 'text', text: 'Hello' }] }, ]; diff --git a/packages/openai-compat/src/openai-compat-chat-language-model.ts b/packages/openai-compatible/src/openai-compatible-chat-language-model.ts similarity index 90% rename from packages/openai-compat/src/openai-compat-chat-language-model.ts rename to packages/openai-compatible/src/openai-compatible-chat-language-model.ts index e0c05924d28c..f165e38b14ee 100644 --- a/packages/openai-compat/src/openai-compat-chat-language-model.ts +++ b/packages/openai-compatible/src/openai-compatible-chat-language-model.ts @@ -18,41 +18,41 @@ import { postJsonToApi, } from '@ai-sdk/provider-utils'; import { z } from 'zod'; -import { convertToOpenAICompatChatMessages } from './convert-to-openai-compat-chat-messages'; +import { convertToOpenAICompatibleChatMessages } from './convert-to-openai-compatible-chat-messages'; import { getResponseMetadata } from './get-response-metadata'; import { - OpenAICompatChatModelId, - OpenAICompatChatSettings, -} from './openai-compat-chat-settings'; + OpenAICompatibleChatModelId, + OpenAICompatibleChatSettings, +} from './openai-compatible-chat-settings'; import { - openaiCompatErrorDataSchema, - openaiCompatFailedResponseHandler, -} from './openai-compat-error'; -import { prepareTools } from './openai-compat-prepare-tools'; -import { mapOpenAICompatFinishReason } from './map-openai-compat-finish-reason'; + OpenAICompatibleErrorDataSchema, + OpenAICompatibleFailedResponseHandler, +} from './openai-compatible-error'; +import { prepareTools } from './openai-compatible-prepare-tools'; +import { mapOpenAICompatibleFinishReason } from './map-openai-compatible-finish-reason'; -type OpenAICompatChatConfig = { +type OpenAICompatibleChatConfig = { provider: string; headers: () => Record<string, string | undefined>; url: (options: { modelId: string; path: string }) => string; fetch?: FetchFunction; }; -export class OpenAICompatChatLanguageModel implements LanguageModelV1 { +export class OpenAICompatibleChatLanguageModel implements LanguageModelV1 { readonly specificationVersion = 'v1'; readonly supportsStructuredOutputs = false; readonly defaultObjectGenerationMode =
'tool'; - readonly modelId: OpenAICompatChatModelId; - readonly settings: OpenAICompatChatSettings; + readonly modelId: OpenAICompatibleChatModelId; + readonly settings: OpenAICompatibleChatSettings; - private readonly config: OpenAICompatChatConfig; + private readonly config: OpenAICompatibleChatConfig; constructor( - modelId: OpenAICompatChatModelId, - settings: OpenAICompatChatSettings, - config: OpenAICompatChatConfig, + modelId: OpenAICompatibleChatModelId, + settings: OpenAICompatibleChatSettings, + config: OpenAICompatibleChatConfig, ) { this.modelId = modelId; this.settings = settings; @@ -124,7 +124,7 @@ export class OpenAICompatChatLanguageModel implements LanguageModelV1 { undefined, // messages: - messages: convertToOpenAICompatChatMessages(prompt), + messages: convertToOpenAICompatibleChatMessages(prompt), }; switch (type) { @@ -190,9 +190,9 @@ export class OpenAICompatChatLanguageModel implements LanguageModelV1 { }), headers: combineHeaders(this.config.headers(), options.headers), body: args, - failedResponseHandler: openaiCompatFailedResponseHandler, + failedResponseHandler: OpenAICompatibleFailedResponseHandler, successfulResponseHandler: createJsonResponseHandler( - openaiCompatChatResponseSchema, + OpenAICompatibleChatResponseSchema, ), abortSignal: options.abortSignal, fetch: this.config.fetch, @@ -209,7 +209,7 @@ export class OpenAICompatChatLanguageModel implements LanguageModelV1 { toolName: toolCall.function.name, args: toolCall.function.arguments!, })), - finishReason: mapOpenAICompatFinishReason(choice.finish_reason), + finishReason: mapOpenAICompatibleFinishReason(choice.finish_reason), usage: { promptTokens: response.usage?.prompt_tokens ?? NaN, completionTokens: response.usage?.completion_tokens ?? NaN, @@ -239,9 +239,9 @@ export class OpenAICompatChatLanguageModel implements LanguageModelV1 { ...args, stream: true, }, - failedResponseHandler: openaiCompatFailedResponseHandler, + failedResponseHandler: OpenAICompatibleFailedResponseHandler, successfulResponseHandler: createEventSourceResponseHandler( - openaiCompatChatChunkSchema, + OpenAICompatibleChatChunkSchema, ), abortSignal: options.abortSignal, fetch: this.config.fetch, @@ -272,7 +272,7 @@ export class OpenAICompatChatLanguageModel implements LanguageModelV1 { return { stream: response.pipeThrough( new TransformStream< - ParseResult<z.infer<typeof openaiCompatChatChunkSchema>>, + ParseResult<z.infer<typeof OpenAICompatibleChatChunkSchema>>, LanguageModelV1StreamPart >({ transform(chunk, controller) { @@ -311,7 +311,9 @@ export class OpenAICompatChatLanguageModel implements LanguageModelV1 { const choice = value.choices[0]; if (choice?.finish_reason != null) { - finishReason = mapOpenAICompatFinishReason(choice.finish_reason); + finishReason = mapOpenAICompatibleFinishReason( + choice.finish_reason, + ); } if (choice?.delta == null) { @@ -453,7 +455,7 @@ export class OpenAICompatChatLanguageModel implements LanguageModelV1 { // limited version of the schema, focussed on what is needed for the implementation // this approach limits breakages when the API changes and increases efficiency -const openaiCompatChatResponseSchema = z.object({ +const OpenAICompatibleChatResponseSchema = z.object({ id: z.string().nullish(), created: z.number().nullish(), model: z.string().nullish(), @@ -489,7 +491,7 @@ const openaiCompatChatResponseSchema = z.object({ // limited version of the schema, focussed on what is needed for the implementation // this approach limits breakages when the API changes and increases efficiency -const openaiCompatChatChunkSchema = z.union([ +const OpenAICompatibleChatChunkSchema =
z.union([ z.object({ id: z.string().nullish(), created: z.number().nullish(), @@ -526,5 +528,5 @@ const openaiCompatChatChunkSchema = z.union([ }) .nullish(), }), - openaiCompatErrorDataSchema, + OpenAICompatibleErrorDataSchema, ]); diff --git a/packages/openai-compat/src/openai-compat-chat-settings.ts b/packages/openai-compatible/src/openai-compatible-chat-settings.ts similarity index 57% rename from packages/openai-compat/src/openai-compat-chat-settings.ts rename to packages/openai-compatible/src/openai-compatible-chat-settings.ts index d7c6dbccec2d..515704032c21 100644 --- a/packages/openai-compat/src/openai-compat-chat-settings.ts +++ b/packages/openai-compatible/src/openai-compatible-chat-settings.ts @@ -1,6 +1,6 @@ -export type OpenAICompatChatModelId = string; +export type OpenAICompatibleChatModelId = string; -export interface OpenAICompatChatSettings { +export interface OpenAICompatibleChatSettings { /** A unique identifier representing your end-user, which can help the provider to monitor and detect abuse. diff --git a/packages/openai-compat/src/openai-compat-completion-settings.ts b/packages/openai-compatible/src/openai-compatible-completion-settings.ts similarity index 93% rename from packages/openai-compat/src/openai-compat-completion-settings.ts rename to packages/openai-compatible/src/openai-compatible-completion-settings.ts index e64feac878e9..518cf4b26d04 100644 --- a/packages/openai-compat/src/openai-compat-completion-settings.ts +++ b/packages/openai-compatible/src/openai-compatible-completion-settings.ts @@ -1,6 +1,6 @@ -export type OpenAICompatCompletionModelId = string; +export type OpenAICompatibleCompletionModelId = string; -export interface OpenAICompatCompletionSettings { +export interface OpenAICompatibleCompletionSettings { /** Echo back the prompt in addition to the completion. 
*/ diff --git a/packages/openai-compat/src/openai-compat-embedding-model.ts b/packages/openai-compatible/src/openai-compatible-embedding-model.ts similarity index 81% rename from packages/openai-compat/src/openai-compat-embedding-model.ts rename to packages/openai-compatible/src/openai-compatible-embedding-model.ts index 92638aeaed6a..23dc313590ff 100644 --- a/packages/openai-compat/src/openai-compat-embedding-model.ts +++ b/packages/openai-compatible/src/openai-compatible-embedding-model.ts @@ -10,10 +10,10 @@ import { } from '@ai-sdk/provider-utils'; import { z } from 'zod'; import { - OpenAICompatEmbeddingModelId, - OpenAICompatEmbeddingSettings, -} from './openai-compat-embedding-settings'; -import { openaiCompatFailedResponseHandler } from './openai-compat-error'; + OpenAICompatibleEmbeddingModelId, + OpenAICompatibleEmbeddingSettings, +} from './openai-compatible-embedding-settings'; +import { OpenAICompatibleFailedResponseHandler } from './openai-compatible-error'; type OpenAIEmbeddingConfig = { provider: string; @@ -22,12 +22,14 @@ type OpenAIEmbeddingConfig = { fetch?: FetchFunction; }; -export class OpenAICompatEmbeddingModel implements EmbeddingModelV1 { +export class OpenAICompatibleEmbeddingModel + implements EmbeddingModelV1 +{ readonly specificationVersion = 'v1'; - readonly modelId: OpenAICompatEmbeddingModelId; + readonly modelId: OpenAICompatibleEmbeddingModelId; private readonly config: OpenAIEmbeddingConfig; - private readonly settings: OpenAICompatEmbeddingSettings; + private readonly settings: OpenAICompatibleEmbeddingSettings; get provider(): string { return this.config.provider; @@ -42,8 +44,8 @@ export class OpenAICompatEmbeddingModel implements EmbeddingModelV1 { } constructor( - modelId: OpenAICompatEmbeddingModelId, - settings: OpenAICompatEmbeddingSettings, + modelId: OpenAICompatibleEmbeddingModelId, + settings: OpenAICompatibleEmbeddingSettings, config: OpenAIEmbeddingConfig, ) { this.modelId = modelId; @@ -80,7 +82,7 @@ export class OpenAICompatEmbeddingModel implements EmbeddingModelV1 { dimensions: this.settings.dimensions, user: this.settings.user, }, - failedResponseHandler: openaiCompatFailedResponseHandler, + failedResponseHandler: OpenAICompatibleFailedResponseHandler, successfulResponseHandler: createJsonResponseHandler( openaiTextEmbeddingResponseSchema, ), diff --git a/packages/openai-compat/src/openai-compat-embedding-settings.ts b/packages/openai-compatible/src/openai-compatible-embedding-settings.ts similarity index 81% rename from packages/openai-compat/src/openai-compat-embedding-settings.ts rename to packages/openai-compatible/src/openai-compatible-embedding-settings.ts index 42a8d730dc37..cce45e0667d5 100644 --- a/packages/openai-compat/src/openai-compat-embedding-settings.ts +++ b/packages/openai-compatible/src/openai-compatible-embedding-settings.ts @@ -1,6 +1,6 @@ -export type OpenAICompatEmbeddingModelId = string; +export type OpenAICompatibleEmbeddingModelId = string; -export interface OpenAICompatEmbeddingSettings { +export interface OpenAICompatibleEmbeddingSettings { /** Override the maximum number of embeddings per call. 
*/ diff --git a/packages/openai-compatible/src/openai-compatible-error.ts b/packages/openai-compatible/src/openai-compatible-error.ts new file mode 100644 index 000000000000..dd847395fe9c --- /dev/null +++ b/packages/openai-compatible/src/openai-compatible-error.ts @@ -0,0 +1,17 @@ +import { z } from 'zod'; +import { createJsonErrorResponseHandler } from '@ai-sdk/provider-utils'; + +export const OpenAICompatibleErrorDataSchema = z.object({ + code: z.string(), + error: z.string(), +}); + +export type OpenAICompatibleErrorData = z.infer< + typeof OpenAICompatibleErrorDataSchema +>; + +export const OpenAICompatibleFailedResponseHandler = + createJsonErrorResponseHandler({ + errorSchema: OpenAICompatibleErrorDataSchema, + errorToMessage: data => data.error, + }); diff --git a/packages/openai-compat/src/openai-compat-prepare-tools.ts b/packages/openai-compatible/src/openai-compatible-prepare-tools.ts similarity index 100% rename from packages/openai-compat/src/openai-compat-prepare-tools.ts rename to packages/openai-compatible/src/openai-compatible-prepare-tools.ts diff --git a/packages/openai-compat/src/openai-compat-provider.ts b/packages/openai-compatible/src/openai-compatible-provider.ts similarity index 59% rename from packages/openai-compat/src/openai-compat-provider.ts rename to packages/openai-compatible/src/openai-compatible-provider.ts index 1de1c127fd30..de10c958e440 100644 --- a/packages/openai-compat/src/openai-compat-provider.ts +++ b/packages/openai-compatible/src/openai-compatible-provider.ts @@ -8,32 +8,35 @@ import { loadApiKey, withoutTrailingSlash, } from '@ai-sdk/provider-utils'; -import { OpenAICompatChatLanguageModel } from './openai-compat-chat-language-model'; -import { OpenAICompatChatSettings } from './openai-compat-chat-settings'; -import { OpenAICompatCompletionSettings } from './openai-compat-completion-settings'; -import { OpenAICompatEmbeddingSettings } from './openai-compat-embedding-settings'; -import { OpenAICompatEmbeddingModel } from './openai-compat-embedding-model'; +import { OpenAICompatibleChatLanguageModel } from './openai-compatible-chat-language-model'; +import { OpenAICompatibleChatSettings } from './openai-compatible-chat-settings'; +import { OpenAICompatibleCompletionSettings } from './openai-compatible-completion-settings'; +import { OpenAICompatibleEmbeddingSettings } from './openai-compatible-embedding-settings'; +import { OpenAICompatibleEmbeddingModel } from './openai-compatible-embedding-model'; -export interface OpenAICompatProvider +export interface OpenAICompatibleProvider extends ProviderV1 { - (modelId: M, settings?: OpenAICompatChatSettings): LanguageModelV1; + (modelId: M, settings?: OpenAICompatibleChatSettings): LanguageModelV1; languageModel( modelId: M, - settings?: OpenAICompatCompletionSettings, + settings?: OpenAICompatibleCompletionSettings, ): LanguageModelV1; - chatModel(modelId: M, settings?: OpenAICompatChatSettings): LanguageModelV1; + chatModel( + modelId: M, + settings?: OpenAICompatibleChatSettings, + ): LanguageModelV1; textEmbeddingModel( modelId: M, - settings?: OpenAICompatEmbeddingSettings, + settings?: OpenAICompatibleEmbeddingSettings, ): EmbeddingModelV1; } -export interface OpenAICompatProviderSettings { +export interface OpenAICompatibleProviderSettings { /** -Base URL for the OpenAICompat API calls. +Base URL for the OpenAICompatible API calls. */ baseURL?: string; @@ -65,11 +68,11 @@ Description of the API key environment variable for error messages. } /** -Create an OpenAICompat provider instance. 
+Create an OpenAICompatible provider instance. */ -export function createOpenAICompat( - options: OpenAICompatProviderSettings, -): OpenAICompatProvider { +export function createOpenAICompatible( + options: OpenAICompatibleProviderSettings, +): OpenAICompatibleProvider { // TODO(shaper): Throw if baseURL isn't set. const baseURL = withoutTrailingSlash(options.baseURL); @@ -84,11 +87,11 @@ export function createOpenAICompat( const createLanguageModel = ( modelId: M, - settings?: OpenAICompatCompletionSettings, + settings?: OpenAICompatibleCompletionSettings, ) => { if (new.target) { throw new Error( - 'The OpenAICompat model function cannot be called with the new keyword.', + 'The OpenAICompatible model function cannot be called with the new keyword.', ); } @@ -98,10 +101,10 @@ export function createOpenAICompat( const createChatModel = ( modelId: M, - settings: OpenAICompatChatSettings = {}, + settings: OpenAICompatibleChatSettings = {}, ) => - new OpenAICompatChatLanguageModel(modelId, settings, { - provider: 'openaiCompat.chat', + new OpenAICompatibleChatLanguageModel(modelId, settings, { + provider: 'openAICompatible.chat', url: ({ path }) => `${baseURL}${path}`, headers: getHeaders, fetch: options.fetch, @@ -109,16 +112,19 @@ export function createOpenAICompat( const createEmbeddingModel = ( modelId: M, - settings: OpenAICompatEmbeddingSettings = {}, + settings: OpenAICompatibleEmbeddingSettings = {}, ) => - new OpenAICompatEmbeddingModel(modelId, settings, { - provider: 'openaiCompat.embedding', + new OpenAICompatibleEmbeddingModel(modelId, settings, { + provider: 'openaiCompatible.embedding', url: ({ path }) => `${baseURL}${path}`, headers: getHeaders, fetch: options.fetch, }); - const provider = function (modelId: M, settings?: OpenAICompatChatSettings) { + const provider = function ( + modelId: M, + settings?: OpenAICompatibleChatSettings, + ) { return createLanguageModel(modelId, settings); }; @@ -132,5 +138,5 @@ export function createOpenAICompat( // throw new NoSuchModelError({ modelId, modelType: 'textEmbeddingModel' }); // }; - return provider as OpenAICompatProvider; + return provider as OpenAICompatibleProvider; } diff --git a/packages/openai-compat/tsconfig.json b/packages/openai-compatible/tsconfig.json similarity index 100% rename from packages/openai-compat/tsconfig.json rename to packages/openai-compatible/tsconfig.json diff --git a/packages/openai-compat/tsup.config.ts b/packages/openai-compatible/tsup.config.ts similarity index 100% rename from packages/openai-compat/tsup.config.ts rename to packages/openai-compatible/tsup.config.ts diff --git a/packages/openai-compat/turbo.json b/packages/openai-compatible/turbo.json similarity index 100% rename from packages/openai-compat/turbo.json rename to packages/openai-compatible/turbo.json diff --git a/packages/openai-compat/vitest.edge.config.js b/packages/openai-compatible/vitest.edge.config.js similarity index 100% rename from packages/openai-compat/vitest.edge.config.js rename to packages/openai-compatible/vitest.edge.config.js diff --git a/packages/openai-compat/vitest.node.config.js b/packages/openai-compatible/vitest.node.config.js similarity index 100% rename from packages/openai-compat/vitest.node.config.js rename to packages/openai-compatible/vitest.node.config.js diff --git a/packages/togetherai/package.json b/packages/togetherai/package.json index 94812940f4cb..59961118854b 100644 --- a/packages/togetherai/package.json +++ b/packages/togetherai/package.json @@ -37,7 +37,7 @@ } }, "dependencies": { - 
"@ai-sdk/openai-compat": "0.0.0", + "@ai-sdk/openai-compatible": "0.0.0", "@ai-sdk/provider": "1.0.0", "@ai-sdk/provider-utils": "2.0.0" }, diff --git a/packages/togetherai/src/togetherai-chat-settings.ts b/packages/togetherai/src/togetherai-chat-settings.ts index 4a964e866181..bb2c3d720482 100644 --- a/packages/togetherai/src/togetherai-chat-settings.ts +++ b/packages/togetherai/src/togetherai-chat-settings.ts @@ -1,4 +1,4 @@ -import { OpenAICompatChatSettings } from '@ai-sdk/openai-compat'; +import { OpenAICompatibleChatSettings } from '@ai-sdk/openai-compatible'; // https://docs.together.ai/docs/serverless-models#chat-models export type TogetherAIChatModelId = @@ -131,4 +131,4 @@ export type TogetherAIChatModelId = // | 'WizardLM/WizardLM-13B-V1.2' // | (string & {}); -export interface TogetherAIChatSettings extends OpenAICompatChatSettings {} +export interface TogetherAIChatSettings extends OpenAICompatibleChatSettings {} diff --git a/packages/togetherai/src/togetherai-completion-settings.ts b/packages/togetherai/src/togetherai-completion-settings.ts index 562d08f49da2..2316be30e587 100644 --- a/packages/togetherai/src/togetherai-completion-settings.ts +++ b/packages/togetherai/src/togetherai-completion-settings.ts @@ -1,4 +1,4 @@ -import { OpenAICompatCompletionSettings } from '@ai-sdk/openai-compat'; +import { OpenAICompatibleCompletionSettings } from '@ai-sdk/openai-compatible'; // https://docs.together.ai/docs/serverless-models#language-models export type TogetherAICompletionModelId = @@ -56,4 +56,4 @@ export type TogetherAICompletionModelId = // | (string & {}); export interface TogetherAICompletionSettings - extends OpenAICompatCompletionSettings {} + extends OpenAICompatibleCompletionSettings {} diff --git a/packages/togetherai/src/togetherai-embedding-settings.ts b/packages/togetherai/src/togetherai-embedding-settings.ts index f010e8016d1b..7d008442bd15 100644 --- a/packages/togetherai/src/togetherai-embedding-settings.ts +++ b/packages/togetherai/src/togetherai-embedding-settings.ts @@ -1,4 +1,4 @@ -import { OpenAICompatEmbeddingSettings } from '@ai-sdk/openai-compat'; +import { OpenAICompatibleEmbeddingSettings } from '@ai-sdk/openai-compatible'; // https://docs.together.ai/docs/serverless-models#embedding-models export type TogetherAIEmbeddingModelId = @@ -26,4 +26,4 @@ export type TogetherAIEmbeddingModelId = // | (string & {}); export interface TogetherAIEmbeddingSettings - extends OpenAICompatEmbeddingSettings {} + extends OpenAICompatibleEmbeddingSettings {} diff --git a/packages/togetherai/src/togetherai-provider.ts b/packages/togetherai/src/togetherai-provider.ts index a2aa163b2dfb..5d134443a4c3 100644 --- a/packages/togetherai/src/togetherai-provider.ts +++ b/packages/togetherai/src/togetherai-provider.ts @@ -1,9 +1,9 @@ import { - OpenAICompatProvider, - createOpenAICompat, - OpenAICompatChatSettings, - OpenAICompatProviderSettings, -} from '@ai-sdk/openai-compat'; + OpenAICompatibleProvider, + createOpenAICompatible, + OpenAICompatibleChatSettings, + OpenAICompatibleProviderSettings, +} from '@ai-sdk/openai-compatible'; import { LanguageModelV1, EmbeddingModelV1 } from '@ai-sdk/provider'; import { TogetherAIChatModelId } from './togetherai-chat-settings'; import { @@ -16,22 +16,22 @@ import { } from './togetherai-completion-settings'; export interface TogetherAIProviderSettings - extends OpenAICompatProviderSettings {} + extends OpenAICompatibleProviderSettings {} export interface TogetherAIProvider - extends OpenAICompatProvider< + extends OpenAICompatibleProvider< 
| TogetherAIChatModelId | TogetherAICompletionModelId | TogetherAIEmbeddingModelId > { chatModel( modelId: TogetherAIChatModelId, - settings?: OpenAICompatChatSettings, + settings?: OpenAICompatibleChatSettings, ): LanguageModelV1; completionModel( modelId: TogetherAICompletionModelId, - settings?: OpenAICompatChatSettings, + settings?: OpenAICompatibleChatSettings, ): LanguageModelV1; textEmbeddingModel( @@ -43,14 +43,14 @@ export interface TogetherAIProvider export function createTogetherAI( options: TogetherAIProviderSettings = {}, ): TogetherAIProvider { - const providerOptions: OpenAICompatProviderSettings = { + const providerOptions: OpenAICompatibleProviderSettings = { baseURL: 'https://api.together.xyz/v1/', apiKeyEnvVarName: 'TOGETHER_AI_API_KEY', apiKeyEnvVarDescription: "TogetherAI's API key", ...options, }; // TODO(shaper): Consider separating generics in the ctor. - const openAICompatProvider = createOpenAICompat< + const openAICompatibleProvider = createOpenAICompatible< | TogetherAIChatModelId | TogetherAICompletionModelId | TogetherAIEmbeddingModelId @@ -59,30 +59,30 @@ export function createTogetherAI( const togetheraiProvider: TogetherAIProvider = Object.assign( ( modelId: TogetherAIChatModelId, - settings?: OpenAICompatChatSettings, + settings?: OpenAICompatibleChatSettings, ): LanguageModelV1 => { - return openAICompatProvider(modelId, settings); + return openAICompatibleProvider(modelId, settings); }, { chatModel: ( modelId: TogetherAIChatModelId, - settings?: OpenAICompatChatSettings, + settings?: OpenAICompatibleChatSettings, ) => { - return openAICompatProvider.chatModel(modelId, settings); + return openAICompatibleProvider.chatModel(modelId, settings); }, completionModel: ( modelId: TogetherAICompletionModelId, settings?: TogetherAICompletionSettings, ) => { - return openAICompatProvider.languageModel(modelId, settings); + return openAICompatibleProvider.languageModel(modelId, settings); }, textEmbeddingModel: ( modelId: TogetherAIEmbeddingModelId, settings?: TogetherAIEmbeddingSettings, ) => { - return openAICompatProvider.textEmbeddingModel(modelId, settings); + return openAICompatibleProvider.textEmbeddingModel(modelId, settings); }, }, ) as TogetherAIProvider; diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 98abc7023434..7c0711bc319e 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -1364,7 +1364,7 @@ importers: specifier: 3.23.8 version: 3.23.8 - packages/openai-compat: + packages/openai-compatible: dependencies: '@ai-sdk/provider': specifier: 1.0.0 @@ -1619,9 +1619,9 @@ importers: packages/togetherai: dependencies: - '@ai-sdk/openai-compat': + '@ai-sdk/openai-compatible': specifier: 0.0.0 - version: link:../openai-compat + version: link:../openai-compatible '@ai-sdk/provider': specifier: 1.0.0 version: link:../provider From 171644453db6a65afa8c02ac52a5928d8dedc547 Mon Sep 17 00:00:00 2001 From: Walter Korman Date: Tue, 19 Nov 2024 11:30:56 -0800 Subject: [PATCH 04/13] Integrating completion model. 
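The completion path flattens chat-style prompts into a plain-text transcript with role labels and derives a stop sequence from the user label, so the model stops before inventing the next turn. A rough sketch of the intended mapping (illustrative only, not a test included in this patch; output values shown as comments):

    import { convertToOpenAICompatibleCompletionPrompt } from './convert-to-openai-compatible-completion-prompt';

    const { prompt, stopSequences } = convertToOpenAICompatibleCompletionPrompt({
      inputFormat: 'messages',
      prompt: [
        { role: 'system', content: 'You are terse.' },
        { role: 'user', content: [{ type: 'text', text: 'Hi' }] },
      ],
    });

    // prompt === 'You are terse.\n\nuser:\nHi\n\nassistant:\n'
    // stopSequences === ['\nuser:']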
--- ...-to-openai-compatible-completion-prompt.ts | 110 +++++ ...p-openai-compatible-completion-logprobs.ts | 24 ++ ...nai-compatible-chat-language-model.test.ts | 8 +- .../openai-compatible-chat-language-model.ts | 10 +- ...ai-compatible-completion-language-model.ts | 387 ++++++++++++++++++ .../src/openai-compatible-embedding-model.ts | 4 +- .../src/openai-compatible-error.ts | 10 +- 7 files changed, 538 insertions(+), 15 deletions(-) create mode 100644 packages/openai-compatible/src/convert-to-openai-compatible-completion-prompt.ts create mode 100644 packages/openai-compatible/src/map-openai-compatible-completion-logprobs.ts create mode 100644 packages/openai-compatible/src/openai-compatible-completion-language-model.ts diff --git a/packages/openai-compatible/src/convert-to-openai-compatible-completion-prompt.ts b/packages/openai-compatible/src/convert-to-openai-compatible-completion-prompt.ts new file mode 100644 index 000000000000..8b13f10b1ef3 --- /dev/null +++ b/packages/openai-compatible/src/convert-to-openai-compatible-completion-prompt.ts @@ -0,0 +1,110 @@ +import { + InvalidPromptError, + LanguageModelV1Prompt, + UnsupportedFunctionalityError, +} from '@ai-sdk/provider'; + +export function convertToOpenAICompatibleCompletionPrompt({ + prompt, + inputFormat, + user = 'user', + assistant = 'assistant', +}: { + prompt: LanguageModelV1Prompt; + inputFormat: 'prompt' | 'messages'; + user?: string; + assistant?: string; +}): { + prompt: string; + stopSequences?: string[]; +} { + // When the user supplied a prompt input, we don't transform it: + if ( + inputFormat === 'prompt' && + prompt.length === 1 && + prompt[0].role === 'user' && + prompt[0].content.length === 1 && + prompt[0].content[0].type === 'text' + ) { + return { prompt: prompt[0].content[0].text }; + } + + // otherwise transform to a chat message format: + let text = ''; + + // if first message is a system message, add it to the text: + if (prompt[0].role === 'system') { + text += `${prompt[0].content}\n\n`; + prompt = prompt.slice(1); + } + + for (const { role, content } of prompt) { + switch (role) { + case 'system': { + throw new InvalidPromptError({ + message: `Unexpected system message in prompt: ${content}`, + prompt, + }); + } + + case 'user': { + const userMessage = content + .map(part => { + switch (part.type) { + case 'text': { + return part.text; + } + case 'image': { + throw new UnsupportedFunctionalityError({ + functionality: 'images', + }); + } + } + }) + .join(''); + + text += `${user}:\n${userMessage}\n\n`; + break; + } + + case 'assistant': { + const assistantMessage = content + .map(part => { + switch (part.type) { + case 'text': { + return part.text; + } + case 'tool-call': { + throw new UnsupportedFunctionalityError({ + functionality: 'tool-call messages', + }); + } + } + }) + .join(''); + + text += `${assistant}:\n${assistantMessage}\n\n`; + break; + } + + case 'tool': { + throw new UnsupportedFunctionalityError({ + functionality: 'tool messages', + }); + } + + default: { + const _exhaustiveCheck: never = role; + throw new Error(`Unsupported role: ${_exhaustiveCheck}`); + } + } + } + + // Assistant message prefix: + text += `${assistant}:\n`; + + return { + prompt: text, + stopSequences: [`\n${user}:`], + }; +} diff --git a/packages/openai-compatible/src/map-openai-compatible-completion-logprobs.ts b/packages/openai-compatible/src/map-openai-compatible-completion-logprobs.ts new file mode 100644 index 000000000000..a7ea6279574d --- /dev/null +++ 
b/packages/openai-compatible/src/map-openai-compatible-completion-logprobs.ts @@ -0,0 +1,24 @@ +import { LanguageModelV1LogProbs } from '@ai-sdk/provider'; + +type OpenAICompatibleCompletionLogProbs = { + tokens: string[]; + token_logprobs: number[]; + top_logprobs: Record<string, number>[] | null; +}; + +export function mapOpenAICompatibleCompletionLogProbs( + logprobs: OpenAICompatibleCompletionLogProbs | null | undefined, +): LanguageModelV1LogProbs | undefined { + return logprobs?.tokens.map((token, index) => ({ + token, + logprob: logprobs.token_logprobs[index], + topLogprobs: logprobs.top_logprobs + ? Object.entries(logprobs.top_logprobs[index]).map( + ([token, logprob]) => ({ + token, + logprob, + }), + ) + : [], + })); +} diff --git a/packages/openai-compatible/src/openai-compatible-chat-language-model.test.ts b/packages/openai-compatible/src/openai-compatible-chat-language-model.test.ts index 925c718e3a5b..3e79947f16a1 100644 --- a/packages/openai-compatible/src/openai-compatible-chat-language-model.test.ts +++ b/packages/openai-compatible/src/openai-compatible-chat-language-model.test.ts @@ -4,13 +4,13 @@ import { StreamingTestServer, convertReadableStreamToArray, } from '@ai-sdk/provider-utils/test'; -import { createOpenAICompat } from './openai-compatible-provider'; +import { createOpenAICompat as createOpenAICompatible } from './openai-compatible-provider'; const TEST_PROMPT: LanguageModelV1Prompt = [ { role: 'user', content: [{ type: 'text', text: 'Hello' }] }, ]; -const provider = createOpenAICompat({ +const provider = createOpenAICompatible({ apiKey: 'test-api-key', }); @@ -305,7 +305,7 @@ describe('doGenerate', () => { it('should pass headers', async () => { prepareJsonResponse({ content: '' }); - const provider = createOpenAICompat({ + const provider = createOpenAICompatible({ apiKey: 'test-api-key', headers: { 'Custom-Provider-Header': 'provider-header-value', }, @@ -877,7 +877,7 @@ describe('doStream', () => { it('should pass headers', async () => { prepareStreamResponse({ content: [] }); - const provider = createOpenAICompat({ + const provider = createOpenAICompatible({ apiKey: 'test-api-key', headers: { 'Custom-Provider-Header': 'provider-header-value', diff --git a/packages/openai-compatible/src/openai-compatible-chat-language-model.ts b/packages/openai-compatible/src/openai-compatible-chat-language-model.ts index f165e38b14ee..99c03b363bb5 100644 --- a/packages/openai-compatible/src/openai-compatible-chat-language-model.ts +++ b/packages/openai-compatible/src/openai-compatible-chat-language-model.ts @@ -25,8 +25,8 @@ import { OpenAICompatibleChatSettings, } from './openai-compatible-chat-settings'; import { - OpenAICompatibleErrorDataSchema, - OpenAICompatibleFailedResponseHandler, + openAICompatibleErrorDataSchema, + openAICompatibleFailedResponseHandler, } from './openai-compatible-error'; import { prepareTools } from './openai-compatible-prepare-tools'; import { mapOpenAICompatibleFinishReason } from './map-openai-compatible-finish-reason'; @@ -190,7 +190,7 @@ export class OpenAICompatibleChatLanguageModel implements LanguageModelV1 { }), headers: combineHeaders(this.config.headers(), options.headers), body: args, - failedResponseHandler: OpenAICompatibleFailedResponseHandler, + failedResponseHandler: openAICompatibleFailedResponseHandler, successfulResponseHandler: createJsonResponseHandler( OpenAICompatibleChatResponseSchema, ), @@ -239,7 +239,7 @@ export class OpenAICompatibleChatLanguageModel implements LanguageModelV1 { ...args, stream: true, }, - failedResponseHandler:
OpenAICompatibleFailedResponseHandler, + failedResponseHandler: openAICompatibleFailedResponseHandler, successfulResponseHandler: createEventSourceResponseHandler( OpenAICompatibleChatChunkSchema, ), @@ -528,5 +528,5 @@ const OpenAICompatibleChatChunkSchema = z.union([ }) .nullish(), }), - OpenAICompatibleErrorDataSchema, + openAICompatibleErrorDataSchema, ]); diff --git a/packages/openai-compatible/src/openai-compatible-completion-language-model.ts b/packages/openai-compatible/src/openai-compatible-completion-language-model.ts new file mode 100644 index 000000000000..51b0370dab71 --- /dev/null +++ b/packages/openai-compatible/src/openai-compatible-completion-language-model.ts @@ -0,0 +1,387 @@ +import { + LanguageModelV1, + LanguageModelV1CallWarning, + LanguageModelV1FinishReason, + LanguageModelV1LogProbs, + LanguageModelV1StreamPart, + UnsupportedFunctionalityError, +} from '@ai-sdk/provider'; +import { + FetchFunction, + ParseResult, + combineHeaders, + createEventSourceResponseHandler, + createJsonResponseHandler, + postJsonToApi, +} from '@ai-sdk/provider-utils'; +import { z } from 'zod'; +import { convertToOpenAICompatibleCompletionPrompt } from './convert-to-openai-compatible-completion-prompt'; +import { mapOpenAICompatibleCompletionLogProbs } from './map-openai-compatible-completion-logprobs'; +import { mapOpenAICompatibleFinishReason } from './map-openai-compatible-finish-reason'; +import { + OpenAICompatibleCompletionModelId, + OpenAICompatibleCompletionSettings, +} from './openai-compatible-completion-settings'; +import { + openAICompatibleErrorDataSchema, + openAICompatibleFailedResponseHandler, +} from './openai-compatible-error'; +import { getResponseMetadata } from './get-response-metadata'; + +type OpenAICompatibleCompletionConfig = { + provider: string; + compatibility: 'strict' | 'compatible'; + headers: () => Record<string, string | undefined>; + url: (options: { modelId: string; path: string }) => string; + fetch?: FetchFunction; +}; + +export class OpenAICompatibleCompletionLanguageModel + implements LanguageModelV1 +{ + readonly specificationVersion = 'v1'; + readonly defaultObjectGenerationMode = undefined; + + readonly modelId: OpenAICompatibleCompletionModelId; + readonly settings: OpenAICompatibleCompletionSettings; + + private readonly config: OpenAICompatibleCompletionConfig; + + constructor( + modelId: OpenAICompatibleCompletionModelId, + settings: OpenAICompatibleCompletionSettings, + config: OpenAICompatibleCompletionConfig, + ) { + this.modelId = modelId; + this.settings = settings; + this.config = config; + } + + get provider(): string { + return this.config.provider; + } + + private getArgs({ + mode, + inputFormat, + prompt, + maxTokens, + temperature, + topP, + topK, + frequencyPenalty, + presencePenalty, + stopSequences: userStopSequences, + responseFormat, + seed, + }: Parameters<LanguageModelV1['doGenerate']>[0]) { + const type = mode.type; + + const warnings: LanguageModelV1CallWarning[] = []; + + if (topK != null) { + warnings.push({ + type: 'unsupported-setting', + setting: 'topK', + }); + } + + if (responseFormat != null && responseFormat.type !== 'text') { + warnings.push({ + type: 'unsupported-setting', + setting: 'responseFormat', + details: 'JSON response format is not supported.', + }); + } + + const { prompt: completionPrompt, stopSequences } = + convertToOpenAICompatibleCompletionPrompt({ prompt, inputFormat }); + + const stop = [...(stopSequences ?? []), ...(userStopSequences ?? [])]; + + const baseArgs = { + // model id: + model: this.modelId, + + // model specific settings: + echo: this.settings.echo, + logit_bias: this.settings.logitBias, + logprobs: + typeof this.settings.logprobs === 'number' + ? this.settings.logprobs + : typeof this.settings.logprobs === 'boolean' + ? this.settings.logprobs + ? 0 + : undefined + : undefined, + suffix: this.settings.suffix, + user: this.settings.user, + + // standardized settings: + max_tokens: maxTokens, + temperature, + top_p: topP, + frequency_penalty: frequencyPenalty, + presence_penalty: presencePenalty, + seed, + + // prompt: + prompt: completionPrompt, + + // stop sequences: + stop: stop.length > 0 ? stop : undefined, + }; + + switch (type) { + case 'regular': { + if (mode.tools?.length) { + throw new UnsupportedFunctionalityError({ + functionality: 'tools', + }); + } + + if (mode.toolChoice) { + throw new UnsupportedFunctionalityError({ + functionality: 'toolChoice', + }); + } + + return { args: baseArgs, warnings }; + } + + case 'object-json': { + throw new UnsupportedFunctionalityError({ + functionality: 'object-json mode', + }); + } + + case 'object-tool': { + throw new UnsupportedFunctionalityError({ + functionality: 'object-tool mode', + }); + } + + default: { + const _exhaustiveCheck: never = type; + throw new Error(`Unsupported type: ${_exhaustiveCheck}`); + } + } + } + + async doGenerate( + options: Parameters<LanguageModelV1['doGenerate']>[0], + ): Promise<Awaited<ReturnType<LanguageModelV1['doGenerate']>>> { + const { args, warnings } = this.getArgs(options); + + const { responseHeaders, value: response } = await postJsonToApi({ + url: this.config.url({ + path: '/completions', + modelId: this.modelId, + }), + headers: combineHeaders(this.config.headers(), options.headers), + body: args, + failedResponseHandler: openAICompatibleFailedResponseHandler, + successfulResponseHandler: createJsonResponseHandler( + openAICompatibleCompletionResponseSchema, + ), + abortSignal: options.abortSignal, + fetch: this.config.fetch, + }); + + const { prompt: rawPrompt, ...rawSettings } = args; + const choice = response.choices[0]; + + return { + text: choice.text, + usage: { + promptTokens: response.usage.prompt_tokens, + completionTokens: response.usage.completion_tokens, + }, + finishReason: mapOpenAICompatibleFinishReason(choice.finish_reason), + logprobs: mapOpenAICompatibleCompletionLogProbs(choice.logprobs), + rawCall: { rawPrompt, rawSettings }, + rawResponse: { headers: responseHeaders }, + response: getResponseMetadata(response), + warnings, + request: { body: JSON.stringify(args) }, + }; + } + + async doStream( + options: Parameters<LanguageModelV1['doStream']>[0], + ): Promise<Awaited<ReturnType<LanguageModelV1['doStream']>>> { + const { args, warnings } = this.getArgs(options); + + const body = { + ...args, + stream: true, + + // only include stream_options when in strict compatibility mode: + stream_options: + this.config.compatibility === 'strict' + ? 
{ include_usage: true } + : undefined, + }; + + const { responseHeaders, value: response } = await postJsonToApi({ + url: this.config.url({ + path: '/completions', + modelId: this.modelId, + }), + headers: combineHeaders(this.config.headers(), options.headers), + body, + failedResponseHandler: openAICompatibleFailedResponseHandler, + successfulResponseHandler: createEventSourceResponseHandler( + openaiCompatibleCompletionChunkSchema, + ), + abortSignal: options.abortSignal, + fetch: this.config.fetch, + }); + + const { prompt: rawPrompt, ...rawSettings } = args; + + let finishReason: LanguageModelV1FinishReason = 'unknown'; + let usage: { promptTokens: number; completionTokens: number } = { + promptTokens: Number.NaN, + completionTokens: Number.NaN, + }; + let logprobs: LanguageModelV1LogProbs; + let isFirstChunk = true; + + return { + stream: response.pipeThrough( + new TransformStream< + ParseResult>, + LanguageModelV1StreamPart + >({ + transform(chunk, controller) { + // handle failed chunk parsing / validation: + if (!chunk.success) { + finishReason = 'error'; + controller.enqueue({ type: 'error', error: chunk.error }); + return; + } + + const value = chunk.value; + + // handle error chunks: + if ('error' in value) { + finishReason = 'error'; + controller.enqueue({ type: 'error', error: value.error }); + return; + } + + if (isFirstChunk) { + isFirstChunk = false; + + controller.enqueue({ + type: 'response-metadata', + ...getResponseMetadata(value), + }); + } + + if (value.usage != null) { + usage = { + promptTokens: value.usage.prompt_tokens, + completionTokens: value.usage.completion_tokens, + }; + } + + const choice = value.choices[0]; + + if (choice?.finish_reason != null) { + finishReason = mapOpenAICompatibleFinishReason( + choice.finish_reason, + ); + } + + if (choice?.text != null) { + controller.enqueue({ + type: 'text-delta', + textDelta: choice.text, + }); + } + + const mappedLogprobs = mapOpenAICompatibleCompletionLogProbs( + choice?.logprobs, + ); + if (mappedLogprobs?.length) { + if (logprobs === undefined) logprobs = []; + logprobs.push(...mappedLogprobs); + } + }, + + flush(controller) { + controller.enqueue({ + type: 'finish', + finishReason, + logprobs, + usage, + }); + }, + }), + ), + rawCall: { rawPrompt, rawSettings }, + rawResponse: { headers: responseHeaders }, + warnings, + request: { body: JSON.stringify(body) }, + }; + } +} + +// limited version of the schema, focussed on what is needed for the implementation +// this approach limits breakages when the API changes and increases efficiency +// TODO(shaper): Fix naming to match others e.g. 
'openai' +const openAICompatibleCompletionResponseSchema = z.object({ + id: z.string().nullish(), + created: z.number().nullish(), + model: z.string().nullish(), + choices: z.array( + z.object({ + text: z.string(), + finish_reason: z.string(), + logprobs: z + .object({ + tokens: z.array(z.string()), + token_logprobs: z.array(z.number()), + top_logprobs: z.array(z.record(z.string(), z.number())).nullable(), + }) + .nullish(), + }), + ), + usage: z.object({ + prompt_tokens: z.number(), + completion_tokens: z.number(), + }), +}); + +// limited version of the schema, focussed on what is needed for the implementation +// this approach limits breakages when the API changes and increases efficiency +const openaiCompatibleCompletionChunkSchema = z.union([ + z.object({ + id: z.string().nullish(), + created: z.number().nullish(), + model: z.string().nullish(), + choices: z.array( + z.object({ + text: z.string(), + finish_reason: z.string().nullish(), + index: z.number(), + logprobs: z + .object({ + tokens: z.array(z.string()), + token_logprobs: z.array(z.number()), + top_logprobs: z.array(z.record(z.string(), z.number())).nullable(), + }) + .nullish(), + }), + ), + usage: z + .object({ + prompt_tokens: z.number(), + completion_tokens: z.number(), + }) + .nullish(), + }), + openAICompatibleErrorDataSchema, +]); diff --git a/packages/openai-compatible/src/openai-compatible-embedding-model.ts b/packages/openai-compatible/src/openai-compatible-embedding-model.ts index 23dc313590ff..53184dc405b5 100644 --- a/packages/openai-compatible/src/openai-compatible-embedding-model.ts +++ b/packages/openai-compatible/src/openai-compatible-embedding-model.ts @@ -13,7 +13,7 @@ import { OpenAICompatibleEmbeddingModelId, OpenAICompatibleEmbeddingSettings, } from './openai-compatible-embedding-settings'; -import { OpenAICompatibleFailedResponseHandler } from './openai-compatible-error'; +import { openAICompatibleFailedResponseHandler } from './openai-compatible-error'; type OpenAIEmbeddingConfig = { provider: string; @@ -82,7 +82,7 @@ export class OpenAICompatibleEmbeddingModel dimensions: this.settings.dimensions, user: this.settings.user, }, - failedResponseHandler: OpenAICompatibleFailedResponseHandler, + failedResponseHandler: openAICompatibleFailedResponseHandler, successfulResponseHandler: createJsonResponseHandler( openaiTextEmbeddingResponseSchema, ), diff --git a/packages/openai-compatible/src/openai-compatible-error.ts b/packages/openai-compatible/src/openai-compatible-error.ts index dd847395fe9c..883f5993cf8e 100644 --- a/packages/openai-compatible/src/openai-compatible-error.ts +++ b/packages/openai-compatible/src/openai-compatible-error.ts @@ -1,17 +1,19 @@ import { z } from 'zod'; import { createJsonErrorResponseHandler } from '@ai-sdk/provider-utils'; -export const OpenAICompatibleErrorDataSchema = z.object({ +// TODO(shaper): Reconcile this with openai-error.ts. We derived from `xai`. 
+ +export const openAICompatibleErrorDataSchema = z.object({ code: z.string(), error: z.string(), }); export type OpenAICompatibleErrorData = z.infer< - typeof OpenAICompatibleErrorDataSchema + typeof openAICompatibleErrorDataSchema >; -export const OpenAICompatibleFailedResponseHandler = +export const openAICompatibleFailedResponseHandler = createJsonErrorResponseHandler({ - errorSchema: OpenAICompatibleErrorDataSchema, + errorSchema: openAICompatibleErrorDataSchema, errorToMessage: data => data.error, }); From fb946ab2130cf382a32b12884bfaf41354344bef Mon Sep 17 00:00:00 2001 From: Walter Korman Date: Tue, 19 Nov 2024 12:43:02 -0800 Subject: [PATCH 05/13] Cleaning up some TODOs and model id sets. --- .../src/openai-compatible-provider.ts | 95 +++++++++++------- .../src/togetherai-chat-settings.ts | 96 ------------------- .../src/togetherai-completion-settings.ts | 49 ---------- .../src/togetherai-embedding-settings.ts | 13 --- .../togetherai/src/togetherai-provider.ts | 14 +-- 5 files changed, 65 insertions(+), 202 deletions(-) diff --git a/packages/openai-compatible/src/openai-compatible-provider.ts b/packages/openai-compatible/src/openai-compatible-provider.ts index de10c958e440..7d04ab67e771 100644 --- a/packages/openai-compatible/src/openai-compatible-provider.ts +++ b/packages/openai-compatible/src/openai-compatible-provider.ts @@ -10,33 +10,42 @@ import { } from '@ai-sdk/provider-utils'; import { OpenAICompatibleChatLanguageModel } from './openai-compatible-chat-language-model'; import { OpenAICompatibleChatSettings } from './openai-compatible-chat-settings'; +import { OpenAICompatibleCompletionLanguageModel } from './openai-compatible-completion-language-model'; import { OpenAICompatibleCompletionSettings } from './openai-compatible-completion-settings'; import { OpenAICompatibleEmbeddingSettings } from './openai-compatible-embedding-settings'; import { OpenAICompatibleEmbeddingModel } from './openai-compatible-embedding-model'; -export interface OpenAICompatibleProvider - extends ProviderV1 { - (modelId: M, settings?: OpenAICompatibleChatSettings): LanguageModelV1; +export interface OpenAICompatibleProvider< + L extends string = string, + C extends string = string, + E extends string = string, +> extends ProviderV1 { + (modelId: L, settings?: OpenAICompatibleChatSettings): LanguageModelV1; languageModel( - modelId: M, - settings?: OpenAICompatibleCompletionSettings, + modelId: L, + settings?: OpenAICompatibleChatSettings, ): LanguageModelV1; chatModel( - modelId: M, + modelId: L, settings?: OpenAICompatibleChatSettings, ): LanguageModelV1; + completionModel( + modelId: C, + settings?: OpenAICompatibleCompletionSettings, + ): LanguageModelV1; + textEmbeddingModel( - modelId: M, + modelId: E, settings?: OpenAICompatibleEmbeddingSettings, ): EmbeddingModelV1; } export interface OpenAICompatibleProviderSettings { /** -Base URL for the OpenAICompatible API calls. +Base URL for the API calls. */ baseURL?: string; @@ -57,24 +66,36 @@ or to provide a custom fetch implementation for e.g. testing. fetch?: FetchFunction; /** -The name of the environment variable from which to load the API key if not explicitly provided. +The name of the environment variable from which to load the API key (if a key isn't explicitly provided). */ apiKeyEnvVarName?: string; /** -Description of the API key environment variable for error messages. +Description of the API key environment variable (for use in error messages). */ apiKeyEnvVarDescription?: string; + + /** +Provider name. 
Overrides the `openai` default name for 3rd party providers. + */ + name?: string; } /** Create an OpenAICompatible provider instance. */ -export function createOpenAICompatible( +export function createOpenAICompatible< + L extends string, + C extends string, + E extends string, +>( options: OpenAICompatibleProviderSettings, -): OpenAICompatibleProvider { - // TODO(shaper): Throw if baseURL isn't set. +): OpenAICompatibleProvider { + // TODO(shaper): + // - consider throwing if baseUrl, name, sufficient api key info not available + // - force only 'compatible' -- look into whether we can remove some 'strict' logic/configs entirely const baseURL = withoutTrailingSlash(options.baseURL); + const providerName = options.name ?? 'openaiCompatible'; const getHeaders = () => ({ Authorization: `Bearer ${loadApiKey({ @@ -86,43 +107,48 @@ export function createOpenAICompatible( }); const createLanguageModel = ( - modelId: M, - settings?: OpenAICompatibleCompletionSettings, - ) => { - if (new.target) { - throw new Error( - 'The OpenAICompatible model function cannot be called with the new keyword.', - ); - } - - // TODO(shaper): Do we need to pull in and strip down the OpenAI Completion Model? - return createChatModel(modelId, settings); - }; + modelId: L, + settings?: OpenAICompatibleChatSettings, + ) => createChatModel(modelId, settings); + // TODO(shaper): Change provider strings below to allow concrete impls to specify. + // See openai-provider.ts:141 and subsequent configs. const createChatModel = ( - modelId: M, + modelId: L, settings: OpenAICompatibleChatSettings = {}, ) => new OpenAICompatibleChatLanguageModel(modelId, settings, { - provider: 'openAICompatible.chat', + provider: `${providerName}.chat`, + url: ({ path }) => `${baseURL}${path}`, + headers: getHeaders, + fetch: options.fetch, + }); + + const createCompletionModel = ( + modelId: C, + settings: OpenAICompatibleCompletionSettings = {}, + ) => + new OpenAICompatibleCompletionLanguageModel(modelId, settings, { + provider: `${providerName}.completion`, + compatibility: 'compatible', url: ({ path }) => `${baseURL}${path}`, headers: getHeaders, fetch: options.fetch, }); const createEmbeddingModel = ( - modelId: M, + modelId: E, settings: OpenAICompatibleEmbeddingSettings = {}, ) => new OpenAICompatibleEmbeddingModel(modelId, settings, { - provider: 'openaiCompatible.embedding', + provider: `${providerName}.embedding`, url: ({ path }) => `${baseURL}${path}`, headers: getHeaders, fetch: options.fetch, }); const provider = function ( - modelId: M, + modelId: L, settings?: OpenAICompatibleChatSettings, ) { return createLanguageModel(modelId, settings); @@ -130,13 +156,8 @@ export function createOpenAICompatible( provider.languageModel = createLanguageModel; provider.chatModel = createChatModel; + provider.completionModel = createCompletionModel; provider.textEmbeddingModel = createEmbeddingModel; - // TODO(shaper): Need a way for concrete impls to note if they don't support - // one of the model types. 
- // provider.textEmbeddingModel = (modelId: string) => { - // throw new NoSuchModelError({ modelId, modelType: 'textEmbeddingModel' }); - // }; - - return provider as OpenAICompatibleProvider; + return provider as OpenAICompatibleProvider; } diff --git a/packages/togetherai/src/togetherai-chat-settings.ts b/packages/togetherai/src/togetherai-chat-settings.ts index bb2c3d720482..29c4d60f6247 100644 --- a/packages/togetherai/src/togetherai-chat-settings.ts +++ b/packages/togetherai/src/togetherai-chat-settings.ts @@ -35,100 +35,4 @@ export type TogetherAIChatModelId = | 'upstage/SOLAR-10.7B-Instruct-v1.0' | (string & {}); -// https://docs.together.ai/docs/dedicated-models#chat-models -// export type TogetherAIChatModelId_Dedicated = -// | 'databricks/dbrx-instruct' -// | 'deepseek-ai/deepseek-coder-33b-instruct' -// | 'deepseek-ai/deepseek-llm-67b-chat' -// | 'google/gemma-2-27b-it' -// | 'google/gemma-2-9b-it' -// | 'google/gemma-2b-it' -// | 'google/gemma-7b-it' -// | 'gradientai/Llama-3-70B-Instruct-Gradient-1048k' -// | 'Gryphe/MythoMax-L2-13b-Lite' -// | 'Gryphe/MythoMax-L2-13b' -// | 'Haotian Liu/LLaVa-Next (Mistral-7B)' -// | 'HuggingFaceH4/zephyr-7b-beta' -// | 'lmSys/Koala (13B)' -// | 'lmSys/Koala-7B' -// | 'lmSys/Vicuna v1.3 (13B)' -// | 'lmSys/Vicuna v1.3 (7B)' -// | 'lmSys/Vicuna v1.5 (13B)' -// | 'lmSys/Vicuna v1.5 (7B)' -// | 'lmSys/Vicuna v1.5 16K (13B)' -// | 'Meta/Code Llama Instruct (13B)' -// | 'Meta/Code Llama Instruct (13B)' -// | 'Meta/Code Llama Instruct (34B)' -// | 'Meta/Code Llama Instruct (34B)' -// | 'Meta/Code Llama Instruct (70B)' -// | 'Meta/Code Llama Instruct (7B)' -// | 'Meta/LLaMA-2 Chat (13B)' -// | 'Meta/LLaMA-2 Chat (13B)' -// | 'Meta/LLaMA-2 Chat (70B)' -// | 'Meta/LLaMA-2 Chat (7B)' -// | 'Meta/LLaMA-2 Chat (7B)' -// | 'Meta/Llama3 8B Chat HF INT4' -// | 'Meta/Meta Llama 3 70B Instruct Lite' -// | 'Meta/Meta Llama 3 70B Instruct Reference' -// | 'Meta/Meta Llama 3 70B Instruct Turbo' -// | 'Meta/Meta Llama 3 70B Instruct' -// | 'Meta/Meta Llama 3 8B Instruct Lite' -// | 'Meta/Meta Llama 3 8B Instruct Reference' -// | 'Meta/Meta Llama 3 8B Instruct Turbo' -// | 'Meta/Meta Llama 3 8B Instruct' -// | 'Meta/Meta Llama 3.1 405B Instruct Turbo' -// | 'Meta/Meta Llama 3.1 405B Instruct Turbo' -// | 'Meta/Meta Llama 3.1 70B Instruct Turbo' -// | 'Meta/Meta Llama 3.1 8B Instruct Turbo' -// | 'Meta/Meta Llama 3.2 11B Vision Instruct Turbo' -// | 'Meta/Meta Llama 3.2 3B Instruct Turbo' -// | 'Meta/Meta Llama 3.2 90B Vision Instruct Turbo' -// | 'Meta/Meta Llama Vision Free' -// | 'Meta/Togethercomputer Llama3 8B Instruct Int8' -// | 'microsoft/WizardLM-2-8x22B' -// | 'mistralai/Mistral-7B-Instruct-v0.1' -// | 'mistralai/Mistral-7B-Instruct-v0.2' -// | 'mistralai/Mistral-7B-Instruct-v0.3' -// | 'mistralai/Mixtral-8x22B-Instruct-v0.1' -// | 'mistralai/Mixtral-8x7B-Instruct-v0.1' -// | 'NousResearch/Hermes-2-Theta-Llama-3-70B' -// | 'NousResearch/Nous-Capybara-7B-V1p9' -// | 'NousResearch/Nous-Hermes-2-Mistral-7B-DPO' -// | 'NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO' -// | 'NousResearch/Nous-Hermes-2-Mixtral-8x7B-SFT' -// | 'NousResearch/Nous-Hermes-Llama2-13b' -// | 'NousResearch/Nous-Hermes-Llama2-70b' -// | 'Open-Orca/Mistral-7B-OpenOrca' -// | 'openchat/openchat-3.5-1210' -// | 'Qwen/Qwen1.5-0.5B-Chat' -// | 'Qwen/Qwen1.5-1.8B-Chat' -// | 'Qwen/Qwen1.5-110B-Chat' -// | 'Qwen/Qwen1.5-14B-Chat' -// | 'Qwen/Qwen1.5-32B-Chat' -// | 'Qwen/Qwen1.5-4B-Chat' -// | 'Qwen/Qwen1.5-72B-Chat' -// | 'Qwen/Qwen1.5-7B-Chat' -// | 'Qwen/Qwen2-1.5B-Instruct' -// | 
'Qwen/Qwen2-72B-Instruct' -// | 'Qwen/Qwen2-7B-Instruct' -// | 'Qwen/Qwen2.5-72B-Instruct-Turbo' -// | 'Qwen/Qwen2.5-7B-Instruct-Turbo' -// | 'Qwen/Qwen2.5-Coder-32B-Instruct' -// | 'snorkelai/Snorkel-Mistral-PairRM-DPO' -// | 'Snowflake/snowflake-arctic-instruct' -// | 'teknium/OpenHermes-2-Mistral-7B' -// | 'teknium/OpenHermes-2p5-Mistral-7B' -// | 'test/test11' -// | 'togethercomputer/alpaca-7b' -// | 'togethercomputer/guanaco-13b' -// | 'togethercomputer/guanaco-33b' -// | 'togethercomputer/guanaco-65b' -// | 'togethercomputer/guanaco-7b' -// | 'togethercomputer/SOLAR-10.7B-Instruct-v1.0-int4' -// | 'togethercomputer/SOLAR-10.7B-Instruct-v1.0' -// | 'Undi95/ReMM-SLERP-L2-13B' -// | 'Undi95/Toppy-M-7B' -// | 'WizardLM/WizardLM-13B-V1.2' -// | (string & {}); - export interface TogetherAIChatSettings extends OpenAICompatibleChatSettings {} diff --git a/packages/togetherai/src/togetherai-completion-settings.ts b/packages/togetherai/src/togetherai-completion-settings.ts index 2316be30e587..16fc86425161 100644 --- a/packages/togetherai/src/togetherai-completion-settings.ts +++ b/packages/togetherai/src/togetherai-completion-settings.ts @@ -6,54 +6,5 @@ export type TogetherAICompletionModelId = | 'Qwen/Qwen2.5-Coder-32B-Instruct' | (string & {}); -// https://docs.together.ai/docs/dedicated-models#language-models -// export type TogetherAICompletionModelId = -// | 'allenai/OLMo-7B-Instruct' -// | 'EleutherAI/llemma_7b' -// | 'google/gemma-2-9b' -// | 'google/gemma-2b' -// | 'google/gemma-7b' -// | 'gpt-3.5-turbo-instruct' -// | 'huggyllama/llama-13b' -// | 'huggyllama/llama-30b' -// | 'huggyllama/llama-65b' -// | 'huggyllama/llama-7b' -// | 'meta-llama/Llama-2-13b-hf' -// | 'meta-llama/Llama-2-70b-hf' -// | 'meta-llama/Llama-2-7b-hf' -// | 'meta-llama/Llama-3-8b-hf' -// | 'meta-llama/Meta-Llama-3-70b-hf' -// | 'meta-llama/Meta-Llama-3-70B' -// | 'meta-llama/Meta-Llama-3-8B' -// | 'meta-llama/Meta-Llama-3.1-70B-Reference' -// | 'meta-llama/Meta-Llama-3.1-8B-Reference' -// | 'microsoft/phi-2' -// | 'mistralai/Mistral-7B-v0.1' -// | 'mistralai/Mixtral-8x22B' -// | 'mistralai/Mixtral-8x7B-v0.1' -// | 'Nexusflow/NexusRaven-V2-13B' -// | 'NousResearch/Nous-Hermes-13b' -// | 'Qwen/Qwen1.5-0.5B' -// | 'Qwen/Qwen1.5-1.8B' -// | 'Qwen/Qwen1.5-14B' -// | 'Qwen/Qwen1.5-32B' -// | 'Qwen/Qwen1.5-4B' -// | 'Qwen/Qwen1.5-72B' -// | 'Qwen/Qwen1.5-7B' -// | 'Qwen/Qwen2-1.5B' -// | 'Qwen/Qwen2-72B' -// | 'Qwen/Qwen2-7B' -// | 'togethercomputer/evo-1-131k-base' -// | 'togethercomputer/evo-1-8k-base' -// | 'togethercomputer/llama-2-13b' -// | 'togethercomputer/llama-2-70b' -// | 'togethercomputer/LLaMA-2-7B-32K' -// | 'togethercomputer/llama-2-7b' -// | 'togethercomputer/StripedHyena-Hessian-7B' -// | 'WizardLM/WizardLM-70B-V1.0' -// | 'zero-one-ai/Yi-34B' -// | 'zero-one-ai/Yi-6B' -// | (string & {}); - export interface TogetherAICompletionSettings extends OpenAICompatibleCompletionSettings {} diff --git a/packages/togetherai/src/togetherai-embedding-settings.ts b/packages/togetherai/src/togetherai-embedding-settings.ts index 7d008442bd15..9cc008f3026f 100644 --- a/packages/togetherai/src/togetherai-embedding-settings.ts +++ b/packages/togetherai/src/togetherai-embedding-settings.ts @@ -12,18 +12,5 @@ export type TogetherAIEmbeddingModelId = | 'WhereIsAI/UAE-Large-V1' | (string & {}); -// https://docs.together.ai/docs/dedicated-models#embedding-models -// export type TogetherAIEmbeddingModelId = -// | 'BAAI/bge-base-en-v1.5' -// | 'BAAI/bge-large-en-v1.5' -// | 'bert-base-uncased' -// | 
'hazyresearch/M2-BERT-2k-Retrieval-Encoder-V1' -// | 'sentence-transformers/msmarco-bert-base-dot-v5' -// | 'togethercomputer/m2-bert-80M-2k-retrieval' -// | 'togethercomputer/m2-bert-80M-32k-retrieval' -// | 'togethercomputer/m2-bert-80M-8k-retrieval' -// | 'WhereIsAI/UAE-Large-V1' -// | (string & {}); - export interface TogetherAIEmbeddingSettings extends OpenAICompatibleEmbeddingSettings {} diff --git a/packages/togetherai/src/togetherai-provider.ts b/packages/togetherai/src/togetherai-provider.ts index 5d134443a4c3..d4b0f4188486 100644 --- a/packages/togetherai/src/togetherai-provider.ts +++ b/packages/togetherai/src/togetherai-provider.ts @@ -20,9 +20,9 @@ export interface TogetherAIProviderSettings export interface TogetherAIProvider extends OpenAICompatibleProvider< - | TogetherAIChatModelId - | TogetherAICompletionModelId - | TogetherAIEmbeddingModelId + TogetherAIChatModelId, + TogetherAICompletionModelId, + TogetherAIEmbeddingModelId > { chatModel( modelId: TogetherAIChatModelId, @@ -45,15 +45,15 @@ export function createTogetherAI( ): TogetherAIProvider { const providerOptions: OpenAICompatibleProviderSettings = { baseURL: 'https://api.together.xyz/v1/', + name: 'togetherai', apiKeyEnvVarName: 'TOGETHER_AI_API_KEY', apiKeyEnvVarDescription: "TogetherAI's API key", ...options, }; - // TODO(shaper): Consider separating generics in the ctor. const openAICompatibleProvider = createOpenAICompatible< - | TogetherAIChatModelId - | TogetherAICompletionModelId - | TogetherAIEmbeddingModelId + TogetherAIChatModelId, + TogetherAICompletionModelId, + TogetherAIEmbeddingModelId >(providerOptions); const togetheraiProvider: TogetherAIProvider = Object.assign( From 610a98b34304338548cc9d9eb6a8f444b2b13369 Mon Sep 17 00:00:00 2001 From: Walter Korman Date: Tue, 19 Nov 2024 15:36:38 -0800 Subject: [PATCH 06/13] Allow concrete impl to specify defaultObjectGenerationMode. 
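The object generation mode now comes from the chat settings instead of being hard-coded to 'tool', so each concrete provider can choose what its models handle best. A minimal sketch of the intended usage (the provider name, base URL, env var, and model id below are placeholders, not part of this change):

    import { createOpenAICompatible } from '@ai-sdk/openai-compatible';

    const provider = createOpenAICompatible<string, string, string>({
      name: 'example',
      baseURL: 'https://api.example.com/v1',
      apiKeyEnvVarName: 'EXAMPLE_API_KEY',
      apiKeyEnvVarDescription: 'Example API key',
    });

    // The concrete integration decides which mode fits the model:
    const model = provider.chatModel('example-model', {
      defaultObjectGenerationMode: 'json',
    });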
--- .../src/generate-text/togetherai-tool-call.ts | 60 +++++++++++++++++++ .../openai-compatible-chat-language-model.ts | 7 ++- .../src/openai-compatible-chat-settings.ts | 13 +++- .../togetherai/src/togetherai-provider.ts | 20 +++++-- 4 files changed, 92 insertions(+), 8 deletions(-) create mode 100644 examples/ai-core/src/generate-text/togetherai-tool-call.ts diff --git a/examples/ai-core/src/generate-text/togetherai-tool-call.ts b/examples/ai-core/src/generate-text/togetherai-tool-call.ts new file mode 100644 index 000000000000..e0972cf113a7 --- /dev/null +++ b/examples/ai-core/src/generate-text/togetherai-tool-call.ts @@ -0,0 +1,60 @@ +import { togetherai } from '@ai-sdk/togetherai'; +import { generateText, tool } from 'ai'; +import 'dotenv/config'; +import { z } from 'zod'; +import { weatherTool } from '../tools/weather-tool'; + +async function main() { + const result = await generateText({ + model: togetherai('meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo'), + maxTokens: 512, + tools: { + weather: weatherTool, + cityAttractions: tool({ + parameters: z.object({ city: z.string() }), + }), + }, + prompt: + 'What is the weather in San Francisco and what attractions should I visit?', + }); + + // typed tool calls: + for (const toolCall of result.toolCalls) { + switch (toolCall.toolName) { + case 'cityAttractions': { + toolCall.args.city; // string + break; + } + + case 'weather': { + toolCall.args.location; // string + break; + } + } + } + + // typed tool results for tools with execute method: + for (const toolResult of result.toolResults) { + switch (toolResult.toolName) { + // NOT AVAILABLE (NO EXECUTE METHOD) + // case 'cityAttractions': { + // toolResult.args.city; // string + // toolResult.result; + // break; + // } + + case 'weather': { + toolResult.args.location; // string + toolResult.result.location; // string + toolResult.result.temperature; // number + break; + } + } + } + + console.log('Text:', result.text); + console.log('Tool Calls:', JSON.stringify(result.toolCalls, null, 2)); + console.log('Tool Results:', JSON.stringify(result.toolResults, null, 2)); +} + +main().catch(console.error); diff --git a/packages/openai-compatible/src/openai-compatible-chat-language-model.ts b/packages/openai-compatible/src/openai-compatible-chat-language-model.ts index 99c03b363bb5..d73de06cf4af 100644 --- a/packages/openai-compatible/src/openai-compatible-chat-language-model.ts +++ b/packages/openai-compatible/src/openai-compatible-chat-language-model.ts @@ -42,7 +42,6 @@ export class OpenAICompatibleChatLanguageModel implements LanguageModelV1 { readonly specificationVersion = 'v1'; readonly supportsStructuredOutputs = false; - readonly defaultObjectGenerationMode = 'tool'; readonly modelId: OpenAICompatibleChatModelId; readonly settings: OpenAICompatibleChatSettings; @@ -59,6 +58,10 @@ export class OpenAICompatibleChatLanguageModel implements LanguageModelV1 { this.config = config; } + get defaultObjectGenerationMode(): 'json' | 'tool' | undefined { + return this.settings.defaultObjectGenerationMode; + } + get provider(): string { return this.config.provider; } @@ -120,6 +123,7 @@ export class OpenAICompatibleChatLanguageModel implements LanguageModelV1 { // response format: response_format: + // TODO(shaper): Review vs. OpenAI impl here. // json object response format is not currently supported undefined, @@ -141,6 +145,7 @@ export class OpenAICompatibleChatLanguageModel implements LanguageModelV1 { } case 'object-json': { + // TODO(shaper): Review vs. OpenAI impl here. 
throw new UnsupportedFunctionalityError({
          functionality: 'object-json mode',
        });
diff --git a/packages/openai-compatible/src/openai-compatible-chat-settings.ts b/packages/openai-compatible/src/openai-compatible-chat-settings.ts
index 515704032c21..e3ab3d164370 100644
--- a/packages/openai-compatible/src/openai-compatible-chat-settings.ts
+++ b/packages/openai-compatible/src/openai-compatible-chat-settings.ts
@@ -4,6 +4,17 @@ export interface OpenAICompatibleChatSettings {
   /**
 A unique identifier representing your end-user, which can help the provider to
 monitor and detect abuse.
-*/
+ */
   user?: string;
+
+  /**
+Default object generation mode that should be used with this model when
+no mode is specified. Should be the mode with the best results for this
+model. `undefined` can be returned if object generation is not supported.
+
+This is needed to generate the best objects possible w/o requiring the
+user to explicitly specify the object generation mode.
+ */
+  // TODO(shaper): This is really model-specific, move to config or elsewhere?
+  defaultObjectGenerationMode?: 'json' | 'tool' | undefined;
 }
diff --git a/packages/togetherai/src/togetherai-provider.ts b/packages/togetherai/src/togetherai-provider.ts
index d4b0f4188486..f59864bf59c2 100644
--- a/packages/togetherai/src/togetherai-provider.ts
+++ b/packages/togetherai/src/togetherai-provider.ts
@@ -5,7 +5,10 @@ import {
   OpenAICompatibleProviderSettings,
 } from '@ai-sdk/openai-compatible';
 import { LanguageModelV1, EmbeddingModelV1 } from '@ai-sdk/provider';
-import { TogetherAIChatModelId } from './togetherai-chat-settings';
+import {
+  TogetherAIChatModelId,
+  TogetherAIChatSettings,
+} from './togetherai-chat-settings';
 import {
   TogetherAIEmbeddingModelId,
   TogetherAIEmbeddingSettings,
@@ -26,12 +29,12 @@ export interface TogetherAIProvider
 > {
   chatModel(
     modelId: TogetherAIChatModelId,
-    settings?: OpenAICompatibleChatSettings,
+    settings?: TogetherAIChatSettings,
   ): LanguageModelV1;
 
   completionModel(
     modelId: TogetherAICompletionModelId,
-    settings?: OpenAICompatibleChatSettings,
+    settings?: TogetherAICompletionSettings,
   ): LanguageModelV1;
 
   textEmbeddingModel(
@@ -59,16 +62,21 @@ export function createTogetherAI(
   const togetheraiProvider: TogetherAIProvider = Object.assign(
     (
       modelId: TogetherAIChatModelId,
-      settings?: OpenAICompatibleChatSettings,
+      settings?: TogetherAIChatSettings,
     ): LanguageModelV1 => {
       return openAICompatibleProvider(modelId, settings);
     },
     {
       chatModel: (
         modelId: TogetherAIChatModelId,
-        settings?: OpenAICompatibleChatSettings,
+        settings?: TogetherAIChatSettings,
       ) => {
-        return openAICompatibleProvider.chatModel(modelId, settings);
+        // TODO(shaper): Perhaps the object generation mode will vary by model.
+        const defaultSettings: Partial<TogetherAIChatSettings> = {
+          defaultObjectGenerationMode: 'json',
+        };
+        const mergedSettings = { ...defaultSettings, ...settings };
+        return openAICompatibleProvider.chatModel(modelId, mergedSettings);
       },
 
       completionModel: (

From cfd4fdb165a14e150111e3b1f10623d6607a47eb Mon Sep 17 00:00:00 2001
From: Walter Korman
Date: Tue, 19 Nov 2024 17:03:01 -0800
Subject: [PATCH 07/13] Rudimentary test fixing/adding.
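The tests below follow the provider-utils pattern: start a JsonTestServer (or
StreamingTestServer for SSE) on the endpoint under test, stub a canned response
body, then assert on both the parsed result and the raw request the model sent.
A minimal sketch of that shape, assuming the fake base URL used throughout these
tests ('https://my.api.com/v1/') and a `model`/`TEST_PROMPT` defined as in the
test files:

    import { JsonTestServer } from '@ai-sdk/provider-utils/test';

    const server = new JsonTestServer('https://my.api.com/v1/chat/completions');
    server.setupTestEnvironment();

    it('round-trips a chat completion', async () => {
      server.responseBodyJson = { /* canned OpenAI-style completion payload */ };

      await model.doGenerate({
        inputFormat: 'prompt',
        mode: { type: 'regular' },
        prompt: TEST_PROMPT,
      });

      expect(await server.getRequestBodyJson()).toMatchObject({
        model: 'grok-beta',
      });
    });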
--- examples/ai-core/src/embed/togetherai.ts | 15 + .../ai-core/src/generate-object/togetherai.ts | 4 +- .../src/js-test/rename-format-stream-part.js | 4 - .../src/stream-text/togetherai-tool-call.ts | 70 +++ ...nai-compatible-chat-language-model.test.ts | 9 +- ...mpatible-completion-language-model.test.ts | 529 ++++++++++++++++++ .../openai-compatible-embedding-model.test.ts | 136 +++++ .../src/togetherai-provider.test.ts | 113 ++++ .../togetherai/src/togetherai-provider.ts | 66 +-- 9 files changed, 903 insertions(+), 43 deletions(-) create mode 100644 examples/ai-core/src/embed/togetherai.ts delete mode 100644 examples/ai-core/src/js-test/rename-format-stream-part.js create mode 100644 examples/ai-core/src/stream-text/togetherai-tool-call.ts create mode 100644 packages/openai-compatible/src/openai-compatible-completion-language-model.test.ts create mode 100644 packages/openai-compatible/src/openai-compatible-embedding-model.test.ts create mode 100644 packages/togetherai/src/togetherai-provider.test.ts diff --git a/examples/ai-core/src/embed/togetherai.ts b/examples/ai-core/src/embed/togetherai.ts new file mode 100644 index 000000000000..43c2ec20b939 --- /dev/null +++ b/examples/ai-core/src/embed/togetherai.ts @@ -0,0 +1,15 @@ +import { togetherai } from '@ai-sdk/togetherai'; +import { embed } from 'ai'; +import 'dotenv/config'; + +async function main() { + const { embedding, usage } = await embed({ + model: togetherai.textEmbeddingModel('BAAI/bge-base-en-v1.5'), + value: 'sunny day at the beach', + }); + + console.log(embedding); + console.log(usage); +} + +main().catch(console.error); diff --git a/examples/ai-core/src/generate-object/togetherai.ts b/examples/ai-core/src/generate-object/togetherai.ts index 8c44b0eddce3..e37f22397559 100644 --- a/examples/ai-core/src/generate-object/togetherai.ts +++ b/examples/ai-core/src/generate-object/togetherai.ts @@ -5,7 +5,9 @@ import { z } from 'zod'; async function main() { const result = await generateObject({ - model: togetherai('google/gemma-2b-it'), + model: togetherai.chatModel( + 'meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo', + ), schema: z.object({ recipe: z.object({ name: z.string(), diff --git a/examples/ai-core/src/js-test/rename-format-stream-part.js b/examples/ai-core/src/js-test/rename-format-stream-part.js deleted file mode 100644 index 27f763edcfdd..000000000000 --- a/examples/ai-core/src/js-test/rename-format-stream-part.js +++ /dev/null @@ -1,4 +0,0 @@ -// @ts-nocheck -import { formatStreamPart } from 'ai'; - -const response = new Response(formatStreamPart('text', cached)); diff --git a/examples/ai-core/src/stream-text/togetherai-tool-call.ts b/examples/ai-core/src/stream-text/togetherai-tool-call.ts new file mode 100644 index 000000000000..896242b61434 --- /dev/null +++ b/examples/ai-core/src/stream-text/togetherai-tool-call.ts @@ -0,0 +1,70 @@ +import { togetherai } from '@ai-sdk/togetherai'; +import { streamText, CoreMessage, ToolCallPart, ToolResultPart } from 'ai'; +import 'dotenv/config'; +import { weatherTool } from '../tools/weather-tool'; + +const messages: CoreMessage[] = []; + +async function main() { + let toolResponseAvailable = false; + + const result = streamText({ + model: togetherai('meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo'), + maxTokens: 512, + tools: { + weather: weatherTool, + }, + toolChoice: 'required', + prompt: + 'What is the weather in San Francisco and what attractions should I visit?', + }); + + let fullResponse = ''; + const toolCalls: ToolCallPart[] = []; + const toolResponses: ToolResultPart[] 
= []; + + for await (const delta of result.fullStream) { + switch (delta.type) { + case 'text-delta': { + fullResponse += delta.textDelta; + process.stdout.write(delta.textDelta); + break; + } + + case 'tool-call': { + toolCalls.push(delta); + + process.stdout.write( + `\nTool call: '${delta.toolName}' ${JSON.stringify(delta.args)}`, + ); + break; + } + + case 'tool-result': { + toolResponses.push(delta); + + process.stdout.write( + `\nTool response: '${delta.toolName}' ${JSON.stringify( + delta.result, + )}`, + ); + break; + } + } + } + process.stdout.write('\n\n'); + + messages.push({ + role: 'assistant', + content: [{ type: 'text', text: fullResponse }, ...toolCalls], + }); + + if (toolResponses.length > 0) { + messages.push({ role: 'tool', content: toolResponses }); + } + + toolResponseAvailable = toolCalls.length > 0; + console.log('Messages:', messages[0].content); +} + +main().catch(console.error); diff --git a/packages/openai-compatible/src/openai-compatible-chat-language-model.test.ts b/packages/openai-compatible/src/openai-compatible-chat-language-model.test.ts index 3e79947f16a1..e40763762244 100644 --- a/packages/openai-compatible/src/openai-compatible-chat-language-model.test.ts +++ b/packages/openai-compatible/src/openai-compatible-chat-language-model.test.ts @@ -4,7 +4,7 @@ import { StreamingTestServer, convertReadableStreamToArray, } from '@ai-sdk/provider-utils/test'; -import { createOpenAICompat as createOpenAICompatible } from './openai-compatible-provider'; +import { createOpenAICompatible } from './openai-compatible-provider'; const TEST_PROMPT: LanguageModelV1Prompt = [ { role: 'user', content: [{ type: 'text', text: 'Hello' }] }, @@ -12,12 +12,13 @@ const TEST_PROMPT: LanguageModelV1Prompt = [ const provider = createOpenAICompatible({ apiKey: 'test-api-key', + baseURL: 'https://my.api.com/v1/', }); const model = provider('grok-beta'); describe('doGenerate', () => { - const server = new JsonTestServer('https://api.x.ai/v1/chat/completions'); + const server = new JsonTestServer('https://my.api.com/v1/chat/completions'); server.setupTestEnvironment(); @@ -307,6 +308,7 @@ describe('doGenerate', () => { const provider = createOpenAICompatible({ apiKey: 'test-api-key', + baseURL: 'https://my.api.com/v1/', headers: { 'Custom-Provider-Header': 'provider-header-value', }, @@ -397,7 +399,7 @@ describe('doGenerate', () => { describe('doStream', () => { const server = new StreamingTestServer( - 'https://api.x.ai/v1/chat/completions', + 'https://my.api.com/v1/chat/completions', ); server.setupTestEnvironment(); @@ -879,6 +881,7 @@ describe('doStream', () => { const provider = createOpenAICompatible({ apiKey: 'test-api-key', + baseURL: 'https://my.api.com/v1', headers: { 'Custom-Provider-Header': 'provider-header-value', }, diff --git a/packages/openai-compatible/src/openai-compatible-completion-language-model.test.ts b/packages/openai-compatible/src/openai-compatible-completion-language-model.test.ts new file mode 100644 index 000000000000..f1e64a5431ee --- /dev/null +++ b/packages/openai-compatible/src/openai-compatible-completion-language-model.test.ts @@ -0,0 +1,529 @@ +import { LanguageModelV1Prompt } from '@ai-sdk/provider'; +import { + JsonTestServer, + StreamingTestServer, + convertReadableStreamToArray, +} from '@ai-sdk/provider-utils/test'; +import { createOpenAICompatible } from './openai-compatible-provider'; +import { mapOpenAICompatibleCompletionLogProbs } from './map-openai-compatible-completion-logprobs'; + +const TEST_PROMPT: LanguageModelV1Prompt = [ + { role: 
'user', content: [{ type: 'text', text: 'Hello' }] },
+];
+
+const TEST_LOGPROBS = {
+  tokens: [' ever', ' after', '.\n\n', 'The', ' end', '.'],
+  token_logprobs: [
+    -0.0664508, -0.014520033, -1.3820221, -0.7890417, -0.5323165, -0.10247037,
+  ],
+  top_logprobs: [
+    {
+      ' ever': -0.0664508,
+    },
+    {
+      ' after': -0.014520033,
+    },
+    {
+      '.\n\n': -1.3820221,
+    },
+    {
+      The: -0.7890417,
+    },
+    {
+      ' end': -0.5323165,
+    },
+    {
+      '.': -0.10247037,
+    },
+  ] as Record<string, number>[],
+};
+
+const provider = createOpenAICompatible({
+  apiKey: 'test-api-key',
+  baseURL: 'https://my.api.com/v1/',
+});
+
+const model = provider.completionModel('gpt-3.5-turbo-instruct');
+
+describe('doGenerate', () => {
+  const server = new JsonTestServer('https://my.api.com/v1/completions');
+
+  server.setupTestEnvironment();
+
+  function prepareJsonResponse({
+    content = '',
+    usage = {
+      prompt_tokens: 4,
+      total_tokens: 34,
+      completion_tokens: 30,
+    },
+    logprobs = null,
+    finish_reason = 'stop',
+    id = 'cmpl-96cAM1v77r4jXa4qb2NSmRREV5oWB',
+    created = 1711363706,
+    model = 'gpt-3.5-turbo-instruct',
+  }: {
+    content?: string;
+    usage?: {
+      prompt_tokens: number;
+      total_tokens: number;
+      completion_tokens: number;
+    };
+    logprobs?: {
+      tokens: string[];
+      token_logprobs: number[];
+      top_logprobs: Record<string, number>[];
+    } | null;
+    finish_reason?: string;
+    id?: string;
+    created?: number;
+    model?: string;
+  }) {
+    server.responseBodyJson = {
+      id,
+      object: 'text_completion',
+      created,
+      model,
+      choices: [
+        {
+          text: content,
+          index: 0,
+          logprobs,
+          finish_reason,
+        },
+      ],
+      usage,
+    };
+  }
+
+  it('should extract text response', async () => {
+    prepareJsonResponse({ content: 'Hello, World!' });
+
+    const { text } = await model.doGenerate({
+      inputFormat: 'prompt',
+      mode: { type: 'regular' },
+      prompt: TEST_PROMPT,
+    });
+
+    expect(text).toStrictEqual('Hello, World!');
+  });
+
+  it('should extract usage', async () => {
+    prepareJsonResponse({
+      content: '',
+      usage: { prompt_tokens: 20, total_tokens: 25, completion_tokens: 5 },
+    });
+
+    const { usage } = await model.doGenerate({
+      inputFormat: 'prompt',
+      mode: { type: 'regular' },
+      prompt: TEST_PROMPT,
+    });
+
+    expect(usage).toStrictEqual({
+      promptTokens: 20,
+      completionTokens: 5,
+    });
+  });
+
+  it('should send request body', async () => {
+    prepareJsonResponse({});
+
+    const { request } = await model.doGenerate({
+      inputFormat: 'prompt',
+      mode: { type: 'regular' },
+      prompt: TEST_PROMPT,
+    });
+
+    expect(request).toStrictEqual({
+      body: '{"model":"gpt-3.5-turbo-instruct","prompt":"Hello"}',
+    });
+  });
+
+  it('should send additional response information', async () => {
+    prepareJsonResponse({
+      id: 'test-id',
+      created: 123,
+      model: 'test-model',
+    });
+
+    const { response } = await model.doGenerate({
+      inputFormat: 'prompt',
+      mode: { type: 'regular' },
+      prompt: TEST_PROMPT,
+    });
+
+    expect(response).toStrictEqual({
+      id: 'test-id',
+      timestamp: new Date(123 * 1000),
+      modelId: 'test-model',
+    });
+  });
+
+  it('should extract logprobs', async () => {
+    prepareJsonResponse({ logprobs: TEST_LOGPROBS });
+
+    const provider = createOpenAICompatible({
+      apiKey: 'test-api-key',
+      baseURL: 'https://my.api.com/v1/',
+    });
+
+    const response = await provider
+      .completionModel('gpt-3.5-turbo', { logprobs: 1 })
+      .doGenerate({
+        inputFormat: 'prompt',
+        mode: { type: 'regular' },
+        prompt: TEST_PROMPT,
+      });
+    expect(response.logprobs).toStrictEqual(
+      mapOpenAICompatibleCompletionLogProbs(TEST_LOGPROBS),
+    );
+  });
+
+  it('should extract finish reason', async () => {
prepareJsonResponse({
+      content: '',
+      finish_reason: 'stop',
+    });
+
+    const { finishReason } = await provider
+      .completionModel('gpt-3.5-turbo-instruct')
+      .doGenerate({
+        inputFormat: 'prompt',
+        mode: { type: 'regular' },
+        prompt: TEST_PROMPT,
+      });
+
+    expect(finishReason).toStrictEqual('stop');
+  });
+
+  it('should support unknown finish reason', async () => {
+    prepareJsonResponse({
+      content: '',
+      finish_reason: 'eos',
+    });
+
+    const { finishReason } = await provider
+      .completionModel('gpt-3.5-turbo-instruct')
+      .doGenerate({
+        inputFormat: 'prompt',
+        mode: { type: 'regular' },
+        prompt: TEST_PROMPT,
+      });
+
+    expect(finishReason).toStrictEqual('unknown');
+  });
+
+  it('should expose the raw response headers', async () => {
+    prepareJsonResponse({ content: '' });
+
+    server.responseHeaders = {
+      'test-header': 'test-value',
+    };
+
+    const { rawResponse } = await model.doGenerate({
+      inputFormat: 'prompt',
+      mode: { type: 'regular' },
+      prompt: TEST_PROMPT,
+    });
+
+    expect(rawResponse?.headers).toStrictEqual({
+      // default headers:
+      'content-length': '266',
+      'content-type': 'application/json',
+
+      // custom header
+      'test-header': 'test-value',
+    });
+  });
+
+  it('should pass the model and the prompt', async () => {
+    prepareJsonResponse({ content: '' });
+
+    await model.doGenerate({
+      inputFormat: 'prompt',
+      mode: { type: 'regular' },
+      prompt: TEST_PROMPT,
+    });
+
+    expect(await server.getRequestBodyJson()).toStrictEqual({
+      model: 'gpt-3.5-turbo-instruct',
+      prompt: 'Hello',
+    });
+  });
+
+  it('should pass headers', async () => {
+    prepareJsonResponse({ content: '' });
+
+    const provider = createOpenAICompatible({
+      apiKey: 'test-api-key',
+      baseURL: 'https://my.api.com/v1/',
+      // TODO(shaper): Do we need these?
+      // organization: 'test-organization',
+      // project: 'test-project',
+      headers: {
+        'Custom-Provider-Header': 'provider-header-value',
+      },
+    });
+
+    await provider.completionModel('gpt-3.5-turbo-instruct').doGenerate({
+      inputFormat: 'prompt',
+      mode: { type: 'regular' },
+      prompt: TEST_PROMPT,
+      headers: {
+        'Custom-Request-Header': 'request-header-value',
+      },
+    });
+
+    const requestHeaders = await server.getRequestHeaders();
+
+    expect(requestHeaders).toStrictEqual({
+      authorization: 'Bearer test-api-key',
+      'content-type': 'application/json',
+      'custom-provider-header': 'provider-header-value',
+      'custom-request-header': 'request-header-value',
+      // 'openai-organization': 'test-organization',
+      // 'openai-project': 'test-project',
+    });
+  });
+});
+
+describe('doStream', () => {
+  const server = new StreamingTestServer('https://my.api.com/v1/completions');
+
+  server.setupTestEnvironment();
+
+  function prepareStreamResponse({
+    content,
+    finish_reason = 'stop',
+    usage = {
+      prompt_tokens: 10,
+      total_tokens: 372,
+      completion_tokens: 362,
+    },
+    logprobs = null,
+  }: {
+    content: string[];
+    usage?: {
+      prompt_tokens: number;
+      total_tokens: number;
+      completion_tokens: number;
+    };
+    logprobs?: {
+      tokens: string[];
+      token_logprobs: number[];
+      top_logprobs: Record<string, number>[];
+    } | null;
+    finish_reason?: string;
+  }) {
+    server.responseChunks = [
+      ...content.map(text => {
+        return (
+          `data: {"id":"cmpl-96c64EdfhOw8pjFFgVpLuT8k2MtdT","object":"text_completion","created":1711363440,` +
+          `"choices":[{"text":"${text}","index":0,"logprobs":null,"finish_reason":null}],"model":"gpt-3.5-turbo-instruct"}\n\n`
+        );
+      }),
+      `data: {"id":"cmpl-96c3yLQE1TtZCd6n6OILVmzev8M8H","object":"text_completion","created":1711363310,` +
`"choices":[{"text":"","index":0,"logprobs":${JSON.stringify( + logprobs, + )},"finish_reason":"${finish_reason}"}],"model":"gpt-3.5-turbo-instruct"}\n\n`, + `data: {"id":"cmpl-96c3yLQE1TtZCd6n6OILVmzev8M8H","object":"text_completion","created":1711363310,` + + `"model":"gpt-3.5-turbo-instruct","usage":${JSON.stringify( + usage, + )},"choices":[]}\n\n`, + 'data: [DONE]\n\n', + ]; + } + + it('should stream text deltas', async () => { + prepareStreamResponse({ + content: ['Hello', ', ', 'World!'], + finish_reason: 'stop', + usage: { + prompt_tokens: 10, + total_tokens: 372, + completion_tokens: 362, + }, + logprobs: TEST_LOGPROBS, + }); + + const { stream } = await model.doStream({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); + + // note: space moved to last chunk bc of trimming + expect(await convertReadableStreamToArray(stream)).toStrictEqual([ + { + id: 'cmpl-96c64EdfhOw8pjFFgVpLuT8k2MtdT', + modelId: 'gpt-3.5-turbo-instruct', + timestamp: new Date('2024-03-25T10:44:00.000Z'), + type: 'response-metadata', + }, + { type: 'text-delta', textDelta: 'Hello' }, + { type: 'text-delta', textDelta: ', ' }, + { type: 'text-delta', textDelta: 'World!' }, + { type: 'text-delta', textDelta: '' }, + { + type: 'finish', + finishReason: 'stop', + logprobs: mapOpenAICompatibleCompletionLogProbs(TEST_LOGPROBS), + usage: { promptTokens: 10, completionTokens: 362 }, + }, + ]); + }); + + // TODO(shaper): Look into type validation failure in the below. + it.skip('should handle error stream parts', async () => { + server.responseChunks = [ + `data: {"error":{"message": "The server had an error processing your request. Sorry about that! You can retry your request, or contact us through our ` + + `help center at help.openai.com if you keep seeing this error.","type":"server_error","param":null,"code":null}}\n\n`, + 'data: [DONE]\n\n', + ]; + + const { stream } = await model.doStream({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); + + expect(await convertReadableStreamToArray(stream)).toStrictEqual([ + { + type: 'error', + error: { + message: + 'The server had an error processing your request. Sorry about that! 
' + + 'You can retry your request, or contact us through our help center at ' + + 'help.openai.com if you keep seeing this error.', + type: 'server_error', + code: null, + param: null, + }, + }, + { + finishReason: 'error', + logprobs: undefined, + type: 'finish', + usage: { + completionTokens: NaN, + promptTokens: NaN, + }, + }, + ]); + }); + + it('should handle unparsable stream parts', async () => { + server.responseChunks = [`data: {unparsable}\n\n`, 'data: [DONE]\n\n']; + + const { stream } = await model.doStream({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); + + const elements = await convertReadableStreamToArray(stream); + + expect(elements.length).toBe(2); + expect(elements[0].type).toBe('error'); + expect(elements[1]).toStrictEqual({ + finishReason: 'error', + logprobs: undefined, + type: 'finish', + usage: { + completionTokens: NaN, + promptTokens: NaN, + }, + }); + }); + + it('should send request body', async () => { + prepareStreamResponse({ content: [] }); + + const { request } = await model.doStream({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); + + expect(request).toStrictEqual({ + // body: '{"model":"gpt-3.5-turbo-instruct","prompt":"Hello","stream":true,"stream_options":{"include_usage":true}}', + body: '{"model":"gpt-3.5-turbo-instruct","prompt":"Hello","stream":true}', + }); + }); + + it('should expose the raw response headers', async () => { + prepareStreamResponse({ content: [] }); + + server.responseHeaders = { + 'test-header': 'test-value', + }; + + const { rawResponse } = await model.doStream({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); + + expect(rawResponse?.headers).toStrictEqual({ + // default headers: + 'content-type': 'text/event-stream', + 'cache-control': 'no-cache', + connection: 'keep-alive', + + // custom header + 'test-header': 'test-value', + }); + }); + + it('should pass the model and the prompt', async () => { + prepareStreamResponse({ content: [] }); + + await model.doStream({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); + + expect(await server.getRequestBodyJson()).toStrictEqual({ + stream: true, + // stream_options: { include_usage: true }, + model: 'gpt-3.5-turbo-instruct', + prompt: 'Hello', + }); + }); + + it('should pass headers', async () => { + prepareStreamResponse({ content: [] }); + + const provider = createOpenAICompatible({ + apiKey: 'test-api-key', + baseURL: 'https://my.api.com/v1/', + // organization: 'test-organization', + // project: 'test-project', + headers: { + 'Custom-Provider-Header': 'provider-header-value', + }, + }); + + await provider.completionModel('gpt-3.5-turbo-instruct').doStream({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + headers: { + 'Custom-Request-Header': 'request-header-value', + }, + }); + + const requestHeaders = await server.getRequestHeaders(); + + expect(requestHeaders).toStrictEqual({ + authorization: 'Bearer test-api-key', + 'content-type': 'application/json', + 'custom-provider-header': 'provider-header-value', + 'custom-request-header': 'request-header-value', + // 'openai-organization': 'test-organization', + // 'openai-project': 'test-project', + }); + }); +}); diff --git a/packages/openai-compatible/src/openai-compatible-embedding-model.test.ts b/packages/openai-compatible/src/openai-compatible-embedding-model.test.ts new file mode 100644 index 000000000000..2e81b01ffe62 --- /dev/null +++ 
b/packages/openai-compatible/src/openai-compatible-embedding-model.test.ts @@ -0,0 +1,136 @@ +import { EmbeddingModelV1Embedding } from '@ai-sdk/provider'; +import { JsonTestServer } from '@ai-sdk/provider-utils/test'; +import { createOpenAICompatible } from './openai-compatible-provider'; + +const dummyEmbeddings = [ + [0.1, 0.2, 0.3, 0.4, 0.5], + [0.6, 0.7, 0.8, 0.9, 1.0], +]; +const testValues = ['sunny day at the beach', 'rainy day in the city']; + +const provider = createOpenAICompatible({ + apiKey: 'test-api-key', + baseURL: 'https://my.api.com/v1/', +}); +const model = provider.textEmbeddingModel('text-embedding-3-large'); + +describe('doEmbed', () => { + const server = new JsonTestServer('https://my.api.com/v1/embeddings'); + + server.setupTestEnvironment(); + + function prepareJsonResponse({ + embeddings = dummyEmbeddings, + usage = { prompt_tokens: 8, total_tokens: 8 }, + }: { + embeddings?: EmbeddingModelV1Embedding[]; + usage?: { prompt_tokens: number; total_tokens: number }; + } = {}) { + server.responseBodyJson = { + object: 'list', + data: embeddings.map((embedding, i) => ({ + object: 'embedding', + index: i, + embedding, + })), + model: 'text-embedding-3-large', + usage, + }; + } + + it('should extract embedding', async () => { + prepareJsonResponse(); + + const { embeddings } = await model.doEmbed({ values: testValues }); + + expect(embeddings).toStrictEqual(dummyEmbeddings); + }); + + it('should expose the raw response headers', async () => { + prepareJsonResponse(); + + server.responseHeaders = { + 'test-header': 'test-value', + }; + + const { rawResponse } = await model.doEmbed({ values: testValues }); + + expect(rawResponse?.headers).toStrictEqual({ + // default headers: + 'content-length': '236', + 'content-type': 'application/json', + + // custom header + 'test-header': 'test-value', + }); + }); + + it('should extract usage', async () => { + prepareJsonResponse({ + usage: { prompt_tokens: 20, total_tokens: 20 }, + }); + + const { usage } = await model.doEmbed({ values: testValues }); + + expect(usage).toStrictEqual({ tokens: 20 }); + }); + + it('should pass the model and the values', async () => { + prepareJsonResponse(); + + await model.doEmbed({ values: testValues }); + + expect(await server.getRequestBodyJson()).toStrictEqual({ + model: 'text-embedding-3-large', + input: testValues, + encoding_format: 'float', + }); + }); + + it('should pass the dimensions setting', async () => { + prepareJsonResponse(); + + await provider + .textEmbeddingModel('text-embedding-3-large', { dimensions: 64 }) + .doEmbed({ values: testValues }); + + expect(await server.getRequestBodyJson()).toStrictEqual({ + model: 'text-embedding-3-large', + input: testValues, + encoding_format: 'float', + dimensions: 64, + }); + }); + + it('should pass headers', async () => { + prepareJsonResponse(); + + const provider = createOpenAICompatible({ + apiKey: 'test-api-key', + baseURL: 'https://my.api.com/v1/', + // organization: 'test-organization', + // project: 'test-project', + headers: { + 'Custom-Provider-Header': 'provider-header-value', + }, + }); + + await provider.textEmbeddingModel('text-embedding-3-large').doEmbed({ + values: testValues, + headers: { + 'Custom-Request-Header': 'request-header-value', + }, + }); + + const requestHeaders = await server.getRequestHeaders(); + + expect(requestHeaders).toStrictEqual({ + authorization: 'Bearer test-api-key', + 'content-type': 'application/json', + 'custom-provider-header': 'provider-header-value', + 'custom-request-header': 'request-header-value', 
+      // 'openai-organization': 'test-organization',
+      // 'openai-project': 'test-project',
+    });
+  });
+});
diff --git a/packages/togetherai/src/togetherai-provider.test.ts b/packages/togetherai/src/togetherai-provider.test.ts
new file mode 100644
index 000000000000..8a44903f3013
--- /dev/null
+++ b/packages/togetherai/src/togetherai-provider.test.ts
@@ -0,0 +1,113 @@
+import { describe, it, expect, vi, beforeEach, Mock } from 'vitest';
+import { createTogetherAI } from './togetherai-provider';
+import {
+  OpenAICompatibleProvider,
+  createOpenAICompatible,
+} from '@ai-sdk/openai-compatible';
+import { LanguageModelV1, EmbeddingModelV1 } from '@ai-sdk/provider';
+import { TogetherAIChatSettings } from './togetherai-chat-settings';
+
+vi.mock('@ai-sdk/openai-compatible', () => {
+  const actual = vi.importActual('@ai-sdk/openai-compatible');
+  return {
+    ...actual,
+    createOpenAICompatible: vi.fn(),
+  };
+});
+
+describe('TogetherAIProvider', () => {
+  let mockLanguageModel: LanguageModelV1;
+  let mockEmbeddingModel: EmbeddingModelV1<string>;
+  let mockOpenAICompatibleProvider: OpenAICompatibleProvider;
+  let createOpenAICompatibleMock: Mock;
+
+  beforeEach(() => {
+    // Mock implementations of models
+    mockLanguageModel = {} as LanguageModelV1;
+    mockEmbeddingModel = {} as EmbeddingModelV1<string>;
+
+    // Mock the OpenAICompatibleProvider methods
+    mockOpenAICompatibleProvider = Object.assign(
+      vi.fn(() => mockLanguageModel),
+      {
+        chatModel: vi.fn(() => mockLanguageModel),
+        completionModel: vi.fn(() => mockLanguageModel),
+        languageModel: vi.fn(() => mockLanguageModel),
+        textEmbeddingModel: vi.fn(() => mockEmbeddingModel),
+      },
+    );
+
+    // Mock createOpenAICompatible to return our mock provider
+    createOpenAICompatibleMock = createOpenAICompatible as unknown as Mock;
+    createOpenAICompatibleMock.mockReturnValue(mockOpenAICompatibleProvider);
+  });
+
+  describe('createTogetherAI', () => {
+    it('should create a TogetherAIProvider instance', () => {
+      const provider = createTogetherAI();
+      expect(provider).toBeDefined();
+      expect(typeof provider).toBe('function');
+    });
+
+    it('should return a default language model when called as a function', () => {
+      const provider = createTogetherAI();
+      const modelId = 'foo-model-id';
+      const settings = { user: 'foo-user' };
+
+      const model = provider(modelId, settings);
+
+      expect(model).toBe(mockLanguageModel);
+      expect(mockOpenAICompatibleProvider.languageModel).toHaveBeenCalledWith(
+        modelId,
+        settings,
+      );
+    });
+  });
+
+  describe('chatModel', () => {
+    it('should construct a chat model using the openAICompatibleProvider', () => {
+      const provider = createTogetherAI();
+      const modelId = 'together-chat-model';
+      const settings: TogetherAIChatSettings = { user: 'foo-user' };
+
+      const model = provider.chatModel(modelId, settings);
+
+      expect(model).toBe(mockLanguageModel);
+      expect(mockOpenAICompatibleProvider.chatModel).toHaveBeenCalledWith(
+        modelId,
+        { defaultObjectGenerationMode: 'json', ...settings },
+      );
+    });
+  });
+
+  describe('completionModel', () => {
+    it('should construct a completion model using the openAICompatibleProvider', () => {
+      const provider = createTogetherAI();
+      const modelId = 'together-completion-model';
+      const settings: TogetherAIChatSettings = { user: 'foo-user' };
+
+      const model = provider.completionModel(modelId, settings);
+
+      expect(model).toBe(mockLanguageModel);
+      expect(mockOpenAICompatibleProvider.languageModel).toHaveBeenCalledWith(
+        modelId,
+        settings,
+      );
+    });
+  });
+
+  describe('textEmbeddingModel', () => {
+    it('should construct a text embedding model using the openAICompatibleProvider', () => {
+      const provider = createTogetherAI();
+      const modelId = 'together-embedding-model';
+      const settings: TogetherAIChatSettings = { user: 'foo-user' };
+
+      const model = provider.textEmbeddingModel(modelId, settings);
+
+      expect(model).toBe(mockEmbeddingModel);
+      expect(
+        mockOpenAICompatibleProvider.textEmbeddingModel,
+      ).toHaveBeenCalledWith(modelId, settings);
+    });
+  });
+});
diff --git a/packages/togetherai/src/togetherai-provider.ts b/packages/togetherai/src/togetherai-provider.ts
index f59864bf59c2..36c35faa39fb 100644
--- a/packages/togetherai/src/togetherai-provider.ts
+++ b/packages/togetherai/src/togetherai-provider.ts
@@ -1,7 +1,6 @@
 import {
   OpenAICompatibleProvider,
   createOpenAICompatible,
-  OpenAICompatibleChatSettings,
   OpenAICompatibleProviderSettings,
 } from '@ai-sdk/openai-compatible';
 import { LanguageModelV1, EmbeddingModelV1 } from '@ai-sdk/provider';
@@ -59,43 +58,40 @@ export function createTogetherAI(
     TogetherAIEmbeddingModelId
   >(providerOptions);
 
-  const togetheraiProvider: TogetherAIProvider = Object.assign(
-    (
-      modelId: TogetherAIChatModelId,
-      settings?: TogetherAIChatSettings,
-    ): LanguageModelV1 => {
-      return openAICompatibleProvider(modelId, settings);
-    },
-    {
-      chatModel: (
-        modelId: TogetherAIChatModelId,
-        settings?: TogetherAIChatSettings,
-      ) => {
-        // TODO(shaper): Perhaps the object generation mode will vary by model.
-        const defaultSettings: Partial<TogetherAIChatSettings> = {
-          defaultObjectGenerationMode: 'json',
-        };
-        const mergedSettings = { ...defaultSettings, ...settings };
-        return openAICompatibleProvider.chatModel(modelId, mergedSettings);
-      },
+  const createChatModel = (
+    modelId: TogetherAIChatModelId,
+    settings?: TogetherAIChatSettings,
+  ) => {
+    // TODO(shaper): Perhaps the object generation mode will vary by model.
+    const defaultSettings: Partial<TogetherAIChatSettings> = {
+      defaultObjectGenerationMode: 'json',
+    };
+    const mergedSettings = { ...defaultSettings, ...settings };
+    return openAICompatibleProvider.chatModel(modelId, mergedSettings);
+  };
+
+  const createCompletionModel = (
+    modelId: TogetherAICompletionModelId,
+    settings?: TogetherAICompletionSettings,
+  ) => openAICompatibleProvider.languageModel(modelId, settings);
 
-      completionModel: (
-        modelId: TogetherAICompletionModelId,
-        settings?: TogetherAICompletionSettings,
-      ) => {
-        return openAICompatibleProvider.languageModel(modelId, settings);
-      },
+  const createTextEmbeddingModel = (
+    modelId: TogetherAIEmbeddingModelId,
+    settings?: TogetherAIEmbeddingSettings,
+  ) => openAICompatibleProvider.textEmbeddingModel(modelId, settings);
+
+  const provider = function (
+    modelId: TogetherAIChatModelId,
+    settings?: TogetherAIChatSettings,
+  ) {
+    return createCompletionModel(modelId, settings);
+  };
 
-      textEmbeddingModel: (
-        modelId: TogetherAIEmbeddingModelId,
-        settings?: TogetherAIEmbeddingSettings,
-      ) => {
-        return openAICompatibleProvider.textEmbeddingModel(modelId, settings);
-      },
-    },
-  ) as TogetherAIProvider;
+  provider.completionModel = createCompletionModel;
+  provider.chatModel = createChatModel;
+  provider.textEmbeddingModel = createTextEmbeddingModel;
 
-  return togetheraiProvider;
+  return provider as TogetherAIProvider;
 }
 
 export const togetherai = createTogetherAI();

From 69bb4b1f29c3c240dbe028b199f7849ea9317ec0 Mon Sep 17 00:00:00 2001
From: Walter Korman
Date: Wed, 20 Nov 2024 10:58:00 -0800
Subject: [PATCH 08/13] Investigating tools and restoring json.
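Restoring 'object-json' here means mapping it to the OpenAI-style
`response_format: { type: 'json_object' }` request parameter instead of
throwing, as the chat model hunk below shows. A hedged sketch of the
user-level call this unblocks (the model choice is illustrative, and
`mode: 'json'` forces the JSON path explicitly rather than relying on the
provider default):

    import { togetherai } from '@ai-sdk/togetherai';
    import { generateObject } from 'ai';
    import { z } from 'zod';

    const { object } = await generateObject({
      model: togetherai.chatModel('mistralai/Mistral-7B-Instruct-v0.1'),
      mode: 'json', // exercise response_format instead of tool calling
      schema: z.object({ name: z.string(), description: z.string() }),
      prompt: 'Invent a fantasy character.',
    });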
--- .../ai-core/src/generate-object/togetherai.ts | 4 +-- .../ai-core/src/stream-object/togetherai.ts | 33 +++++++++++++++++++ packages/openai-compatible/README.md | 2 +- .../openai-compatible-chat-language-model.ts | 12 ++++--- .../togetherai/src/togetherai-provider.ts | 2 +- 5 files changed, 44 insertions(+), 9 deletions(-) create mode 100644 examples/ai-core/src/stream-object/togetherai.ts diff --git a/examples/ai-core/src/generate-object/togetherai.ts b/examples/ai-core/src/generate-object/togetherai.ts index e37f22397559..6eae7344ca56 100644 --- a/examples/ai-core/src/generate-object/togetherai.ts +++ b/examples/ai-core/src/generate-object/togetherai.ts @@ -5,9 +5,7 @@ import { z } from 'zod'; async function main() { const result = await generateObject({ - model: togetherai.chatModel( - 'meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo', - ), + model: togetherai.chatModel('mistralai/Mistral-7B-Instruct-v0.1'), schema: z.object({ recipe: z.object({ name: z.string(), diff --git a/examples/ai-core/src/stream-object/togetherai.ts b/examples/ai-core/src/stream-object/togetherai.ts new file mode 100644 index 000000000000..8cfb6b557379 --- /dev/null +++ b/examples/ai-core/src/stream-object/togetherai.ts @@ -0,0 +1,33 @@ +import { togetherai } from '@ai-sdk/togetherai'; +import { streamObject } from 'ai'; +import 'dotenv/config'; +import { z } from 'zod'; + +async function main() { + const result = streamObject({ + model: togetherai.chatModel('mistralai/Mistral-7B-Instruct-v0.1'), + schema: z.object({ + characters: z.array( + z.object({ + name: z.string(), + class: z + .string() + .describe('Character class, e.g. warrior, mage, or thief.'), + description: z.string(), + }), + ), + }), + prompt: + 'Generate 3 character descriptions for a fantasy role playing game.', + }); + + for await (const partialObject of result.partialObjectStream) { + console.clear(); + console.log(partialObject); + } + + console.log(); + console.log('Token usage:', await result.usage); +} + +main().catch(console.error); diff --git a/packages/openai-compatible/README.md b/packages/openai-compatible/README.md index a06670c922f4..2bbcb1f19d2c 100644 --- a/packages/openai-compatible/README.md +++ b/packages/openai-compatible/README.md @@ -1,6 +1,6 @@ # AI SDK - OpenAI Compatible Provider -This packge aims to speed and support the implementation of new +This package aims to speed and support the implementation of new OpenAI-compatible providers. The intent is to allow more effective code sharing across multiple concrete provider implementations. diff --git a/packages/openai-compatible/src/openai-compatible-chat-language-model.ts b/packages/openai-compatible/src/openai-compatible-chat-language-model.ts index d73de06cf4af..f74d7130a796 100644 --- a/packages/openai-compatible/src/openai-compatible-chat-language-model.ts +++ b/packages/openai-compatible/src/openai-compatible-chat-language-model.ts @@ -145,10 +145,14 @@ export class OpenAICompatibleChatLanguageModel implements LanguageModelV1 { } case 'object-json': { - // TODO(shaper): Review vs. OpenAI impl here. - throw new UnsupportedFunctionalityError({ - functionality: 'object-json mode', - }); + return { + args: { + ...baseArgs, + // TODO(shaper): We removed structuredOutputs here. 
+          response_format: { type: 'json_object' },
+        },
+        warnings,
+      };
     }
 
     case 'object-tool': {
diff --git a/packages/togetherai/src/togetherai-provider.ts b/packages/togetherai/src/togetherai-provider.ts
index 36c35faa39fb..5223ff06fe20 100644
--- a/packages/togetherai/src/togetherai-provider.ts
+++ b/packages/togetherai/src/togetherai-provider.ts
@@ -64,7 +64,7 @@ export function createTogetherAI(
   ) => {
     // TODO(shaper): Perhaps the object generation mode will vary by model.
     const defaultSettings: Partial<TogetherAIChatSettings> = {
-      defaultObjectGenerationMode: 'json',
+      defaultObjectGenerationMode: 'tool',
     };
     const mergedSettings = { ...defaultSettings, ...settings };
     return openAICompatibleProvider.chatModel(modelId, mergedSettings);

From b961925822618527c33ffe8ba8384ee0e0d9be53 Mon Sep 17 00:00:00 2001
From: Walter Korman
Date: Wed, 20 Nov 2024 14:23:27 -0800
Subject: [PATCH 09/13] rm logprobs, capitalize types in provider

---
 ...p-openai-compatible-completion-logprobs.ts | 24 ------
 ...nai-compatible-chat-language-model.test.ts | 11 ++-
 ...mpatible-completion-language-model.test.ts | 80 ++-----------------
 ...ai-compatible-completion-language-model.ts | 35 --------
 .../openai-compatible-completion-settings.ts  | 13 ---
 .../openai-compatible-embedding-model.test.ts |  6 +-
 .../src/openai-compatible-provider.ts         | 78 +++++++++++------
 pnpm-lock.yaml                                | 30 ++++++-
 8 files changed, 92 insertions(+), 185 deletions(-)
 delete mode 100644 packages/openai-compatible/src/map-openai-compatible-completion-logprobs.ts

diff --git a/packages/openai-compatible/src/map-openai-compatible-completion-logprobs.ts b/packages/openai-compatible/src/map-openai-compatible-completion-logprobs.ts
deleted file mode 100644
index a7ea6279574d..000000000000
--- a/packages/openai-compatible/src/map-openai-compatible-completion-logprobs.ts
+++ /dev/null
@@ -1,24 +0,0 @@
-import { LanguageModelV1LogProbs } from '@ai-sdk/provider';
-
-type OpenAICompatibleCompletionLogProps = {
-  tokens: string[];
-  token_logprobs: number[];
-  top_logprobs: Record<string, number>[] | null;
-};
-
-export function mapOpenAICompatibleCompletionLogProbs(
-  logprobs: OpenAICompatibleCompletionLogProps | null | undefined,
-): LanguageModelV1LogProbs | undefined {
-  return logprobs?.tokens.map((token, index) => ({
-    token,
-    logprob: logprobs.token_logprobs[index],
-    topLogprobs: logprobs.top_logprobs
-      ?
Object.entries(logprobs.top_logprobs[index]).map( - ([token, logprob]) => ({ - token, - logprob, - }), - ) - : [], - })); -} diff --git a/packages/openai-compatible/src/openai-compatible-chat-language-model.test.ts b/packages/openai-compatible/src/openai-compatible-chat-language-model.test.ts index e40763762244..9cf8e6186471 100644 --- a/packages/openai-compatible/src/openai-compatible-chat-language-model.test.ts +++ b/packages/openai-compatible/src/openai-compatible-chat-language-model.test.ts @@ -13,6 +13,7 @@ const TEST_PROMPT: LanguageModelV1Prompt = [ const provider = createOpenAICompatible({ apiKey: 'test-api-key', baseURL: 'https://my.api.com/v1/', + name: 'test-provider', }); const model = provider('grok-beta'); @@ -309,6 +310,7 @@ describe('doGenerate', () => { const provider = createOpenAICompatible({ apiKey: 'test-api-key', baseURL: 'https://my.api.com/v1/', + name: 'test-provider', headers: { 'Custom-Provider-Header': 'provider-header-value', }, @@ -423,7 +425,7 @@ describe('doStream', () => { `data: {"id":"chatcmpl-e7f8e220-656c-4455-a132-dacfc1370798","object":"chat.completion.chunk","created":1702657020,"model":"grok-beta",` + `"system_fingerprint":null,"choices":[{"index":0,"delta":{},"finish_reason":"${finish_reason}"}]}\n\n`, `data: {"id":"chatcmpl-e7f8e220-656c-4455-a132-dacfc1370798","object":"chat.completion.chunk","created":1729171479,"model":"grok-beta",` + - `"system_fingerprint":"fp_10c08bf97d","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"${finish_reason}"}],` + + `"system_fingerprint":"fp_10c08bf97d","choices":[{"index":0,"delta":{},"finish_reason":"${finish_reason}"}],` + `"usage":{"queue_time":0.061348671,"prompt_tokens":18,"prompt_time":0.000211569,` + `"completion_tokens":439,"completion_time":0.798181818,"total_tokens":457,"total_time":0.798393387}}\n\n`, 'data: [DONE]\n\n', @@ -490,7 +492,7 @@ describe('doStream', () => { `"system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\\"}"}}]},` + `"finish_reason":null}]}\n\n`, `data: {"id":"chatcmpl-e7f8e220-656c-4455-a132-dacfc1370798","object":"chat.completion.chunk","created":1729171479,"model":"grok-beta",` + - `"system_fingerprint":"fp_10c08bf97d","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}],` + + `"system_fingerprint":"fp_10c08bf97d","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}],` + `"usage":{"queue_time":0.061348671,"prompt_tokens":18,"prompt_time":0.000211569,` + `"completion_tokens":439,"completion_time":0.798181818,"total_tokens":457,"total_time":0.798393387}}\n\n`, 'data: [DONE]\n\n', @@ -616,7 +618,7 @@ describe('doStream', () => { `"system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\\"}"}}]},` + `"finish_reason":null}]}\n\n`, `data: {"id":"chatcmpl-e7f8e220-656c-4455-a132-dacfc1370798","object":"chat.completion.chunk","created":1729171479,"model":"grok-beta",` + - `"system_fingerprint":"fp_10c08bf97d","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}],` + + `"system_fingerprint":"fp_10c08bf97d","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}],` + `"usage":{"queue_time":0.061348671,"prompt_tokens":18,"prompt_time":0.000211569,` + `"completion_tokens":439,"completion_time":0.798181818,"total_tokens":457,"total_time":0.798393387}}\n\n`, 'data: [DONE]\n\n', @@ -728,7 +730,7 @@ describe('doStream', () => { 
`"tool_calls":[{"index":0,"id":"call_O17Uplv4lJvD6DVdIvFFeRMw","type":"function","function":{"name":"test-tool","arguments":"{\\"value\\":\\"Sparkle Day\\"}"}}]},` +
         `"finish_reason":null}]}\n\n`,
       `data: {"id":"chatcmpl-e7f8e220-656c-4455-a132-dacfc1370798","object":"chat.completion.chunk","created":1729171479,"model":"grok-beta",` +
-        `"system_fingerprint":"fp_10c08bf97d","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}],` +
+        `"system_fingerprint":"fp_10c08bf97d","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}],` +
         `"usage":{"queue_time":0.061348671,"prompt_tokens":18,"prompt_time":0.000211569,` +
         `"completion_tokens":439,"completion_time":0.798181818,"total_tokens":457,"total_time":0.798393387}}\n\n`,
       'data: [DONE]\n\n',
@@ -882,6 +884,7 @@ describe('doStream', () => {
     const provider = createOpenAICompatible({
       apiKey: 'test-api-key',
       baseURL: 'https://my.api.com/v1',
+      name: 'test-provider',
       headers: {
         'Custom-Provider-Header': 'provider-header-value',
       },
diff --git a/packages/openai-compatible/src/openai-compatible-completion-language-model.test.ts b/packages/openai-compatible/src/openai-compatible-completion-language-model.test.ts
index f1e64a5431ee..e63ac8ccafdc 100644
--- a/packages/openai-compatible/src/openai-compatible-completion-language-model.test.ts
+++ b/packages/openai-compatible/src/openai-compatible-completion-language-model.test.ts
@@ -5,42 +5,15 @@ import {
   convertReadableStreamToArray,
 } from '@ai-sdk/provider-utils/test';
 import { createOpenAICompatible } from './openai-compatible-provider';
-import { mapOpenAICompatibleCompletionLogProbs } from './map-openai-compatible-completion-logprobs';
 
 const TEST_PROMPT: LanguageModelV1Prompt = [
   { role: 'user', content: [{ type: 'text', text: 'Hello' }] },
 ];
 
-const TEST_LOGPROBS = {
-  tokens: [' ever', ' after', '.\n\n', 'The', ' end', '.'],
-  token_logprobs: [
-    -0.0664508, -0.014520033, -1.3820221, -0.7890417, -0.5323165, -0.10247037,
-  ],
-  top_logprobs: [
-    {
-      ' ever': -0.0664508,
-    },
-    {
-      ' after': -0.014520033,
-    },
-    {
-      '.\n\n': -1.3820221,
-    },
-    {
-      The: -0.7890417,
-    },
-    {
-      ' end': -0.5323165,
-    },
-    {
-      '.': -0.10247037,
-    },
-  ] as Record<string, number>[],
-};
-
 const provider = createOpenAICompatible({
   apiKey: 'test-api-key',
   baseURL: 'https://my.api.com/v1/',
+  name: 'test-provider',
 });
 
 const model = provider.completionModel('gpt-3.5-turbo-instruct');
@@ -57,7 +30,6 @@ describe('doGenerate', () => {
       total_tokens: 34,
       completion_tokens: 30,
     },
-    logprobs = null,
     finish_reason = 'stop',
     id = 'cmpl-96cAM1v77r4jXa4qb2NSmRREV5oWB',
     created = 1711363706,
     model = 'gpt-3.5-turbo-instruct',
   }: {
     content?: string;
     usage?: {
       prompt_tokens: number;
       total_tokens: number;
       completion_tokens: number;
     };
-    logprobs?: {
-      tokens: string[];
-      token_logprobs: number[];
-      top_logprobs: Record<string, number>[];
-    } | null;
     finish_reason?: string;
     id?: string;
     created?: number;
     model?: string;
   }) {
     server.responseBodyJson = {
       id,
       object: 'text_completion',
       created,
       model,
       choices: [
         {
           text: content,
           index: 0,
-          logprobs,
           finish_reason,
         },
       ],
       usage,
     };
   }
 
@@ -160,26 +126,6 @@ describe('doGenerate', () => {
     });
   });
 
-  it('should extract logprobs', async () => {
-    prepareJsonResponse({ logprobs: TEST_LOGPROBS });
-
-    const provider = createOpenAICompatible({
-      apiKey: 'test-api-key',
-      baseURL: 'https://my.api.com/v1/',
-    });
-
-    const response = await provider
-      .completionModel('gpt-3.5-turbo', { logprobs: 1 })
-      .doGenerate({
-        inputFormat: 'prompt',
-        mode: { type: 'regular' },
-        prompt: TEST_PROMPT,
-      });
-    expect(response.logprobs).toStrictEqual(
-      mapOpenAICompatibleCompletionLogProbs(TEST_LOGPROBS),
-    );
-  });
-
   it('should extract finish reason', async () => {
     prepareJsonResponse({
       content: '',
@@ -229,7 +175,7 @@ describe('doGenerate', () => {
 
     expect(rawResponse?.headers).toStrictEqual({
       // default headers:
-      'content-length': '266',
+      'content-length': '250',
       'content-type': 'application/json',
 
       // custom header
@@ -258,6 +204,7 @@ describe('doGenerate', () => {
     const provider = createOpenAICompatible({
       apiKey: 'test-api-key',
       baseURL: 'https://my.api.com/v1/',
+      name: 'test-provider',
       // TODO(shaper): Do we need these?
       // organization: 'test-organization',
       // project: 'test-project',
@@ -301,7 +248,6 @@ describe('doStream', () => {
       total_tokens: 372,
       completion_tokens: 362,
     },
-    logprobs = null,
   }: {
     content: string[];
     usage?: {
       prompt_tokens: number;
       total_tokens: number;
       completion_tokens: number;
     };
-    logprobs?: {
-      tokens: string[];
-      token_logprobs: number[];
-      top_logprobs: Record<string, number>[];
-    } | null;
     finish_reason?: string;
   }) {
     server.responseChunks = [
       ...content.map(text => {
         return (
           `data: {"id":"cmpl-96c64EdfhOw8pjFFgVpLuT8k2MtdT","object":"text_completion","created":1711363440,` +
-          `"choices":[{"text":"${text}","index":0,"logprobs":null,"finish_reason":null}],"model":"gpt-3.5-turbo-instruct"}\n\n`
+          `"choices":[{"text":"${text}","index":0,"finish_reason":null}],"model":"gpt-3.5-turbo-instruct"}\n\n`
         );
       }),
       `data: {"id":"cmpl-96c3yLQE1TtZCd6n6OILVmzev8M8H","object":"text_completion","created":1711363310,` +
-        `"choices":[{"text":"","index":0,"logprobs":${JSON.stringify(
-          logprobs,
-        )},"finish_reason":"${finish_reason}"}],"model":"gpt-3.5-turbo-instruct"}\n\n`,
+        `"choices":[{"text":"","index":0,"finish_reason":"${finish_reason}"}],"model":"gpt-3.5-turbo-instruct"}\n\n`,
       `data: {"id":"cmpl-96c3yLQE1TtZCd6n6OILVmzev8M8H","object":"text_completion","created":1711363310,` +
         `"model":"gpt-3.5-turbo-instruct","usage":${JSON.stringify(
           usage,
         )},"choices":[]}\n\n`,
       'data: [DONE]\n\n',
     ];
   }
 
@@ -344,7 +283,6 @@ describe('doStream', () => {
       total_tokens: 372,
       completion_tokens: 362,
     },
-      logprobs: TEST_LOGPROBS,
     });
 
     const { stream } = await model.doStream({
@@ -368,7 +306,6 @@ describe('doStream', () => {
       {
         type: 'finish',
         finishReason: 'stop',
-        logprobs: mapOpenAICompatibleCompletionLogProbs(TEST_LOGPROBS),
         usage: { promptTokens: 10, completionTokens: 362 },
       },
     ]);
   });
@@ -403,7 +340,6 @@ describe('doStream', () => {
       },
       {
         finishReason: 'error',
-        logprobs: undefined,
         type: 'finish',
         usage: {
           completionTokens: NaN,
@@ -428,7 +364,6 @@ describe('doStream', () => {
     expect(elements[0].type).toBe('error');
     expect(elements[1]).toStrictEqual({
       finishReason: 'error',
-      logprobs: undefined,
       type: 'finish',
       usage: {
         completionTokens: NaN,
@@ -499,8 +434,7 @@ describe('doStream', () => {
     const provider = createOpenAICompatible({
       apiKey: 'test-api-key',
       baseURL: 'https://my.api.com/v1/',
-      // organization: 'test-organization',
-      // project: 'test-project',
+      name: 'test-provider',
       headers: {
         'Custom-Provider-Header': 'provider-header-value',
       },
@@ -522,8 +456,6 @@ describe('doStream', () => {
       'content-type': 'application/json',
       'custom-provider-header': 'provider-header-value',
       'custom-request-header': 'request-header-value',
-      // 'openai-organization': 'test-organization',
-      // 'openai-project': 'test-project',
     });
   });
});
diff --git a/packages/openai-compatible/src/openai-compatible-completion-language-model.ts b/packages/openai-compatible/src/openai-compatible-completion-language-model.ts
index 51b0370dab71..2d0f155db44d 100644
--- a/packages/openai-compatible/src/openai-compatible-completion-language-model.ts
+++ b/packages/openai-compatible/src/openai-compatible-completion-language-model.ts
@@ -2,7 +2,6 @@ import {
   LanguageModelV1,
   LanguageModelV1CallWarning,
   LanguageModelV1FinishReason,
-  LanguageModelV1LogProbs,
   LanguageModelV1StreamPart,
   UnsupportedFunctionalityError,
 } from '@ai-sdk/provider';
@@ -16,7 +15,6 @@ import {
 } from '@ai-sdk/provider-utils';
 import { z } from 'zod';
 import { convertToOpenAICompatibleCompletionPrompt } from './convert-to-openai-compatible-completion-prompt';
-import { mapOpenAICompatibleCompletionLogProbs } from './map-openai-compatible-completion-logprobs';
 import { mapOpenAICompatibleFinishReason } from './map-openai-compatible-finish-reason';
 import {
   OpenAICompatibleCompletionModelId,
@@ -106,14 +104,6 @@ export class OpenAICompatibleCompletionLanguageModel
       // model specific settings:
       echo: this.settings.echo,
       logit_bias: this.settings.logitBias,
-      logprobs:
-        typeof this.settings.logprobs === 'number'
-          ? this.settings.logprobs
-          : typeof this.settings.logprobs === 'boolean'
-            ? this.settings.logprobs
-              ? 0
-              : undefined
-            : undefined,
       suffix: this.settings.suffix,
       user: this.settings.user,
 
@@ -198,7 +188,6 @@ export class OpenAICompatibleCompletionLanguageModel
         completionTokens: response.usage.completion_tokens,
       },
       finishReason: mapOpenAICompatibleFinishReason(choice.finish_reason),
-      logprobs: mapOpenAICompatibleCompletionLogProbs(choice.logprobs),
       rawCall: { rawPrompt, rawSettings },
       rawResponse: { headers: responseHeaders },
       response: getResponseMetadata(response),
@@ -245,7 +234,6 @@ export class OpenAICompatibleCompletionLanguageModel
       promptTokens: Number.NaN,
       completionTokens: Number.NaN,
     };
-    let logprobs: LanguageModelV1LogProbs;
     let isFirstChunk = true;
 
     return {
@@ -301,21 +289,12 @@ export class OpenAICompatibleCompletionLanguageModel
               textDelta: choice.text,
             });
           }
-
-          const mappedLogprobs = mapOpenAICompatibleCompletionLogProbs(
-            choice?.logprobs,
-          );
-          if (mappedLogprobs?.length) {
-            if (logprobs === undefined) logprobs = [];
-            logprobs.push(...mappedLogprobs);
-          }
         },
 
         flush(controller) {
           controller.enqueue({
             type: 'finish',
             finishReason,
-            logprobs,
             usage,
           });
         },
@@ -340,13 +319,6 @@ const openAICompatibleCompletionResponseSchema = z.object({
     z.object({
       text: z.string(),
      finish_reason: z.string(),
-      logprobs: z
-        .object({
-          tokens: z.array(z.string()),
-          token_logprobs: z.array(z.number()),
-          top_logprobs: z.array(z.record(z.string(), z.number())).nullable(),
-        })
-        .nullish(),
     }),
   ),
  usage: z.object({
@@ -367,13 +339,6 @@ const openaiCompatibleCompletionChunkSchema = z.union([
       text: z.string(),
       finish_reason: z.string().nullish(),
       index: z.number(),
-      logprobs: z
-        .object({
-          tokens: z.array(z.string()),
-          token_logprobs: z.array(z.number()),
-          top_logprobs: z.array(z.record(z.string(), z.number())).nullable(),
-        })
-        .nullish(),
     }),
   ),
   usage: z
diff --git a/packages/openai-compatible/src/openai-compatible-completion-settings.ts b/packages/openai-compatible/src/openai-compatible-completion-settings.ts
index 518cf4b26d04..6396750d80c6 100644
--- a/packages/openai-compatible/src/openai-compatible-completion-settings.ts
+++ b/packages/openai-compatible/src/openai-compatible-completion-settings.ts
@@ -22,19 +22,6 @@ token from being generated.
  */
   logitBias?: Record<number, number>;
 
-  /**
-Return the log probabilities of the tokens. Including logprobs will increase
-the response size and can slow down response times. However, it can
-be useful to better understand how the model is behaving.
-
-Setting to true will return the log probabilities of the tokens that
-were generated.
-
-Setting to a number will return the log probabilities of the top n
-tokens that were generated.
- */
-  logprobs?: boolean | number;
-
   /**
 The suffix that comes after a completion of inserted text.
 */
diff --git a/packages/openai-compatible/src/openai-compatible-embedding-model.test.ts b/packages/openai-compatible/src/openai-compatible-embedding-model.test.ts
index 2e81b01ffe62..6c7be160ae91 100644
--- a/packages/openai-compatible/src/openai-compatible-embedding-model.test.ts
+++ b/packages/openai-compatible/src/openai-compatible-embedding-model.test.ts
@@ -11,6 +11,7 @@ const testValues = ['sunny day at the beach', 'rainy day in the city'];
 const provider = createOpenAICompatible({
   apiKey: 'test-api-key',
   baseURL: 'https://my.api.com/v1/',
+  name: 'test-provider',
 });
 const model = provider.textEmbeddingModel('text-embedding-3-large');
 
@@ -108,8 +109,7 @@ describe('doEmbed', () => {
     const provider = createOpenAICompatible({
       apiKey: 'test-api-key',
       baseURL: 'https://my.api.com/v1/',
-      // organization: 'test-organization',
-      // project: 'test-project',
+      name: 'test-provider',
       headers: {
         'Custom-Provider-Header': 'provider-header-value',
       },
@@ -129,8 +129,6 @@ describe('doEmbed', () => {
       'content-type': 'application/json',
       'custom-provider-header': 'provider-header-value',
       'custom-request-header': 'request-header-value',
-      // 'openai-organization': 'test-organization',
-      // 'openai-project': 'test-project',
     });
   });
 });
diff --git a/packages/openai-compatible/src/openai-compatible-provider.ts b/packages/openai-compatible/src/openai-compatible-provider.ts
index 7d04ab67e771..72b7ee79f3fe 100644
--- a/packages/openai-compatible/src/openai-compatible-provider.ts
+++ b/packages/openai-compatible/src/openai-compatible-provider.ts
@@ -16,29 +16,32 @@ import { OpenAICompatibleEmbeddingSettings } from './openai-compatible-embedding
 import { OpenAICompatibleEmbeddingModel } from './openai-compatible-embedding-model';
 
 export interface OpenAICompatibleProvider<
-  L extends string = string,
-  C extends string = string,
-  E extends string = string,
+  CHAT_MODEL_IDS extends string = string,
+  COMPLETION_MODEL_IDS extends string = string,
+  EMBEDDING_MODEL_IDS extends string = string,
 > extends ProviderV1 {
-  (modelId: L, settings?: OpenAICompatibleChatSettings): LanguageModelV1;
+  (
+    modelId: CHAT_MODEL_IDS,
+    settings?: OpenAICompatibleChatSettings,
+  ): LanguageModelV1;
 
   languageModel(
-    modelId: L,
+    modelId: CHAT_MODEL_IDS,
     settings?: OpenAICompatibleChatSettings,
   ): LanguageModelV1;
 
   chatModel(
-    modelId: L,
+    modelId: CHAT_MODEL_IDS,
     settings?: OpenAICompatibleChatSettings,
   ): LanguageModelV1;
 
   completionModel(
-    modelId: C,
+    modelId: COMPLETION_MODEL_IDS,
     settings?: OpenAICompatibleCompletionSettings,
   ): LanguageModelV1;
 
   textEmbeddingModel(
-    modelId: E,
+    modelId: EMBEDDING_MODEL_IDS,
     settings?: OpenAICompatibleEmbeddingSettings,
   ): EmbeddingModelV1<string>;
 }
@@ -85,36 +88,49 @@ Provider name. Overrides the `openai` default name for 3rd party providers.
 Create an OpenAICompatible provider instance.
*/ export function createOpenAICompatible< - L extends string, - C extends string, - E extends string, + CHAT_MODEL_IDS extends string, + COMPLETION_MODEL_IDS extends string, + EMBEDDING_MODEL_IDS extends string, >( options: OpenAICompatibleProviderSettings, -): OpenAICompatibleProvider { - // TODO(shaper): - // - consider throwing if baseUrl, name, sufficient api key info not available - // - force only 'compatible' -- look into whether we can remove some 'strict' logic/configs entirely +): OpenAICompatibleProvider< + CHAT_MODEL_IDS, + COMPLETION_MODEL_IDS, + EMBEDDING_MODEL_IDS +> { + // TODO(shaper): force only 'compatible' -- look into whether we can remove some 'strict' logic/configs entirely + + if (!options.baseURL) { + throw new Error('Base URL is required'); + } const baseURL = withoutTrailingSlash(options.baseURL); - const providerName = options.name ?? 'openaiCompatible'; + + if (!options.name) { + throw new Error('Provider name is required'); + } + const providerName = options.name; + + const apiKey = loadApiKey({ + apiKey: options.apiKey, + environmentVariableName: options.apiKeyEnvVarName ?? '', + description: options.apiKeyEnvVarDescription ?? '', + }); + if (!apiKey) { + throw new Error('API key is required'); + } const getHeaders = () => ({ - Authorization: `Bearer ${loadApiKey({ - apiKey: options.apiKey, - environmentVariableName: options.apiKeyEnvVarName ?? '', - description: options.apiKeyEnvVarDescription ?? '', - })}`, + Authorization: `Bearer ${apiKey}`, ...options.headers, }); const createLanguageModel = ( - modelId: L, + modelId: CHAT_MODEL_IDS, settings?: OpenAICompatibleChatSettings, ) => createChatModel(modelId, settings); - // TODO(shaper): Change provider strings below to allow concrete impls to specify. - // See openai-provider.ts:141 and subsequent configs. 
const createChatModel = ( - modelId: L, + modelId: CHAT_MODEL_IDS, settings: OpenAICompatibleChatSettings = {}, ) => new OpenAICompatibleChatLanguageModel(modelId, settings, { @@ -125,7 +141,7 @@ export function createOpenAICompatible< }); const createCompletionModel = ( - modelId: C, + modelId: COMPLETION_MODEL_IDS, settings: OpenAICompatibleCompletionSettings = {}, ) => new OpenAICompatibleCompletionLanguageModel(modelId, settings, { @@ -137,7 +153,7 @@ export function createOpenAICompatible< }); const createEmbeddingModel = ( - modelId: E, + modelId: EMBEDDING_MODEL_IDS, settings: OpenAICompatibleEmbeddingSettings = {}, ) => new OpenAICompatibleEmbeddingModel(modelId, settings, { @@ -148,7 +164,7 @@ export function createOpenAICompatible< }); const provider = function ( - modelId: L, + modelId: CHAT_MODEL_IDS, settings?: OpenAICompatibleChatSettings, ) { return createLanguageModel(modelId, settings); @@ -159,5 +175,9 @@ export function createOpenAICompatible< provider.completionModel = createCompletionModel; provider.textEmbeddingModel = createEmbeddingModel; - return provider as OpenAICompatibleProvider; + return provider as OpenAICompatibleProvider< + CHAT_MODEL_IDS, + COMPLETION_MODEL_IDS, + EMBEDDING_MODEL_IDS + >; } diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 7c0711bc319e..ab1608dd406b 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -1371,7 +1371,7 @@ importers: version: link:../provider '@ai-sdk/provider-utils': specifier: 2.0.0 - version: link:../provider-utils + version: 2.0.0(zod@3.23.8) devDependencies: '@types/node': specifier: ^18 @@ -1627,7 +1627,7 @@ importers: version: link:../provider '@ai-sdk/provider-utils': specifier: 2.0.0 - version: link:../provider-utils + version: 2.0.0(zod@3.23.8) devDependencies: '@types/node': specifier: ^18 @@ -1781,6 +1781,19 @@ packages: '@adobe/css-tools@4.4.0': resolution: {integrity: sha512-Ff9+ksdQQB3rMncgqDK78uLznstjyfIf2Arnh22pW8kBpLs6rpKDwgnZT46hin5Hl1WzazzK64DOrhSwYpS7bQ==} + '@ai-sdk/provider-utils@2.0.0': + resolution: {integrity: sha512-uITgVJByhtzuQU2ZW+2CidWRmQqTUTp6KADevy+4aRnmILZxY2LCt+UZ/ZtjJqq0MffwkuQPPY21ExmFAQ6kKA==} + engines: {node: '>=18'} + peerDependencies: + zod: ^3.0.0 + peerDependenciesMeta: + zod: + optional: true + + '@ai-sdk/provider@1.0.0': + resolution: {integrity: sha512-Sj29AzooJ7SYvhPd+AAWt/E7j63E9+AzRnoMHUaJPRYzOd/WDrVNxxv85prF9gDcQ7XPVlSk9j6oAZV9/DXYpA==} + engines: {node: '>=18'} + '@alloc/quick-lru@5.2.0': resolution: {integrity: sha512-UrcABB+4bUrFABwbluTIBErXwvbsU/V7TZWfmbgJfbkwiBuziS9gxdODUyuiecfdGQ85jglMW6juS3+z5TsKLw==} engines: {node: '>=10'} @@ -13762,6 +13775,19 @@ snapshots: '@adobe/css-tools@4.4.0': {} + '@ai-sdk/provider-utils@2.0.0(zod@3.23.8)': + dependencies: + '@ai-sdk/provider': 1.0.0 + eventsource-parser: 3.0.0 + nanoid: 5.0.8 + secure-json-parse: 2.7.0 + optionalDependencies: + zod: 3.23.8 + + '@ai-sdk/provider@1.0.0': + dependencies: + json-schema: 0.4.0 + '@alloc/quick-lru@5.2.0': {} '@ampproject/remapping@2.2.1': From 26ca652cd453e47e4a10ce6647053284def65da1 Mon Sep 17 00:00:00 2001 From: Walter Korman Date: Wed, 20 Nov 2024 14:48:25 -0800 Subject: [PATCH 10/13] rm compatibility --- .../src/openai-compatible-completion-language-model.ts | 7 ------- .../openai-compatible/src/openai-compatible-provider.ts | 3 --- 2 files changed, 10 deletions(-) diff --git a/packages/openai-compatible/src/openai-compatible-completion-language-model.ts b/packages/openai-compatible/src/openai-compatible-completion-language-model.ts index 2d0f155db44d..8b5f7d5c0690 100644 --- 
a/packages/openai-compatible/src/openai-compatible-completion-language-model.ts +++ b/packages/openai-compatible/src/openai-compatible-completion-language-model.ts @@ -28,7 +28,6 @@ import { getResponseMetadata } from './get-response-metadata'; type OpenAICompatibleCompletionConfig = { provider: string; - compatibility: 'strict' | 'compatible'; headers: () => Record<string, string>; url: (options: { modelId: string; path: string }) => string; fetch?: FetchFunction; @@ -204,12 +203,6 @@ export class OpenAICompatibleCompletionLanguageModel const body = { ...args, stream: true, - - // only include stream_options when in strict compatibility mode: - stream_options: - this.config.compatibility === 'strict' - ? { include_usage: true } - : undefined, }; const { responseHeaders, value: response } = await postJsonToApi({ diff --git a/packages/openai-compatible/src/openai-compatible-provider.ts b/packages/openai-compatible/src/openai-compatible-provider.ts index 72b7ee79f3fe..673a9216d3fe 100644 --- a/packages/openai-compatible/src/openai-compatible-provider.ts +++ b/packages/openai-compatible/src/openai-compatible-provider.ts @@ -98,8 +98,6 @@ export function createOpenAICompatible< COMPLETION_MODEL_IDS, EMBEDDING_MODEL_IDS > { - // TODO(shaper): force only 'compatible' -- look into whether we can remove some 'strict' logic/configs entirely - if (!options.baseURL) { throw new Error('Base URL is required'); } @@ -146,7 +144,6 @@ export function createOpenAICompatible< ) => new OpenAICompatibleCompletionLanguageModel(modelId, settings, { provider: `${providerName}.completion`, - compatibility: 'compatible', url: ({ path }) => `${baseURL}${path}`, headers: getHeaders, fetch: options.fetch, From b7aed86390ce454caa6241b5a33a89b208ec4554 Mon Sep 17 00:00:00 2001 From: Walter Korman Date: Wed, 20 Nov 2024 15:44:18 -0800 Subject: [PATCH 11/13] move default object generation mode to model config. --- .../openai-compatible-chat-language-model.ts | 9 ++- .../src/openai-compatible-chat-settings.ts | 11 ---- .../src/openai-compatible-provider.ts | 56 +++++++++++-------- .../src/togetherai-provider.test.ts | 3 +- .../togetherai/src/togetherai-provider.ts | 11 ++-- 5 files changed, 47 insertions(+), 43 deletions(-) diff --git a/packages/openai-compatible/src/openai-compatible-chat-language-model.ts b/packages/openai-compatible/src/openai-compatible-chat-language-model.ts index f74d7130a796..b5737a9fa5e9 100644 --- a/packages/openai-compatible/src/openai-compatible-chat-language-model.ts +++ b/packages/openai-compatible/src/openai-compatible-chat-language-model.ts @@ -36,6 +36,13 @@ type OpenAICompatibleChatConfig = { headers: () => Record<string, string>; url: (options: { modelId: string; path: string }) => string; fetch?: FetchFunction; + + /** +Default object generation mode that should be used with this model when +no mode is specified. Should be the mode with the best results for this +model. `undefined` can be specified if object generation is not supported.
+ */ + defaultObjectGenerationMode?: 'json' | 'tool' | undefined; }; export class OpenAICompatibleChatLanguageModel implements LanguageModelV1 { @@ -59,7 +66,7 @@ export class OpenAICompatibleChatLanguageModel implements LanguageModelV1 { } get defaultObjectGenerationMode(): 'json' | 'tool' | undefined { - return this.settings.defaultObjectGenerationMode; + return this.config.defaultObjectGenerationMode; } get provider(): string { diff --git a/packages/openai-compatible/src/openai-compatible-chat-settings.ts b/packages/openai-compatible/src/openai-compatible-chat-settings.ts index e3ab3d164370..face0c1e700b 100644 --- a/packages/openai-compatible/src/openai-compatible-chat-settings.ts +++ b/packages/openai-compatible/src/openai-compatible-chat-settings.ts @@ -6,15 +6,4 @@ A unique identifier representing your end-user, which can help the provider to monitor and detect abuse. */ user?: string; - - /** -Default object generation mode that should be used with this model when -no mode is specified. Should be the mode with the best results for this -model. `undefined` can be returned if object generation is not supported. - -This is needed to generate the best objects possible w/o requiring the -user to explicitly specify the object generation mode. - */ - // TODO(shaper): This is really model-specific, move to config or elsewhere? - defaultObjectGenerationMode?: 'json' | 'tool' | undefined; } diff --git a/packages/openai-compatible/src/openai-compatible-provider.ts b/packages/openai-compatible/src/openai-compatible-provider.ts index 673a9216d3fe..efbf2b4d41b1 100644 --- a/packages/openai-compatible/src/openai-compatible-provider.ts +++ b/packages/openai-compatible/src/openai-compatible-provider.ts @@ -33,6 +33,7 @@ export interface OpenAICompatibleProvider< chatModel( modelId: CHAT_MODEL_IDS, settings?: OpenAICompatibleChatSettings, + options?: { defaultObjectGenerationMode: 'json' | 'tool' | undefined }, ): LanguageModelV1; completionModel( @@ -113,15 +114,27 @@ export function createOpenAICompatible< environmentVariableName: options.apiKeyEnvVarName ?? '', description: options.apiKeyEnvVarDescription ?? 
'', }); - if (!apiKey) { - throw new Error('API key is required'); - } - const getHeaders = () => ({ Authorization: `Bearer ${apiKey}`, ...options.headers, }); + interface CommonModelConfig { + provider: string; + url: ({ path }: { path: string }) => string; + headers: () => Record<string, string>; + fetch?: FetchFunction; + } + + const getCommonModelConfig = (modelType: string): CommonModelConfig => { + return { + provider: `${providerName}.${modelType}`, + url: ({ path }) => `${baseURL}${path}`, + headers: getHeaders, + fetch: options.fetch, + }; + }; + const createLanguageModel = ( modelId: CHAT_MODEL_IDS, settings?: OpenAICompatibleChatSettings, @@ -130,42 +143,37 @@ export function createOpenAICompatible< const createChatModel = ( modelId: CHAT_MODEL_IDS, settings: OpenAICompatibleChatSettings = {}, + options: { defaultObjectGenerationMode?: 'tool' | 'json' | undefined } = {}, ) => new OpenAICompatibleChatLanguageModel(modelId, settings, { - provider: `${providerName}.chat`, - url: ({ path }) => `${baseURL}${path}`, - headers: getHeaders, - fetch: options.fetch, + ...getCommonModelConfig('chat'), + defaultObjectGenerationMode: options.defaultObjectGenerationMode, }); const createCompletionModel = ( modelId: COMPLETION_MODEL_IDS, settings: OpenAICompatibleCompletionSettings = {}, ) => - new OpenAICompatibleCompletionLanguageModel(modelId, settings, { - provider: `${providerName}.completion`, - url: ({ path }) => `${baseURL}${path}`, - headers: getHeaders, - fetch: options.fetch, - }); + new OpenAICompatibleCompletionLanguageModel( + modelId, + settings, + getCommonModelConfig('completion'), + ); const createEmbeddingModel = ( modelId: EMBEDDING_MODEL_IDS, settings: OpenAICompatibleEmbeddingSettings = {}, ) => - new OpenAICompatibleEmbeddingModel(modelId, settings, { - provider: `${providerName}.embedding`, - url: ({ path }) => `${baseURL}${path}`, - headers: getHeaders, - fetch: options.fetch, - }); + new OpenAICompatibleEmbeddingModel( + modelId, + settings, + getCommonModelConfig('embedding'), + ); - const provider = function ( + const provider = ( modelId: CHAT_MODEL_IDS, settings?: OpenAICompatibleChatSettings, - ) { - return createLanguageModel(modelId, settings); - }; + ) => createLanguageModel(modelId, settings); provider.languageModel = createLanguageModel; provider.chatModel = createChatModel; diff --git a/packages/togetherai/src/togetherai-provider.test.ts b/packages/togetherai/src/togetherai-provider.test.ts index 8a44903f3013..ab87333463c7 100644 --- a/packages/togetherai/src/togetherai-provider.test.ts +++ b/packages/togetherai/src/togetherai-provider.test.ts @@ -75,7 +75,8 @@ describe('TogetherAIProvider', () => { expect(model).toBe(mockLanguageModel); expect(mockOpenAICompatibleProvider.chatModel).toHaveBeenCalledWith( modelId, - { defaultObjectGenerationMode: 'json', ...settings }, + settings, + { defaultObjectGenerationMode: 'json' }, ); }); }); diff --git a/packages/togetherai/src/togetherai-provider.ts b/packages/togetherai/src/togetherai-provider.ts index 5223ff06fe20..91f14e8692e4 100644 --- a/packages/togetherai/src/togetherai-provider.ts +++ b/packages/togetherai/src/togetherai-provider.ts @@ -29,6 +29,7 @@ export interface TogetherAIProvider chatModel( modelId: TogetherAIChatModelId, settings?: TogetherAIChatSettings, + options?: { defaultObjectGenerationMode: 'json' | 'tool' | undefined }, ): LanguageModelV1; completionModel( @@ -62,12 +63,10 @@ export function createTogetherAI( modelId: TogetherAIChatModelId, settings?: TogetherAIChatSettings, ) => { - // TODO(shaper):
Perhaps the object generation mode will vary by model. - const defaultSettings: Partial<TogetherAIChatSettings> = { - defaultObjectGenerationMode: 'tool', - }; - const mergedSettings = { ...defaultSettings, ...settings }; - return openAICompatibleProvider.chatModel(modelId, mergedSettings); + // TODO(shaper): Likely need a registry mapping model IDs to object generation modes. + return openAICompatibleProvider.chatModel(modelId, settings, { + defaultObjectGenerationMode: 'json', + }); }; const createCompletionModel = ( From 2184a4a418ea1f33b8541c17424d668e322be45f Mon Sep 17 00:00:00 2001 From: Walter Korman Date: Wed, 20 Nov 2024 16:26:31 -0800 Subject: [PATCH 12/13] restore responseFormat, fix openAI casing, move embed settings to config --- .../openai-compatible-chat-language-model.ts | 16 ++++++---------- ...nai-compatible-completion-language-model.ts | 15 +++++++-------- .../src/openai-compatible-embedding-model.ts | 18 ++++++++++++++---- .../openai-compatible-embedding-settings.ts | 10 ---------- .../src/openai-compatible-error.ts | 8 ++++---- 5 files changed, 31 insertions(+), 36 deletions(-) diff --git a/packages/openai-compatible/src/openai-compatible-chat-language-model.ts b/packages/openai-compatible/src/openai-compatible-chat-language-model.ts index b5737a9fa5e9..a9674a7ff2a8 100644 --- a/packages/openai-compatible/src/openai-compatible-chat-language-model.ts +++ b/packages/openai-compatible/src/openai-compatible-chat-language-model.ts @@ -5,7 +5,6 @@ import { LanguageModelV1FinishReason, LanguageModelV1ProviderMetadata, LanguageModelV1StreamPart, - UnsupportedFunctionalityError, } from '@ai-sdk/provider'; import { FetchFunction, @@ -25,8 +24,8 @@ import { OpenAICompatibleChatSettings, } from './openai-compatible-chat-settings'; import { - openAICompatibleErrorDataSchema, - openAICompatibleFailedResponseHandler, + openaiCompatibleErrorDataSchema, + openaiCompatibleFailedResponseHandler, } from './openai-compatible-error'; import { prepareTools } from './openai-compatible-prepare-tools'; import { mapOpenAICompatibleFinishReason } from './map-openai-compatible-finish-reason'; @@ -130,9 +129,7 @@ export class OpenAICompatibleChatLanguageModel implements LanguageModelV1 { // response format: response_format: - // TODO(shaper): Review vs. OpenAI impl here. - // json object response format is not currently supported - undefined, + responseFormat?.type === 'json' ? { type: 'json_object' } : undefined, // messages: messages: convertToOpenAICompatibleChatMessages(prompt), @@ -155,7 +152,6 @@ export class OpenAICompatibleChatLanguageModel implements LanguageModelV1 { return { args: { ...baseArgs, - // TODO(shaper): We removed structuredOutputs here.
response_format: { type: 'json_object' }, }, warnings, @@ -206,7 +202,7 @@ export class OpenAICompatibleChatLanguageModel implements LanguageModelV1 { }), headers: combineHeaders(this.config.headers(), options.headers), body: args, - failedResponseHandler: openAICompatibleFailedResponseHandler, + failedResponseHandler: openaiCompatibleFailedResponseHandler, successfulResponseHandler: createJsonResponseHandler( OpenAICompatibleChatResponseSchema, ), @@ -255,7 +251,7 @@ export class OpenAICompatibleChatLanguageModel implements LanguageModelV1 { ...args, stream: true, }, - failedResponseHandler: openAICompatibleFailedResponseHandler, + failedResponseHandler: openaiCompatibleFailedResponseHandler, successfulResponseHandler: createEventSourceResponseHandler( OpenAICompatibleChatChunkSchema, ), @@ -544,5 +540,5 @@ const OpenAICompatibleChatChunkSchema = z.union([ }) .nullish(), }), - openAICompatibleErrorDataSchema, + openaiCompatibleErrorDataSchema, ]); diff --git a/packages/openai-compatible/src/openai-compatible-completion-language-model.ts b/packages/openai-compatible/src/openai-compatible-completion-language-model.ts index 8b5f7d5c0690..624c3fe80c0f 100644 --- a/packages/openai-compatible/src/openai-compatible-completion-language-model.ts +++ b/packages/openai-compatible/src/openai-compatible-completion-language-model.ts @@ -21,8 +21,8 @@ import { OpenAICompatibleCompletionSettings, } from './openai-compatible-completion-settings'; import { - openAICompatibleErrorDataSchema, - openAICompatibleFailedResponseHandler, + openaiCompatibleErrorDataSchema, + openaiCompatibleFailedResponseHandler, } from './openai-compatible-error'; import { getResponseMetadata } from './get-response-metadata'; @@ -169,9 +169,9 @@ export class OpenAICompatibleCompletionLanguageModel }), headers: combineHeaders(this.config.headers(), options.headers), body: args, - failedResponseHandler: openAICompatibleFailedResponseHandler, + failedResponseHandler: openaiCompatibleFailedResponseHandler, successfulResponseHandler: createJsonResponseHandler( - openAICompatibleCompletionResponseSchema, + openaiCompatibleCompletionResponseSchema, ), abortSignal: options.abortSignal, fetch: this.config.fetch, @@ -212,7 +212,7 @@ export class OpenAICompatibleCompletionLanguageModel }), headers: combineHeaders(this.config.headers(), options.headers), body, - failedResponseHandler: openAICompatibleFailedResponseHandler, + failedResponseHandler: openaiCompatibleFailedResponseHandler, successfulResponseHandler: createEventSourceResponseHandler( openaiCompatibleCompletionChunkSchema, ), @@ -303,8 +303,7 @@ export class OpenAICompatibleCompletionLanguageModel // limited version of the schema, focussed on what is needed for the implementation // this approach limits breakages when the API changes and increases efficiency -// TODO(shaper): Fix naming to match others e.g. 
'openai' -const openAICompatibleCompletionResponseSchema = z.object({ +const openaiCompatibleCompletionResponseSchema = z.object({ id: z.string().nullish(), created: z.number().nullish(), model: z.string().nullish(), @@ -341,5 +340,5 @@ const openaiCompatibleCompletionChunkSchema = z.union([ }) .nullish(), }), - openAICompatibleErrorDataSchema, + openaiCompatibleErrorDataSchema, ]); diff --git a/packages/openai-compatible/src/openai-compatible-embedding-model.ts b/packages/openai-compatible/src/openai-compatible-embedding-model.ts index 53184dc405b5..8dbe75393854 100644 --- a/packages/openai-compatible/src/openai-compatible-embedding-model.ts +++ b/packages/openai-compatible/src/openai-compatible-embedding-model.ts @@ -13,9 +13,19 @@ import { OpenAICompatibleEmbeddingModelId, OpenAICompatibleEmbeddingSettings, } from './openai-compatible-embedding-settings'; -import { openAICompatibleFailedResponseHandler } from './openai-compatible-error'; +import { openaiCompatibleFailedResponseHandler } from './openai-compatible-error'; type OpenAIEmbeddingConfig = { + /** +Override the maximum number of embeddings per call. + */ + maxEmbeddingsPerCall?: number; + + /** +Override the parallelism of embedding calls. + */ + supportsParallelCalls?: boolean; + provider: string; url: (options: { modelId: string; path: string }) => string; headers: () => Record<string, string>; fetch?: FetchFunction; @@ -36,11 +46,11 @@ export class OpenAICompatibleEmbeddingModel } get maxEmbeddingsPerCall(): number { - return this.settings.maxEmbeddingsPerCall ?? 2048; + return this.config.maxEmbeddingsPerCall ?? 2048; } get supportsParallelCalls(): boolean { - return this.settings.supportsParallelCalls ?? true; + return this.config.supportsParallelCalls ?? true; } constructor( @@ -82,7 +92,7 @@ export class OpenAICompatibleEmbeddingModel dimensions: this.settings.dimensions, user: this.settings.user, }, - failedResponseHandler: openAICompatibleFailedResponseHandler, + failedResponseHandler: openaiCompatibleFailedResponseHandler, successfulResponseHandler: createJsonResponseHandler( openaiTextEmbeddingResponseSchema, ), diff --git a/packages/openai-compatible/src/openai-compatible-embedding-settings.ts b/packages/openai-compatible/src/openai-compatible-embedding-settings.ts index cce45e0667d5..4fab037bcbad 100644 --- a/packages/openai-compatible/src/openai-compatible-embedding-settings.ts +++ b/packages/openai-compatible/src/openai-compatible-embedding-settings.ts @@ -1,16 +1,6 @@ export type OpenAICompatibleEmbeddingModelId = string; export interface OpenAICompatibleEmbeddingSettings { - /** -Override the maximum number of embeddings per call. - */ - maxEmbeddingsPerCall?: number; - - /** -Override the parallelism of embedding calls. - */ - supportsParallelCalls?: boolean; - /** The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models. diff --git a/packages/openai-compatible/src/openai-compatible-error.ts b/packages/openai-compatible/src/openai-compatible-error.ts index 883f5993cf8e..5acb8d139fff 100644 --- a/packages/openai-compatible/src/openai-compatible-error.ts +++ b/packages/openai-compatible/src/openai-compatible-error.ts @@ -3,17 +3,17 @@ import { createJsonErrorResponseHandler } from '@ai-sdk/provider-utils'; // TODO(shaper): Reconcile this with openai-error.ts. We derived from `xai`.
-export const openAICompatibleErrorDataSchema = z.object({ +export const openaiCompatibleErrorDataSchema = z.object({ code: z.string(), error: z.string(), }); export type OpenAICompatibleErrorData = z.infer< - typeof openAICompatibleErrorDataSchema + typeof openaiCompatibleErrorDataSchema >; -export const openAICompatibleFailedResponseHandler = +export const openaiCompatibleFailedResponseHandler = createJsonErrorResponseHandler({ - errorSchema: openAICompatibleErrorDataSchema, + errorSchema: openaiCompatibleErrorDataSchema, errorToMessage: data => data.error, }); From e326ae688bc9305a4f9fb2339e4ff057a1e0630d Mon Sep 17 00:00:00 2001 From: Walter Korman Date: Wed, 20 Nov 2024 17:07:59 -0800 Subject: [PATCH 13/13] restore original openai error data schema --- ...openai-compatible-chat-language-model.test.ts | 8 ++++---- .../src/openai-compatible-chat-language-model.ts | 2 +- .../src/openai-compatible-error.ts | 16 +++++++++++----- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/packages/openai-compatible/src/openai-compatible-chat-language-model.test.ts b/packages/openai-compatible/src/openai-compatible-chat-language-model.test.ts index 9cf8e6186471..a56ab6aad1c2 100644 --- a/packages/openai-compatible/src/openai-compatible-chat-language-model.test.ts +++ b/packages/openai-compatible/src/openai-compatible-chat-language-model.test.ts @@ -788,7 +788,7 @@ describe('doStream', () => { it('should handle error stream parts', async () => { server.responseChunks = [ - `data: {"code":"Client specified an invalid argument","error":"Incorrect API key provided: as***T7. You can obtain an API key from https://console.x.ai."}\n\n`, + `data: {"error": {"message": "Incorrect API key provided: as***T7. You can obtain an API key from https://console.api.com.", "code": "Client specified an invalid argument"}}\n\n`, 'data: [DONE]\n\n', ]; @@ -802,14 +802,14 @@ describe('doStream', () => { { type: 'error', error: - 'Incorrect API key provided: as***T7. You can obtain an API key from https://console.x.ai.', + 'Incorrect API key provided: as***T7. You can obtain an API key from https://console.api.com.', }, { - finishReason: 'error', type: 'finish', + finishReason: 'error', usage: { - completionTokens: NaN, promptTokens: NaN, + completionTokens: NaN, }, }, ]); diff --git a/packages/openai-compatible/src/openai-compatible-chat-language-model.ts b/packages/openai-compatible/src/openai-compatible-chat-language-model.ts index a9674a7ff2a8..4ccdcfc4a6fa 100644 --- a/packages/openai-compatible/src/openai-compatible-chat-language-model.ts +++ b/packages/openai-compatible/src/openai-compatible-chat-language-model.ts @@ -300,7 +300,7 @@ export class OpenAICompatibleChatLanguageModel implements LanguageModelV1 { // handle error chunks: if ('error' in value) { finishReason = 'error'; - controller.enqueue({ type: 'error', error: value.error }); + controller.enqueue({ type: 'error', error: value.error.message }); return; } diff --git a/packages/openai-compatible/src/openai-compatible-error.ts b/packages/openai-compatible/src/openai-compatible-error.ts index 5acb8d139fff..28a7d5b6c9d1 100644 --- a/packages/openai-compatible/src/openai-compatible-error.ts +++ b/packages/openai-compatible/src/openai-compatible-error.ts @@ -1,11 +1,17 @@ import { z } from 'zod'; import { createJsonErrorResponseHandler } from '@ai-sdk/provider-utils'; -// TODO(shaper): Reconcile this with openai-error.ts. We derived from `xai`. 
- export const openaiCompatibleErrorDataSchema = z.object({ - code: z.string(), - error: z.string(), + error: z.object({ + message: z.string(), + + // The additional information below is handled loosely to support + // OpenAI-compatible providers that have slightly different error + // responses: + type: z.string().nullish(), + param: z.any().nullish(), + code: z.union([z.string(), z.number()]).nullish(), + }), }); export type OpenAICompatibleErrorData = z.infer< @@ -15,5 +21,5 @@ export type OpenAICompatibleErrorData = z.infer< export const openaiCompatibleFailedResponseHandler = createJsonErrorResponseHandler({ errorSchema: openaiCompatibleErrorDataSchema, - errorToMessage: data => data.error, + errorToMessage: data => data.error.message, });
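A closing note on usage: with the full series applied, `createOpenAICompatible` requires an explicit `name` and `baseURL`, the default object generation mode travels as model config rather than as a chat setting, and error payloads are parsed with the nested `{ error: { message } }` shape. The sketch below shows the resulting call pattern, assuming the package publishes as `@ai-sdk/openai-compatible`; the provider name, base URL, environment variable, and model ID are illustrative placeholders, not part of this changeset.

import { createOpenAICompatible } from '@ai-sdk/openai-compatible';
import { generateText } from 'ai';
import 'dotenv/config';

// Hypothetical OpenAI-compatible endpoint; any conforming provider works.
// The factory throws if `name` or `baseURL` is missing, and strips any
// trailing slash from the base URL.
const example = createOpenAICompatible({
  name: 'example',
  baseURL: 'https://api.example.com/v1',
  apiKey: process.env.EXAMPLE_API_KEY,
});

async function main() {
  const { text } = await generateText({
    // As of patch 11, the default object generation mode is a third
    // options argument on chatModel instead of a chat setting.
    model: example.chatModel(
      'example-org/example-model',
      {},
      { defaultObjectGenerationMode: 'json' },
    ),
    prompt: 'Invent a new holiday and describe its traditions.',
  });
  console.log(text);
}

main().catch(console.error);

The shared `getCommonModelConfig` helper keeps the provider string, URL builder, headers, and fetch override identical across the chat, completion, and embedding models, so concrete providers such as `@ai-sdk/togetherai` only layer on model IDs and per-model defaults.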