Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

aws[minor]: Implement WSO with tool_choice #6443

Merged
merged 7 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions libs/langchain-aws/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
"@aws-sdk/client-kendra": "^3.352.0",
"@aws-sdk/credential-provider-node": "^3.600.0",
"@langchain/core": ">=0.2.21 <0.3.0",
"zod": "^3.23.8",
"zod-to-json-schema": "^3.22.5"
},
"devDependencies": {
Expand Down
194 changes: 187 additions & 7 deletions libs/langchain-aws/src/chat_models.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import type { BaseMessage } from "@langchain/core/messages";
import { AIMessageChunk } from "@langchain/core/messages";
import type { BaseLanguageModelInput } from "@langchain/core/language_models/base";
import type {
BaseLanguageModelInput,
StructuredOutputMethodOptions,
ToolDefinition,
} from "@langchain/core/language_models/base";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import {
type BaseChatModelParams,
Expand All @@ -24,12 +28,15 @@ import {
DefaultProviderInit,
} from "@aws-sdk/credential-provider-node";
import type { DocumentType as __DocumentType } from "@smithy/types";
import { Runnable } from "@langchain/core/runnables";
import {
ChatBedrockConverseToolType,
ConverseCommandParams,
CredentialType,
} from "./types.js";
Runnable,
RunnableLambda,
RunnablePassthrough,
RunnableSequence,
} from "@langchain/core/runnables";
import { zodToJsonSchema } from "zod-to-json-schema";
import { isZodSchema } from "@langchain/core/utils/types";
import { z } from "zod";
import {
convertToConverseTools,
convertToBedrockToolChoice,
Expand All @@ -40,6 +47,11 @@ import {
handleConverseStreamContentBlockStart,
BedrockConverseToolChoice,
} from "./common.js";
import {
ChatBedrockConverseToolType,
ConverseCommandParams,
CredentialType,
} from "./types.js";

/**
* Inputs for ChatBedrockConverse.
Expand Down Expand Up @@ -120,6 +132,14 @@ export interface ChatBedrockConverseInput
* Configuration information for a guardrail that you want to use in the request.
*/
guardrailConfig?: GuardrailConfiguration;

/**
* Which types of `tool_choice` values the model supports.
*
* Inferred if not specified. Inferred as ['auto', 'any', 'tool'] if a 'claude-3'
* model is used, ['auto', 'any'] if a 'mistral-large' model is used, empty otherwise.
*/
supportsToolChoiceValues?: Array<"auto" | "any" | "tool">;
}

export interface ChatBedrockConverseCallOptions
Expand Down Expand Up @@ -214,6 +234,14 @@ export class ChatBedrockConverse

client: BedrockRuntimeClient;

/**
* Which types of `tool_choice` values the model supports.
*
* Inferred if not specified. Inferred as ['auto', 'any', 'tool'] if a 'claude-3'
* model is used, ['auto', 'any'] if a 'mistral-large' model is used, empty otherwise.
*/
supportsToolChoiceValues?: Array<"auto" | "any" | "tool">;

constructor(fields?: ChatBedrockConverseInput) {
super(fields ?? {});
const {
Expand Down Expand Up @@ -264,6 +292,18 @@ export class ChatBedrockConverse
this.additionalModelRequestFields = rest?.additionalModelRequestFields;
this.streamUsage = rest?.streamUsage ?? this.streamUsage;
this.guardrailConfig = rest?.guardrailConfig;

if (rest?.supportsToolChoiceValues === undefined) {
if (this.model.includes("claude-3")) {
this.supportsToolChoiceValues = ["auto", "any", "tool"];
} else if (this.model.includes("mistral-large")) {
this.supportsToolChoiceValues = ["auto", "any"];
} else {
this.supportsToolChoiceValues = undefined;
}
} else {
this.supportsToolChoiceValues = rest.supportsToolChoiceValues;
}
}

getLsParams(options: this["ParsedCallOptions"]): LangSmithParams {
Expand Down Expand Up @@ -303,7 +343,10 @@ export class ChatBedrockConverse
toolConfig = {
tools,
toolChoice: options.tool_choice
? convertToBedrockToolChoice(options.tool_choice, tools)
? convertToBedrockToolChoice(options.tool_choice, tools, {
model: this.model,
supportsToolChoiceValues: this.supportsToolChoiceValues,
})
: undefined,
};
}
Expand Down Expand Up @@ -430,4 +473,141 @@ export class ChatBedrockConverse
}
}
}

withStructuredOutput<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunOutput extends Record<string, any> = Record<string, any>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
>(
outputSchema:
| z.ZodType<RunOutput>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
| Record<string, any>,
config?: StructuredOutputMethodOptions<false>
): Runnable<BaseLanguageModelInput, RunOutput>;

withStructuredOutput<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunOutput extends Record<string, any> = Record<string, any>
>(
outputSchema:
| z.ZodType<RunOutput>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
| Record<string, any>,
config?: StructuredOutputMethodOptions<true>
): Runnable<BaseLanguageModelInput, { raw: BaseMessage; parsed: RunOutput }>;

