Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cohere[minor]: Fix token counts, add usage_metadata #5732

Merged
merged 27 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
d9df403
cohere[minor]: Fix token counts, add usage_metadata
bracesproul Jun 11, 2024
0290115
chore: lint files
bracesproul Jun 11, 2024
fc30a10
remove skipped token usage tests from cohere standard int test
bracesproul Jun 11, 2024
b38a802
bump min core version to usage_metadata update
bracesproul Jun 11, 2024
448511e
add streamUsage
bracesproul Jun 11, 2024
e28c29d
added cohere to latest/lowest dep tests
bracesproul Jun 11, 2024
ddf98b1
conditionally run latest/lowest
bracesproul Jun 11, 2024
f0cc874
Merge branch 'main' into brace/cohere-token-count
bracesproul Jun 11, 2024
8a4eaa6
nit
bracesproul Jun 11, 2024
0127915
cr
bracesproul Jun 11, 2024
745aa15
cr
bracesproul Jun 11, 2024
87756a9
cr
bracesproul Jun 11, 2024
f01e778
revert
bracesproul Jun 11, 2024
427b0a2
test
bracesproul Jun 11, 2024
8676b88
try just pr
bracesproul Jun 11, 2024
7ccce30
cr
bracesproul Jun 11, 2024
fe01ba5
cr
bracesproul Jun 11, 2024
b586528
only log files
bracesproul Jun 11, 2024
8f22a96
more tests
bracesproul Jun 11, 2024
e9d4817
toJson
bracesproul Jun 11, 2024
bb17fe5
use git to access changed files
bracesproul Jun 11, 2024
c80eb11
Merge branch 'main' into brace/cohere-token-count
bracesproul Jun 11, 2024
5331f45
fix if statements
bracesproul Jun 12, 2024
2a8632e
fix test
bracesproul Jun 12, 2024
d57893d
Merge branch 'main' into brace/cohere-token-count
bracesproul Jun 12, 2024
3060a41
unfocus jest tests and add eslint rule
bracesproul Jun 12, 2024
018aa93
add eslint-plugin-jest
bracesproul Jun 12, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions .github/workflows/compatibility.yml
Original file line number Diff line number Diff line change
Expand Up @@ -221,3 +221,40 @@ jobs:
run: yarn build --filter=@langchain/standard-tests
- name: Test `@langchain/google-vertexai` with lowest deps
run: docker compose -f dependency_range_tests/docker-compose.yml run google-vertexai-lowest-deps

# Cohere
cohere-latest-deps:
runs-on: ubuntu-latest
needs: get-changed-files
if: contains(needs.get-changed-files.outputs.changed_files, 'langchain-core/') || contains(needs.get-changed-files.outputs.changed_files, 'libs/langchain-cohere/')
steps:
- uses: actions/checkout@v4
- name: Use Node.js ${{ env.NODE_VERSION }}
uses: actions/setup-node@v3
with:
node-version: ${{ env.NODE_VERSION }}
cache: "yarn"
- name: Install dependencies
run: yarn install --immutable
- name: Build `@langchain/standard-tests`
run: yarn build --filter=@langchain/standard-tests
- name: Test `@langchain/cohere` with latest deps
run: docker compose -f dependency_range_tests/docker-compose.yml run cohere-latest-deps

cohere-lowest-deps:
runs-on: ubuntu-latest
needs: get-changed-files
if: contains(needs.get-changed-files.outputs.changed_files, 'libs/langchain-cohere/')
steps:
- uses: actions/checkout@v4
- name: Use Node.js ${{ env.NODE_VERSION }}
uses: actions/setup-node@v3
with:
node-version: ${{ env.NODE_VERSION }}
cache: "yarn"
- name: Install dependencies
run: yarn install --immutable
- name: Build `@langchain/standard-tests`
run: yarn build --filter=@langchain/standard-tests
- name: Test `@langchain/cohere` with lowest deps
run: docker compose -f dependency_range_tests/docker-compose.yml run cohere-lowest-deps
32 changes: 31 additions & 1 deletion dependency_range_tests/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -142,4 +142,34 @@ services:
- ../libs/langchain-standard-tests:/libs/langchain-standard-tests
- ../libs/langchain-google-vertexai:/libs/langchain-google-vertexai
- ./scripts:/scripts
command: bash /scripts/with_standard_tests/google-vertexai/test-with-lowest-deps.sh
command: bash /scripts/with_standard_tests/google-vertexai/test-with-lowest-deps.sh

