diff --git a/.changeset/fuzzy-colts-sleep.md b/.changeset/fuzzy-colts-sleep.md new file mode 100644 index 000000000..54a17396c --- /dev/null +++ b/.changeset/fuzzy-colts-sleep.md @@ -0,0 +1,5 @@ +--- +"@browserbasehq/stagehand": patch +--- + +Refactor LLMProvider to use MODEL_PROVIDER_MAP for model-provider mapping diff --git a/packages/core/lib/v3/llm/LLMProvider.ts b/packages/core/lib/v3/llm/LLMProvider.ts index 7c16f2118..492cca588 100644 --- a/packages/core/lib/v3/llm/LLMProvider.ts +++ b/packages/core/lib/v3/llm/LLMProvider.ts @@ -7,7 +7,9 @@ import { LogLine } from "../types/public/logs"; import { AvailableModel, ClientOptions, + KnownModel, ModelProvider, + MODEL_PROVIDER_MAP, } from "../types/public/model"; import { AISdkClient } from "./aisdk"; import { AnthropicClient } from "./AnthropicClient"; @@ -58,41 +60,6 @@ const AISDKProvidersWithAPIKey: Record = { perplexity: createPerplexity, }; -const modelToProviderMap: { [key in AvailableModel]: ModelProvider } = { - "gpt-4.1": "openai", - "gpt-4.1-mini": "openai", - "gpt-4.1-nano": "openai", - "o4-mini": "openai", - //prettier-ignore - "o3": "openai", - "o3-mini": "openai", - //prettier-ignore - "o1": "openai", - "o1-mini": "openai", - "gpt-4o": "openai", - "gpt-4o-mini": "openai", - "gpt-4o-2024-08-06": "openai", - "gpt-4.5-preview": "openai", - "o1-preview": "openai", - "claude-3-5-sonnet-latest": "anthropic", - "claude-3-5-sonnet-20240620": "anthropic", - "claude-3-5-sonnet-20241022": "anthropic", - "claude-3-7-sonnet-20250219": "anthropic", - "claude-3-7-sonnet-latest": "anthropic", - "cerebras-llama-3.3-70b": "cerebras", - "cerebras-llama-3.1-8b": "cerebras", - "groq-llama-3.3-70b-versatile": "groq", - "groq-llama-3.3-70b-specdec": "groq", - "moonshotai/kimi-k2-instruct": "groq", - "gemini-1.5-flash": "google", - "gemini-1.5-pro": "google", - "gemini-1.5-flash-8b": "google", - "gemini-2.0-flash-lite": "google", - "gemini-2.0-flash": "google", - "gemini-2.5-flash-preview-04-17": "google", - "gemini-2.5-pro-preview-03-25": "google", -}; - export function getAISDKLanguageModel( subProvider: string, subModelName: string, @@ -156,11 +123,11 @@ export class LLMProvider { }); } - const provider = modelToProviderMap[modelName]; + const provider = MODEL_PROVIDER_MAP[modelName as KnownModel]; if (!provider) { - throw new UnsupportedModelError(Object.keys(modelToProviderMap)); + throw new UnsupportedModelError(Object.keys(MODEL_PROVIDER_MAP)); } - const availableModel = modelName as AvailableModel; + const availableModel = modelName as KnownModel; switch (provider) { case "openai": return new OpenAIClient({ @@ -194,7 +161,7 @@ export class LLMProvider { }); default: throw new UnsupportedModelProviderError([ - ...new Set(Object.values(modelToProviderMap)), + ...new Set(Object.values(MODEL_PROVIDER_MAP)), ]); } } @@ -207,7 +174,8 @@ export class LLMProvider { return "aisdk"; } } - const provider = modelToProviderMap[modelName]; + const provider = + MODEL_PROVIDER_MAP[modelName as keyof typeof MODEL_PROVIDER_MAP]; return provider; } } diff --git a/packages/core/lib/v3/types/public/model.ts b/packages/core/lib/v3/types/public/model.ts index ea8aa57da..96608ead3 100644 --- a/packages/core/lib/v3/types/public/model.ts +++ b/packages/core/lib/v3/types/public/model.ts @@ -26,46 +26,48 @@ export type AISDKCustomProvider = (options: { apiKey: string; }) => AISDKProvider; -export type AvailableModel = - | "gpt-4.1" - | "gpt-4.1-mini" - | "gpt-4.1-nano" - | "o4-mini" - | "o3" - | "o3-mini" - | "o1" - | "o1-mini" - | "gpt-4o" - | "gpt-4o-mini" - | "gpt-4o-2024-08-06" - | "gpt-4.5-preview" - | "o1-preview" - | "claude-3-5-sonnet-latest" - | "claude-3-5-sonnet-20241022" - | "claude-3-5-sonnet-20240620" - | "claude-3-7-sonnet-latest" - | "claude-3-7-sonnet-20250219" - | "cerebras-llama-3.3-70b" - | "cerebras-llama-3.1-8b" - | "groq-llama-3.3-70b-versatile" - | "groq-llama-3.3-70b-specdec" - | "gemini-1.5-flash" - | "gemini-1.5-pro" - | "gemini-1.5-flash-8b" - | "gemini-2.0-flash-lite" - | "gemini-2.0-flash" - | "gemini-2.5-flash-preview-04-17" - | "gemini-2.5-pro-preview-03-25" - | string; +export const MODEL_PROVIDER_MAP = { + "gpt-4.1": "openai", + "gpt-4.1-mini": "openai", + "gpt-4.1-nano": "openai", + "o4-mini": "openai", + //prettier-ignore + o3: "openai", + "o3-mini": "openai", + //prettier-ignore + o1: "openai", + "o1-mini": "openai", + "gpt-4o": "openai", + "gpt-4o-mini": "openai", + "gpt-4o-2024-08-06": "openai", + "gpt-4.5-preview": "openai", + "o1-preview": "openai", + "claude-3-5-sonnet-latest": "anthropic", + "claude-3-5-sonnet-20241022": "anthropic", + "claude-3-5-sonnet-20240620": "anthropic", + "claude-3-7-sonnet-latest": "anthropic", + "claude-3-7-sonnet-20250219": "anthropic", + "cerebras-llama-3.3-70b": "cerebras", + "cerebras-llama-3.1-8b": "cerebras", + "groq-llama-3.3-70b-versatile": "groq", + "groq-llama-3.3-70b-specdec": "groq", + "moonshotai/kimi-k2-instruct": "groq", + "gemini-1.5-flash": "google", + "gemini-1.5-pro": "google", + "gemini-1.5-flash-8b": "google", + "gemini-2.0-flash-lite": "google", + "gemini-2.0-flash": "google", + "gemini-2.5-flash-preview-04-17": "google", + "gemini-2.5-pro-preview-03-25": "google", +} as const; export type ModelProvider = - | "openai" - | "anthropic" - | "cerebras" - | "groq" - | "google" + | (typeof MODEL_PROVIDER_MAP)[keyof typeof MODEL_PROVIDER_MAP] | "aisdk"; +export type KnownModel = keyof typeof MODEL_PROVIDER_MAP; +export type AvailableModel = KnownModel | (string & {}); + export type ClientOptions = OpenAIClientOptions | AnthropicClientOptions; export type ModelConfiguration = diff --git a/packages/core/tests/public-types.test.ts b/packages/core/tests/public-types.test.ts index ec36d3b07..b2c0e3eaf 100644 --- a/packages/core/tests/public-types.test.ts +++ b/packages/core/tests/public-types.test.ts @@ -24,6 +24,7 @@ const publicApiShape = { LLMResponseError: Stagehand.LLMResponseError, LOG_LEVEL_NAMES: Stagehand.LOG_LEVEL_NAMES, MCPConnectionError: Stagehand.MCPConnectionError, + MODEL_PROVIDER_MAP: Stagehand.MODEL_PROVIDER_MAP, MissingEnvironmentVariableError: Stagehand.MissingEnvironmentVariableError, MissingLLMConfigurationError: Stagehand.MissingLLMConfigurationError, PageNotFoundError: Stagehand.PageNotFoundError,