diff --git a/langchain-core/src/language_models/chat_models.ts b/langchain-core/src/language_models/chat_models.ts index 1f39802436d0..772964724e58 100644 --- a/langchain-core/src/language_models/chat_models.ts +++ b/langchain-core/src/language_models/chat_models.ts @@ -46,6 +46,9 @@ import { concat } from "../utils/stream.js"; import { RunnablePassthrough } from "../runnables/passthrough.js"; import { isZodSchema } from "../utils/types/is_zod_schema.js"; +// eslint-disable-next-line @typescript-eslint/no-explicit-any +type ToolChoice = string | Record<string, any> | "auto" | "any"; + /** * Represents a serialized chat model. */ @@ -73,7 +76,23 @@ export type BaseChatModelParams = BaseLanguageModelParams; /** * Represents the call options for a base chat model. */ -export type BaseChatModelCallOptions = BaseLanguageModelCallOptions; +export type BaseChatModelCallOptions = BaseLanguageModelCallOptions & { + /** + * Specifies how the chat model should use tools. + * @default undefined + * + * Possible values: + * - "auto": The model may choose to use any of the provided tools, or none. + * - "any": The model must use one of the provided tools. + * - "none": The model must not use any tools. + * - A string (not "auto", "any", or "none"): The name of a specific tool the model must use. + * - An object: A custom schema specifying tool choice parameters. Specific to the provider. + * + * Note: Not all providers support tool_choice. An error will be thrown + * if used with an unsupporting model. + */ + tool_choice?: ToolChoice; +}; /** * Creates a transform stream for encoding chat message chunks. 
diff --git a/libs/langchain-anthropic/src/chat_models.ts b/libs/langchain-anthropic/src/chat_models.ts index 8aa8295f2676..9d1e0bab5fba 100644 --- a/libs/langchain-anthropic/src/chat_models.ts +++ b/libs/langchain-anthropic/src/chat_models.ts @@ -21,12 +21,12 @@ import { import { getEnvironmentVariable } from "@langchain/core/utils/env"; import { BaseChatModel, + BaseChatModelCallOptions, LangSmithParams, type BaseChatModelParams, } from "@langchain/core/language_models/chat_models"; import { type StructuredOutputMethodOptions, - type BaseLanguageModelCallOptions, type BaseLanguageModelInput, type ToolDefinition, isOpenAITool, @@ -53,6 +53,11 @@ import { extractToolCalls, } from "./output_parsers.js"; import { AnthropicToolResponse } from "./types.js"; +import { + AnthropicToolChoice, + AnthropicToolTypes, + handleToolChoice, +} from "./utils.js"; type AnthropicMessage = Anthropic.MessageParam; type AnthropicMessageCreateParams = Anthropic.MessageCreateParamsNonStreaming; @@ -60,23 +65,11 @@ type AnthropicStreamingMessageCreateParams = Anthropic.MessageCreateParamsStreaming; type AnthropicMessageStreamEvent = Anthropic.MessageStreamEvent; type AnthropicRequestOptions = Anthropic.RequestOptions; -type AnthropicToolChoice = - | { - type: "tool"; - name: string; - } - | "any" - | "auto"; + export interface ChatAnthropicCallOptions - extends BaseLanguageModelCallOptions, + extends BaseChatModelCallOptions, Pick<AnthropicInput, "streamUsage"> { - tools?: ( - | StructuredToolInterface - | AnthropicTool - | Record<string, unknown> - | ToolDefinition - | RunnableToolLike - )[]; + tools?: AnthropicToolTypes[]; /** * Whether or not to specify what tool the model should use * @default "auto" */ @@ -855,24 +848,11 @@ export class ChatAnthropicMessages< "messages" > & Kwargs { - let tool_choice: + const tool_choice: | MessageCreateParams.ToolChoiceAuto | MessageCreateParams.ToolChoiceAny | MessageCreateParams.ToolChoiceTool - | undefined; - if (options?.tool_choice) { - if (options?.tool_choice === "any") { - tool_choice = { - 
type: "any", - }; - } else if (options?.tool_choice === "auto") { - tool_choice = { - type: "auto", - }; - } else { - tool_choice = options?.tool_choice; - } - } + | undefined = handleToolChoice(options?.tool_choice); return { model: this.model, diff --git a/libs/langchain-anthropic/src/utils.ts b/libs/langchain-anthropic/src/utils.ts new file mode 100644 index 000000000000..860dd0f594ff --- /dev/null +++ b/libs/langchain-anthropic/src/utils.ts @@ -0,0 +1,51 @@ +import type { + MessageCreateParams, + Tool as AnthropicTool, +} from "@anthropic-ai/sdk/resources/index.mjs"; +import { ToolDefinition } from "@langchain/core/language_models/base"; +import { RunnableToolLike } from "@langchain/core/runnables"; +import { StructuredToolInterface } from "@langchain/core/tools"; + +export type AnthropicToolChoice = + | { + type: "tool"; + name: string; + } + | "any" + | "auto" + | "none" + | string; + +export type AnthropicToolTypes = + | StructuredToolInterface + | AnthropicTool + | Record<string, unknown> + | ToolDefinition + | RunnableToolLike; + +export function handleToolChoice( + toolChoice?: AnthropicToolChoice +): + | MessageCreateParams.ToolChoiceAuto + | MessageCreateParams.ToolChoiceAny + | MessageCreateParams.ToolChoiceTool + | undefined { + if (!toolChoice) { + return undefined; + } else if (toolChoice === "any") { + return { + type: "any", + }; + } else if (toolChoice === "auto") { + return { + type: "auto", + }; + } else if (typeof toolChoice === "string") { + return { + type: "tool", + name: toolChoice, + }; + } else { + return toolChoice; + } +} diff --git a/libs/langchain-aws/src/chat_models.ts b/libs/langchain-aws/src/chat_models.ts index 8d5ae1abf00d..96b47993c3a6 100644 --- a/libs/langchain-aws/src/chat_models.ts +++ b/libs/langchain-aws/src/chat_models.ts @@ -2,7 +2,6 @@ import type { BaseMessage } from "@langchain/core/messages"; import { AIMessageChunk } from "@langchain/core/messages"; import type { ToolDefinition, - BaseLanguageModelCallOptions, 
BaseLanguageModelInput, } from "@langchain/core/language_models/base"; import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; @@ -10,6 +9,7 @@ import { type BaseChatModelParams, BaseChatModel, LangSmithParams, + BaseChatModelCallOptions, } from "@langchain/core/language_models/chat_models"; import type { ToolConfiguration, @@ -30,11 +30,7 @@ import { import type { DocumentType as __DocumentType } from "@smithy/types"; import { StructuredToolInterface } from "@langchain/core/tools"; import { Runnable, RunnableToolLike } from "@langchain/core/runnables"; -import { - BedrockToolChoice, - ConverseCommandParams, - CredentialType, -} from "./types.js"; +import { ConverseCommandParams, CredentialType } from "./types.js"; import { convertToConverseTools, convertToBedrockToolChoice, @@ -43,6 +39,7 @@ import { handleConverseStreamContentBlockDelta, handleConverseStreamMetadata, handleConverseStreamContentBlockStart, + BedrockConverseToolChoice, } from "./common.js"; /** @@ -127,7 +124,7 @@ export interface ChatBedrockConverseInput } export interface ChatBedrockConverseCallOptions - extends BaseLanguageModelCallOptions, + extends BaseChatModelCallOptions, Pick< ChatBedrockConverseInput, "additionalModelRequestFields" | "streamUsage" @@ -149,7 +146,7 @@ export interface ChatBedrockConverseCallOptions * or whether to generate text instead. * If a tool name is passed, it will force the model to call that specific tool. 
*/ - tool_choice?: "any" | "auto" | string | BedrockToolChoice; + tool_choice?: BedrockConverseToolChoice; } /** diff --git a/libs/langchain-aws/src/common.ts b/libs/langchain-aws/src/common.ts index 14afc63cad6e..e9dd95759b59 100644 --- a/libs/langchain-aws/src/common.ts +++ b/libs/langchain-aws/src/common.ts @@ -258,8 +258,14 @@ export function convertToConverseTools( ); } +export type BedrockConverseToolChoice = + | "any" + | "auto" + | string + | BedrockToolChoice; + export function convertToBedrockToolChoice( - toolChoice: string | BedrockToolChoice, + toolChoice: BedrockConverseToolChoice, tools: BedrockTool[] ): BedrockToolChoice { if (typeof toolChoice === "string") { diff --git a/libs/langchain-groq/src/chat_models.ts b/libs/langchain-groq/src/chat_models.ts index eb8f2dffc76d..ca7c302cfd7f 100644 --- a/libs/langchain-groq/src/chat_models.ts +++ b/libs/langchain-groq/src/chat_models.ts @@ -69,7 +69,7 @@ import { ToolCallChunk } from "@langchain/core/messages/tool"; export interface ChatGroqCallOptions extends BaseChatModelCallOptions { headers?: Record<string, string>; tools?: OpenAIClient.ChatCompletionTool[]; - tool_choice?: OpenAIClient.ChatCompletionToolChoiceOption; + tool_choice?: OpenAIClient.ChatCompletionToolChoiceOption | "any" | string; response_format?: { type: "json_object" }; } diff --git a/libs/langchain-openai/src/chat_models.ts b/libs/langchain-openai/src/chat_models.ts index 487b54ffab9b..db86e6e91940 100644 --- a/libs/langchain-openai/src/chat_models.ts +++ b/libs/langchain-openai/src/chat_models.ts @@ -64,7 +64,11 @@ import type { LegacyOpenAIInput, } from "./types.js"; import { type OpenAIEndpointConfig, getEndpoint } from "./utils/azure.js"; -import { wrapOpenAIClientError } from "./utils/openai.js"; +import { + OpenAIToolChoice, + formatToOpenAIToolChoice, + wrapOpenAIClientError, +} from "./utils/openai.js"; import { FunctionDef, formatFunctionDefinitions, @@ -274,7 +278,7 @@ export interface ChatOpenAICallOptions extends OpenAICallOptions, 
BaseFunctionCallOptions { tools?: StructuredToolInterface[] | OpenAIClient.ChatCompletionTool[]; - tool_choice?: OpenAIClient.ChatCompletionToolChoiceOption; + tool_choice?: OpenAIToolChoice; promptIndex?: number; response_format?: { type: "json_object" }; seed?: number; @@ -613,7 +617,7 @@ export class ChatOpenAI< tools: isStructuredToolArray(options?.tools) ? options?.tools.map(convertToOpenAITool) : options?.tools, - tool_choice: options?.tool_choice, + tool_choice: formatToOpenAIToolChoice(options?.tool_choice), response_format: options?.response_format, seed: options?.seed, ...streamOptionsConfig, diff --git a/libs/langchain-openai/src/utils/openai.ts b/libs/langchain-openai/src/utils/openai.ts index 990a7f1dc939..e95297e56b64 100644 --- a/libs/langchain-openai/src/utils/openai.ts +++ b/libs/langchain-openai/src/utils/openai.ts @@ -1,4 +1,8 @@ -import { APIConnectionTimeoutError, APIUserAbortError } from "openai"; +import { + APIConnectionTimeoutError, + APIUserAbortError, + OpenAI as OpenAIClient, +} from "openai"; import { zodToJsonSchema } from "zod-to-json-schema"; import type { StructuredToolInterface } from "@langchain/core/tools"; import { @@ -36,3 +40,31 @@ export function formatToOpenAIAssistantTool(tool: StructuredToolInterface) { }, }; } + +export type OpenAIToolChoice = + | OpenAIClient.ChatCompletionToolChoiceOption + | "any" + | string; + +export function formatToOpenAIToolChoice( + toolChoice?: OpenAIToolChoice +): OpenAIClient.ChatCompletionToolChoiceOption | undefined { + if (!toolChoice) { + return undefined; + } else if (toolChoice === "any" || toolChoice === "required") { + return "required"; + } else if (toolChoice === "auto") { + return "auto"; + } else if (toolChoice === "none") { + return "none"; + } else if (typeof toolChoice === "string") { + return { + type: "function", + function: { + name: toolChoice, + }, + }; + } else { + return toolChoice; + } +}