Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

google-common[minor]: Fix streaming tool calls #6204

Merged
merged 3 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions libs/langchain-google-common/src/utils/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { v4 as uuidv4 } from "uuid";
import {
AIMessage,
AIMessageChunk,
AIMessageFields,
AIMessageChunkFields,
BaseMessage,
BaseMessageChunk,
BaseMessageFields,
Expand Down Expand Up @@ -566,7 +566,7 @@ export function chunkToString(chunk: BaseMessageChunk): string {
}

export function partToMessageChunk(part: GeminiPart): BaseMessageChunk {
const fields = partsToBaseMessageFields([part]);
const fields = partsToBaseMessageChunkFields([part]);
if (typeof fields.content === "string") {
return new AIMessageChunk(fields);
} else if (fields.content.every((item) => item.type === "text")) {
Expand Down Expand Up @@ -636,12 +636,15 @@ export function responseToBaseMessageFields(
response: GoogleLLMResponse
): BaseMessageFields {
const parts = responseToParts(response);
return partsToBaseMessageFields(parts);
return partsToBaseMessageChunkFields(parts);
}

export function partsToBaseMessageFields(parts: GeminiPart[]): AIMessageFields {
const fields: AIMessageFields = {
export function partsToBaseMessageChunkFields(
parts: GeminiPart[]
): AIMessageChunkFields {
const fields: AIMessageChunkFields = {
content: partsToMessageContent(parts),
tool_call_chunks: [],
tool_calls: [],
invalid_tool_calls: [],
};
Expand All @@ -650,6 +653,13 @@ export function partsToBaseMessageFields(parts: GeminiPart[]): AIMessageFields {
if (rawTools.length > 0) {
const tools = toolsRawToTools(rawTools);
for (const tool of tools) {
fields.tool_call_chunks?.push({
name: tool.function.name,
args: tool.function.arguments,
id: tool.id,
type: "tool_call_chunk",
});

try {
fields.tool_calls?.push({
name: tool.function.name,
Expand All @@ -661,7 +671,7 @@ export function partsToBaseMessageFields(parts: GeminiPart[]): AIMessageFields {
} catch (e: any) {
fields.invalid_tool_calls?.push({
name: tool.function.name,
args: JSON.parse(tool.function.arguments),
args: tool.function.arguments,
id: tool.id,
error: e.message,
type: "invalid_tool_call",
Expand Down
3 changes: 2 additions & 1 deletion libs/langchain-google-vertexai/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@
"release-it": "^15.10.1",
"rollup": "^4.5.2",
"ts-jest": "^29.1.0",
"typescript": "<5.2.0"
"typescript": "<5.2.0",
"zod": "^3.22.4"
},
"publishConfig": {
"access": "public"
Expand Down
177 changes: 86 additions & 91 deletions libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,111 +11,70 @@ import {
SystemMessage,
ToolMessage,
} from "@langchain/core/messages";
import { ChatVertexAI } from "../chat_models.js";
import { tool } from "@langchain/core/tools";
import { concat } from "@langchain/core/utils/stream";
import { z } from "zod";
import { GeminiTool } from "../types.js";
import { ChatVertexAI } from "../chat_models.js";

describe("GAuth Chat", () => {
test("invoke", async () => {
const model = new ChatVertexAI();
try {
const res = await model.invoke("What is 1 + 1?");
expect(res).toBeDefined();
expect(res._getType()).toEqual("ai");

const aiMessage = res as AIMessageChunk;
expect(aiMessage.content).toBeDefined();

expect(typeof aiMessage.content).toBe("string");
const text = aiMessage.content as string;
expect(text).toMatch(/(1 + 1 (equals|is|=) )?2.? ?/);
const res = await model.invoke("What is 1 + 1?");
expect(res).toBeDefined();
expect(res._getType()).toEqual("ai");

/*
expect(aiMessage.content.length).toBeGreaterThan(0);
expect(aiMessage.content[0]).toBeDefined();
const content = aiMessage.content[0] as MessageContentComplex;
expect(content).toHaveProperty("type");
expect(content.type).toEqual("text");
const aiMessage = res as AIMessageChunk;
expect(aiMessage.content).toBeDefined();

const textContent = content as MessageContentText;
expect(textContent.text).toBeDefined();
expect(textContent.text).toEqual("2");
*/
} catch (e) {
console.error(e);
throw e;
}
expect(typeof aiMessage.content).toBe("string");
const text = aiMessage.content as string;
expect(text).toMatch(/(1 + 1 (equals|is|=) )?2.? ?/);
});

test("generate", async () => {
const model = new ChatVertexAI();
try {
const messages: BaseMessage[] = [
new SystemMessage(
"You will reply to all requests to flip a coin with either H, indicating heads, or T, indicating tails."
),
new HumanMessage("Flip it"),
new AIMessage("T"),
new HumanMessage("Flip the coin again"),
];
const res = await model.predictMessages(messages);
expect(res).toBeDefined();
expect(res._getType()).toEqual("ai");

const aiMessage = res as AIMessageChunk;
expect(aiMessage.content).toBeDefined();

expect(typeof aiMessage.content).toBe("string");
const text = aiMessage.content as string;
expect(["H", "T"]).toContainEqual(text);

/*
expect(aiMessage.content.length).toBeGreaterThan(0);
expect(aiMessage.content[0]).toBeDefined();
const messages: BaseMessage[] = [
new SystemMessage(
"You will reply to all requests to flip a coin with either H, indicating heads, or T, indicating tails."
),
new HumanMessage("Flip it"),
new AIMessage("T"),
new HumanMessage("Flip the coin again"),
];
const res = await model.predictMessages(messages);
expect(res).toBeDefined();
expect(res._getType()).toEqual("ai");

const content = aiMessage.content[0] as MessageContentComplex;
expect(content).toHaveProperty("type");
expect(content.type).toEqual("text");
const aiMessage = res as AIMessageChunk;
expect(aiMessage.content).toBeDefined();

const textContent = content as MessageContentText;
expect(textContent.text).toBeDefined();
expect(["H", "T"]).toContainEqual(textContent.text);
*/
} catch (e) {
console.error(e);
throw e;
}
expect(typeof aiMessage.content).toBe("string");
const text = aiMessage.content as string;
expect(["H", "T"]).toContainEqual(text);
});

test("stream", async () => {
const model = new ChatVertexAI();
try {
const input: BaseLanguageModelInput = new ChatPromptValue([
new SystemMessage(
"You will reply to all requests to flip a coin with either H, indicating heads, or T, indicating tails."
),
new HumanMessage("Flip it"),
new AIMessage("T"),
new HumanMessage("Flip the coin again"),
]);
const res = await model.stream(input);
const resArray: BaseMessageChunk[] = [];
for await (const chunk of res) {
resArray.push(chunk);
}
expect(resArray).toBeDefined();
expect(resArray.length).toBeGreaterThanOrEqual(1);

const lastChunk = resArray[resArray.length - 1];
expect(lastChunk).toBeDefined();
expect(lastChunk._getType()).toEqual("ai");
const aiChunk = lastChunk as AIMessageChunk;
console.log(aiChunk);

console.log(JSON.stringify(resArray, null, 2));
} catch (e) {
console.error(e);
throw e;
const input: BaseLanguageModelInput = new ChatPromptValue([
new SystemMessage(
"You will reply to all requests to flip a coin with either H, indicating heads, or T, indicating tails."
),
new HumanMessage("Flip it"),
new AIMessage("T"),
new HumanMessage("Flip the coin again"),
]);
const res = await model.stream(input);
const resArray: BaseMessageChunk[] = [];
for await (const chunk of res) {
resArray.push(chunk);
}
expect(resArray).toBeDefined();
expect(resArray.length).toBeGreaterThanOrEqual(1);

const lastChunk = resArray[resArray.length - 1];
expect(lastChunk).toBeDefined();
expect(lastChunk._getType()).toEqual("ai");
});

test("function", async () => {
Expand Down Expand Up @@ -209,7 +168,7 @@ describe("GAuth Chat", () => {
for await (const chunk of res) {
resArray.push(chunk);
}
console.log(JSON.stringify(resArray, null, 2));
// console.log(JSON.stringify(resArray, null, 2));
});

test("withStructuredOutput", async () => {
Expand Down Expand Up @@ -249,7 +208,7 @@ test("Stream token count usage_metadata", async () => {
res = res.concat(chunk);
}
}
console.log(res);
// console.log(res);
expect(res?.usage_metadata).toBeDefined();
if (!res?.usage_metadata) {
return;
Expand All @@ -276,7 +235,7 @@ test("streamUsage excludes token usage", async () => {
res = res.concat(chunk);
}
}
console.log(res);
// console.log(res);
expect(res?.usage_metadata).not.toBeDefined();
});

Expand All @@ -286,7 +245,7 @@ test("Invoke token count usage_metadata", async () => {
maxOutputTokens: 10,
});
const res = await model.invoke("Why is the sky blue? Be concise.");
console.log(res);
// console.log(res);
expect(res?.usage_metadata).toBeDefined();
if (!res?.usage_metadata) {
return;
Expand Down Expand Up @@ -322,3 +281,39 @@ test("Streaming true constructor param will stream", async () => {

expect(totalTokenCount).toBeGreaterThan(1);
});

test("ChatGoogleGenerativeAI can stream tools", async () => {
const model = new ChatVertexAI({});

const weatherTool = tool(
(_) => "The weather in San Francisco today is 18 degrees and sunny.",
{
name: "current_weather_tool",
description: "Get the current weather for a given location.",
schema: z.object({
location: z.string().describe("The location to get the weather for."),
}),
}
);

const modelWithTools = model.bindTools([weatherTool]);
const stream = await modelWithTools.stream(
"Whats the weather like today in San Francisco?"
);
let finalChunk: AIMessageChunk | undefined;
for await (const chunk of stream) {
finalChunk = !finalChunk ? chunk : concat(finalChunk, chunk);
}

expect(finalChunk).toBeDefined();
if (!finalChunk) return;

const toolCalls = finalChunk.tool_calls;
expect(toolCalls).toBeDefined();
if (!toolCalls) {
throw new Error("tool_calls not in response");
}
expect(toolCalls.length).toBe(1);
expect(toolCalls[0].name).toBe("current_weather_tool");
expect(toolCalls[0].args).toHaveProperty("location");
});
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class ChatVertexAIStandardIntegrationTests extends ChatModelIntegrationTests<
Cls: ChatVertexAI,
chatModelHasToolCalling: true,
chatModelHasStructuredOutput: true,
invokeResponseType: AIMessageChunk,
constructorArgs: {
model: "gemini-1.5-pro",
},
Expand All @@ -32,6 +33,14 @@ class ChatVertexAIStandardIntegrationTests extends ChatModelIntegrationTests<
"Not implemented."
);
}

async testInvokeMoreComplexTools() {
this.skipTestMessage(
"testInvokeMoreComplexTools",
"ChatVertexAI",
"Google VertexAI does not support tool schemas where the object properties are not defined."
);
}
}

const testClass = new ChatVertexAIStandardIntegrationTests();
Expand Down
1 change: 1 addition & 0 deletions yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -11695,6 +11695,7 @@ __metadata:
rollup: ^4.5.2
ts-jest: ^29.1.0
typescript: <5.2.0
zod: ^3.22.4
languageName: unknown
linkType: soft

Expand Down
Loading