diff --git a/libs/langchain-aws/package.json b/libs/langchain-aws/package.json index 5ab58d4da376..04b85b592ebc 100644 --- a/libs/langchain-aws/package.json +++ b/libs/langchain-aws/package.json @@ -44,6 +44,7 @@ "@aws-sdk/client-kendra": "^3.352.0", "@aws-sdk/credential-provider-node": "^3.600.0", "@langchain/core": ">=0.2.21 <0.3.0", + "zod": "^3.23.8", "zod-to-json-schema": "^3.22.5" }, "devDependencies": { diff --git a/libs/langchain-aws/src/chat_models.ts b/libs/langchain-aws/src/chat_models.ts index 5abeaa518248..7b3a58dc29e9 100644 --- a/libs/langchain-aws/src/chat_models.ts +++ b/libs/langchain-aws/src/chat_models.ts @@ -1,6 +1,10 @@ import type { BaseMessage } from "@langchain/core/messages"; import { AIMessageChunk } from "@langchain/core/messages"; -import type { BaseLanguageModelInput } from "@langchain/core/language_models/base"; +import type { + BaseLanguageModelInput, + StructuredOutputMethodOptions, + ToolDefinition, +} from "@langchain/core/language_models/base"; import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; import { type BaseChatModelParams, @@ -24,12 +28,15 @@ import { DefaultProviderInit, } from "@aws-sdk/credential-provider-node"; import type { DocumentType as __DocumentType } from "@smithy/types"; -import { Runnable } from "@langchain/core/runnables"; import { - ChatBedrockConverseToolType, - ConverseCommandParams, - CredentialType, -} from "./types.js"; + Runnable, + RunnableLambda, + RunnablePassthrough, + RunnableSequence, +} from "@langchain/core/runnables"; +import { zodToJsonSchema } from "zod-to-json-schema"; +import { isZodSchema } from "@langchain/core/utils/types"; +import { z } from "zod"; import { convertToConverseTools, convertToBedrockToolChoice, @@ -40,6 +47,11 @@ import { handleConverseStreamContentBlockStart, BedrockConverseToolChoice, } from "./common.js"; +import { + ChatBedrockConverseToolType, + ConverseCommandParams, + CredentialType, +} from "./types.js"; /** * Inputs for ChatBedrockConverse. @@ -120,6 +132,14 @@ export interface ChatBedrockConverseInput * Configuration information for a guardrail that you want to use in the request. */ guardrailConfig?: GuardrailConfiguration; + + /** + * Which types of `tool_choice` values the model supports. + * + * Inferred if not specified. Inferred as ['auto', 'any', 'tool'] if a 'claude-3' + * model is used, ['auto', 'any'] if a 'mistral-large' model is used, empty otherwise. + */ + supportsToolChoiceValues?: Array<"auto" | "any" | "tool">; } export interface ChatBedrockConverseCallOptions @@ -214,6 +234,14 @@ export class ChatBedrockConverse client: BedrockRuntimeClient; + /** + * Which types of `tool_choice` values the model supports. + * + * Inferred if not specified. Inferred as ['auto', 'any', 'tool'] if a 'claude-3' + * model is used, ['auto', 'any'] if a 'mistral-large' model is used, empty otherwise. + */ + supportsToolChoiceValues?: Array<"auto" | "any" | "tool">; + constructor(fields?: ChatBedrockConverseInput) { super(fields ?? {}); const { @@ -264,6 +292,18 @@ export class ChatBedrockConverse this.additionalModelRequestFields = rest?.additionalModelRequestFields; this.streamUsage = rest?.streamUsage ?? this.streamUsage; this.guardrailConfig = rest?.guardrailConfig; + + if (rest?.supportsToolChoiceValues === undefined) { + if (this.model.includes("claude-3")) { + this.supportsToolChoiceValues = ["auto", "any", "tool"]; + } else if (this.model.includes("mistral-large")) { + this.supportsToolChoiceValues = ["auto", "any"]; + } else { + this.supportsToolChoiceValues = undefined; + } + } else { + this.supportsToolChoiceValues = rest.supportsToolChoiceValues; + } } getLsParams(options: this["ParsedCallOptions"]): LangSmithParams { @@ -303,7 +343,10 @@ export class ChatBedrockConverse toolConfig = { tools, toolChoice: options.tool_choice - ? convertToBedrockToolChoice(options.tool_choice, tools) + ? convertToBedrockToolChoice(options.tool_choice, tools, { + model: this.model, + supportsToolChoiceValues: this.supportsToolChoiceValues, + }) : undefined, }; } @@ -430,4 +473,141 @@ export class ChatBedrockConverse } } } + + withStructuredOutput< + // eslint-disable-next-line @typescript-eslint/no-explicit-any + RunOutput extends Record = Record + // eslint-disable-next-line @typescript-eslint/no-explicit-any + >( + outputSchema: + | z.ZodType + // eslint-disable-next-line @typescript-eslint/no-explicit-any + | Record, + config?: StructuredOutputMethodOptions + ): Runnable; + + withStructuredOutput< + // eslint-disable-next-line @typescript-eslint/no-explicit-any + RunOutput extends Record = Record + >( + outputSchema: + | z.ZodType + // eslint-disable-next-line @typescript-eslint/no-explicit-any + | Record, + config?: StructuredOutputMethodOptions + ): Runnable; + + withStructuredOutput< + // eslint-disable-next-line @typescript-eslint/no-explicit-any + RunOutput extends Record = Record + >( + outputSchema: + | z.ZodType + // eslint-disable-next-line @typescript-eslint/no-explicit-any + | Record, + config?: StructuredOutputMethodOptions + ): + | Runnable + | Runnable< + BaseLanguageModelInput, + { + raw: BaseMessage; + parsed: RunOutput; + } + > { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const schema: z.ZodType | Record = outputSchema; + const name = config?.name; + const description = schema.description ?? "A function available to call."; + const method = config?.method; + const includeRaw = config?.includeRaw; + if (method === "jsonMode") { + throw new Error(`ChatBedrockConverse does not support 'jsonMode'.`); + } + + let functionName = name ?? "extract"; + let tools: ToolDefinition[]; + if (isZodSchema(schema)) { + tools = [ + { + type: "function", + function: { + name: functionName, + description, + parameters: zodToJsonSchema(schema), + }, + }, + ]; + } else { + if ("name" in schema) { + functionName = schema.name; + } + tools = [ + { + type: "function", + function: { + name: functionName, + description, + parameters: schema, + }, + }, + ]; + } + + const supportsToolChoiceValues = this.supportsToolChoiceValues ?? []; + let toolChoiceObj: { tool_choice: string } | undefined; + if (supportsToolChoiceValues.includes("tool")) { + toolChoiceObj = { + tool_choice: tools[0].function.name, + }; + } else if (supportsToolChoiceValues.includes("any")) { + toolChoiceObj = { + tool_choice: "any", + }; + } + + const llm = this.bindTools(tools, toolChoiceObj); + const outputParser = RunnableLambda.from( + (input: AIMessageChunk): RunOutput => { + if (!input.tool_calls || input.tool_calls.length === 0) { + throw new Error("No tool calls found in the response."); + } + const toolCall = input.tool_calls.find( + (tc) => tc.name === functionName + ); + if (!toolCall) { + throw new Error(`No tool call found with name ${functionName}.`); + } + return toolCall.args as RunOutput; + } + ); + + if (!includeRaw) { + return llm.pipe(outputParser).withConfig({ + runName: "StructuredOutput", + }) as Runnable; + } + + const parserAssign = RunnablePassthrough.assign({ + // eslint-disable-next-line @typescript-eslint/no-explicit-any + parsed: (input: any, config) => outputParser.invoke(input.raw, config), + }); + const parserNone = RunnablePassthrough.assign({ + parsed: () => null, + }); + const parsedWithFallback = parserAssign.withFallbacks({ + fallbacks: [parserNone], + }); + return RunnableSequence.from< + BaseLanguageModelInput, + { raw: BaseMessage; parsed: RunOutput } + >([ + { + raw: llm, + }, + parsedWithFallback, + ]).withConfig({ + runName: "StructuredOutputRunnable", + }); + } } diff --git a/libs/langchain-aws/src/common.ts b/libs/langchain-aws/src/common.ts index 1e201dc8a97c..fbb3356e90e4 100644 --- a/libs/langchain-aws/src/common.ts +++ b/libs/langchain-aws/src/common.ts @@ -271,18 +271,27 @@ export type BedrockConverseToolChoice = export function convertToBedrockToolChoice( toolChoice: BedrockConverseToolChoice, - tools: BedrockTool[] + tools: BedrockTool[], + fields: { + model: string; + supportsToolChoiceValues?: Array<"auto" | "any" | "tool">; + } ): BedrockToolChoice { + const supportsToolChoiceValues = fields.supportsToolChoiceValues ?? []; + + let bedrockToolChoice: BedrockToolChoice; if (typeof toolChoice === "string") { switch (toolChoice) { case "any": - return { + bedrockToolChoice = { any: {}, }; + break; case "auto": - return { + bedrockToolChoice = { auto: {}, }; + break; default: { const foundTool = tools.find( (tool) => tool.toolSpec?.name === toolChoice @@ -292,15 +301,40 @@ export function convertToBedrockToolChoice( `Tool with name ${toolChoice} not found in tools list.` ); } - return { + bedrockToolChoice = { tool: { name: toolChoice, }, }; } } + } else { + bedrockToolChoice = toolChoice; + } + + const toolChoiceType = Object.keys(bedrockToolChoice)[0] as + | "auto" + | "any" + | "tool"; + if (!supportsToolChoiceValues.includes(toolChoiceType)) { + let supportedTxt = ""; + if (supportsToolChoiceValues.length) { + supportedTxt = + `Model ${fields.model} does not currently support 'tool_choice' ` + + `of type ${toolChoiceType}. The following 'tool_choice' types ` + + `are supported: ${supportsToolChoiceValues.join(", ")}.`; + } else { + supportedTxt = `Model ${fields.model} does not currently support 'tool_choice'.`; + } + + throw new Error( + `${supportedTxt} Please see` + + "https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html" + + "for the latest documentation on models that support tool choice." + ); } - return toolChoice; + + return bedrockToolChoice; } export function convertConverseMessageToLangChainMessage( diff --git a/libs/langchain-aws/src/tests/chat_models.test.ts b/libs/langchain-aws/src/tests/chat_models.test.ts index 22ccd342ebc0..e82faf216627 100644 --- a/libs/langchain-aws/src/tests/chat_models.test.ts +++ b/libs/langchain-aws/src/tests/chat_models.test.ts @@ -11,10 +11,13 @@ import type { Message as BedrockMessage, SystemContentBlock as BedrockSystemContentBlock, } from "@aws-sdk/client-bedrock-runtime"; +import { z } from "zod"; +import { describe, expect, test } from "@jest/globals"; import { convertToConverseMessages, handleConverseStreamContentBlockDelta, } from "../common.js"; +import { ChatBedrockConverse } from "../chat_models.js"; describe("convertToConverseMessages", () => { const testCases: { @@ -337,3 +340,204 @@ test("Streaming supports empty string chunks", async () => { if (!finalChunk) return; expect(finalChunk.content).toBe("Hello world!"); }); + +describe("tool_choice works for supported models", () => { + const tool = { + name: "weather", + schema: z.object({ + location: z.string(), + }), + }; + const baseConstructorArgs = { + region: "us-east-1", + credentials: { + secretAccessKey: "process.env.BEDROCK_AWS_SECRET_ACCESS_KEY", + accessKeyId: "process.env.BEDROCK_AWS_ACCESS_KEY_ID", + }, + }; + const supportsToolChoiceValuesClaude3: Array<"auto" | "any" | "tool"> = [ + "auto", + "any", + "tool", + ]; + const supportsToolChoiceValuesMistralLarge: Array<"auto" | "any" | "tool"> = [ + "auto", + "any", + ]; + + it("throws an error if passing tool_choice with unsupported models", async () => { + // Claude 2 should throw + const claude2Model = new ChatBedrockConverse({ + ...baseConstructorArgs, + model: "anthropic.claude-v2", + }); + const claude2WithTool = claude2Model.bindTools([tool], { + tool_choice: tool.name, + }); + await expect(claude2WithTool.invoke("foo")).rejects.toThrow(); + + // Cohere should throw + const cohereModel = new ChatBedrockConverse({ + ...baseConstructorArgs, + model: "cohere.command-text-v14", + }); + const cohereModelWithTool = cohereModel.bindTools([tool], { + tool_choice: tool.name, + }); + await expect(cohereModelWithTool.invoke("foo")).rejects.toThrow(); + + // Mistral (not mistral large) should throw + const mistralModel = new ChatBedrockConverse({ + ...baseConstructorArgs, + model: "mistral.mistral-7b-instruct-v0:2", + }); + const mistralModelWithTool = mistralModel.bindTools([tool], { + tool_choice: tool.name, + }); + await expect(mistralModelWithTool.invoke("foo")).rejects.toThrow(); + }); + + it("does NOT throw and binds tool_choice when calling bindTools with supported models", async () => { + // Claude 3 should NOT throw + const claude3Model = new ChatBedrockConverse({ + ...baseConstructorArgs, + model: "anthropic.claude-3-5-sonnet-20240620-v1:0", + supportsToolChoiceValues: supportsToolChoiceValuesClaude3, + }); + const claude3ModelWithTool = claude3Model.bindTools([tool], { + tool_choice: tool.name, + }); + expect(claude3ModelWithTool).toBeDefined(); + const claude3ModelWithToolAsJSON = claude3ModelWithTool.toJSON(); + if (!("kwargs" in claude3ModelWithToolAsJSON)) { + throw new Error("kwargs not found in claude3ModelWithToolAsJSON"); + } + expect(claude3ModelWithToolAsJSON.kwargs.kwargs).toHaveProperty( + "tool_choice" + ); + expect(claude3ModelWithToolAsJSON.kwargs.kwargs.tool_choice).toBe( + tool.name + ); + + // Mistral large should NOT throw + const mistralModel = new ChatBedrockConverse({ + ...baseConstructorArgs, + model: "mistral.mistral-large-2407-v1:0", + supportsToolChoiceValues: supportsToolChoiceValuesMistralLarge, + }); + const mistralModelWithTool = mistralModel.bindTools([tool], { + tool_choice: tool.name, + }); + expect(mistralModelWithTool).toBeDefined(); + const mistralModelWithToolAsJSON = mistralModelWithTool.toJSON(); + if (!("kwargs" in mistralModelWithToolAsJSON)) { + throw new Error("kwargs not found in mistralModelWithToolAsJSON"); + } + expect(mistralModelWithToolAsJSON.kwargs.kwargs).toHaveProperty( + "tool_choice" + ); + expect(mistralModelWithToolAsJSON.kwargs.kwargs.tool_choice).toBe( + tool.name + ); + }); + + it("should NOT bind and NOT throw when using WSO with unsupported models", async () => { + // Claude 2 should NOT throw is using WSO + const claude2Model = new ChatBedrockConverse({ + ...baseConstructorArgs, + model: "anthropic.claude-v2", + }); + const claude2ModelWSO = claude2Model.withStructuredOutput(tool.schema, { + name: tool.name, + }); + expect(claude2ModelWSO).toBeDefined(); + const claude2ModelWSOAsJSON = claude2ModelWSO.toJSON(); + if (!("kwargs" in claude2ModelWSOAsJSON)) { + throw new Error("kwargs not found in claude2ModelWSOAsJSON"); + } + expect(claude2ModelWSOAsJSON.kwargs.bound.first.kwargs).not.toHaveProperty( + "tool_choice" + ); + + // Cohere should NOT throw is using WSO + const cohereModel = new ChatBedrockConverse({ + ...baseConstructorArgs, + model: "cohere.command-text-v14", + }); + const cohereModelWSO = cohereModel.withStructuredOutput(tool.schema, { + name: tool.name, + }); + expect(cohereModelWSO).toBeDefined(); + const cohereModelWSOAsJSON = cohereModelWSO.toJSON(); + if (!("kwargs" in cohereModelWSOAsJSON)) { + throw new Error("kwargs not found in cohereModelWSOAsJSON"); + } + expect(cohereModelWSOAsJSON.kwargs.bound.first.kwargs).not.toHaveProperty( + "tool_choice" + ); + + // Mistral (not mistral large) should NOT throw is using WSO + const mistralModel = new ChatBedrockConverse({ + ...baseConstructorArgs, + model: "mistral.mistral-7b-instruct-v0:2", + }); + const mistralModelWSO = mistralModel.withStructuredOutput(tool.schema, { + name: tool.name, + }); + expect(mistralModelWSO).toBeDefined(); + const mistralModelWSOAsJSON = mistralModelWSO.toJSON(); + if (!("kwargs" in mistralModelWSOAsJSON)) { + throw new Error("kwargs not found in mistralModelWSOAsJSON"); + } + expect(mistralModelWSOAsJSON.kwargs.bound.first.kwargs).not.toHaveProperty( + "tool_choice" + ); + }); + + it("should bind tool_choice when using WSO with supported models", async () => { + // Claude 3 should NOT throw is using WSO & it should have `tool_choice` bound. + const claude3Model = new ChatBedrockConverse({ + ...baseConstructorArgs, + model: "anthropic.claude-3-5-sonnet-20240620-v1:0", + // We are not passing the `supportsToolChoiceValues` arg here as + // it should be inferred from the model name. + }); + const claude3ModelWSO = claude3Model.withStructuredOutput(tool.schema, { + name: tool.name, + }); + expect(claude3ModelWSO).toBeDefined(); + const claude3ModelWSOAsJSON = claude3ModelWSO.toJSON(); + if (!("kwargs" in claude3ModelWSOAsJSON)) { + throw new Error("kwargs not found in claude3ModelWSOAsJSON"); + } + expect(claude3ModelWSOAsJSON.kwargs.bound.first.kwargs).toHaveProperty( + "tool_choice" + ); + expect(claude3ModelWSOAsJSON.kwargs.bound.first.kwargs.tool_choice).toBe( + tool.name + ); + + // Mistral (not mistral large) should NOT throw is using WSO + const mistralModel = new ChatBedrockConverse({ + ...baseConstructorArgs, + model: "mistral.mistral-large-2407-v1:0", + // We are not passing the `supportsToolChoiceValues` arg here as + // it should be inferred from the model name. + }); + const mistralModelWSO = mistralModel.withStructuredOutput(tool.schema, { + name: tool.name, + }); + expect(mistralModelWSO).toBeDefined(); + const mistralModelWSOAsJSON = mistralModelWSO.toJSON(); + if (!("kwargs" in mistralModelWSOAsJSON)) { + throw new Error("kwargs not found in mistralModelWSOAsJSON"); + } + expect(mistralModelWSOAsJSON.kwargs.bound.first.kwargs).toHaveProperty( + "tool_choice" + ); + // Mistral large only supports "auto" and "any" for tool_choice, not the actual tool name + expect(mistralModelWSOAsJSON.kwargs.bound.first.kwargs.tool_choice).toBe( + "any" + ); + }); +});