withStructuredOutput<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunOutput extends Record<string, any> = Record<string, any>
>(
outputSchema:
| z.ZodType<RunOutput>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
| Record<string, any>,
config?: StructuredOutputMethodOptions<boolean>
):
| Runnable<BaseLanguageModelInput, RunOutput>
| Runnable<
BaseLanguageModelInput,
{
raw: BaseMessage;
parsed: RunOutput;
}
> {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const schema: z.ZodType<RunOutput> | Record<string, any> = outputSchema;
const name = config?.name;
const description = schema.description ?? "A function available to call.";
const method = config?.method;
const includeRaw = config?.includeRaw;
if (method === "jsonMode") {
throw new Error(`ChatBedrockConverse does not support 'jsonMode'.`);
}

let functionName = name ?? "extract";
let tools: ToolDefinition[];
if (isZodSchema(schema)) {
tools = [
{
type: "function",
function: {
name: functionName,
description,
parameters: zodToJsonSchema(schema),
},
},
];
} else {
if ("name" in schema) {
functionName = schema.name;
}
tools = [
{
type: "function",
function: {
name: functionName,
description,
parameters: schema,
},
},
];
}

const supportsToolChoiceValues = this.supportsToolChoiceValues ?? [];
let toolChoiceObj: { tool_choice: string } | undefined;
if (supportsToolChoiceValues.includes("tool")) {
toolChoiceObj = {
tool_choice: tools[0].function.name,
};
} else if (supportsToolChoiceValues.includes("any")) {
toolChoiceObj = {
tool_choice: "any",
};
}

const llm = this.bindTools(tools, toolChoiceObj);
const outputParser = RunnableLambda.from<AIMessageChunk, RunOutput>(
(input: AIMessageChunk): RunOutput => {
if (!input.tool_calls || input.tool_calls.length === 0) {
throw new Error("No tool calls found in the response.");
}
const toolCall = input.tool_calls.find(
(tc) => tc.name === functionName
);
if (!toolCall) {
throw new Error(`No tool call found with name ${functionName}.`);
}
return toolCall.args as RunOutput;
}
);

if (!includeRaw) {
return llm.pipe(outputParser).withConfig({
runName: "StructuredOutput",
}) as Runnable<BaseLanguageModelInput, RunOutput>;
}

const parserAssign = RunnablePassthrough.assign({
// eslint-disable-next-line @typescript-eslint/no-explicit-any
parsed: (input: any, config) => outputParser.invoke(input.raw, config),
});
const parserNone = RunnablePassthrough.assign({
parsed: () => null,
});
const parsedWithFallback = parserAssign.withFallbacks({
fallbacks: [parserNone],
});
return RunnableSequence.from<
BaseLanguageModelInput,
{ raw: BaseMessage; parsed: RunOutput }
>([
{
raw: llm,
},
parsedWithFallback,
]).withConfig({
runName: "StructuredOutputRunnable",
});
}
}
44 changes: 39 additions & 5 deletions libs/langchain-aws/src/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -271,18 +271,27 @@ export type BedrockConverseToolChoice =

export function convertToBedrockToolChoice(
toolChoice: BedrockConverseToolChoice,
tools: BedrockTool[]
tools: BedrockTool[],
fields: {
model: string;
supportsToolChoiceValues?: Array<"auto" | "any" | "tool">;
}
): BedrockToolChoice {
const supportsToolChoiceValues = fields.supportsToolChoiceValues ?? [];

let bedrockToolChoice: BedrockToolChoice;
if (typeof toolChoice === "string") {
switch (toolChoice) {
case "any":
return {
bedrockToolChoice = {
any: {},
};
break;
case "auto":
return {
bedrockToolChoice = {
auto: {},
};
break;
default: {
const foundTool = tools.find(
(tool) => tool.toolSpec?.name === toolChoice
Expand All @@ -292,15 +301,40 @@ export function convertToBedrockToolChoice(
`Tool with name ${toolChoice} not found in tools list.`
);
}
return {
bedrockToolChoice = {
tool: {
name: toolChoice,
},
};
}
}
} else {
bedrockToolChoice = toolChoice;
}

const toolChoiceType = Object.keys(bedrockToolChoice)[0] as
| "auto"
| "any"
| "tool";
if (!supportsToolChoiceValues.includes(toolChoiceType)) {
let supportedTxt = "";
if (supportsToolChoiceValues.length) {
supportedTxt =
`Model ${fields.model} does not currently support 'tool_choice' ` +
`of type ${toolChoiceType}. The following 'tool_choice' types ` +
`are supported: ${supportsToolChoiceValues.join(", ")}.`;
} else {
supportedTxt = `Model ${fields.model} does not currently support 'tool_choice'.`;
}

throw new Error(
`${supportedTxt} Please see` +
"https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html" +
"for the latest documentation on models that support tool choice."
);
}
return toolChoice;

return bedrockToolChoice;
}

export function convertConverseMessageToLangChainMessage(
Expand Down
Loading
Loading