Skip to content

Commit

Permalink
core[minor]: Standardize tool choice (#6111)
Browse files Browse the repository at this point in the history
* core[minor]: Standardize tool choice

* implement in partner pkgs

* use basechatmodelcalloptions
  • Loading branch information
bracesproul authored Jul 17, 2024
1 parent 77904e3 commit a64203b
Show file tree
Hide file tree
Showing 8 changed files with 135 additions and 46 deletions.
21 changes: 20 additions & 1 deletion langchain-core/src/language_models/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ import { concat } from "../utils/stream.js";
import { RunnablePassthrough } from "../runnables/passthrough.js";
import { isZodSchema } from "../utils/types/is_zod_schema.js";

// eslint-disable-next-line @typescript-eslint/no-explicit-any
type ToolChoice = string | Record<string, any> | "auto" | "any";

/**
* Represents a serialized chat model.
*/
Expand Down Expand Up @@ -73,7 +76,23 @@ export type BaseChatModelParams = BaseLanguageModelParams;
/**
* Represents the call options for a base chat model.
*/
export type BaseChatModelCallOptions = BaseLanguageModelCallOptions;
export type BaseChatModelCallOptions = BaseLanguageModelCallOptions & {
/**
* Specifies how the chat model should use tools.
* @default undefined
*
* Possible values:
* - "auto": The model may choose to use any of the provided tools, or none.
* - "any": The model must use one of the provided tools.
* - "none": The model must not use any tools.
* - A string (not "auto", "any", or "none"): The name of a specific tool the model must use.
* - An object: A custom schema specifying tool choice parameters. Specific to the provider.
*
* Note: Not all providers support tool_choice. An error will be thrown
* if used with an unsupporting model.
*/
tool_choice?: ToolChoice;
};

/**
* Creates a transform stream for encoding chat message chunks.
Expand Down
42 changes: 11 additions & 31 deletions libs/langchain-anthropic/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ import {
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import {
BaseChatModel,
BaseChatModelCallOptions,
LangSmithParams,
type BaseChatModelParams,
} from "@langchain/core/language_models/chat_models";
import {
type StructuredOutputMethodOptions,
type BaseLanguageModelCallOptions,
type BaseLanguageModelInput,
type ToolDefinition,
isOpenAITool,
Expand All @@ -53,30 +53,23 @@ import {
extractToolCalls,
} from "./output_parsers.js";
import { AnthropicToolResponse } from "./types.js";
import {
AnthropicToolChoice,
AnthropicToolTypes,
handleToolChoice,
} from "./utils.js";

type AnthropicMessage = Anthropic.MessageParam;
type AnthropicMessageCreateParams = Anthropic.MessageCreateParamsNonStreaming;
type AnthropicStreamingMessageCreateParams =
Anthropic.MessageCreateParamsStreaming;
type AnthropicMessageStreamEvent = Anthropic.MessageStreamEvent;
type AnthropicRequestOptions = Anthropic.RequestOptions;
type AnthropicToolChoice =
| {
type: "tool";
name: string;
}
| "any"
| "auto";

export interface ChatAnthropicCallOptions
extends BaseLanguageModelCallOptions,
extends BaseChatModelCallOptions,
Pick<AnthropicInput, "streamUsage"> {
tools?: (
| StructuredToolInterface
| AnthropicTool
| Record<string, unknown>
| ToolDefinition
| RunnableToolLike
)[];
tools?: AnthropicToolTypes[];
/**
* Whether or not to specify what tool the model should use
* @default "auto"
Expand Down Expand Up @@ -855,24 +848,11 @@ export class ChatAnthropicMessages<
"messages"
> &
Kwargs {
let tool_choice:
const tool_choice:
| MessageCreateParams.ToolChoiceAuto
| MessageCreateParams.ToolChoiceAny
| MessageCreateParams.ToolChoiceTool
| undefined;
if (options?.tool_choice) {
if (options?.tool_choice === "any") {
tool_choice = {
type: "any",
};
} else if (options?.tool_choice === "auto") {
tool_choice = {
type: "auto",
};
} else {
tool_choice = options?.tool_choice;
}
}
| undefined = handleToolChoice(options?.tool_choice);

return {
model: this.model,
Expand Down
51 changes: 51 additions & 0 deletions libs/langchain-anthropic/src/utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import type {
MessageCreateParams,
Tool as AnthropicTool,
} from "@anthropic-ai/sdk/resources/index.mjs";
import { ToolDefinition } from "@langchain/core/language_models/base";
import { RunnableToolLike } from "@langchain/core/runnables";
import { StructuredToolInterface } from "@langchain/core/tools";

export type AnthropicToolChoice =
| {
type: "tool";
name: string;
}
| "any"
| "auto"
| "none"
| string;

export type AnthropicToolTypes =
| StructuredToolInterface
| AnthropicTool
| Record<string, unknown>
| ToolDefinition
| RunnableToolLike;

export function handleToolChoice(
toolChoice?: AnthropicToolChoice
):
| MessageCreateParams.ToolChoiceAuto
| MessageCreateParams.ToolChoiceAny
| MessageCreateParams.ToolChoiceTool
| undefined {
if (!toolChoice) {
return undefined;
} else if (toolChoice === "any") {
return {
type: "any",
};
} else if (toolChoice === "auto") {
return {
type: "auto",
};
} else if (typeof toolChoice === "string") {
return {
type: "tool",
name: toolChoice,
};
} else {
return toolChoice;
}
}
13 changes: 5 additions & 8 deletions libs/langchain-aws/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@ import type { BaseMessage } from "@langchain/core/messages";
import { AIMessageChunk } from "@langchain/core/messages";
import type {
ToolDefinition,
BaseLanguageModelCallOptions,
BaseLanguageModelInput,
} from "@langchain/core/language_models/base";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import {
type BaseChatModelParams,
BaseChatModel,
LangSmithParams,
BaseChatModelCallOptions,
} from "@langchain/core/language_models/chat_models";
import type {
ToolConfiguration,
Expand All @@ -30,11 +30,7 @@ import {
import type { DocumentType as __DocumentType } from "@smithy/types";
import { StructuredToolInterface } from "@langchain/core/tools";
import { Runnable, RunnableToolLike } from "@langchain/core/runnables";
import {
BedrockToolChoice,
ConverseCommandParams,
CredentialType,
} from "./types.js";
import { ConverseCommandParams, CredentialType } from "./types.js";
import {
convertToConverseTools,
convertToBedrockToolChoice,
Expand All @@ -43,6 +39,7 @@ import {
handleConverseStreamContentBlockDelta,
handleConverseStreamMetadata,
handleConverseStreamContentBlockStart,
BedrockConverseToolChoice,
} from "./common.js";

/**
Expand Down Expand Up @@ -127,7 +124,7 @@ export interface ChatBedrockConverseInput
}

export interface ChatBedrockConverseCallOptions
extends BaseLanguageModelCallOptions,
extends BaseChatModelCallOptions,
Pick<
ChatBedrockConverseInput,
"additionalModelRequestFields" | "streamUsage"
Expand All @@ -149,7 +146,7 @@ export interface ChatBedrockConverseCallOptions
* or whether to generate text instead.
* If a tool name is passed, it will force the model to call that specific tool.
*/
tool_choice?: "any" | "auto" | string | BedrockToolChoice;
tool_choice?: BedrockConverseToolChoice;
}

