Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cohere[minor]: Fix token counts, add usage_metadata #5732

Merged
merged 27 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
d9df403
cohere[minor]: Fix token counts, add usage_metadata
bracesproul Jun 11, 2024
0290115
chore: lint files
bracesproul Jun 11, 2024
fc30a10
remove skipped token usage tests from cohere standard int test
bracesproul Jun 11, 2024
b38a802
bump min core version to usage_metadata update
bracesproul Jun 11, 2024
448511e
add streamUsage
bracesproul Jun 11, 2024
e28c29d
added cohere to latest/lowest dep tests
bracesproul Jun 11, 2024
ddf98b1
conditionally run latest/lowest
bracesproul Jun 11, 2024
f0cc874
Merge branch 'main' into brace/cohere-token-count
bracesproul Jun 11, 2024
8a4eaa6
nit
bracesproul Jun 11, 2024
0127915
cr
bracesproul Jun 11, 2024
745aa15
cr
bracesproul Jun 11, 2024
87756a9
cr
bracesproul Jun 11, 2024
f01e778
revert
bracesproul Jun 11, 2024
427b0a2
test
bracesproul Jun 11, 2024
8676b88
try just pr
bracesproul Jun 11, 2024
7ccce30
cr
bracesproul Jun 11, 2024
fe01ba5
cr
bracesproul Jun 11, 2024
b586528
only log files
bracesproul Jun 11, 2024
8f22a96
more tests
bracesproul Jun 11, 2024
e9d4817
toJson
bracesproul Jun 11, 2024
bb17fe5
use git to access changed files
bracesproul Jun 11, 2024
c80eb11
Merge branch 'main' into brace/cohere-token-count
bracesproul Jun 11, 2024
5331f45
fix if statements
bracesproul Jun 12, 2024
2a8632e
fix test
bracesproul Jun 12, 2024
d57893d
Merge branch 'main' into brace/cohere-token-count
bracesproul Jun 12, 2024
3060a41
unfocus jest tests and add eslint rule
bracesproul Jun 12, 2024
018aa93
add eslint-plugin-jest
bracesproul Jun 12, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion libs/langchain-cohere/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
"license": "MIT",
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey there! 👋 I noticed that the package.json file has an update for the "cohere-ai" dependency, which seems to be a peer dependency change. I've flagged this for your review. Keep up the great work! 🚀

