Skip to content

Commit

Permalink
community[minor]: Update bedrock to accept openai formatted tools (#5852
Browse files Browse the repository at this point in the history
)
  • Loading branch information
bracesproul authored Jun 22, 2024
1 parent 828b1f9 commit 91aacde
Show file tree
Hide file tree
Showing 8 changed files with 208 additions and 41 deletions.
2 changes: 1 addition & 1 deletion libs/langchain-community/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
"author": "LangChain",
"license": "MIT",
"dependencies": {
"@langchain/core": "~0.2.6",
"@langchain/core": "~0.2.9",
"@langchain/openai": "~0.1.0",
"binary-extensions": "^2.2.0",
"expr-eval": "^2.0.2",
Expand Down
121 changes: 88 additions & 33 deletions libs/langchain-community/src/chat_models/bedrock/web.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@ import {
LangSmithParams,
BaseChatModelCallOptions,
} from "@langchain/core/language_models/chat_models";
import { BaseLanguageModelInput } from "@langchain/core/language_models/base";
import {
BaseLanguageModelInput,
ToolDefinition,
isOpenAITool,
} from "@langchain/core/language_models/base";
import { Runnable } from "@langchain/core/runnables";
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import {
Expand All @@ -32,12 +36,15 @@ import { isStructuredTool } from "@langchain/core/utils/function_calling";
import { ToolCall } from "@langchain/core/messages/tool";
import { zodToJsonSchema } from "zod-to-json-schema";

import type { SerializedFields } from "../../load/map_keys.js";
import {
BaseBedrockInput,
BedrockLLMInputOutputAdapter,
type CredentialType,
} from "../../utils/bedrock/index.js";
import type { SerializedFields } from "../../load/map_keys.js";
import { isAnthropicTool } from "../../utils/bedrock/anthropic.js";

type AnthropicTool = Record<string, unknown>;

const PRELUDE_TOTAL_LENGTH_BYTES = 4;

Expand Down Expand Up @@ -99,6 +106,49 @@ export function convertMessagesToPrompt(
throw new Error(`Provider ${provider} does not support chat.`);
}

function formatTools(tools: BedrockChatCallOptions["tools"]): AnthropicTool[] {
if (!tools || !tools.length) {
return [];
}
if (tools.every((tc) => isStructuredTool(tc))) {
return (tools as StructuredToolInterface[]).map((tc) => ({
name: tc.name,
description: tc.description,
input_schema: zodToJsonSchema(tc.schema),
}));
}
if (tools.every((tc) => isOpenAITool(tc))) {
return (tools as ToolDefinition[]).map((tc) => ({
name: tc.function.name,
description: tc.function.description,
input_schema: tc.function.parameters,
}));
}

if (tools.every((tc) => isAnthropicTool(tc))) {
return tools as AnthropicTool[];
}

if (
tools.some((tc) => isStructuredTool(tc)) ||
tools.some((tc) => isOpenAITool(tc)) ||
tools.some((tc) => isAnthropicTool(tc))
) {
throw new Error(
"All tools passed to BedrockChat must be of the same type."
);
}
throw new Error("Invalid tool format received.");
}

export interface BedrockChatCallOptions extends BaseChatModelCallOptions {
tools?: (StructuredToolInterface | AnthropicTool | ToolDefinition)[];
}

export interface BedrockChatFields
extends Partial<BaseBedrockInput>,
BaseChatModelParams {}

/**
* A type of Large Language Model (LLM) that interacts with the Bedrock
* service. It extends the base `LLM` class and implements the
Expand Down Expand Up @@ -195,7 +245,10 @@ export function convertMessagesToPrompt(
* runStreaming().catch(console.error);
* ```
*/
export class BedrockChat extends BaseChatModel implements BaseBedrockInput {
export class BedrockChat
extends BaseChatModel<BedrockChatCallOptions, AIMessageChunk>
implements BaseBedrockInput
{
model = "amazon.titan-tg1-large";

region: string;
Expand Down Expand Up @@ -234,7 +287,7 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput {
streamProcessingMode: "SYNCHRONOUS" | "ASYNCHRONOUS";
};

protected _anthropicTools?: Record<string, unknown>[];
protected _anthropicTools?: AnthropicTool[];

get lc_aliases(): Record<string, string> {
return {
Expand Down Expand Up @@ -268,7 +321,7 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput {
return "BedrockChat";
}

constructor(fields?: Partial<BaseBedrockInput> & BaseChatModelParams) {
constructor(fields?: BedrockChatFields) {
super(fields ?? {});

this.model = fields?.model ?? this.model;
Expand Down Expand Up @@ -318,11 +371,14 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput {
}

override invocationParams(options?: this["ParsedCallOptions"]) {
const callOptionTools = formatTools(options?.tools ?? []);
return {
tools: this._anthropicTools,
tools: [...(this._anthropicTools ?? []), ...callOptionTools],
temperature: this.temperature,
max_tokens: this.maxTokens,
stop: options?.stop,
stop: options?.stop ?? this.stopSequences,
modelKwargs: this.modelKwargs,
guardrailConfig: this.guardrailConfig,
};
}

Expand All @@ -340,7 +396,7 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput {

async _generate(
messages: BaseMessage[],
options: Partial<BaseChatModelParams>,
options: Partial<this["ParsedCallOptions"]>,
runManager?: CallbackManagerForLLMRun
): Promise<ChatResult> {
if (this.streaming) {
Expand Down Expand Up @@ -368,7 +424,7 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput {

async _generateNonStreaming(
messages: BaseMessage[],
options: Partial<BaseChatModelParams>,
options: Partial<this["ParsedCallOptions"]>,
_runManager?: CallbackManagerForLLMRun
): Promise<ChatResult> {
const service = "bedrock-runtime";
Expand Down Expand Up @@ -412,26 +468,34 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput {
}
) {
const { bedrockMethod, endpointHost, provider } = fields;
const {
max_tokens,
temperature,
stop,
modelKwargs,
guardrailConfig,
tools,
} = this.invocationParams(options);
const inputBody = this.usesMessagesApi
? BedrockLLMInputOutputAdapter.prepareMessagesInput(
provider,
messages,
this.maxTokens,
this.temperature,
options.stop ?? this.stopSequences,
this.modelKwargs,
this.guardrailConfig,
this._anthropicTools
max_tokens,
temperature,
stop,
modelKwargs,
guardrailConfig,
tools
)
: BedrockLLMInputOutputAdapter.prepareInput(
provider,
convertMessagesToPromptAnthropic(messages),
this.maxTokens,
this.temperature,
options.stop ?? this.stopSequences,
this.modelKwargs,
max_tokens,
temperature,
stop,
modelKwargs,
fields.bedrockMethod,
this.guardrailConfig
guardrailConfig
);

const url = new URL(
Expand Down Expand Up @@ -680,29 +744,20 @@ export class BedrockChat extends BaseChatModel implements BaseBedrockInput {
}

override bindTools(
tools: (StructuredToolInterface | Record<string, unknown>)[],
_kwargs?: Partial<BaseChatModelCallOptions>
tools: (StructuredToolInterface | AnthropicTool | ToolDefinition)[],
_kwargs?: Partial<this["ParsedCallOptions"]>
): Runnable<
BaseLanguageModelInput,
BaseMessageChunk,
BaseChatModelCallOptions
this["ParsedCallOptions"]
> {
const provider = this.model.split(".")[0];
if (provider !== "anthropic") {
throw new Error(
"Currently, tool calling through Bedrock is only supported for Anthropic models."
);
}
this._anthropicTools = tools.map((tool) => {
if (isStructuredTool(tool)) {
return {
name: tool.name,
description: tool.description,
input_schema: zodToJsonSchema(tool.schema),
};
}
return tool;
});
this._anthropicTools = formatTools(tools);
return this;
}
}
Expand Down
102 changes: 102 additions & 0 deletions libs/langchain-community/src/chat_models/tests/chatbedrock.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import { test, expect } from "@jest/globals";
import { HumanMessage } from "@langchain/core/messages";
import { AgentExecutor, createToolCallingAgent } from "langchain/agents";
import { ChatPromptTemplate } from "@langchain/core/prompts";
import { z } from "zod";
import { zodToJsonSchema } from "zod-to-json-schema";
import { BedrockChat as BedrockChatWeb } from "../bedrock/web.js";
import { TavilySearchResults } from "../../tools/tavily_search.js";

Expand Down Expand Up @@ -383,3 +385,103 @@ test.skip.each([

expect(res.content.length).toBeGreaterThan(1);
});

test.skip("withStructuredOutput", async () => {
const weatherTool = z
.object({
city: z.string().describe("The city to get the weather for"),
state: z.string().describe("The state to get the weather for").optional(),
})
.describe("Get the weather for a city");
const model = new BedrockChatWeb({
region: process.env.BEDROCK_AWS_REGION,
model: "anthropic.claude-3-sonnet-20240229-v1:0",
maxRetries: 0,
credentials: {
secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
},
});
const modelWithTools = model.withStructuredOutput(weatherTool, {
name: "weather",
});
const response = await modelWithTools.invoke(
"Whats the weather like in san francisco?"
);
expect(response.city.toLowerCase()).toBe("san francisco");
});

test.skip(".bind tools", async () => {
const weatherTool = z
.object({
city: z.string().describe("The city to get the weather for"),
state: z.string().describe("The state to get the weather for").optional(),
})
.describe("Get the weather for a city");
const model = new BedrockChatWeb({
region: process.env.BEDROCK_AWS_REGION,
model: "anthropic.claude-3-sonnet-20240229-v1:0",
maxRetries: 0,
credentials: {
secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
},
});
const modelWithTools = model.bind({
tools: [
{
name: "weather_tool",
description: weatherTool.description,
input_schema: zodToJsonSchema(weatherTool),
},
],
});
const response = await modelWithTools.invoke(
"Whats the weather like in san francisco?"
);
console.log(response);
if (!response.tool_calls?.[0]) {
throw new Error("No tool calls found in response");
}
const { tool_calls } = response;
expect(tool_calls[0].name.toLowerCase()).toBe("weather_tool");
});

test.skip(".bindTools with openai tool format", async () => {
const weatherTool = z
.object({
city: z.string().describe("The city to get the weather for"),
state: z.string().describe("The state to get the weather for").optional(),
})
.describe("Get the weather for a city");
const model = new BedrockChatWeb({
region: process.env.BEDROCK_AWS_REGION,
model: "anthropic.claude-3-sonnet-20240229-v1:0",
maxRetries: 0,
credentials: {
secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
},
});
const modelWithTools = model.bind({
tools: [
{
type: "function",
function: {
name: "weather_tool",
description: weatherTool.description,
parameters: zodToJsonSchema(weatherTool),
},
},
],
});
const response = await modelWithTools.invoke(
"Whats the weather like in san francisco?"
);
console.log(response);
if (!response.tool_calls?.[0]) {
throw new Error("No tool calls found in response");
}
const { tool_calls } = response;
expect(tool_calls[0].name.toLowerCase()).toBe("weather_tool");
});
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class BedrockChatStandardIntegrationTests extends ChatModelIntegrationTests<
super({
Cls: BedrockChat,
chatModelHasToolCalling: true,
chatModelHasStructuredOutput: false,
chatModelHasStructuredOutput: true,
constructorArgs: {
region,
model: "anthropic.claude-3-sonnet-20240229-v1:0",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ class BedrockChatStandardUnitTests extends ChatModelUnitTests<
constructor() {
super({
Cls: BedrockChat,
chatModelHasToolCalling: false,
chatModelHasStructuredOutput: false,
constructorArgs: {},
chatModelHasToolCalling: true,
chatModelHasStructuredOutput: true,
constructorArgs: {
model: "anthropic.claude-3-sonnet-20240229-v1:0",
},
});
process.env.BEDROCK_AWS_SECRET_ACCESS_KEY = "test";
process.env.BEDROCK_AWS_ACCESS_KEY_ID = "test";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ class ChatTogetherAIStandardUnitTests extends ChatModelUnitTests<
constructor() {
super({
Cls: ChatTogetherAI,
chatModelHasToolCalling: false,
chatModelHasStructuredOutput: false,
chatModelHasToolCalling: true,
chatModelHasStructuredOutput: true,
constructorArgs: {},
});
process.env.TOGETHER_AI_API_KEY = "test";
Expand Down
8 changes: 8 additions & 0 deletions libs/langchain-community/src/utils/bedrock/anthropic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,11 @@ export function formatMessagesForAnthropic(messages: BaseMessage[]): {
system,
};
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
export function isAnthropicTool(
tool: unknown
): tool is Record<string, unknown> {
if (typeof tool !== "object" || !tool) return false;
return "input_schema" in tool;
}
Loading

0 comments on commit 91aacde

Please sign in to comment.