/**
Expand Down
8 changes: 7 additions & 1 deletion libs/langchain-aws/src/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,14 @@ export function convertToConverseTools(
);
}

export type BedrockConverseToolChoice =
| "any"
| "auto"
| string
| BedrockToolChoice;

export function convertToBedrockToolChoice(
toolChoice: string | BedrockToolChoice,
toolChoice: BedrockConverseToolChoice,
tools: BedrockTool[]
): BedrockToolChoice {
if (typeof toolChoice === "string") {
Expand Down
2 changes: 1 addition & 1 deletion libs/langchain-groq/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ import { ToolCallChunk } from "@langchain/core/messages/tool";
export interface ChatGroqCallOptions extends BaseChatModelCallOptions {
headers?: Record<string, string>;
tools?: OpenAIClient.ChatCompletionTool[];
tool_choice?: OpenAIClient.ChatCompletionToolChoiceOption;
tool_choice?: OpenAIClient.ChatCompletionToolChoiceOption | "any" | string;
response_format?: { type: "json_object" };
}

Expand Down
10 changes: 7 additions & 3 deletions libs/langchain-openai/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,11 @@ import type {
LegacyOpenAIInput,
} from "./types.js";
import { type OpenAIEndpointConfig, getEndpoint } from "./utils/azure.js";
import { wrapOpenAIClientError } from "./utils/openai.js";
import {
OpenAIToolChoice,
formatToOpenAIToolChoice,
wrapOpenAIClientError,
} from "./utils/openai.js";
import {
FunctionDef,
formatFunctionDefinitions,
Expand Down Expand Up @@ -274,7 +278,7 @@ export interface ChatOpenAICallOptions
extends OpenAICallOptions,
BaseFunctionCallOptions {
tools?: StructuredToolInterface[] | OpenAIClient.ChatCompletionTool[];
tool_choice?: OpenAIClient.ChatCompletionToolChoiceOption;
tool_choice?: OpenAIToolChoice;
promptIndex?: number;
response_format?: { type: "json_object" };
seed?: number;
Expand Down Expand Up @@ -613,7 +617,7 @@ export class ChatOpenAI<
tools: isStructuredToolArray(options?.tools)
? options?.tools.map(convertToOpenAITool)
: options?.tools,
tool_choice: options?.tool_choice,
tool_choice: formatToOpenAIToolChoice(options?.tool_choice),
response_format: options?.response_format,
seed: options?.seed,
...streamOptionsConfig,
Expand Down
34 changes: 33 additions & 1 deletion libs/langchain-openai/src/utils/openai.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import { APIConnectionTimeoutError, APIUserAbortError } from "openai";
import {
APIConnectionTimeoutError,
APIUserAbortError,
OpenAI as OpenAIClient,
} from "openai";
import { zodToJsonSchema } from "zod-to-json-schema";
import type { StructuredToolInterface } from "@langchain/core/tools";
import {
Expand Down Expand Up @@ -36,3 +40,31 @@ export function formatToOpenAIAssistantTool(tool: StructuredToolInterface) {
},
};
}

export type OpenAIToolChoice =
| OpenAIClient.ChatCompletionToolChoiceOption
| "any"
| string;

export function formatToOpenAIToolChoice(
toolChoice?: OpenAIToolChoice
): OpenAIClient.ChatCompletionToolChoiceOption | undefined {
if (!toolChoice) {
return undefined;
} else if (toolChoice === "any" || toolChoice === "required") {
return "required";
} else if (toolChoice === "auto") {
return "auto";
} else if (toolChoice === "none") {
return "none";
} else if (typeof toolChoice === "string") {
return {
type: "function",
function: {
name: toolChoice,
},
};
} else {
return toolChoice;
}
}

0 comments on commit a64203b

Please sign in to comment.