"dependencies": {
"@langchain/core": ">0.1.58 <0.3.0",
"cohere-ai": "^7.9.3"
"cohere-ai": "^7.10.5"
},
"devDependencies": {
"@jest/globals": "^29.5.0",
Expand Down
78 changes: 59 additions & 19 deletions libs/langchain-cohere/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ export interface CohereChatCallOptions

function convertMessagesToCohereMessages(
messages: Array<BaseMessage>
): Array<Cohere.ChatMessage> {
): Array<Cohere.Message> {
const getRole = (role: MessageType) => {
switch (role) {
case "system":
Expand Down Expand Up @@ -113,7 +113,7 @@ function convertMessagesToCohereMessages(
export class ChatCohere<
CallOptions extends CohereChatCallOptions = CohereChatCallOptions
>
extends BaseChatModel<CallOptions>
extends BaseChatModel<CallOptions, AIMessageChunk>
implements ChatCohereInput
{
static lc_name() {
Expand Down Expand Up @@ -193,8 +193,14 @@ export class ChatCohere<
const cohereMessages = convertMessagesToCohereMessages(messages);
// The last message in the array is the most recent, all other messages
// are a part of the chat history.
const { message } = cohereMessages[cohereMessages.length - 1];
const chatHistory: Cohere.ChatMessage[] = [];
const lastMessage = cohereMessages[cohereMessages.length - 1];
if (lastMessage.role === "TOOL") {
throw new Error(
"Cohere does not support tool messages as the most recent message in chat history."
);
}
const { message } = lastMessage;
const chatHistory: Cohere.Message[] = [];
if (cohereMessages.length > 1) {
chatHistory.push(...cohereMessages.slice(0, -1));
}
Expand Down Expand Up @@ -241,25 +247,22 @@ export class ChatCohere<
}
);

if ("token_count" in response) {
const {
response_tokens: completionTokens,
prompt_tokens: promptTokens,
total_tokens: totalTokens,
} = response.token_count as Record<string, number>;
if (response.meta?.tokens) {
const { inputTokens, outputTokens } = response.meta.tokens;

if (completionTokens) {
if (outputTokens) {
tokenUsage.completionTokens =
(tokenUsage.completionTokens ?? 0) + completionTokens;
(tokenUsage.completionTokens ?? 0) + outputTokens;
}

if (promptTokens) {
tokenUsage.promptTokens = (tokenUsage.promptTokens ?? 0) + promptTokens;
if (inputTokens) {
tokenUsage.promptTokens = (tokenUsage.promptTokens ?? 0) + inputTokens;
}

if (totalTokens) {
tokenUsage.totalTokens = (tokenUsage.totalTokens ?? 0) + totalTokens;
}
tokenUsage.totalTokens =
(tokenUsage.totalTokens ?? 0) +
(tokenUsage.promptTokens ?? 0) +
(tokenUsage.completionTokens ?? 0);
}

const generationInfo: Record<string, unknown> = { ...response };
Expand All @@ -271,6 +274,11 @@ export class ChatCohere<
message: new AIMessage({
content: response.text,
additional_kwargs: generationInfo,
usage_metadata: {
input_tokens: tokenUsage.promptTokens ?? 0,
output_tokens: tokenUsage.completionTokens ?? 0,
total_tokens: tokenUsage.totalTokens ?? 0,
},
}),
generationInfo,
},
Expand All @@ -290,8 +298,14 @@ export class ChatCohere<
const cohereMessages = convertMessagesToCohereMessages(messages);
// The last message in the array is the most recent, all other messages
// are a part of the chat history.
const { message } = cohereMessages[cohereMessages.length - 1];
const chatHistory: Cohere.ChatMessage[] = [];
const lastMessage = cohereMessages[cohereMessages.length - 1];
if (lastMessage.role === "TOOL") {
throw new Error(
"Cohere does not support tool messages as the most recent message in chat history."
);
}
const { message } = lastMessage;
const chatHistory: Cohere.Message[] = [];
if (cohereMessages.length > 1) {
chatHistory.push(...cohereMessages.slice(0, -1));
}
Expand Down Expand Up @@ -335,6 +349,32 @@ export class ChatCohere<
...chunk,
},
});
} else if (
chunk.eventType === "stream-end" &&
chunk.response.meta?.tokens &&
(chunk.response.meta.tokens.inputTokens ||
chunk.response.meta.tokens.outputTokens)
) {
// stream-end events contain the final token count
const input_tokens = chunk.response.meta.tokens.inputTokens ?? 0;
const output_tokens = chunk.response.meta.tokens.outputTokens ?? 0;
yield new ChatGenerationChunk({
text: "",
message: new AIMessageChunk({
content: "",
additional_kwargs: {
eventType: "stream-end",
},
usage_metadata: {
input_tokens,
output_tokens,
total_tokens: input_tokens + output_tokens,
},
}),
generationInfo: {
eventType: "stream-end",
},
});
}
}
}
Expand Down
47 changes: 46 additions & 1 deletion libs/langchain-cohere/src/tests/chat_models.int.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/* eslint-disable no-promise-executor-return */
import { test, expect } from "@jest/globals";
import { HumanMessage } from "@langchain/core/messages";
import { AIMessageChunk, HumanMessage } from "@langchain/core/messages";
import { ChatCohere } from "../chat_models.js";

test("ChatCohere can invoke", async () => {
Expand Down Expand Up @@ -58,3 +58,48 @@ test("should abort the request", async () => {
return ret;
}).rejects.toThrow("AbortError");
});

test("Stream token count usage_metadata", async () => {
const model = new ChatCohere({
model: "command-light",
temperature: 0,
});
let res: AIMessageChunk | null = null;
for await (const chunk of await model.stream(
"Why is the sky blue? Be concise."
)) {
if (!res) {
res = chunk;
} else {
res = res.concat(chunk);
}
}
console.log(res);
expect(res?.usage_metadata).toBeDefined();
if (!res?.usage_metadata) {
return;
}
expect(res.usage_metadata.input_tokens).toBe(71);
expect(res.usage_metadata.output_tokens).toBeGreaterThan(10);
expect(res.usage_metadata.total_tokens).toBe(
res.usage_metadata.input_tokens + res.usage_metadata.output_tokens
);
});

test("Invoke token count usage_metadata", async () => {
const model = new ChatCohere({
model: "command-light",
temperature: 0,
});
const res = await model.invoke("Why is the sky blue? Be concise.");
console.log(res);
expect(res?.usage_metadata).toBeDefined();
if (!res?.usage_metadata) {
return;
}
expect(res.usage_metadata.input_tokens).toBe(71);
expect(res.usage_metadata.output_tokens).toBeGreaterThan(10);
expect(res.usage_metadata.total_tokens).toBe(
res.usage_metadata.input_tokens + res.usage_metadata.output_tokens
);
});
Loading
Loading