Fix count tokens to include model params (#209)
hsubox76 authored Jul 15, 2024
1 parent 6dc4218 commit e87d5b0
Showing 8 changed files with 174 additions and 60 deletions.
5 changes: 5 additions & 0 deletions .changeset/rare-birds-bow.md
@@ -0,0 +1,5 @@
---
"@google/generative-ai": patch
---

Fix countTokens to include any params set on the model instance.
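In practice, token counts now reflect instance-level configuration. A minimal sketch of the new behavior (model name, system instruction, and API key wiring borrowed from the samples in this commit; actual counts will vary):

import { GoogleGenerativeAI } from "@google/generative-ai";

async function main() {
  const genAI = new GoogleGenerativeAI(process.env.API_KEY);
  const model = genAI.getGenerativeModel({
    model: "models/gemini-1.5-flash",
    systemInstruction: "You are a cat. Your name is Neko.",
  });

  // Before this fix, countTokens only counted the request contents;
  // instance-level params such as systemInstruction and tools now count too.
  const { totalTokens } = await model.countTokens(
    "The quick brown fox jumps over the lazy dog.",
  );
  console.log(totalTokens);
}

main();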
1 change: 1 addition & 0 deletions package.json
@@ -31,6 +31,7 @@
"@web/test-runner": "^0.18.0",
"chai": "^4.3.10",
"chai-as-promised": "^7.1.1",
"chai-deep-equal-ignore-undefined": "^1.1.1",
"eslint": "^8.52.0",
"eslint-plugin-import": "^2.29.0",
"eslint-plugin-unused-imports": "^3.0.0",
5 changes: 4 additions & 1 deletion packages/main/src/models/generative-model.test.ts
@@ -240,6 +240,7 @@ describe("GenerativeModel", () => {
"apiKey",
{
model: "my-model",
systemInstruction: "you are a cat",
},
{
apiVersion: "v2000",
@@ -257,7 +258,9 @@ describe("GenerativeModel", () => {
request.Task.COUNT_TOKENS,
match.any,
false,
match.any,
match((value: string) => {
return value.includes("hello") && value.includes("you are a cat");
}),
match((value) => {
return value.apiVersion === "v2000";
}),
10 changes: 9 additions & 1 deletion packages/main/src/models/generative-model.ts
@@ -164,7 +164,15 @@ export class GenerativeModel {
async countTokens(
request: CountTokensRequest | string | Array<string | Part>,
): Promise<CountTokensResponse> {
const formattedParams = formatCountTokensInput(request, this.model);
const formattedParams = formatCountTokensInput(request, {
model: this.model,
generationConfig: this.generationConfig,
safetySettings: this.safetySettings,
tools: this.tools,
toolConfig: this.toolConfig,
systemInstruction: this.systemInstruction,
cachedContent: this.cachedContent,
});
return countTokens(
this.apiKey,
this.model,
105 changes: 104 additions & 1 deletion packages/main/src/requests/request-helpers.test.ts
@@ -17,10 +17,15 @@

import { expect, use } from "chai";
import * as sinonChai from "sinon-chai";
import chaiDeepEqualIgnoreUndefined from "chai-deep-equal-ignore-undefined";
import { Content } from "../../types";
import { formatGenerateContentInput } from "./request-helpers";
import {
formatCountTokensInput,
formatGenerateContentInput,
} from "./request-helpers";

use(sinonChai);
use(chaiDeepEqualIgnoreUndefined);

describe("request formatting methods", () => {
describe("formatGenerateContentInput", () => {
@@ -172,4 +177,102 @@ describe("request formatting methods", () => {
});
});
});
describe("formatCountTokensInput", () => {
it("formats a text string into a count request", () => {
const result = formatCountTokensInput("some text content", {
model: "gemini-1.5-flash",
});
expect(result.generateContentRequest).to.deepEqualIgnoreUndefined({
model: "gemini-1.5-flash",
contents: [
{
role: "user",
parts: [{ text: "some text content" }],
},
],
});
});
it("formats a text string into a count request, along with model params", () => {
const result = formatCountTokensInput("some text content", {
model: "gemini-1.5-flash",
systemInstruction: "hello",
tools: [{ codeExecution: {} }],
cachedContent: { name: "mycache", contents: [] },
});
expect(result.generateContentRequest).to.deepEqualIgnoreUndefined({
model: "gemini-1.5-flash",
contents: [
{
role: "user",
parts: [{ text: "some text content" }],
},
],
systemInstruction: "hello",
tools: [{ codeExecution: {} }],
cachedContent: "mycache",
});
});
it("formats a 'contents' style count request, along with model params", () => {
const result = formatCountTokensInput(
{
contents: [
{
role: "user",
parts: [{ text: "some text content" }],
},
],
},
{
model: "gemini-1.5-flash",
systemInstruction: "hello",
tools: [{ codeExecution: {} }],
cachedContent: { name: "mycache", contents: [] },
},
);
expect(result.generateContentRequest).to.deepEqualIgnoreUndefined({
model: "gemini-1.5-flash",
contents: [
{
role: "user",
parts: [{ text: "some text content" }],
},
],
systemInstruction: "hello",
tools: [{ codeExecution: {} }],
cachedContent: "mycache",
});
});
it("formats a 'generateContentRequest' style count request, along with model params", () => {
const result = formatCountTokensInput(
{
generateContentRequest: {
contents: [
{
role: "user",
parts: [{ text: "some text content" }],
},
],
},
},
{
model: "gemini-1.5-flash",
systemInstruction: "hello",
tools: [{ codeExecution: {} }],
cachedContent: { name: "mycache", contents: [] },
},
);
expect(result.generateContentRequest).to.deepEqualIgnoreUndefined({
model: "gemini-1.5-flash",
contents: [
{
role: "user",
parts: [{ text: "some text content" }],
},
],
systemInstruction: "hello",
tools: [{ codeExecution: {} }],
cachedContent: "mycache",
});
});
});
});
29 changes: 22 additions & 7 deletions packages/main/src/requests/request-helpers.ts
@@ -20,8 +20,10 @@ import {
CountTokensRequest,
EmbedContentRequest,
GenerateContentRequest,
ModelParams,
Part,
_CountTokensRequestInternal,
_GenerateContentRequestInternal,
} from "../../types";
import {
GoogleGenerativeAIError,
@@ -111,9 +113,18 @@ function assignRoleToPartsAndValidateSendMessageRequest(

export function formatCountTokensInput(
params: CountTokensRequest | string | Array<string | Part>,
model: string,
modelParams?: ModelParams,
): _CountTokensRequestInternal {
let formattedRequest: _CountTokensRequestInternal = {};
let formattedGenerateContentRequest: _GenerateContentRequestInternal = {
model: modelParams?.model,
generationConfig: modelParams?.generationConfig,
safetySettings: modelParams?.safetySettings,
tools: modelParams?.tools,
toolConfig: modelParams?.toolConfig,
systemInstruction: modelParams?.systemInstruction,
cachedContent: modelParams?.cachedContent?.name,
contents: [],
};
const containsGenerateContentRequest =
(params as CountTokensRequest).generateContentRequest != null;
if ((params as CountTokensRequest).contents) {
@@ -122,16 +133,20 @@ export function formatCountTokensInput(
"CountTokensRequest must have one of contents or generateContentRequest, not both.",
);
}
formattedRequest = { ...(params as CountTokensRequest) };
formattedGenerateContentRequest.contents = (
params as CountTokensRequest
).contents;
} else if (containsGenerateContentRequest) {
formattedRequest = { ...(params as CountTokensRequest) };
formattedRequest.generateContentRequest.model = model;
formattedGenerateContentRequest = {
...formattedGenerateContentRequest,
...(params as CountTokensRequest).generateContentRequest,
};
} else {
// Array or string
const content = formatNewContent(params as string | Array<string | Part>);
formattedRequest.contents = [content];
formattedGenerateContentRequest.contents = [content];
}
return formattedRequest;
return { generateContentRequest: formattedGenerateContentRequest };
}

export function formatGenerateContentInput(
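For reference, the reworked helper now always returns a { generateContentRequest } wrapper with the model params folded in. A sketch of the transformation, mirroring the test expectations above (the inline comment shows the expected output, with cachedContent flattened to its name):

const result = formatCountTokensInput("some text content", {
  model: "gemini-1.5-flash",
  systemInstruction: "hello",
  tools: [{ codeExecution: {} }],
  cachedContent: { name: "mycache", contents: [] },
});
// result:
// {
//   generateContentRequest: {
//     model: "gemini-1.5-flash",
//     contents: [{ role: "user", parts: [{ text: "some text content" }] }],
//     systemInstruction: "hello",
//     tools: [{ codeExecution: {} }],
//     cachedContent: "mycache",
//   },
// }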
74 changes: 24 additions & 50 deletions samples/count_tokens.js
@@ -287,36 +287,30 @@ async function tokensCachedContent() {
});

const genAI = new GoogleGenerativeAI(process.env.API_KEY);
const model = genAI.getGenerativeModel({
model: "models/gemini-1.5-flash",
});
const model = genAI.getGenerativeModelFromCachedContent(cacheResult);

const prompt = "Please give a short summary of this file.";

// Call `countTokens` to get the input token count
// of the combined text and file (`totalTokens`).
const result = await model.countTokens({
generateContentRequest: {
contents: [
{
role: "user",
parts: [{ text: "Please give a short summary of this file." }],
},
],
cachedContent: cacheResult.name,
},
});
const result = await model.countTokens(prompt);

console.log(result.totalTokens); // 10

const generateResult = await model.generateContent(
"Please give a short summary of this file.",
);
const generateResult = await model.generateContent(prompt);

// On the response for `generateContent`, use `usageMetadata`
// to get separate input and output token counts
// (`promptTokenCount` and `candidatesTokenCount`, respectively),
// as well as the combined token count (`totalTokenCount`).
// as well as the cached content token count and the combined total
// token count.
console.log(generateResult.response.usageMetadata);
// { promptTokenCount: 10, candidatesTokenCount: 31, totalTokenCount: 41 }
// {
// promptTokenCount: 323396,
// candidatesTokenCount: 113,
// totalTokenCount: 323509,
// cachedContentTokenCount: 323386
// }

await cacheManager.delete(cacheResult.name);
// [END tokens_cached_content]
@@ -329,22 +323,12 @@ async function tokensSystemInstruction() {
const genAI = new GoogleGenerativeAI(process.env.API_KEY);
const model = genAI.getGenerativeModel({
model: "models/gemini-1.5-flash",
systemInstruction: "You are a cat. Your name is Neko.",
});

const result = await model.countTokens({
generateContentRequest: {
contents: [
{
role: "user",
parts: [{ text: "The quick brown fox jumps over the lazy dog." }],
},
],
systemInstruction: {
role: "system",
parts: [{ text: "You are a cat. Your name is Neko." }],
},
},
});
const result = await model.countTokens(
"The quick brown fox jumps over the lazy dog.",
);

console.log(result);
// {
@@ -360,9 +344,6 @@ async function tokensTools() {
// Make sure to include these imports:
// import { GoogleGenerativeAI } from "@google/generative-ai";
const genAI = new GoogleGenerativeAI(process.env.API_KEY);
const model = genAI.getGenerativeModel({
model: "models/gemini-1.5-flash",
});

const functionDeclarations = [
{ name: "add" },
@@ -371,22 +352,15 @@
{ name: "divide" },
];

const result = await model.countTokens({
generateContentRequest: {
contents: [
{
role: "user",
parts: [
{
text: "I have 57 cats, each owns 44 mittens, how many mittens is that in total?",
},
],
},
],
tools: [{ functionDeclarations }],
},
const model = genAI.getGenerativeModel({
model: "models/gemini-1.5-flash",
tools: [{ functionDeclarations }],
});

const result = await model.countTokens(
"I have 57 cats, each owns 44 mittens, how many mittens is that in total?",
);

console.log(result);
// {
// totalTokens: 99,
5 changes: 5 additions & 0 deletions yarn.lock
@@ -2389,6 +2389,11 @@ chai-as-promised@^7.1.1:
dependencies:
check-error "^1.0.2"

chai-deep-equal-ignore-undefined@^1.1.1:
version "1.1.1"
resolved "https://registry.yarnpkg.com/chai-deep-equal-ignore-undefined/-/chai-deep-equal-ignore-undefined-1.1.1.tgz#c9e3736fed06c83572f03c592c025cf2703fd1a1"
integrity sha512-BE4nUR2Jbqmmv8A0EuAydFRB/lXgXWAfa9TvO3YzHeGHAU7ZRwPZyu074oDl/CZtNXM7jXINpQxKBOe7N0P4bg==

chai@^4.3.10:
version "4.3.10"
resolved "https://registry.yarnpkg.com/chai/-/chai-4.3.10.tgz#d784cec635e3b7e2ffb66446a63b4e33bd390384"
