google-common[minor]: Add stream_usage #5763

Merged Jun 14, 2024 (4 commits)
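For reviewers skimming the diff, here is a minimal usage sketch of what this PR enables. The model construction and prompt mirror the integration tests added below; the aggregation loop is illustrative rather than part of the change:

import { AIMessageChunk } from "@langchain/core/messages";
import { ChatVertexAI } from "@langchain/google-vertexai";

// streamUsage defaults to true, so streamed chunks now carry usage_metadata.
const model = new ChatVertexAI({ temperature: 0 });

let final: AIMessageChunk | null = null;
for await (const chunk of await model.stream("Why is the sky blue? Be concise.")) {
  // concat() merges chunk content and usage_metadata as the stream progresses.
  final = final ? final.concat(chunk) : chunk;
}

// e.g. { input_tokens: 9, output_tokens: ..., total_tokens: ... }
console.log(final?.usage_metadata);

Passing streamUsage: false to the constructor (or per call) suppresses the usage data.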
2 changes: 1 addition & 1 deletion libs/langchain-google-common/package.json
@@ -40,7 +40,7 @@
"author": "LangChain",
Hey there! I noticed the dependency change for "@langchain/core" in the package.json file. This might impact the peer/dev/hard dependencies, so I'm flagging this for the maintainers to review. Great work on the PR!

"license": "MIT",
"dependencies": {
"@langchain/core": ">0.1.56 <0.3.0",
"@langchain/core": ">=0.2.5 <0.3.0",
"uuid": "^9.0.0",
"zod-to-json-schema": "^3.22.4"
},
24 changes: 20 additions & 4 deletions libs/langchain-google-common/src/chat_models.ts
@@ -1,5 +1,5 @@
import { getEnvironmentVariable } from "@langchain/core/utils/env";
-import { type BaseMessage } from "@langchain/core/messages";
+import { UsageMetadata, type BaseMessage } from "@langchain/core/messages";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";

import {
@@ -150,7 +150,8 @@ export interface ChatGoogleBaseInput<AuthOptions>
  extends BaseChatModelParams,
    GoogleConnectionParams<AuthOptions>,
    GoogleAIModelParams,
-    GoogleAISafetyParams {}
+    GoogleAISafetyParams,
+    Pick<GoogleAIBaseLanguageModelCallOptions, "streamUsage"> {}

function convertToGeminiTools(
structuredTools: (StructuredToolInterface | Record<string, unknown>)[]
@@ -216,6 +217,8 @@ export abstract class ChatGoogleBase<AuthOptions>

safetyHandler: GoogleAISafetyHandler;

+  streamUsage = true;
+
protected connection: ChatConnection<AuthOptions>;

protected streamedConnection: ChatConnection<AuthOptions>;
@@ -226,7 +229,7 @@ export abstract class ChatGoogleBase<AuthOptions>
copyAndValidateModelParamsInto(fields, this);
this.safetyHandler =
fields?.safetyHandler ?? new DefaultGeminiSafetyHandler();

+    this.streamUsage = fields?.streamUsage ?? this.streamUsage;
const client = this.buildClient(fields);
this.buildConnection(fields ?? {}, client);
}
@@ -342,12 +345,24 @@ export abstract class ChatGoogleBase<AuthOptions>

    // Get the streaming parser of the response
    const stream = response.data as JsonStream;
-
+    let usageMetadata: UsageMetadata | undefined;
    // Loop until the end of the stream
    // During the loop, yield each time we get a chunk from the streaming parser
    // that is either available or added to the queue
    while (!stream.streamDone) {
      const output = await stream.nextChunk();
+      if (
+        output &&
+        output.usageMetadata &&
+        this.streamUsage !== false &&
+        options.streamUsage !== false
+      ) {
+        usageMetadata = {
+          input_tokens: output.usageMetadata.promptTokenCount,
+          output_tokens: output.usageMetadata.candidatesTokenCount,
+          total_tokens: output.usageMetadata.totalTokenCount,
+        };
+      }
      const chunk =
        output !== null
          ? safeResponseToChatGeneration({ data: output }, this.safetyHandler)
@@ -356,6 +371,7 @@
            generationInfo: { finishReason: "stop" },
            message: new AIMessageChunk({
              content: "",
+              usage_metadata: usageMetadata,
            }),
          });
      yield chunk;
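As a reading aid for the hunk above: Gemini's streamed chunks can include a usageMetadata block, and the new code maps it onto LangChain's UsageMetadata shape. A hedged sketch of just that mapping (the GeminiUsageMetadata interface name is illustrative and not part of this PR):

import { UsageMetadata } from "@langchain/core/messages";

// Illustrative shape of the usage block on a streamed Gemini response chunk.
interface GeminiUsageMetadata {
  promptTokenCount: number;
  candidatesTokenCount: number;
  totalTokenCount: number;
}

// The same field mapping the loop applies: prompt -> input, candidates -> output.
function toUsageMetadata(u: GeminiUsageMetadata): UsageMetadata {
  return {
    input_tokens: u.promptTokenCount,
    output_tokens: u.candidatesTokenCount,
    total_tokens: u.totalTokenCount,
  };
}

The loop keeps only the latest snapshot and attaches it to the final empty chunk, so a consumer that concatenates chunks sees one set of counts per stream.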
9 changes: 8 additions & 1 deletion libs/langchain-google-common/src/types.ts
@@ -122,7 +122,14 @@ export interface GoogleAIBaseLLMInput<AuthOptions>
export interface GoogleAIBaseLanguageModelCallOptions
  extends BaseLanguageModelCallOptions,
    GoogleAIModelRequestParams,
-    GoogleAISafetyParams {}
+    GoogleAISafetyParams {
+  /**
+   * Whether or not to include usage data, like token counts
+   * in the streamed response chunks.
+   * @default true
+   */
+  streamUsage?: boolean;
+}

/**
* Input to LLM class.
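Since streamUsage is now part of the call options as well as the constructor input, it can be toggled per invocation. A minimal sketch under that assumption, reusing a model built as in the tests below:

// Per-call override: suppress usage metadata for this one stream.
const stream = await model.stream("Why is the sky blue? Be concise.", {
  streamUsage: false,
});
for await (const chunk of stream) {
  // chunk.usage_metadata stays undefined for this call.
  console.log(chunk.content);
}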
11 changes: 11 additions & 0 deletions libs/langchain-google-common/src/utils/gemini.ts
@@ -12,6 +12,7 @@ import {
  MessageContentText,
  SystemMessage,
  ToolMessage,
+  UsageMetadata,
  isAIMessage,
} from "@langchain/core/messages";
import {
@@ -604,12 +605,22 @@ export function responseToChatGenerations(
      id: toolCall.id,
      index: i,
    }));
+    let usageMetadata: UsageMetadata | undefined;
+    if ("usageMetadata" in response.data) {
+      usageMetadata = {
+        input_tokens: response.data.usageMetadata.promptTokenCount as number,
+        output_tokens: response.data.usageMetadata
+          .candidatesTokenCount as number,
+        total_tokens: response.data.usageMetadata.totalTokenCount as number,
+      };
+    }
    ret = [
      new ChatGenerationChunk({
        message: new AIMessageChunk({
          content: combinedContent,
          additional_kwargs: ret[ret.length - 1]?.message.additional_kwargs,
          tool_call_chunks: toolCallChunks,
+          usage_metadata: usageMetadata,
        }),
        text: combinedText,
        generationInfo: ret[ret.length - 1].generationInfo,
62 changes: 62 additions & 0 deletions libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts
@@ -233,3 +233,65 @@ describe("GAuth Chat", () => {
    expect(result).toHaveProperty("location");
  });
});
+
+test("Stream token count usage_metadata", async () => {
+  const model = new ChatVertexAI({
+    temperature: 0,
+  });
+  let res: AIMessageChunk | null = null;
+  for await (const chunk of await model.stream(
+    "Why is the sky blue? Be concise."
+  )) {
+    if (!res) {
+      res = chunk;
+    } else {
+      res = res.concat(chunk);
+    }
+  }
+  console.log(res);
+  expect(res?.usage_metadata).toBeDefined();
+  if (!res?.usage_metadata) {
+    return;
+  }
+  expect(res.usage_metadata.input_tokens).toBe(9);
+  expect(res.usage_metadata.output_tokens).toBeGreaterThan(10);
+  expect(res.usage_metadata.total_tokens).toBe(
+    res.usage_metadata.input_tokens + res.usage_metadata.output_tokens
+  );
+});
+
+test("streamUsage excludes token usage", async () => {
+  const model = new ChatVertexAI({
+    temperature: 0,
+    streamUsage: false,
+  });
+  let res: AIMessageChunk | null = null;
+  for await (const chunk of await model.stream(
+    "Why is the sky blue? Be concise."
+  )) {
+    if (!res) {
+      res = chunk;
+    } else {
+      res = res.concat(chunk);
+    }
+  }
+  console.log(res);
+  expect(res?.usage_metadata).not.toBeDefined();
+});
+
+test("Invoke token count usage_metadata", async () => {
+  const model = new ChatVertexAI({
+    temperature: 0,
+  });
+  const res = await model.invoke("Why is the sky blue? Be concise.");
+  console.log(res);
+  expect(res?.usage_metadata).toBeDefined();
+  if (!res?.usage_metadata) {
+    return;
+  }
+  expect(res.usage_metadata.input_tokens).toBe(9);
+  expect(res.usage_metadata.output_tokens).toBeGreaterThan(10);
+  expect(res.usage_metadata.total_tokens).toBe(
+    res.usage_metadata.input_tokens + res.usage_metadata.output_tokens
+  );
+});
@@ -25,22 +25,6 @@ class ChatVertexAIStandardIntegrationTests extends ChatModelIntegrationTests<
});
}

-  async testUsageMetadataStreaming() {
-    this.skipTestMessage(
-      "testUsageMetadataStreaming",
-      "ChatVertexAI",
-      "Streaming tokens is not currently supported."
-    );
-  }
-
-  async testUsageMetadata() {
-    this.skipTestMessage(
-      "testUsageMetadata",
-      "ChatVertexAI",
-      "Usage metadata tokens is not currently supported."
-    );
-  }
-
async testToolMessageHistoriesListContent() {
this.skipTestMessage(
"testToolMessageHistoriesListContent",
2 changes: 1 addition & 1 deletion yarn.lock
@@ -10194,7 +10194,7 @@ __metadata:
resolution: "@langchain/google-common@workspace:libs/langchain-google-common"
dependencies:
"@jest/globals": ^29.5.0
"@langchain/core": ">0.1.56 <0.3.0"
"@langchain/core": ">=0.2.5 <0.3.0"
"@langchain/scripts": ~0.0.14
"@swc/core": ^1.3.90
"@swc/jest": ^0.2.29