Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

anthropic[minor]: Add tool_choice arg #5416

Merged
merged 3 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions docs/core_docs/docs/integrations/chat/anthropic.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,42 @@ import AnthropicSingleTool from "@examples/models/chat/integration_anthropic_sin
See the LangSmith trace [here](https://smith.langchain.com/public/90c03ed0-154b-4a50-afbf-83dcbf302647/r)
:::

### Forced tool calling

import AnthropicForcedTool from "@examples/models/chat/integration_anthropic_forced_tool.ts";

In this example we'll provide the model with two tools:

- `calculator`
- `get_weather`

Then, when we call `bindTools`, we'll force the model to use the `get_weather` tool by passing the `tool_choice` arg like this:

```typescript
.bindTools({
tools,
tool_choice: {
type: "tool",
name: "get_weather",
}
});
```

Finally, we'll invoke the model, but instead of asking about the weather, we'll ask it to do some math.
Since we explicitly forced the model to use the `get_weather` tool, it will ignore the input and return the weather information (in this case it returned `<UNKNOWN>`, which is expected.)

<CodeBlock language="typescript">{AnthropicForcedTool}</CodeBlock>

The `bind_tools` argument has three possible values:

- `{ type: "tool", name: "tool_name" }` - Forces the model to use the specified tool.
- `"any"` - Allows the model to choose the tool, but still forcing it to choose at least one.
- `"auto"` - The default value. Allows the model to select any tool, or none.

:::tip
See the LangSmith trace [here](https://smith.langchain.com/public/c5cc8fe7-5e76-4607-8c43-1e0b30e4f5ca/r)
:::

### `withStructuredOutput`

import AnthropicWSA from "@examples/models/chat/integration_anthropic_wsa.ts";
Expand Down
84 changes: 84 additions & 0 deletions examples/src/models/chat/integration_anthropic_forced_tool.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import { ChatAnthropic } from "@langchain/anthropic";
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey team, I've flagged this PR for your review as it explicitly accesses an environment variable via process.env. Please take a look and ensure that this access is handled securely and appropriately.

import { ChatPromptTemplate } from "@langchain/core/prompts";
import { z } from "zod";
import { zodToJsonSchema } from "zod-to-json-schema";

const calculatorSchema = z.object({
operation: z
.enum(["add", "subtract", "multiply", "divide"])
.describe("The type of operation to execute."),
number1: z.number().describe("The first number to operate on."),
number2: z.number().describe("The second number to operate on."),
});

const weatherSchema = z.object({
city: z.string().describe("The city to get the weather from"),
state: z.string().optional().describe("The state to get the weather from"),
});

const tools = [
{
name: "calculator",
description: "A simple calculator tool",
input_schema: zodToJsonSchema(calculatorSchema),
},
{
name: "get_weather",
description:
"Get the weather of a specific location and return the temperature in Celsius.",
input_schema: zodToJsonSchema(weatherSchema),
},
];

const model = new ChatAnthropic({
apiKey: process.env.ANTHROPIC_API_KEY,
model: "claude-3-haiku-20240307",
}).bind({
tools,
tool_choice: {
type: "tool",
name: "get_weather",
},
});

const prompt = ChatPromptTemplate.fromMessages([
[
"system",
"You are a helpful assistant who always needs to use a calculator.",
],
["human", "{input}"],
]);

// Chain your prompt and model together
const chain = prompt.pipe(model);

const response = await chain.invoke({
input: "What is the sum of 2725 and 273639",
});
console.log(JSON.stringify(response, null, 2));
/*
{
"kwargs": {
"tool_calls": [
{
"name": "get_weather",
"args": {
"city": "<UNKNOWN>",
"state": "<UNKNOWN>"
},
"id": "toolu_01MGRNudJvSDrrCZcPa2WrBX"
}
],
"response_metadata": {
"id": "msg_01RW3R4ctq7q5g4GJuGMmRPR",
"model": "claude-3-haiku-20240307",
"stop_sequence": null,
"usage": {
"input_tokens": 672,
"output_tokens": 52
},
"stop_reason": "tool_use"
}
}
}
*/
4 changes: 2 additions & 2 deletions libs/langchain-anthropic/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@langchain/anthropic",
"version": "0.1.18",
"version": "0.1.19",
"description": "Anthropic integrations for LangChain.js",
"type": "module",
"engines": {
Expand Down Expand Up @@ -39,7 +39,7 @@
"author": "LangChain",
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey there! 👋 I noticed that the "@anthropic-ai/sdk" dependency has been updated to version "^0.21.0" in the package.json file. This change is flagged for maintainers to review. Keep up the great work!

"license": "MIT",
"dependencies": {
"@anthropic-ai/sdk": "^0.20.1",
"@anthropic-ai/sdk": "^0.21.0",
"@langchain/core": "<0.3.0 || >0.1.0",
"fast-xml-parser": "^4.3.5",
"zod": "^3.22.4",
Expand Down
35 changes: 34 additions & 1 deletion libs/langchain-anthropic/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,20 @@ type AnthropicStreamingMessageCreateParams =
Anthropic.MessageCreateParamsStreaming;
type AnthropicMessageStreamEvent = Anthropic.MessageStreamEvent;
type AnthropicRequestOptions = Anthropic.RequestOptions;

type AnthropicToolChoice =
| {
type: "tool";
name: string;
}
| "any"
| "auto";
interface ChatAnthropicCallOptions extends BaseLanguageModelCallOptions {
tools?: (StructuredToolInterface | AnthropicTool)[];
/**
* Whether or not to specify what tool the model should use
* @default "auto"
*/
tool_choice?: AnthropicToolChoice;
}

type AnthropicMessageResponse = Anthropic.ContentBlock | AnthropicToolResponse;
Expand Down Expand Up @@ -546,6 +557,26 @@ export class ChatAnthropicMessages<
"messages"
> &
Kwargs {
let tool_choice:
| {
type: string;
name?: string;
}
| undefined;
if (options?.tool_choice) {
if (options?.tool_choice === "any") {
tool_choice = {
type: "any",
};
} else if (options?.tool_choice === "auto") {
tool_choice = {
type: "auto",
};
} else {
tool_choice = options?.tool_choice;
}
}

return {
model: this.model,
temperature: this.temperature,
Expand All @@ -555,6 +586,7 @@ export class ChatAnthropicMessages<
stream: this.streaming,
max_tokens: this.maxTokens,
tools: this.formatStructuredToolToAnthropic(options?.tools),
tool_choice,
...this.invocationKwargs,
};
}
Expand Down Expand Up @@ -910,6 +942,7 @@ export class ChatAnthropicMessages<
}
const llm = this.bind({
tools,
tool_choice: "any",
} as Partial<CallOptions>);

if (!includeRaw) {
Expand Down
69 changes: 69 additions & 0 deletions libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -271,3 +271,72 @@ test("withStructuredOutput JSON Schema only", async () => {
);
expect(typeof result.location).toBe("string");
});

test("Can pass tool_choice", async () => {
const tool1 = {
name: "get_weather",
description:
"Get the weather of a specific location and return the temperature in Celsius.",
input_schema: {
type: "object",
properties: {
location: {
type: "string",
description: "The name of city to get the weather for.",
},
},
required: ["location"],
},
};
const tool2 = {
name: "calculator",
description: "Calculate any math expression and return the result.",
input_schema: {
type: "object",
properties: {
expression: {
type: "string",
description: "The math expression to calculate.",
},
},
required: ["expression"],
},
};
const tools = [tool1, tool2];

const modelWithTools = model.bindTools(tools, {
tool_choice: {
type: "tool",
name: "get_weather",
},
});

const result = await modelWithTools.invoke(
"What is the sum of 272818 and 281818?"
);
console.log(
{
tool_calls: JSON.stringify(result.content, null, 2),
},
"Can bind & invoke StructuredTools"
);
expect(Array.isArray(result.content)).toBeTruthy();
if (!Array.isArray(result.content)) {
throw new Error("Content is not an array");
}
let toolCall: AnthropicToolResponse | undefined;
result.content.forEach((item) => {
if (item.type === "tool_use") {
toolCall = item as AnthropicToolResponse;
}
});
if (!toolCall) {
throw new Error("No tool call found");
}
expect(toolCall).toBeTruthy();
const { name, input } = toolCall;
expect(toolCall.input).toEqual(result.tool_calls?.[0].args);
expect(name).toBe("get_weather");
expect(input).toBeTruthy();
expect(input.location).toBeTruthy();
});
10 changes: 5 additions & 5 deletions yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,9 @@ __metadata:
languageName: node
linkType: hard

"@anthropic-ai/sdk@npm:^0.20.1":
version: 0.20.1
resolution: "@anthropic-ai/sdk@npm:0.20.1"
"@anthropic-ai/sdk@npm:^0.21.0":
version: 0.21.0
resolution: "@anthropic-ai/sdk@npm:0.21.0"
dependencies:
"@types/node": ^18.11.18
"@types/node-fetch": ^2.6.4
Expand All @@ -223,7 +223,7 @@ __metadata:
formdata-node: ^4.3.2
node-fetch: ^2.6.7
web-streams-polyfill: ^3.2.1
checksum: a880088ffeb993ea835f3ec250d53bf6ba23e97c3dfc54c915843aa8cb4778849fb7b85de0a359155c36595a5a5cc1db64139d407d2e36a2423284ebfe763cce
checksum: fbed720938487495f1d28822fa6eb3871cf7e7be325c299b69efa78e72e1e0b66d9f564003ae5d7a1e96c7555cc69c817be4b901d1847ae002f782546a4c987d
languageName: node
linkType: hard

Expand Down Expand Up @@ -8865,7 +8865,7 @@ __metadata:
version: 0.0.0-use.local
resolution: "@langchain/anthropic@workspace:libs/langchain-anthropic"
dependencies:
"@anthropic-ai/sdk": ^0.20.1
"@anthropic-ai/sdk": ^0.21.0
"@jest/globals": ^29.5.0
"@langchain/community": "workspace:*"
"@langchain/core": <0.3.0 || >0.1.0
Expand Down
Loading