Skip to content

Commit

Permalink
anthropic[patch]: Fix passing streamed tool calls back to anthropic (#6199)
Browse files Browse the repository at this point in the history

* anthropic[patch]: Fix passing streamed tool calls back to anthropic

* rm anthropic test, implement standard tests
  • Loading branch information
bracesproul authored Jul 24, 2024
1 parent e8a9458 commit 585e65c
Show file tree
Hide file tree
Showing 2 changed files with 261 additions and 9 deletions.
100 changes: 92 additions & 8 deletions libs/langchain-anthropic/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,18 @@ function _makeMessageChunkFromAnthropicEvent(
streamUsage: boolean;
coerceContentToString: boolean;
usageData: { input_tokens: number; output_tokens: number };
toolUse?: {
id: string;
name: string;
};
}
): {
chunk: AIMessageChunk;
usageData: { input_tokens: number; output_tokens: number };
toolUse?: {
id: string;
name: string;
};
} | null {
let usageDataCopy = { ...fields.usageData };

Expand Down Expand Up @@ -233,6 +241,10 @@ function _makeMessageChunkFromAnthropicEvent(
additional_kwargs: {},
}),
usageData: usageDataCopy,
toolUse: {
id: data.content_block.id,
name: data.content_block.name,
},
};
} else if (
data.type === "content_block_delta" &&
Expand Down Expand Up @@ -274,6 +286,25 @@ function _makeMessageChunkFromAnthropicEvent(
}),
usageData: usageDataCopy,
};
} else if (data.type === "content_block_stop" && fields.toolUse) {
// Only yield the ID & name when the tool_use block is complete.
// This is so the names & IDs do not get concatenated.
return {
chunk: new AIMessageChunk({
content: fields.coerceContentToString
? ""
: [
{
id: fields.toolUse.id,
name: fields.toolUse.name,
index: data.index,
type: "input_json_delta",
},
],
additional_kwargs: {},
}),
usageData: usageDataCopy,
};
}

return null;
Expand Down Expand Up @@ -424,6 +455,9 @@ export function _convertLangChainToolCallToAnthropic(
}

