From a3b2c9a999604bc8b02e6d7d6a9fb5353d462691 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Wed, 7 Aug 2024 11:33:08 -0700 Subject: [PATCH 1/7] aws[minor]: Implement WSO with tool_choice --- libs/langchain-aws/src/chat_models.ts | 141 +++++++++++++++++++++++++- 1 file changed, 139 insertions(+), 2 deletions(-) diff --git a/libs/langchain-aws/src/chat_models.ts b/libs/langchain-aws/src/chat_models.ts index 5abeaa518248..372e3083fd26 100644 --- a/libs/langchain-aws/src/chat_models.ts +++ b/libs/langchain-aws/src/chat_models.ts @@ -1,6 +1,6 @@ import type { BaseMessage } from "@langchain/core/messages"; import { AIMessageChunk } from "@langchain/core/messages"; -import type { BaseLanguageModelInput } from "@langchain/core/language_models/base"; +import type { BaseLanguageModelInput, StructuredOutputMethodOptions, ToolDefinition } from "@langchain/core/language_models/base"; import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; import { type BaseChatModelParams, @@ -24,7 +24,7 @@ import { DefaultProviderInit, } from "@aws-sdk/credential-provider-node"; import type { DocumentType as __DocumentType } from "@smithy/types"; -import { Runnable } from "@langchain/core/runnables"; +import { Runnable, RunnableLambda, RunnablePassthrough, RunnableSequence } from "@langchain/core/runnables"; import { ChatBedrockConverseToolType, ConverseCommandParams, @@ -40,6 +40,9 @@ import { handleConverseStreamContentBlockStart, BedrockConverseToolChoice, } from "./common.js"; +import { zodToJsonSchema } from "zod-to-json-schema"; +import { isZodSchema } from "@langchain/core/utils/types"; +import { z } from "zod"; /** * Inputs for ChatBedrockConverse. @@ -430,4 +433,138 @@ export class ChatBedrockConverse } } } + + withStructuredOutput< + // eslint-disable-next-line @typescript-eslint/no-explicit-any + RunOutput extends Record = Record + // eslint-disable-next-line @typescript-eslint/no-explicit-any + >( + outputSchema: + | z.ZodType + // eslint-disable-next-line @typescript-eslint/no-explicit-any + | Record, + config?: StructuredOutputMethodOptions + ): Runnable; + + withStructuredOutput< + // eslint-disable-next-line @typescript-eslint/no-explicit-any + RunOutput extends Record = Record + >( + outputSchema: + | z.ZodType + // eslint-disable-next-line @typescript-eslint/no-explicit-any + | Record, + config?: StructuredOutputMethodOptions + ): Runnable; + + withStructuredOutput< + // eslint-disable-next-line @typescript-eslint/no-explicit-any + RunOutput extends Record = Record + >( + outputSchema: + | z.ZodType + // eslint-disable-next-line @typescript-eslint/no-explicit-any + | Record, + config?: StructuredOutputMethodOptions + ): + | Runnable + | Runnable< + BaseLanguageModelInput, + { + raw: BaseMessage; + parsed: RunOutput; + } + > { + if (typeof this.bindTools !== "function") { + throw new Error( + `Chat model must implement ".bindTools()" to use withStructuredOutput.` + ); + } + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const schema: z.ZodType | Record = outputSchema; + const name = config?.name; + const description = schema.description ?? "A function available to call."; + const method = config?.method; + const includeRaw = config?.includeRaw; + if (method === "jsonMode") { + throw new Error( + `Base withStructuredOutput implementation only supports "functionCalling" as a method.` + ); + } + + let functionName = name ?? "extract"; + let tools: ToolDefinition[]; + if (isZodSchema(schema)) { + tools = [ + { + type: "function", + function: { + name: functionName, + description, + parameters: zodToJsonSchema(schema), + }, + }, + ]; + } else { + if ("name" in schema) { + functionName = schema.name; + } + tools = [ + { + type: "function", + function: { + name: functionName, + description, + parameters: schema, + }, + }, + ]; + } + + const llm = this.bindTools(tools, { + tool_choice: tools[0].function.name, + }); + const outputParser = RunnableLambda.from( + (input: AIMessageChunk): RunOutput => { + if (!input.tool_calls || input.tool_calls.length === 0) { + throw new Error("No tool calls found in the response."); + } + const toolCall = input.tool_calls.find( + (tc) => tc.name === functionName + ); + if (!toolCall) { + throw new Error(`No tool call found with name ${functionName}.`); + } + return toolCall.args as RunOutput; + } + ); + + if (!includeRaw) { + return llm.pipe(outputParser).withConfig({ + runName: "StructuredOutput", + }) as Runnable; + } + + const parserAssign = RunnablePassthrough.assign({ + // eslint-disable-next-line @typescript-eslint/no-explicit-any + parsed: (input: any, config) => outputParser.invoke(input.raw, config), + }); + const parserNone = RunnablePassthrough.assign({ + parsed: () => null, + }); + const parsedWithFallback = parserAssign.withFallbacks({ + fallbacks: [parserNone], + }); + return RunnableSequence.from< + BaseLanguageModelInput, + { raw: BaseMessage; parsed: RunOutput } + >([ + { + raw: llm, + }, + parsedWithFallback, + ]).withConfig({ + runName: "StructuredOutputRunnable", + }); + } } From 34c35d29ad32d5c7f26585a5e1f8469d9abce981 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Wed, 7 Aug 2024 11:44:58 -0700 Subject: [PATCH 2/7] cr --- libs/langchain-aws/src/tests/chat_models.test.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/libs/langchain-aws/src/tests/chat_models.test.ts b/libs/langchain-aws/src/tests/chat_models.test.ts index 22ccd342ebc0..57f0e42891ec 100644 --- a/libs/langchain-aws/src/tests/chat_models.test.ts +++ b/libs/langchain-aws/src/tests/chat_models.test.ts @@ -15,6 +15,7 @@ import { convertToConverseMessages, handleConverseStreamContentBlockDelta, } from "../common.js"; +import { describe, test, expect, it } from "@jest/globals"; describe("convertToConverseMessages", () => { const testCases: { From f3a34f8d885b0b023e15f87a3516d2cb76de3ad7 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Wed, 7 Aug 2024 11:45:06 -0700 Subject: [PATCH 3/7] chore: lint files --- libs/langchain-aws/src/chat_models.ts | 27 ++++++++++++------- .../src/tests/chat_models.test.ts | 2 +- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/libs/langchain-aws/src/chat_models.ts b/libs/langchain-aws/src/chat_models.ts index 372e3083fd26..b295f4049146 100644 --- a/libs/langchain-aws/src/chat_models.ts +++ b/libs/langchain-aws/src/chat_models.ts @@ -1,6 +1,10 @@ import type { BaseMessage } from "@langchain/core/messages"; import { AIMessageChunk } from "@langchain/core/messages"; -import type { BaseLanguageModelInput, StructuredOutputMethodOptions, ToolDefinition } from "@langchain/core/language_models/base"; +import type { + BaseLanguageModelInput, + StructuredOutputMethodOptions, + ToolDefinition, +} from "@langchain/core/language_models/base"; import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; import { type BaseChatModelParams, @@ -24,12 +28,15 @@ import { DefaultProviderInit, } from "@aws-sdk/credential-provider-node"; import type { DocumentType as __DocumentType } from "@smithy/types"; -import { Runnable, RunnableLambda, RunnablePassthrough, RunnableSequence } from "@langchain/core/runnables"; import { - ChatBedrockConverseToolType, - ConverseCommandParams, - CredentialType, -} from "./types.js"; + Runnable, + RunnableLambda, + RunnablePassthrough, + RunnableSequence, +} from "@langchain/core/runnables"; +import { zodToJsonSchema } from "zod-to-json-schema"; +import { isZodSchema } from "@langchain/core/utils/types"; +import { z } from "zod"; import { convertToConverseTools, convertToBedrockToolChoice, @@ -40,9 +47,11 @@ import { handleConverseStreamContentBlockStart, BedrockConverseToolChoice, } from "./common.js"; -import { zodToJsonSchema } from "zod-to-json-schema"; -import { isZodSchema } from "@langchain/core/utils/types"; -import { z } from "zod"; +import { + ChatBedrockConverseToolType, + ConverseCommandParams, + CredentialType, +} from "./types.js"; /** * Inputs for ChatBedrockConverse. diff --git a/libs/langchain-aws/src/tests/chat_models.test.ts b/libs/langchain-aws/src/tests/chat_models.test.ts index 57f0e42891ec..de9b181f57ee 100644 --- a/libs/langchain-aws/src/tests/chat_models.test.ts +++ b/libs/langchain-aws/src/tests/chat_models.test.ts @@ -11,11 +11,11 @@ import type { Message as BedrockMessage, SystemContentBlock as BedrockSystemContentBlock, } from "@aws-sdk/client-bedrock-runtime"; +import { describe, test, expect, it } from "@jest/globals"; import { convertToConverseMessages, handleConverseStreamContentBlockDelta, } from "../common.js"; -import { describe, test, expect, it } from "@jest/globals"; describe("convertToConverseMessages", () => { const testCases: { From 0cfff8463c7c7ecc395c822c94b1fd0e39684b1f Mon Sep 17 00:00:00 2001 From: bracesproul Date: Wed, 7 Aug 2024 12:08:34 -0700 Subject: [PATCH 4/7] cr --- libs/langchain-aws/src/tests/chat_models.test.ts | 1 - 1 file changed, 1 deletion(-) diff --git a/libs/langchain-aws/src/tests/chat_models.test.ts b/libs/langchain-aws/src/tests/chat_models.test.ts index de9b181f57ee..22ccd342ebc0 100644 --- a/libs/langchain-aws/src/tests/chat_models.test.ts +++ b/libs/langchain-aws/src/tests/chat_models.test.ts @@ -11,7 +11,6 @@ import type { Message as BedrockMessage, SystemContentBlock as BedrockSystemContentBlock, } from "@aws-sdk/client-bedrock-runtime"; -import { describe, test, expect, it } from "@jest/globals"; import { convertToConverseMessages, handleConverseStreamContentBlockDelta, From e96cf9e6a435ddfe5356cfbacbb5d3dc50d80311 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Wed, 7 Aug 2024 12:45:19 -0700 Subject: [PATCH 5/7] add more tests --- libs/langchain-aws/package.json | 1 + libs/langchain-aws/src/chat_models.ts | 21 ++- .../src/tests/chat_models.test.ts | 163 ++++++++++++++++++ 3 files changed, 177 insertions(+), 8 deletions(-) diff --git a/libs/langchain-aws/package.json b/libs/langchain-aws/package.json index 5ab58d4da376..04b85b592ebc 100644 --- a/libs/langchain-aws/package.json +++ b/libs/langchain-aws/package.json @@ -44,6 +44,7 @@ "@aws-sdk/client-kendra": "^3.352.0", "@aws-sdk/credential-provider-node": "^3.600.0", "@langchain/core": ">=0.2.21 <0.3.0", + "zod": "^3.23.8", "zod-to-json-schema": "^3.22.5" }, "devDependencies": { diff --git a/libs/langchain-aws/src/chat_models.ts b/libs/langchain-aws/src/chat_models.ts index b295f4049146..3b0acda8ed3e 100644 --- a/libs/langchain-aws/src/chat_models.ts +++ b/libs/langchain-aws/src/chat_models.ts @@ -53,6 +53,10 @@ import { CredentialType, } from "./types.js"; +// Models which support the `toolChoice` param. +// See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html +const ALLOWED_TOOL_CHOICE_MODELS_PREFIX = ["anthropic.claude-3", "mistral.mistral-large"] + /** * Inputs for ChatBedrockConverse. */ @@ -298,6 +302,11 @@ export class ChatBedrockConverse AIMessageChunk, this["ParsedCallOptions"] > { + if (kwargs?.tool_choice) { + if (!ALLOWED_TOOL_CHOICE_MODELS_PREFIX.find((prefix) => this.model.startsWith(prefix))) { + throw new Error("Only Anthropic Claude 3 and Mistral Large models support the tool_choice parameter."); + } + } return this.bind({ tools: convertToConverseTools(tools), ...kwargs }); } @@ -484,11 +493,6 @@ export class ChatBedrockConverse parsed: RunOutput; } > { - if (typeof this.bindTools !== "function") { - throw new Error( - `Chat model must implement ".bindTools()" to use withStructuredOutput.` - ); - } // eslint-disable-next-line @typescript-eslint/no-explicit-any const schema: z.ZodType | Record = outputSchema; const name = config?.name; @@ -497,7 +501,7 @@ export class ChatBedrockConverse const includeRaw = config?.includeRaw; if (method === "jsonMode") { throw new Error( - `Base withStructuredOutput implementation only supports "functionCalling" as a method.` + `ChatBedrockConverse does not support 'jsonMode'.` ); } @@ -530,9 +534,10 @@ export class ChatBedrockConverse ]; } - const llm = this.bindTools(tools, { + const toolChoiceObj = ALLOWED_TOOL_CHOICE_MODELS_PREFIX.find((prefix) => this.model.startsWith(prefix)) ? { tool_choice: tools[0].function.name, - }); + } : undefined + const llm = this.bindTools(tools, toolChoiceObj); const outputParser = RunnableLambda.from( (input: AIMessageChunk): RunOutput => { if (!input.tool_calls || input.tool_calls.length === 0) { diff --git a/libs/langchain-aws/src/tests/chat_models.test.ts b/libs/langchain-aws/src/tests/chat_models.test.ts index 22ccd342ebc0..1e7e01c1529c 100644 --- a/libs/langchain-aws/src/tests/chat_models.test.ts +++ b/libs/langchain-aws/src/tests/chat_models.test.ts @@ -15,6 +15,9 @@ import { convertToConverseMessages, handleConverseStreamContentBlockDelta, } from "../common.js"; +import { ChatBedrockConverse } from "../chat_models.js"; +import { z } from "zod"; +import { describe, expect, test } from "@jest/globals"; describe("convertToConverseMessages", () => { const testCases: { @@ -337,3 +340,163 @@ test("Streaming supports empty string chunks", async () => { if (!finalChunk) return; expect(finalChunk.content).toBe("Hello world!"); }); + +describe("tool_choice works for supported models", () => { + const tool = { + name: "weather", + schema: z.object({ + location: z.string(), + }) + } + const baseConstructorArgs = { + region: "us-east-1", + credentials: { + secretAccessKey: "process.env.BEDROCK_AWS_SECRET_ACCESS_KEY", + accessKeyId: "process.env.BEDROCK_AWS_ACCESS_KEY_ID", + }, + } + + it("throws an error if passing tool_choice with unsupported models", async () => { + // Claude 2 should throw + const claude2Model = new ChatBedrockConverse({ + ...baseConstructorArgs, + model: "anthropic.claude-v2" + }) + expect(() => claude2Model.bindTools([tool], { + tool_choice: tool.name, + })).toThrow(); + + // Cohere should throw + const cohereModel = new ChatBedrockConverse({ + ...baseConstructorArgs, + model: "cohere.command-text-v14" + }) + expect(() => cohereModel.bindTools([tool], { + tool_choice: tool.name, + })).toThrow(); + + // Mistral (not mistral large) should throw + const mistralModel = new ChatBedrockConverse({ + ...baseConstructorArgs, + model: "mistral.mistral-7b-instruct-v0:2" + }) + expect(() => mistralModel.bindTools([tool], { + tool_choice: tool.name, + })).toThrow(); + }) + + it("does NOT throw and binds tool_choice when calling bindTools with supported models", async () => { + // Claude 3 should NOT throw + const claude3Model = new ChatBedrockConverse({ + ...baseConstructorArgs, + model: "anthropic.claude-3-5-sonnet-20240620-v1:0" + }) + const claude3ModelWithTool = claude3Model.bindTools([tool], { + tool_choice: tool.name, + }) + expect(claude3ModelWithTool).toBeDefined(); + const claude3ModelWithToolAsJSON = claude3ModelWithTool.toJSON(); + if (!("kwargs" in claude3ModelWithToolAsJSON)) { + throw new Error("kwargs not found in claude3ModelWithToolAsJSON"); + } + expect(claude3ModelWithToolAsJSON.kwargs.kwargs).toHaveProperty("tool_choice"); + expect(claude3ModelWithToolAsJSON.kwargs.kwargs.tool_choice).toBe(tool.name); + + // Mistral large should NOT throw + const mistralModel = new ChatBedrockConverse({ + ...baseConstructorArgs, + model: "mistral.mistral-large-2407-v1:0" + }) + const mistralModelWithTool = mistralModel.bindTools([tool], { + tool_choice: tool.name, + }); + expect(mistralModelWithTool).toBeDefined(); + const mistralModelWithToolAsJSON = mistralModelWithTool.toJSON(); + if (!("kwargs" in mistralModelWithToolAsJSON)) { + throw new Error("kwargs not found in mistralModelWithToolAsJSON"); + } + expect(mistralModelWithToolAsJSON.kwargs.kwargs).toHaveProperty("tool_choice"); + expect(mistralModelWithToolAsJSON.kwargs.kwargs.tool_choice).toBe(tool.name); + }) + + it("should NOT bind and NOT throw when using WSO with unsupported models", async () => { + // Claude 2 should NOT throw is using WSO + const claude2Model = new ChatBedrockConverse({ + ...baseConstructorArgs, + model: "anthropic.claude-v2" + }) + const claude2ModelWSO = claude2Model.withStructuredOutput(tool.schema, { + name: tool.name, + }) + expect(claude2ModelWSO).toBeDefined(); + const claude2ModelWSOAsJSON = claude2ModelWSO.toJSON(); + if (!("kwargs" in claude2ModelWSOAsJSON)) { + throw new Error("kwargs not found in claude2ModelWSOAsJSON"); + } + expect(claude2ModelWSOAsJSON.kwargs.bound.first.kwargs).not.toHaveProperty("tool_choice"); + + // Cohere should NOT throw is using WSO + const cohereModel = new ChatBedrockConverse({ + ...baseConstructorArgs, + model: "cohere.command-text-v14" + }) + const cohereModelWSO = cohereModel.withStructuredOutput(tool.schema, { + name: tool.name, + }) + expect(cohereModelWSO).toBeDefined(); + const cohereModelWSOAsJSON = cohereModelWSO.toJSON(); + if (!("kwargs" in cohereModelWSOAsJSON)) { + throw new Error("kwargs not found in cohereModelWSOAsJSON"); + } + expect(cohereModelWSOAsJSON.kwargs.bound.first.kwargs).not.toHaveProperty("tool_choice"); + + // Mistral (not mistral large) should NOT throw is using WSO + const mistralModel = new ChatBedrockConverse({ + ...baseConstructorArgs, + model: "mistral.mistral-7b-instruct-v0:2" + }) + const mistralModelWSO = mistralModel.withStructuredOutput(tool.schema, { + name: tool.name, + }) + expect(mistralModelWSO).toBeDefined(); + const mistralModelWSOAsJSON = mistralModelWSO.toJSON(); + if (!("kwargs" in mistralModelWSOAsJSON)) { + throw new Error("kwargs not found in mistralModelWSOAsJSON"); + } + expect(mistralModelWSOAsJSON.kwargs.bound.first.kwargs).not.toHaveProperty("tool_choice"); + }) + + it("should bind tool_choice when using WSO with supported models", async () => { + // Claude 3 should NOT throw is using WSO & it should have `tool_choice` bound. + const claude3Model = new ChatBedrockConverse({ + ...baseConstructorArgs, + model: "anthropic.claude-3-5-sonnet-20240620-v1:0" + }) + const claude3ModelWSO = claude3Model.withStructuredOutput(tool.schema, { + name: tool.name, + }) + expect(claude3ModelWSO).toBeDefined(); + const claude3ModelWSOAsJSON = claude3ModelWSO.toJSON(); + if (!("kwargs" in claude3ModelWSOAsJSON)) { + throw new Error("kwargs not found in claude3ModelWSOAsJSON"); + } + expect(claude3ModelWSOAsJSON.kwargs.bound.first.kwargs).toHaveProperty("tool_choice"); + expect(claude3ModelWSOAsJSON.kwargs.bound.first.kwargs.tool_choice).toBe(tool.name); + + // Mistral (not mistral large) should NOT throw is using WSO + const mistralModel = new ChatBedrockConverse({ + ...baseConstructorArgs, + model: "mistral.mistral-large-2407-v1:0" + }) + const mistralModelWSO = mistralModel.withStructuredOutput(tool.schema, { + name: tool.name, + }) + expect(mistralModelWSO).toBeDefined(); + const mistralModelWSOAsJSON = mistralModelWSO.toJSON(); + if (!("kwargs" in mistralModelWSOAsJSON)) { + throw new Error("kwargs not found in mistralModelWSOAsJSON"); + } + expect(mistralModelWSOAsJSON.kwargs.bound.first.kwargs).toHaveProperty("tool_choice"); + expect(mistralModelWSOAsJSON.kwargs.bound.first.kwargs.tool_choice).toBe(tool.name); + }) +}); \ No newline at end of file From 8501cc7d3c79293bf416b9a43ac0c59168afaedc Mon Sep 17 00:00:00 2001 From: bracesproul Date: Wed, 7 Aug 2024 12:45:29 -0700 Subject: [PATCH 6/7] chore: lint files --- libs/langchain-aws/src/chat_models.ts | 29 +++- .../src/tests/chat_models.test.ts | 158 +++++++++++------- 2 files changed, 113 insertions(+), 74 deletions(-) diff --git a/libs/langchain-aws/src/chat_models.ts b/libs/langchain-aws/src/chat_models.ts index 3b0acda8ed3e..84a901d7314c 100644 --- a/libs/langchain-aws/src/chat_models.ts +++ b/libs/langchain-aws/src/chat_models.ts @@ -55,7 +55,10 @@ import { // Models which support the `toolChoice` param. // See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html -const ALLOWED_TOOL_CHOICE_MODELS_PREFIX = ["anthropic.claude-3", "mistral.mistral-large"] +const ALLOWED_TOOL_CHOICE_MODELS_PREFIX = [ + "anthropic.claude-3", + "mistral.mistral-large", +]; /** * Inputs for ChatBedrockConverse. @@ -303,8 +306,14 @@ export class ChatBedrockConverse this["ParsedCallOptions"] > { if (kwargs?.tool_choice) { - if (!ALLOWED_TOOL_CHOICE_MODELS_PREFIX.find((prefix) => this.model.startsWith(prefix))) { - throw new Error("Only Anthropic Claude 3 and Mistral Large models support the tool_choice parameter."); + if ( + !ALLOWED_TOOL_CHOICE_MODELS_PREFIX.find((prefix) => + this.model.startsWith(prefix) + ) + ) { + throw new Error( + "Only Anthropic Claude 3 and Mistral Large models support the tool_choice parameter." + ); } } return this.bind({ tools: convertToConverseTools(tools), ...kwargs }); @@ -500,9 +509,7 @@ export class ChatBedrockConverse const method = config?.method; const includeRaw = config?.includeRaw; if (method === "jsonMode") { - throw new Error( - `ChatBedrockConverse does not support 'jsonMode'.` - ); + throw new Error(`ChatBedrockConverse does not support 'jsonMode'.`); } let functionName = name ?? "extract"; @@ -534,9 +541,13 @@ export class ChatBedrockConverse ]; } - const toolChoiceObj = ALLOWED_TOOL_CHOICE_MODELS_PREFIX.find((prefix) => this.model.startsWith(prefix)) ? { - tool_choice: tools[0].function.name, - } : undefined + const toolChoiceObj = ALLOWED_TOOL_CHOICE_MODELS_PREFIX.find((prefix) => + this.model.startsWith(prefix) + ) + ? { + tool_choice: tools[0].function.name, + } + : undefined; const llm = this.bindTools(tools, toolChoiceObj); const outputParser = RunnableLambda.from( (input: AIMessageChunk): RunOutput => { diff --git a/libs/langchain-aws/src/tests/chat_models.test.ts b/libs/langchain-aws/src/tests/chat_models.test.ts index 1e7e01c1529c..c95213761a39 100644 --- a/libs/langchain-aws/src/tests/chat_models.test.ts +++ b/libs/langchain-aws/src/tests/chat_models.test.ts @@ -11,13 +11,13 @@ import type { Message as BedrockMessage, SystemContentBlock as BedrockSystemContentBlock, } from "@aws-sdk/client-bedrock-runtime"; +import { z } from "zod"; +import { describe, expect, test } from "@jest/globals"; import { convertToConverseMessages, handleConverseStreamContentBlockDelta, } from "../common.js"; import { ChatBedrockConverse } from "../chat_models.js"; -import { z } from "zod"; -import { describe, expect, test } from "@jest/globals"; describe("convertToConverseMessages", () => { const testCases: { @@ -346,67 +346,77 @@ describe("tool_choice works for supported models", () => { name: "weather", schema: z.object({ location: z.string(), - }) - } + }), + }; const baseConstructorArgs = { region: "us-east-1", credentials: { secretAccessKey: "process.env.BEDROCK_AWS_SECRET_ACCESS_KEY", accessKeyId: "process.env.BEDROCK_AWS_ACCESS_KEY_ID", }, - } + }; it("throws an error if passing tool_choice with unsupported models", async () => { // Claude 2 should throw const claude2Model = new ChatBedrockConverse({ ...baseConstructorArgs, - model: "anthropic.claude-v2" - }) - expect(() => claude2Model.bindTools([tool], { - tool_choice: tool.name, - })).toThrow(); - + model: "anthropic.claude-v2", + }); + expect(() => + claude2Model.bindTools([tool], { + tool_choice: tool.name, + }) + ).toThrow(); + // Cohere should throw const cohereModel = new ChatBedrockConverse({ ...baseConstructorArgs, - model: "cohere.command-text-v14" - }) - expect(() => cohereModel.bindTools([tool], { - tool_choice: tool.name, - })).toThrow(); - + model: "cohere.command-text-v14", + }); + expect(() => + cohereModel.bindTools([tool], { + tool_choice: tool.name, + }) + ).toThrow(); + // Mistral (not mistral large) should throw const mistralModel = new ChatBedrockConverse({ ...baseConstructorArgs, - model: "mistral.mistral-7b-instruct-v0:2" - }) - expect(() => mistralModel.bindTools([tool], { - tool_choice: tool.name, - })).toThrow(); - }) - + model: "mistral.mistral-7b-instruct-v0:2", + }); + expect(() => + mistralModel.bindTools([tool], { + tool_choice: tool.name, + }) + ).toThrow(); + }); + it("does NOT throw and binds tool_choice when calling bindTools with supported models", async () => { // Claude 3 should NOT throw const claude3Model = new ChatBedrockConverse({ ...baseConstructorArgs, - model: "anthropic.claude-3-5-sonnet-20240620-v1:0" - }) + model: "anthropic.claude-3-5-sonnet-20240620-v1:0", + }); const claude3ModelWithTool = claude3Model.bindTools([tool], { tool_choice: tool.name, - }) + }); expect(claude3ModelWithTool).toBeDefined(); const claude3ModelWithToolAsJSON = claude3ModelWithTool.toJSON(); if (!("kwargs" in claude3ModelWithToolAsJSON)) { throw new Error("kwargs not found in claude3ModelWithToolAsJSON"); } - expect(claude3ModelWithToolAsJSON.kwargs.kwargs).toHaveProperty("tool_choice"); - expect(claude3ModelWithToolAsJSON.kwargs.kwargs.tool_choice).toBe(tool.name); - + expect(claude3ModelWithToolAsJSON.kwargs.kwargs).toHaveProperty( + "tool_choice" + ); + expect(claude3ModelWithToolAsJSON.kwargs.kwargs.tool_choice).toBe( + tool.name + ); + // Mistral large should NOT throw const mistralModel = new ChatBedrockConverse({ ...baseConstructorArgs, - model: "mistral.mistral-large-2407-v1:0" - }) + model: "mistral.mistral-large-2407-v1:0", + }); const mistralModelWithTool = mistralModel.bindTools([tool], { tool_choice: tool.name, }); @@ -415,88 +425,106 @@ describe("tool_choice works for supported models", () => { if (!("kwargs" in mistralModelWithToolAsJSON)) { throw new Error("kwargs not found in mistralModelWithToolAsJSON"); } - expect(mistralModelWithToolAsJSON.kwargs.kwargs).toHaveProperty("tool_choice"); - expect(mistralModelWithToolAsJSON.kwargs.kwargs.tool_choice).toBe(tool.name); - }) - + expect(mistralModelWithToolAsJSON.kwargs.kwargs).toHaveProperty( + "tool_choice" + ); + expect(mistralModelWithToolAsJSON.kwargs.kwargs.tool_choice).toBe( + tool.name + ); + }); + it("should NOT bind and NOT throw when using WSO with unsupported models", async () => { // Claude 2 should NOT throw is using WSO const claude2Model = new ChatBedrockConverse({ ...baseConstructorArgs, - model: "anthropic.claude-v2" - }) + model: "anthropic.claude-v2", + }); const claude2ModelWSO = claude2Model.withStructuredOutput(tool.schema, { name: tool.name, - }) + }); expect(claude2ModelWSO).toBeDefined(); const claude2ModelWSOAsJSON = claude2ModelWSO.toJSON(); if (!("kwargs" in claude2ModelWSOAsJSON)) { throw new Error("kwargs not found in claude2ModelWSOAsJSON"); } - expect(claude2ModelWSOAsJSON.kwargs.bound.first.kwargs).not.toHaveProperty("tool_choice"); - + expect(claude2ModelWSOAsJSON.kwargs.bound.first.kwargs).not.toHaveProperty( + "tool_choice" + ); + // Cohere should NOT throw is using WSO const cohereModel = new ChatBedrockConverse({ ...baseConstructorArgs, - model: "cohere.command-text-v14" - }) + model: "cohere.command-text-v14", + }); const cohereModelWSO = cohereModel.withStructuredOutput(tool.schema, { name: tool.name, - }) + }); expect(cohereModelWSO).toBeDefined(); const cohereModelWSOAsJSON = cohereModelWSO.toJSON(); if (!("kwargs" in cohereModelWSOAsJSON)) { throw new Error("kwargs not found in cohereModelWSOAsJSON"); } - expect(cohereModelWSOAsJSON.kwargs.bound.first.kwargs).not.toHaveProperty("tool_choice"); - + expect(cohereModelWSOAsJSON.kwargs.bound.first.kwargs).not.toHaveProperty( + "tool_choice" + ); + // Mistral (not mistral large) should NOT throw is using WSO const mistralModel = new ChatBedrockConverse({ ...baseConstructorArgs, - model: "mistral.mistral-7b-instruct-v0:2" - }) + model: "mistral.mistral-7b-instruct-v0:2", + }); const mistralModelWSO = mistralModel.withStructuredOutput(tool.schema, { name: tool.name, - }) + }); expect(mistralModelWSO).toBeDefined(); const mistralModelWSOAsJSON = mistralModelWSO.toJSON(); if (!("kwargs" in mistralModelWSOAsJSON)) { throw new Error("kwargs not found in mistralModelWSOAsJSON"); } - expect(mistralModelWSOAsJSON.kwargs.bound.first.kwargs).not.toHaveProperty("tool_choice"); - }) - + expect(mistralModelWSOAsJSON.kwargs.bound.first.kwargs).not.toHaveProperty( + "tool_choice" + ); + }); + it("should bind tool_choice when using WSO with supported models", async () => { // Claude 3 should NOT throw is using WSO & it should have `tool_choice` bound. const claude3Model = new ChatBedrockConverse({ ...baseConstructorArgs, - model: "anthropic.claude-3-5-sonnet-20240620-v1:0" - }) + model: "anthropic.claude-3-5-sonnet-20240620-v1:0", + }); const claude3ModelWSO = claude3Model.withStructuredOutput(tool.schema, { name: tool.name, - }) + }); expect(claude3ModelWSO).toBeDefined(); const claude3ModelWSOAsJSON = claude3ModelWSO.toJSON(); if (!("kwargs" in claude3ModelWSOAsJSON)) { throw new Error("kwargs not found in claude3ModelWSOAsJSON"); } - expect(claude3ModelWSOAsJSON.kwargs.bound.first.kwargs).toHaveProperty("tool_choice"); - expect(claude3ModelWSOAsJSON.kwargs.bound.first.kwargs.tool_choice).toBe(tool.name); - + expect(claude3ModelWSOAsJSON.kwargs.bound.first.kwargs).toHaveProperty( + "tool_choice" + ); + expect(claude3ModelWSOAsJSON.kwargs.bound.first.kwargs.tool_choice).toBe( + tool.name + ); + // Mistral (not mistral large) should NOT throw is using WSO const mistralModel = new ChatBedrockConverse({ ...baseConstructorArgs, - model: "mistral.mistral-large-2407-v1:0" - }) + model: "mistral.mistral-large-2407-v1:0", + }); const mistralModelWSO = mistralModel.withStructuredOutput(tool.schema, { name: tool.name, - }) + }); expect(mistralModelWSO).toBeDefined(); const mistralModelWSOAsJSON = mistralModelWSO.toJSON(); if (!("kwargs" in mistralModelWSOAsJSON)) { throw new Error("kwargs not found in mistralModelWSOAsJSON"); } - expect(mistralModelWSOAsJSON.kwargs.bound.first.kwargs).toHaveProperty("tool_choice"); - expect(mistralModelWSOAsJSON.kwargs.bound.first.kwargs.tool_choice).toBe(tool.name); - }) -}); \ No newline at end of file + expect(mistralModelWSOAsJSON.kwargs.bound.first.kwargs).toHaveProperty( + "tool_choice" + ); + expect(mistralModelWSOAsJSON.kwargs.bound.first.kwargs.tool_choice).toBe( + tool.name + ); + }); +}); From 7a47bfb7e7eebda156901346dce91ad0ca858940 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Wed, 7 Aug 2024 13:37:54 -0700 Subject: [PATCH 7/7] allow users to pass tool choice supported values --- libs/langchain-aws/src/chat_models.ts | 70 ++++++++++++------- libs/langchain-aws/src/common.ts | 44 ++++++++++-- .../src/tests/chat_models.test.ts | 45 +++++++----- 3 files changed, 112 insertions(+), 47 deletions(-) diff --git a/libs/langchain-aws/src/chat_models.ts b/libs/langchain-aws/src/chat_models.ts index 84a901d7314c..7b3a58dc29e9 100644 --- a/libs/langchain-aws/src/chat_models.ts +++ b/libs/langchain-aws/src/chat_models.ts @@ -53,13 +53,6 @@ import { CredentialType, } from "./types.js"; -// Models which support the `toolChoice` param. -// See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html -const ALLOWED_TOOL_CHOICE_MODELS_PREFIX = [ - "anthropic.claude-3", - "mistral.mistral-large", -]; - /** * Inputs for ChatBedrockConverse. */ @@ -139,6 +132,14 @@ export interface ChatBedrockConverseInput * Configuration information for a guardrail that you want to use in the request. */ guardrailConfig?: GuardrailConfiguration; + + /** + * Which types of `tool_choice` values the model supports. + * + * Inferred if not specified. Inferred as ['auto', 'any', 'tool'] if a 'claude-3' + * model is used, ['auto', 'any'] if a 'mistral-large' model is used, empty otherwise. + */ + supportsToolChoiceValues?: Array<"auto" | "any" | "tool">; } export interface ChatBedrockConverseCallOptions @@ -233,6 +234,14 @@ export class ChatBedrockConverse client: BedrockRuntimeClient; + /** + * Which types of `tool_choice` values the model supports. + * + * Inferred if not specified. Inferred as ['auto', 'any', 'tool'] if a 'claude-3' + * model is used, ['auto', 'any'] if a 'mistral-large' model is used, empty otherwise. + */ + supportsToolChoiceValues?: Array<"auto" | "any" | "tool">; + constructor(fields?: ChatBedrockConverseInput) { super(fields ?? {}); const { @@ -283,6 +292,18 @@ export class ChatBedrockConverse this.additionalModelRequestFields = rest?.additionalModelRequestFields; this.streamUsage = rest?.streamUsage ?? this.streamUsage; this.guardrailConfig = rest?.guardrailConfig; + + if (rest?.supportsToolChoiceValues === undefined) { + if (this.model.includes("claude-3")) { + this.supportsToolChoiceValues = ["auto", "any", "tool"]; + } else if (this.model.includes("mistral-large")) { + this.supportsToolChoiceValues = ["auto", "any"]; + } else { + this.supportsToolChoiceValues = undefined; + } + } else { + this.supportsToolChoiceValues = rest.supportsToolChoiceValues; + } } getLsParams(options: this["ParsedCallOptions"]): LangSmithParams { @@ -305,17 +326,6 @@ export class ChatBedrockConverse AIMessageChunk, this["ParsedCallOptions"] > { - if (kwargs?.tool_choice) { - if ( - !ALLOWED_TOOL_CHOICE_MODELS_PREFIX.find((prefix) => - this.model.startsWith(prefix) - ) - ) { - throw new Error( - "Only Anthropic Claude 3 and Mistral Large models support the tool_choice parameter." - ); - } - } return this.bind({ tools: convertToConverseTools(tools), ...kwargs }); } @@ -333,7 +343,10 @@ export class ChatBedrockConverse toolConfig = { tools, toolChoice: options.tool_choice - ? convertToBedrockToolChoice(options.tool_choice, tools) + ? convertToBedrockToolChoice(options.tool_choice, tools, { + model: this.model, + supportsToolChoiceValues: this.supportsToolChoiceValues, + }) : undefined, }; } @@ -541,13 +554,18 @@ export class ChatBedrockConverse ]; } - const toolChoiceObj = ALLOWED_TOOL_CHOICE_MODELS_PREFIX.find((prefix) => - this.model.startsWith(prefix) - ) - ? { - tool_choice: tools[0].function.name, - } - : undefined; + const supportsToolChoiceValues = this.supportsToolChoiceValues ?? []; + let toolChoiceObj: { tool_choice: string } | undefined; + if (supportsToolChoiceValues.includes("tool")) { + toolChoiceObj = { + tool_choice: tools[0].function.name, + }; + } else if (supportsToolChoiceValues.includes("any")) { + toolChoiceObj = { + tool_choice: "any", + }; + } + const llm = this.bindTools(tools, toolChoiceObj); const outputParser = RunnableLambda.from( (input: AIMessageChunk): RunOutput => { diff --git a/libs/langchain-aws/src/common.ts b/libs/langchain-aws/src/common.ts index 1e201dc8a97c..fbb3356e90e4 100644 --- a/libs/langchain-aws/src/common.ts +++ b/libs/langchain-aws/src/common.ts @@ -271,18 +271,27 @@ export type BedrockConverseToolChoice = export function convertToBedrockToolChoice( toolChoice: BedrockConverseToolChoice, - tools: BedrockTool[] + tools: BedrockTool[], + fields: { + model: string; + supportsToolChoiceValues?: Array<"auto" | "any" | "tool">; + } ): BedrockToolChoice { + const supportsToolChoiceValues = fields.supportsToolChoiceValues ?? []; + + let bedrockToolChoice: BedrockToolChoice; if (typeof toolChoice === "string") { switch (toolChoice) { case "any": - return { + bedrockToolChoice = { any: {}, }; + break; case "auto": - return { + bedrockToolChoice = { auto: {}, }; + break; default: { const foundTool = tools.find( (tool) => tool.toolSpec?.name === toolChoice @@ -292,15 +301,40 @@ export function convertToBedrockToolChoice( `Tool with name ${toolChoice} not found in tools list.` ); } - return { + bedrockToolChoice = { tool: { name: toolChoice, }, }; } } + } else { + bedrockToolChoice = toolChoice; + } + + const toolChoiceType = Object.keys(bedrockToolChoice)[0] as + | "auto" + | "any" + | "tool"; + if (!supportsToolChoiceValues.includes(toolChoiceType)) { + let supportedTxt = ""; + if (supportsToolChoiceValues.length) { + supportedTxt = + `Model ${fields.model} does not currently support 'tool_choice' ` + + `of type ${toolChoiceType}. The following 'tool_choice' types ` + + `are supported: ${supportsToolChoiceValues.join(", ")}.`; + } else { + supportedTxt = `Model ${fields.model} does not currently support 'tool_choice'.`; + } + + throw new Error( + `${supportedTxt} Please see` + + "https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html" + + "for the latest documentation on models that support tool choice." + ); } - return toolChoice; + + return bedrockToolChoice; } export function convertConverseMessageToLangChainMessage( diff --git a/libs/langchain-aws/src/tests/chat_models.test.ts b/libs/langchain-aws/src/tests/chat_models.test.ts index c95213761a39..e82faf216627 100644 --- a/libs/langchain-aws/src/tests/chat_models.test.ts +++ b/libs/langchain-aws/src/tests/chat_models.test.ts @@ -355,6 +355,15 @@ describe("tool_choice works for supported models", () => { accessKeyId: "process.env.BEDROCK_AWS_ACCESS_KEY_ID", }, }; + const supportsToolChoiceValuesClaude3: Array<"auto" | "any" | "tool"> = [ + "auto", + "any", + "tool", + ]; + const supportsToolChoiceValuesMistralLarge: Array<"auto" | "any" | "tool"> = [ + "auto", + "any", + ]; it("throws an error if passing tool_choice with unsupported models", async () => { // Claude 2 should throw @@ -362,33 +371,30 @@ describe("tool_choice works for supported models", () => { ...baseConstructorArgs, model: "anthropic.claude-v2", }); - expect(() => - claude2Model.bindTools([tool], { - tool_choice: tool.name, - }) - ).toThrow(); + const claude2WithTool = claude2Model.bindTools([tool], { + tool_choice: tool.name, + }); + await expect(claude2WithTool.invoke("foo")).rejects.toThrow(); // Cohere should throw const cohereModel = new ChatBedrockConverse({ ...baseConstructorArgs, model: "cohere.command-text-v14", }); - expect(() => - cohereModel.bindTools([tool], { - tool_choice: tool.name, - }) - ).toThrow(); + const cohereModelWithTool = cohereModel.bindTools([tool], { + tool_choice: tool.name, + }); + await expect(cohereModelWithTool.invoke("foo")).rejects.toThrow(); // Mistral (not mistral large) should throw const mistralModel = new ChatBedrockConverse({ ...baseConstructorArgs, model: "mistral.mistral-7b-instruct-v0:2", }); - expect(() => - mistralModel.bindTools([tool], { - tool_choice: tool.name, - }) - ).toThrow(); + const mistralModelWithTool = mistralModel.bindTools([tool], { + tool_choice: tool.name, + }); + await expect(mistralModelWithTool.invoke("foo")).rejects.toThrow(); }); it("does NOT throw and binds tool_choice when calling bindTools with supported models", async () => { @@ -396,6 +402,7 @@ describe("tool_choice works for supported models", () => { const claude3Model = new ChatBedrockConverse({ ...baseConstructorArgs, model: "anthropic.claude-3-5-sonnet-20240620-v1:0", + supportsToolChoiceValues: supportsToolChoiceValuesClaude3, }); const claude3ModelWithTool = claude3Model.bindTools([tool], { tool_choice: tool.name, @@ -416,6 +423,7 @@ describe("tool_choice works for supported models", () => { const mistralModel = new ChatBedrockConverse({ ...baseConstructorArgs, model: "mistral.mistral-large-2407-v1:0", + supportsToolChoiceValues: supportsToolChoiceValuesMistralLarge, }); const mistralModelWithTool = mistralModel.bindTools([tool], { tool_choice: tool.name, @@ -491,6 +499,8 @@ describe("tool_choice works for supported models", () => { const claude3Model = new ChatBedrockConverse({ ...baseConstructorArgs, model: "anthropic.claude-3-5-sonnet-20240620-v1:0", + // We are not passing the `supportsToolChoiceValues` arg here as + // it should be inferred from the model name. }); const claude3ModelWSO = claude3Model.withStructuredOutput(tool.schema, { name: tool.name, @@ -511,6 +521,8 @@ describe("tool_choice works for supported models", () => { const mistralModel = new ChatBedrockConverse({ ...baseConstructorArgs, model: "mistral.mistral-large-2407-v1:0", + // We are not passing the `supportsToolChoiceValues` arg here as + // it should be inferred from the model name. }); const mistralModelWSO = mistralModel.withStructuredOutput(tool.schema, { name: tool.name, @@ -523,8 +535,9 @@ describe("tool_choice works for supported models", () => { expect(mistralModelWSOAsJSON.kwargs.bound.first.kwargs).toHaveProperty( "tool_choice" ); + // Mistral large only supports "auto" and "any" for tool_choice, not the actual tool name expect(mistralModelWSOAsJSON.kwargs.bound.first.kwargs.tool_choice).toBe( - tool.name + "any" ); }); });