# Cohere
cohere-latest-deps:
image: node:18
environment:
PUPPETEER_SKIP_DOWNLOAD: "true"
PLAYWRIGHT_SKIP_BROWSER_DOWNLOAD: "true"
COHERE_API_KEY: ${COHERE_API_KEY}
working_dir: /app
volumes:
- ../turbo.json:/turbo.json
- ../package.json:/package.json
- ../libs/langchain-standard-tests:/libs/langchain-standard-tests
- ../libs/langchain-cohere:/libs/langchain-cohere
- ./scripts:/scripts
command: bash /scripts/with_standard_tests/cohere/test-with-latest-deps.sh
cohere-lowest-deps:
image: node:18
environment:
PUPPETEER_SKIP_DOWNLOAD: "true"
PLAYWRIGHT_SKIP_BROWSER_DOWNLOAD: "true"
COHERE_API_KEY: ${COHERE_API_KEY}
working_dir: /app
volumes:
- ../turbo.json:/turbo.json
- ../package.json:/package.json
- ../libs/langchain-standard-tests:/libs/langchain-standard-tests
- ../libs/langchain-cohere:/libs/langchain-cohere
- ./scripts:/scripts
command: bash /scripts/with_standard_tests/cohere/test-with-lowest-deps.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"name": "dependency-range-tests",
"version": "0.0.0",
"private": true,
"description": "Tests dependency ranges for LangChain.",
"dependencies": {
"semver": "^7.5.4"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
const fs = require("fs");
const semver = require("semver");

const communityPackageJsonPath = "/app/monorepo/libs/langchain-cohere/package.json";

const currentPackageJson = JSON.parse(fs.readFileSync(communityPackageJsonPath));

if (currentPackageJson.dependencies["@langchain/core"] && !currentPackageJson.dependencies["@langchain/core"].includes("rc")) {
const minVersion = semver.minVersion(
currentPackageJson.dependencies["@langchain/core"]
).version;
currentPackageJson.overrides = {
...currentPackageJson.overrides,
"@langchain/core": minVersion,
};
currentPackageJson.dependencies = {
...currentPackageJson.dependencies,
"@langchain/core": minVersion,
};
}

fs.writeFileSync(communityPackageJsonPath, JSON.stringify(currentPackageJson, null, 2));
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# THIS IS AN AUTOGENERATED FILE. DO NOT EDIT THIS FILE DIRECTLY.
# yarn lockfile v1


lru-cache@^6.0.0:
version "6.0.0"
resolved "https://registry.yarnpkg.com/lru-cache/-/lru-cache-6.0.0.tgz#6d6fe6570ebd96aaf90fcad1dafa3b2566db3a94"
integrity sha512-Jo6dJ04CmSjuznwJSS3pUeWmd/H0ffTlkXXgwZi+eq1UCmqQwCh+eLsYOYCwY991i2Fah4h1BEMCx4qThGbsiA==
dependencies:
yallist "^4.0.0"

semver@^7.5.4:
version "7.5.4"
resolved "https://registry.yarnpkg.com/semver/-/semver-7.5.4.tgz#483986ec4ed38e1c6c48c34894a9182dbff68a6e"
integrity sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==
dependencies:
lru-cache "^6.0.0"

yallist@^4.0.0:
version "4.0.0"
resolved "https://registry.yarnpkg.com/yallist/-/yallist-4.0.0.tgz#9bb92790d9c0effec63be73519e11a35019a3a72"
integrity sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#!/usr/bin/env bash

set -euxo pipefail

export CI=true

# New monorepo directory paths
monorepo_dir="/app/monorepo"
monorepo_openai_dir="/app/monorepo/libs/langchain-cohere"

# Run the shared script to copy all necessary folders/files
bash /scripts/with_standard_tests/shared.sh cohere

# Navigate back to monorepo root and install dependencies
cd "$monorepo_dir"
yarn

# Navigate into `@langchain/cohere` to build and run tests
# We need to run inside the cohere directory so turbo repo does
# not try to build the package/its workspace dependencies.
cd "$monorepo_openai_dir"
yarn test
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#!/usr/bin/env bash

set -euxo pipefail

export CI=true

monorepo_dir="/app/monorepo"
monorepo_cohere_dir="/app/monorepo/libs/langchain-cohere"
updater_script_dir="/app/updater_script"
updater_script_dir="/app/updater_script"
original_updater_script_dir="/scripts/with_standard_tests/cohere/node"

# Run the shared script to copy all necessary folders/files
bash /scripts/with_standard_tests/shared.sh cohere

# Copy the updater script to the monorepo
mkdir -p "$updater_script_dir"
cp "$original_updater_script_dir"/* "$updater_script_dir/"

# Install deps (e.g semver) for the updater script
cd "$updater_script_dir"
yarn
# Run the updater script
node "update_resolutions_lowest.js"


# Navigate back to monorepo root and install dependencies
cd "$monorepo_dir"
yarn

# Navigate into `@langchain/cohere` to build and run tests
# We need to run inside the cohere directory so turbo repo does
# not try to build the package/its workspace dependencies.
cd "$monorepo_cohere_dir"
yarn test
4 changes: 2 additions & 2 deletions libs/langchain-cohere/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@
"author": "LangChain",
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey there! 👋 I noticed that this PR updates the dependencies for LangChain JS, specifically the peer dependencies "@langchain/core" and "cohere-ai". This is flagged for maintainers to review the changes in peer dependencies. Keep up the great work! 🚀

"license": "MIT",
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey there! 👋 I noticed that the package.json file has an update for the "cohere-ai" dependency, which seems to be a peer dependency change. I've flagged this for your review. Keep up the great work! 🚀

"dependencies": {
"@langchain/core": ">0.1.58 <0.3.0",
"cohere-ai": "^7.9.3"
"@langchain/core": ">=0.2.5 <0.3.0",
"cohere-ai": "^7.10.5"
},
"devDependencies": {
"@jest/globals": "^29.5.0",
Expand Down
98 changes: 76 additions & 22 deletions libs/langchain-cohere/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ export interface ChatCohereInput extends BaseChatModelParams {
model?: string;
/**
* What sampling temperature to use, between 0.0 and 2.0.
* Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
* Higher values like 0.8 will make the output more random,
* while lower values like 0.2 will make it more focused
* and deterministic.
* @default {0.3}
*/
temperature?: number;
Expand All @@ -47,6 +49,14 @@ export interface ChatCohereInput extends BaseChatModelParams {
* @default {false}
*/
streaming?: boolean;
/**
* Whether or not to include token usage when streaming.
* This will include an extra chunk at the end of the stream
* with `eventType: "stream-end"` and the token usage in
* `usage_metadata`.
* @default {true}
*/
streamUsage?: boolean;
}

interface TokenUsage {
Expand All @@ -58,11 +68,12 @@ interface TokenUsage {
export interface CohereChatCallOptions
extends BaseLanguageModelCallOptions,
Partial<Omit<Cohere.ChatRequest, "message">>,
Partial<Omit<Cohere.ChatStreamRequest, "message">> {}
Partial<Omit<Cohere.ChatStreamRequest, "message">>,
Pick<ChatCohereInput, "streamUsage"> {}

function convertMessagesToCohereMessages(
messages: Array<BaseMessage>
): Array<Cohere.ChatMessage> {
): Array<Cohere.Message> {
const getRole = (role: MessageType) => {
switch (role) {
case "system":
Expand Down Expand Up @@ -113,7 +124,7 @@ function convertMessagesToCohereMessages(
export class ChatCohere<
CallOptions extends CohereChatCallOptions = CohereChatCallOptions
>
extends BaseChatModel<CallOptions>
extends BaseChatModel<CallOptions, AIMessageChunk>
implements ChatCohereInput
{
static lc_name() {
Expand All @@ -130,6 +141,8 @@ export class ChatCohere<

streaming = false;

streamUsage: boolean = true;

constructor(fields?: ChatCohereInput) {
super(fields ?? {});

Expand All @@ -144,6 +157,7 @@ export class ChatCohere<
this.model = fields?.model ?? this.model;
this.temperature = fields?.temperature ?? this.temperature;
this.streaming = fields?.streaming ?? this.streaming;
this.streamUsage = fields?.streamUsage ?? this.streamUsage;
}

getLsParams(options: this["ParsedCallOptions"]): LangSmithParams {
Expand Down Expand Up @@ -193,8 +207,14 @@ export class ChatCohere<
const cohereMessages = convertMessagesToCohereMessages(messages);
// The last message in the array is the most recent, all other messages
// are apart of the chat history.
const { message } = cohereMessages[cohereMessages.length - 1];
const chatHistory: Cohere.ChatMessage[] = [];
const lastMessage = cohereMessages[cohereMessages.length - 1];
if (lastMessage.role === "TOOL") {
throw new Error(
"Cohere does not support tool messages as the most recent message in chat history."
);
}
const { message } = lastMessage;
const chatHistory: Cohere.Message[] = [];
if (cohereMessages.length > 1) {
chatHistory.push(...cohereMessages.slice(0, -1));
}
Expand Down Expand Up @@ -241,25 +261,22 @@ export class ChatCohere<
}
);

if ("token_count" in response) {
const {
response_tokens: completionTokens,
prompt_tokens: promptTokens,
total_tokens: totalTokens,
} = response.token_count as Record<string, number>;
if (response.meta?.tokens) {
const { inputTokens, outputTokens } = response.meta.tokens;

if (completionTokens) {
if (outputTokens) {
tokenUsage.completionTokens =
(tokenUsage.completionTokens ?? 0) + completionTokens;
(tokenUsage.completionTokens ?? 0) + outputTokens;
}

if (promptTokens) {
tokenUsage.promptTokens = (tokenUsage.promptTokens ?? 0) + promptTokens;
if (inputTokens) {
tokenUsage.promptTokens = (tokenUsage.promptTokens ?? 0) + inputTokens;
}

if (totalTokens) {
tokenUsage.totalTokens = (tokenUsage.totalTokens ?? 0) + totalTokens;
}
tokenUsage.totalTokens =
(tokenUsage.totalTokens ?? 0) +
(tokenUsage.promptTokens ?? 0) +
(tokenUsage.completionTokens ?? 0);
}

const generationInfo: Record<string, unknown> = { ...response };
Expand All @@ -271,6 +288,11 @@ export class ChatCohere<
message: new AIMessage({
content: response.text,
additional_kwargs: generationInfo,
usage_metadata: {
input_tokens: tokenUsage.promptTokens ?? 0,
output_tokens: tokenUsage.completionTokens ?? 0,
total_tokens: tokenUsage.totalTokens ?? 0,
},
}),
generationInfo,
},
Expand All @@ -290,8 +312,14 @@ export class ChatCohere<
const cohereMessages = convertMessagesToCohereMessages(messages);
// The last message in the array is the most recent, all other messages
// are apart of the chat history.
const { message } = cohereMessages[cohereMessages.length - 1];
const chatHistory: Cohere.ChatMessage[] = [];
const lastMessage = cohereMessages[cohereMessages.length - 1];
if (lastMessage.role === "TOOL") {
throw new Error(
"Cohere does not support tool messages as the most recent message in chat history."
);
}
const { message } = lastMessage;
const chatHistory: Cohere.Message[] = [];
if (cohereMessages.length > 1) {
chatHistory.push(...cohereMessages.slice(0, -1));
}
Expand All @@ -317,7 +345,9 @@ export class ChatCohere<
if (chunk.eventType === "text-generation") {
yield new ChatGenerationChunk({
text: chunk.text,
message: new AIMessageChunk({ content: chunk.text }),
message: new AIMessageChunk({
content: chunk.text,
}),
});
await runManager?.handleLLMNewToken(chunk.text);
} else if (chunk.eventType !== "stream-end") {
Expand All @@ -335,6 +365,30 @@ export class ChatCohere<
...chunk,
},
});
} else if (
chunk.eventType === "stream-end" &&
(this.streamUsage || options.streamUsage)
) {
// stream-end events contain the final token count
const input_tokens = chunk.response.meta?.tokens?.inputTokens ?? 0;
const output_tokens = chunk.response.meta?.tokens?.outputTokens ?? 0;
yield new ChatGenerationChunk({
text: "",
message: new AIMessageChunk({
content: "",
additional_kwargs: {
eventType: "stream-end",
},
usage_metadata: {
input_tokens,
output_tokens,
total_tokens: input_tokens + output_tokens,
},
}),
generationInfo: {
eventType: "stream-end",
},
});
}
}
}
Expand Down
Loading
Loading