function _formatContent(content: MessageContent) {
const toolTypes = ["tool_use", "tool_result", "input_json_delta"];
const textTypes = ["text", "text_delta"];

if (typeof content === "string") {
return content;
} else {
Expand All @@ -439,19 +473,40 @@ function _formatContent(content: MessageContent) {
type: "image" as const, // Explicitly setting the type as "image"
source,
};
} else if (contentPart.type === "text") {
} else if (
textTypes.find((t) => t === contentPart.type) &&
"text" in contentPart
) {
// Assuming contentPart is of type MessageContentText here
return {
type: "text" as const, // Explicitly setting the type as "text"
text: contentPart.text,
};
} else if (
contentPart.type === "tool_use" ||
contentPart.type === "tool_result"
) {
} else if (toolTypes.find((t) => t === contentPart.type)) {
const contentPartCopy = { ...contentPart };
if ("index" in contentPartCopy) {
// Anthropic does not support passing the index field here, so we remove it.
delete contentPartCopy.index;
}

if (contentPartCopy.type === "input_json_delta") {
// `input_json_delta` type only represents yielding partial tool inputs
// and is not a valid type for Anthropic messages.
contentPartCopy.type = "tool_use";
}

if ("input" in contentPartCopy) {
// Anthropic tool use inputs should be valid objects, when applicable.
try {
contentPartCopy.input = JSON.parse(contentPartCopy.input);
} catch {
// no-op
}
}

// TODO: Fix when SDK types are fixed
return {
...contentPart,
...contentPartCopy,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} as any;
} else {
Expand Down Expand Up @@ -519,7 +574,9 @@ function _formatMessagesForAnthropic(messages: BaseMessage[]): {
const hasMismatchedToolCalls = !message.tool_calls.every((toolCall) =>
content.find(
(contentPart) =>
contentPart.type === "tool_use" && contentPart.id === toolCall.id
(contentPart.type === "tool_use" ||
contentPart.type === "input_json_delta") &&
contentPart.id === toolCall.id
)
);
if (hasMismatchedToolCalls) {
Expand Down Expand Up @@ -581,12 +638,16 @@ function extractToolCallChunk(
) {
if (typeof inputJsonDeltaChunks.input === "string") {
newToolCallChunk = {
id: inputJsonDeltaChunks.id,
name: inputJsonDeltaChunks.name,
args: inputJsonDeltaChunks.input,
index: inputJsonDeltaChunks.index,
type: "tool_call_chunk",
};
} else {
newToolCallChunk = {
id: inputJsonDeltaChunks.id,
name: inputJsonDeltaChunks.name,
args: JSON.stringify(inputJsonDeltaChunks.input, null, 2),
index: inputJsonDeltaChunks.index,
type: "tool_call_chunk",
Expand Down Expand Up @@ -919,6 +980,14 @@ export class ChatAnthropicMessages<
let usageData = { input_tokens: 0, output_tokens: 0 };

let concatenatedChunks: AIMessageChunk | undefined;
// Anthropic only yields the tool name and id once, so we need to save those
// so we can yield them with the rest of the tool_use content.
let toolUse:
| {
id: string;
name: string;
}
| undefined;

for await (const data of stream) {
if (options.signal?.aborted) {
Expand All @@ -930,12 +999,27 @@ export class ChatAnthropicMessages<
streamUsage: !!(this.streamUsage || options.streamUsage),
coerceContentToString,
usageData,
toolUse: toolUse
? {
id: toolUse.id,
name: toolUse.name,
}
: undefined,
});
if (!result) continue;

const { chunk, usageData: updatedUsageData } = result;
const {
chunk,
usageData: updatedUsageData,
toolUse: updatedToolUse,
} = result;

usageData = updatedUsageData;

if (updatedToolUse) {
toolUse = updatedToolUse;
}

const newToolCallChunk = extractToolCallChunk(chunk);
// Maintain concatenatedChunks for accessing the complete `tool_use` content block.
concatenatedChunks = concatenatedChunks
Expand Down
170 changes: 169 additions & 1 deletion libs/langchain-standard-tests/src/integration_tests/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ import {
getBufferString,
} from "@langchain/core/messages";
import { z } from "zod";
import { StructuredTool } from "@langchain/core/tools";
import { StructuredTool, tool } from "@langchain/core/tools";
import { zodToJsonSchema } from "zod-to-json-schema";
import { ChatPromptTemplate } from "@langchain/core/prompts";
import { RunnableLambda } from "@langchain/core/runnables";
import { concat } from "@langchain/core/utils/stream";
import {
BaseChatModelsTests,
BaseChatModelsTestsFields,
Expand Down Expand Up @@ -522,6 +523,159 @@ export abstract class ChatModelIntegrationTests<
expect(cacheValue2).toEqual(cacheValue);
}

/**
* This test verifies models can invoke a tool, and use the AIMessage
* with the tool call in a followup request. This is useful when building
* agents, or other pipelines that invoke tools.
*/
async testModelCanUseToolUseAIMessage() {
if (!this.chatModelHasToolCalling) {
console.log("Test requires tool calling. Skipping...");
return;
}

const model = new this.Cls(this.constructorArgs);
if (!model.bindTools) {
throw new Error(
"bindTools undefined. Cannot test OpenAI formatted tool calls."
);
}

const weatherSchema = z.object({
location: z.string().describe("The location to get the weather for."),
});

// Define the tool
const weatherTool = tool(
(_) => "The weather in San Francisco is 70 degrees and sunny.",
{
name: "get_current_weather",
schema: weatherSchema,
description: "Get the current weather for a location.",
}
);

const modelWithTools = model.bindTools([weatherTool]);

// List of messages to initially invoke the model with, and to hold
// followup messages to invoke the model with.
const messages = [
new HumanMessage(
"What's the weather like in San Francisco right now? Use the 'get_current_weather' tool to find the answer."
),
];

const result: AIMessage = await modelWithTools.invoke(messages);

expect(result.tool_calls?.[0]).toBeDefined();
if (!result.tool_calls?.[0]) {
throw new Error("result.tool_calls is undefined");
}
const { tool_calls } = result;
expect(tool_calls[0].name).toBe("get_current_weather");

// Push the result of the tool call into the messages array so we can
// confirm in the followup request the model can use the tool call.
messages.push(result);

// Create a dummy ToolMessage representing the output of the tool call.
const toolMessage = new ToolMessage({
tool_call_id: tool_calls[0].id ?? "",
name: tool_calls[0].name,
content: await weatherTool.invoke(
tool_calls[0].args as z.infer<typeof weatherSchema>
),
});
messages.push(toolMessage);

const finalResult = await modelWithTools.invoke(messages);

expect(finalResult.content).not.toBe("");
}

/**
* Same as the above test, but streaming both model invocations.
*/
async testModelCanUseToolUseAIMessageWithStreaming() {
if (!this.chatModelHasToolCalling) {
console.log("Test requires tool calling. Skipping...");
return;
}

const model = new this.Cls(this.constructorArgs);
if (!model.bindTools) {
throw new Error(
"bindTools undefined. Cannot test OpenAI formatted tool calls."
);
}

const weatherSchema = z.object({
location: z.string().describe("The location to get the weather for."),
});

// Define the tool
const weatherTool = tool(
(_) => "The weather in San Francisco is 70 degrees and sunny.",
{
name: "get_current_weather",
schema: weatherSchema,
description: "Get the current weather for a location.",
}
);

const modelWithTools = model.bindTools([weatherTool]);

// List of messages to initially invoke the model with, and to hold
// followup messages to invoke the model with.
const messages = [
new HumanMessage(
"What's the weather like in San Francisco right now? Use the 'get_current_weather' tool to find the answer."
),
];

const stream = await modelWithTools.stream(messages);
let result: AIMessageChunk | undefined;
for await (const chunk of stream) {
result = !result ? chunk : concat(result, chunk);
}

expect(result).toBeDefined();
if (!result) return;

expect(result.tool_calls?.[0]).toBeDefined();
if (!result.tool_calls?.[0]) {
throw new Error("result.tool_calls is undefined");
}

const { tool_calls } = result;
expect(tool_calls[0].name).toBe("get_current_weather");

// Push the result of the tool call into the messages array so we can
// confirm in the followup request the model can use the tool call.
messages.push(result);

// Create a dummy ToolMessage representing the output of the tool call.
const toolMessage = new ToolMessage({
tool_call_id: tool_calls[0].id ?? "",
name: tool_calls[0].name,
content: await weatherTool.invoke(
tool_calls[0].args as z.infer<typeof weatherSchema>
),
});
messages.push(toolMessage);

const finalStream = await modelWithTools.stream(messages);
let finalResult: AIMessageChunk | undefined;
for await (const chunk of finalStream) {
finalResult = !finalResult ? chunk : concat(finalResult, chunk);
}

expect(finalResult).toBeDefined();
if (!finalResult) return;

expect(finalResult.content).not.toBe("");
}

/**
* Run all unit tests for the chat model.
* Each test is wrapped in a try/catch block to prevent the entire test suite from failing.
Expand Down Expand Up @@ -629,6 +783,20 @@ export abstract class ChatModelIntegrationTests<
console.error("testCacheComplexMessageTypes failed", e);
}

try {
await this.testModelCanUseToolUseAIMessage();
} catch (e: any) {
allTestsPassed = false;
console.error("testModelCanUseToolUseAIMessage failed", e);
}

try {
await this.testModelCanUseToolUseAIMessageWithStreaming();
} catch (e: any) {
allTestsPassed = false;
console.error("testModelCanUseToolUseAIMessageWithStreaming failed", e);
}

return allTestsPassed;
}
}

0 comments on commit 585e65c

Please sign in to comment.