
Commit 7af0f8d

Merge remote-tracking branch 'public/vaihi-exp' into erikeldridge-vertex-stream-rebased

2 parents: 34c658e + e069751

File tree

5 files changed: +138 −6 lines

packages/vertexai/src/methods/chrome-adapter.test.ts

Lines changed: 54 additions & 0 deletions
@@ -326,6 +326,60 @@ describe('ChromeAdapter', () => {
       });
     });
   });
+  describe('countTokens', () => {
+    it('counts tokens from a singular input', async () => {
+      const inputText = 'first';
+      const expectedCount = 10;
+      const onDeviceParams = {
+        systemPrompt: 'be yourself'
+      } as LanguageModelCreateOptions;
+
+      // setting up stubs
+      const languageModelProvider = {
+        create: () => Promise.resolve({})
+      } as LanguageModel;
+      const languageModel = {
+        measureInputUsage: _i => Promise.resolve(123)
+      } as LanguageModel;
+      const createStub = stub(languageModelProvider, 'create').resolves(
+        languageModel
+      );
+
+      // overrides impl with stub method
+      const measureInputUsageStub = stub(
+        languageModel,
+        'measureInputUsage'
+      ).resolves(expectedCount);
+
+      const adapter = new ChromeAdapter(
+        languageModelProvider,
+        'prefer_on_device',
+        onDeviceParams
+      );
+
+      const countTokenRequest = {
+        contents: [{ role: 'user', parts: [{ text: inputText }] }]
+      } as GenerateContentRequest;
+      const response = await adapter.countTokens(countTokenRequest);
+      // Asserts initialization params are proxied.
+      expect(createStub).to.have.been.calledOnceWith(onDeviceParams);
+      // Asserts Vertex input type is mapped to Chrome type.
+      expect(measureInputUsageStub).to.have.been.calledOnceWith([
+        {
+          role: 'user',
+          content: [
+            {
+              type: 'text',
+              content: inputText
+            }
+          ]
+        }
+      ]);
+      expect(await response.json()).to.deep.equal({
+        totalTokens: expectedCount
+      });
+    });
+  });
   describe('generateContentStreamOnDevice', () => {
     it('generates content stream', async () => {
       const languageModelProvider = {

packages/vertexai/src/methods/chrome-adapter.ts

Lines changed: 29 additions & 0 deletions
@@ -17,6 +17,7 @@

 import {
   Content,
+  CountTokensRequest,
   GenerateContentRequest,
   InferenceMode,
   Part,
@@ -103,6 +104,15 @@ export class ChromeAdapter {
     const text = await session.prompt(messages);
     return ChromeAdapter.toResponse(text);
   }
+
+  /**
+   * Generates content stream on device.
+   *
+   * <p>This is comparable to {@link GenerativeModel.generateContentStream} for generating content in
+   * Cloud.</p>
+   * @param request a standard Vertex {@link GenerateContentRequest}
+   * @returns {@link Response}, so we can reuse common response formatting.
+   */
   async generateContentStream(
     request: GenerateContentRequest
   ): Promise<Response> {
@@ -114,6 +124,25 @@
     const stream = await session.promptStreaming(messages);
     return ChromeAdapter.toStreamResponse(stream);
   }
+
+  async countTokens(request: CountTokensRequest): Promise<Response> {
+    // TODO: Check if the request contains an image, and if so, throw.
+    const session = await this.createSession(
+      // TODO: normalize on-device params during construction.
+      this.onDeviceParams || {}
+    );
+    const messages = ChromeAdapter.toLanguageModelMessages(request.contents);
+    const tokenCount = await session.measureInputUsage(messages);
+    return {
+      json: async () => ({
+        totalTokens: tokenCount
+      })
+    } as Response;
+  }
+
+  /**
+   * Asserts inference for the given request can be performed by an on-device model.
+   */
   private static isOnDeviceRequest(request: GenerateContentRequest): boolean {
     // Returns false if the prompt is empty.
     if (request.contents.length === 0) {
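
For orientation, the mapping that ChromeAdapter.toLanguageModelMessages applies before calling session.measureInputUsage is the one asserted in chrome-adapter.test.ts above; a rough TypeScript illustration of the input and output shapes, taken from the test expectations rather than the adapter source:

// Vertex-style request contents: one user turn with a single text part.
const vertexContents = [{ role: 'user', parts: [{ text: 'first' }] }];

// Chrome Prompt API message shape that measureInputUsage receives,
// per the calledOnceWith assertion in the test above.
const languageModelMessages = [
  {
    role: 'user',
    content: [{ type: 'text', content: 'first' }]
  }
];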

packages/vertexai/src/methods/count-tokens.test.ts

Lines changed: 33 additions & 4 deletions
@@ -25,6 +25,7 @@ import { countTokens } from './count-tokens';
 import { CountTokensRequest } from '../types';
 import { ApiSettings } from '../types/internal';
 import { Task } from '../requests/request';
+import { ChromeAdapter } from './chrome-adapter';

 use(sinonChai);
 use(chaiAsPromised);
@@ -55,7 +56,8 @@ describe('countTokens()', () => {
     const result = await countTokens(
       fakeApiSettings,
       'model',
-      fakeRequestParams
+      fakeRequestParams,
+      new ChromeAdapter()
     );
     expect(result.totalTokens).to.equal(6);
     expect(result.totalBillableCharacters).to.equal(16);
@@ -81,7 +83,8 @@ describe('countTokens()', () => {
     const result = await countTokens(
       fakeApiSettings,
       'model',
-      fakeRequestParams
+      fakeRequestParams,
+      new ChromeAdapter()
     );
     expect(result.totalTokens).to.equal(1837);
     expect(result.totalBillableCharacters).to.equal(117);
@@ -109,7 +112,8 @@ describe('countTokens()', () => {
     const result = await countTokens(
       fakeApiSettings,
       'model',
-      fakeRequestParams
+      fakeRequestParams,
+      new ChromeAdapter()
     );
     expect(result.totalTokens).to.equal(258);
     expect(result).to.not.have.property('totalBillableCharacters');
@@ -135,8 +139,33 @@ describe('countTokens()', () => {
       json: mockResponse.json
     } as Response);
     await expect(
-      countTokens(fakeApiSettings, 'model', fakeRequestParams)
+      countTokens(
+        fakeApiSettings,
+        'model',
+        fakeRequestParams,
+        new ChromeAdapter()
+      )
     ).to.be.rejectedWith(/404.*not found/);
     expect(mockFetch).to.be.called;
   });
+  it('on-device', async () => {
+    const chromeAdapter = new ChromeAdapter();
+    const isAvailableStub = stub(chromeAdapter, 'isAvailable').resolves(true);
+    const mockResponse = getMockResponse(
+      'vertexAI',
+      'unary-success-total-tokens.json'
+    );
+    const countTokensStub = stub(chromeAdapter, 'countTokens').resolves(
+      mockResponse as Response
+    );
+    const result = await countTokens(
+      fakeApiSettings,
+      'model',
+      fakeRequestParams,
+      chromeAdapter
+    );
+    expect(result.totalTokens).eq(6);
+    expect(isAvailableStub).to.be.called;
+    expect(countTokensStub).to.be.calledWith(fakeRequestParams);
+  });
 });

packages/vertexai/src/methods/count-tokens.ts

Lines changed: 16 additions & 1 deletion
@@ -22,8 +22,9 @@ import {
 } from '../types';
 import { Task, makeRequest } from '../requests/request';
 import { ApiSettings } from '../types/internal';
+import { ChromeAdapter } from './chrome-adapter';

-export async function countTokens(
+export async function countTokensOnCloud(
   apiSettings: ApiSettings,
   model: string,
   params: CountTokensRequest,
@@ -39,3 +40,17 @@
   );
   return response.json();
 }
+
+export async function countTokens(
+  apiSettings: ApiSettings,
+  model: string,
+  params: CountTokensRequest,
+  chromeAdapter: ChromeAdapter,
+  requestOptions?: RequestOptions
+): Promise<CountTokensResponse> {
+  if (await chromeAdapter.isAvailable(params)) {
+    return (await chromeAdapter.countTokens(params)).json();
+  }
+
+  return countTokensOnCloud(apiSettings, model, params, requestOptions);
+}
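
A minimal call-site sketch for the new dispatcher follows; the settings value and request contents are stand-ins, and the on-device branch only runs when chromeAdapter.isAvailable(params) resolves to true:

import { countTokens } from './count-tokens';
import { ChromeAdapter } from './chrome-adapter';
import { CountTokensRequest } from '../types';
import { ApiSettings } from '../types/internal';

// Hypothetical call site: `settings` stands in for real ApiSettings.
async function example(settings: ApiSettings): Promise<void> {
  const params: CountTokensRequest = {
    contents: [{ role: 'user', parts: [{ text: 'hello' }] }]
  };
  // Routes to ChromeAdapter.countTokens when the on-device model is
  // available; otherwise falls through to countTokensOnCloud.
  const result = await countTokens(settings, 'model', params, new ChromeAdapter());
  console.log(result.totalTokens);
}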

packages/vertexai/src/models/generative-model.ts

Lines changed: 6 additions & 1 deletion
@@ -154,6 +154,11 @@ export class GenerativeModel extends VertexAIModel {
     request: CountTokensRequest | string | Array<string | Part>
   ): Promise<CountTokensResponse> {
     const formattedParams = formatGenerateContentInput(request);
-    return countTokens(this._apiSettings, this.model, formattedParams);
+    return countTokens(
+      this._apiSettings,
+      this.model,
+      formattedParams,
+      this.chromeAdapter
+    );
   }
 }
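
From the public API, the call shape is unchanged; GenerativeModel now threads its chromeAdapter into countTokens, so the count comes from the on-device model when it is available and from the Vertex AI backend otherwise. A hedged usage sketch, with the import path, app config, and model name as placeholders:

import { initializeApp } from 'firebase/app';
import { getVertexAI, getGenerativeModel } from 'firebase/vertexai';

// Placeholder config and model name; on this experimental branch the model
// may also be constructed with on-device (ChromeAdapter) options.
const app = initializeApp({ /* firebase config */ });
const vertexAI = getVertexAI(app);
const model = getGenerativeModel(vertexAI, { model: 'gemini-1.5-flash' });

// Unchanged public call; routing between on-device and cloud happens inside.
const { totalTokens } = await model.countTokens('Why is the sky blue?');
console.log(totalTokens);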
