Skip to content

Commit

Permalink
anthropic[minor]: Add tool_choice arg (#5416)
Browse files Browse the repository at this point in the history
* anthropic[minor]: Add tool_choice arg

* chore: lint files

* release 0.1.19
  • Loading branch information
bracesproul authored May 16, 2024
1 parent a9409a5 commit 08ff323
Show file tree
Hide file tree
Showing 6 changed files with 230 additions and 8 deletions.
36 changes: 36 additions & 0 deletions docs/core_docs/docs/integrations/chat/anthropic.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,42 @@ import AnthropicSingleTool from "@examples/models/chat/integration_anthropic_sin
See the LangSmith trace [here](https://smith.langchain.com/public/90c03ed0-154b-4a50-afbf-83dcbf302647/r)
:::

### Forced tool calling

import AnthropicForcedTool from "@examples/models/chat/integration_anthropic_forced_tool.ts";

In this example we'll provide the model with two tools:

- `calculator`
- `get_weather`

Then, when we call `bindTools`, we'll force the model to use the `get_weather` tool by passing the `tool_choice` arg like this:

```typescript
.bindTools({
tools,
tool_choice: {
type: "tool",
name: "get_weather",
}
});
```

Finally, we'll invoke the model, but instead of asking about the weather, we'll ask it to do some math.
Since we explicitly forced the model to use the `get_weather` tool, it will ignore the input and return the weather information (in this case it returned `<UNKNOWN>`, which is expected.)

<CodeBlock language="typescript">{AnthropicForcedTool}</CodeBlock>

The `bind_tools` argument has three possible values:

- `{ type: "tool", name: "tool_name" }` - Forces the model to use the specified tool.
- `"any"` - Allows the model to choose the tool, but still forcing it to choose at least one.
- `"auto"` - The default value. Allows the model to select any tool, or none.

:::tip
See the LangSmith trace [here](https://smith.langchain.com/public/c5cc8fe7-5e76-4607-8c43-1e0b30e4f5ca/r)
:::

### `withStructuredOutput`

import AnthropicWSA from "@examples/models/chat/integration_anthropic_wsa.ts";
Expand Down
84 changes: 84 additions & 0 deletions examples/src/models/chat/integration_anthropic_forced_tool.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import { ChatAnthropic } from "@langchain/anthropic";
import { ChatPromptTemplate } from "@langchain/core/prompts";
import { z } from "zod";
import { zodToJsonSchema } from "zod-to-json-schema";

const calculatorSchema = z.object({
operation: z
.enum(["add", "subtract", "multiply", "divide"])
.describe("The type of operation to execute."),
number1: z.number().describe("The first number to operate on."),
number2: z.number().describe("The second number to operate on."),
});

const weatherSchema = z.object({
city: z.string().describe("The city to get the weather from"),
state: z.string().optional().describe("The state to get the weather from"),
});

const tools = [
{
name: "calculator",
description: "A simple calculator tool",
input_schema: zodToJsonSchema(calculatorSchema),
},
{
name: "get_weather",
description:
"Get the weather of a specific location and return the temperature in Celsius.",
input_schema: zodToJsonSchema(weatherSchema),
},
];

const model = new ChatAnthropic({
apiKey: process.env.ANTHROPIC_API_KEY,
model: "claude-3-haiku-20240307",
}).bind({
tools,
tool_choice: {
type: "tool",
name: "get_weather",
},
});

const prompt = ChatPromptTemplate.fromMessages([
[
"system",
"You are a helpful assistant who always needs to use a calculator.",
],
["human", "{input}"],
]);

// Chain your prompt and model together
const chain = prompt.pipe(model);

const response = await chain.invoke({
input: "What is the sum of 2725 and 273639",
});
console.log(JSON.stringify(response, null, 2));
/*
{
"kwargs": {
"tool_calls": [
{
"name": "get_weather",
"args": {
"city": "<UNKNOWN>",
"state": "<UNKNOWN>"
},
"id": "toolu_01MGRNudJvSDrrCZcPa2WrBX"
}
],
"response_metadata": {
"id": "msg_01RW3R4ctq7q5g4GJuGMmRPR",
"model": "claude-3-haiku-20240307",
"stop_sequence": null,
"usage": {
"input_tokens": 672,
"output_tokens": 52
},
"stop_reason": "tool_use"
}
}
}
*/
4 changes: 2 additions & 2 deletions libs/langchain-anthropic/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@langchain/anthropic",
"version": "0.1.18",
"version": "0.1.19",
"description": "Anthropic integrations for LangChain.js",
"type": "module",
"engines": {
Expand Down Expand Up @@ -39,7 +39,7 @@
"author": "LangChain",
"license": "MIT",
"dependencies": {
"@anthropic-ai/sdk": "^0.20.1",
"@anthropic-ai/sdk": "^0.21.0",
"@langchain/core": "<0.3.0 || >0.1.0",
"fast-xml-parser": "^4.3.5",
"zod": "^3.22.4",
Expand Down
35 changes: 34 additions & 1 deletion libs/langchain-anthropic/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,20 @@ type AnthropicStreamingMessageCreateParams =
Anthropic.MessageCreateParamsStreaming;
type AnthropicMessageStreamEvent = Anthropic.MessageStreamEvent;
type AnthropicRequestOptions = Anthropic.RequestOptions;

type AnthropicToolChoice =
| {
type: "tool";
name: string;
}
| "any"
| "auto";
interface ChatAnthropicCallOptions extends BaseLanguageModelCallOptions {
tools?: (StructuredToolInterface | AnthropicTool)[];
/**
* Whether or not to specify what tool the model should use
* @default "auto"
*/
tool_choice?: AnthropicToolChoice;
}

type AnthropicMessageResponse = Anthropic.ContentBlock | AnthropicToolResponse;
Expand Down Expand Up @@ -546,6 +557,26 @@ export class ChatAnthropicMessages<
"messages"
> &
Kwargs {
let tool_choice:
| {
type: string;
name?: string;
}
| undefined;
if (options?.tool_choice) {
if (options?.tool_choice === "any") {
tool_choice = {
type: "any",
};
} else if (options?.tool_choice === "auto") {
tool_choice = {
type: "auto",
};
} else {
tool_choice = options?.tool_choice;
}
}

return {
model: this.model,
temperature: this.temperature,
Expand All @@ -555,6 +586,7 @@ export class ChatAnthropicMessages<
stream: this.streaming,
max_tokens: this.maxTokens,
tools: this.formatStructuredToolToAnthropic(options?.tools),
tool_choice,
...this.invocationKwargs,
};
}
Expand Down Expand Up @@ -910,6 +942,7 @@ export class ChatAnthropicMessages<
}
const llm = this.bind({
tools,
tool_choice: "any",
} as Partial<CallOptions>);

if (!includeRaw) {
Expand Down
69 changes: 69 additions & 0 deletions libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -271,3 +271,72 @@ test("withStructuredOutput JSON Schema only", async () => {
);
expect(typeof result.location).toBe("string");
});

test("Can pass tool_choice", async () => {
const tool1 = {
name: "get_weather",
description:
"Get the weather of a specific location and return the temperature in Celsius.",
input_schema: {
type: "object",
properties: {
location: {
type: "string",
description: "The name of city to get the weather for.",
},
},
required: ["location"],
},
};
const tool2 = {
name: "calculator",
description: "Calculate any math expression and return the result.",
input_schema: {
type: "object",
properties: {
expression: {
type: "string",
description: "The math expression to calculate.",
},
},
required: ["expression"],
},
};
const tools = [tool1, tool2];

const modelWithTools = model.bindTools(tools, {
tool_choice: {
type: "tool",
name: "get_weather",
},
});

const result = await modelWithTools.invoke(
"What is the sum of 272818 and 281818?"
);
console.log(
{
tool_calls: JSON.stringify(result.content, null, 2),
},
"Can bind & invoke StructuredTools"
);
expect(Array.isArray(result.content)).toBeTruthy();
if (!Array.isArray(result.content)) {
throw new Error("Content is not an array");
}
let toolCall: AnthropicToolResponse | undefined;
result.content.forEach((item) => {
if (item.type === "tool_use") {
toolCall = item as AnthropicToolResponse;
}
});
if (!toolCall) {
throw new Error("No tool call found");
}
expect(toolCall).toBeTruthy();
const { name, input } = toolCall;
expect(toolCall.input).toEqual(result.tool_calls?.[0].args);
expect(name).toBe("get_weather");
expect(input).toBeTruthy();
expect(input.location).toBeTruthy();
});
10 changes: 5 additions & 5 deletions yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,9 @@ __metadata:
languageName: node
linkType: hard

"@anthropic-ai/sdk@npm:^0.20.1":
version: 0.20.1
resolution: "@anthropic-ai/sdk@npm:0.20.1"
"@anthropic-ai/sdk@npm:^0.21.0":
version: 0.21.0
resolution: "@anthropic-ai/sdk@npm:0.21.0"
dependencies:
"@types/node": ^18.11.18
"@types/node-fetch": ^2.6.4
Expand All @@ -223,7 +223,7 @@ __metadata:
formdata-node: ^4.3.2
node-fetch: ^2.6.7
web-streams-polyfill: ^3.2.1
checksum: a880088ffeb993ea835f3ec250d53bf6ba23e97c3dfc54c915843aa8cb4778849fb7b85de0a359155c36595a5a5cc1db64139d407d2e36a2423284ebfe763cce
checksum: fbed720938487495f1d28822fa6eb3871cf7e7be325c299b69efa78e72e1e0b66d9f564003ae5d7a1e96c7555cc69c817be4b901d1847ae002f782546a4c987d
languageName: node
linkType: hard

Expand Down Expand Up @@ -8865,7 +8865,7 @@ __metadata:
version: 0.0.0-use.local
resolution: "@langchain/anthropic@workspace:libs/langchain-anthropic"
dependencies:
"@anthropic-ai/sdk": ^0.20.1
"@anthropic-ai/sdk": ^0.21.0
"@jest/globals": ^29.5.0
"@langchain/community": "workspace:*"
"@langchain/core": <0.3.0 || >0.1.0
Expand Down

0 comments on commit 08ff323

Please sign in to comment.