From f6ed3175d41f1f464b9897ef114da633777a62af Mon Sep 17 00:00:00 2001 From: Eric Barroca Date: Wed, 30 Oct 2024 21:35:03 +0900 Subject: [PATCH 1/2] Update build-and-test.yml Signed-off-by: Eric Barroca --- .github/workflows/build-and-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml index f2a1ac2..4adf863 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/build-and-test.yml @@ -53,7 +53,7 @@ jobs: audience: sts.amazonaws.com role-to-assume: arn:aws:iam::716085231028:role/ComposablePromptExecutor role-session-name: github-actions - aws-region: us-west-2 + aws-region: us-east-1 - run: npx vitest env: From dd7622ded7b1c442d6103c40ccaebc5a997989f7 Mon Sep 17 00:00:00 2001 From: Leon Ruggiero Date: Tue, 5 Nov 2024 13:34:07 +0000 Subject: [PATCH 2/2] Adds finish_reason and token_usage support, streaming included. (#24) * Adds finish_reason and token_usage support. Change return type from string to CompletionChunk object to facilitate this. Adds tests for finish_reason. * Fix build error. Remove accidental inclusion of change for another branch * Fixes AWS support, groq & openai token usage. Bedrock:mistral bypasses token_usage check as it does not support it. Adds chunk count to completion object. * Rename getTextAnsStopReason to getExtractedCompletionChunk, and remove superfluous console.log * Yield chunk.result rather than whole chunk. Only yield non-empty content. Narrow ResultT to JSONObject | string. Change all instances of yield to type string * Fix build, revert ResultT to any. --- core/src/CompletionStream.ts | 40 ++- core/src/Driver.ts | 5 +- core/src/async.ts | 12 +- core/src/types.ts | 18 ++ drivers/src/bedrock/index.ts | 269 ++++++++++++------ drivers/src/groq/index.ts | 15 +- drivers/src/huggingface_ie.ts | 24 +- drivers/src/mistral/index.ts | 16 +- drivers/src/openai/index.ts | 33 ++- drivers/src/replicate.ts | 17 +- .../TestValidationErrorCompletionStream.ts | 4 +- drivers/src/togetherai/index.ts | 17 +- drivers/src/vertexai/index.ts | 4 +- drivers/src/vertexai/models.ts | 4 +- drivers/src/vertexai/models/gemini.ts | 32 ++- drivers/src/watsonx/index.ts | 25 +- drivers/test/all-models.test.ts | 4 +- drivers/test/assertions.ts | 13 +- 18 files changed, 389 insertions(+), 163 deletions(-) diff --git a/core/src/CompletionStream.ts b/core/src/CompletionStream.ts index 510fdde..a48e16a 100644 --- a/core/src/CompletionStream.ts +++ b/core/src/CompletionStream.ts @@ -1,5 +1,5 @@ import { AbstractDriver } from "./Driver.js"; -import { CompletionStream, DriverOptions, ExecutionOptions, ExecutionResponse } from "./types.js"; +import { CompletionStream, DriverOptions, ExecutionOptions, ExecutionResponse, ExecutionTokenUsage } from "./types.js"; export class DefaultCompletionStream implements CompletionStream { @@ -27,27 +27,47 @@ export class DefaultCompletionStream implements CompletionStream< const start = Date.now(); const stream = await this.driver.requestCompletionStream(this.prompt, this.options); + let finish_reason: string | undefined = undefined; + let promptTokens: number = 0; + let resultTokens: number | undefined = undefined; for await (const chunk of stream) { if (chunk) { - chunks.push(chunk); - yield chunk; + if (typeof chunk === 'string') { + chunks.push(chunk); + yield chunk; + }else{ + if (chunk.finish_reason) { //Do not replace non-null values with null values + finish_reason = chunk.finish_reason; //Used to skip empty finish_reason chunks coming 
after "stop" or "length" + } + if (chunk.token_usage) { + //Tokens returned include prior parts of stream, + //so overwrite rather than accumulate + //Math.max used as some models report final token count at beginning of stream + promptTokens = Math.max(promptTokens,chunk.token_usage.prompt ?? 0); + resultTokens = Math.max(resultTokens ?? 0,chunk.token_usage.result ?? 0); + } + if (chunk.result) { + chunks.push(chunk.result); + yield chunk.result; + } + } } } const content = chunks.join(''); - const promptTokens = typeof this.prompt === 'string' ? this.prompt.length : JSON.stringify(this.prompt).length; - const resultTokens = content.length; //TODO use chunks.length ? + // Return undefined for the ExecutionTokenUsage object if there is nothing to fill it with. + // Allows for checking for truthyness on token_usage, rather than it's internals. For testing and downstream usage. + let tokens: ExecutionTokenUsage | undefined = resultTokens ? + { prompt: promptTokens, result: resultTokens, total: resultTokens + promptTokens, } : undefined this.completion = { result: content, prompt: this.prompt, execution_time: Date.now() - start, - token_usage: { - prompt: promptTokens, - result: resultTokens, - total: resultTokens + promptTokens, - } + token_usage: tokens, + finish_reason: finish_reason, + chunks: chunks.length, } this.driver.validateResult(this.completion, this.options); diff --git a/core/src/Driver.ts b/core/src/Driver.ts index f072951..7d13044 100644 --- a/core/src/Driver.ts +++ b/core/src/Driver.ts @@ -22,7 +22,8 @@ import { PromptSegment, TrainingJob, TrainingOptions, - TrainingPromptOptions + TrainingPromptOptions, + CompletionChunk } from "./types.js"; import { validateResult } from "./validation.js"; @@ -223,7 +224,7 @@ export abstract class AbstractDriver; - abstract requestCompletionStream(prompt: PromptT, options: ExecutionOptions): Promise>; + abstract requestCompletionStream(prompt: PromptT, options: ExecutionOptions): Promise>; //list models available for this environement abstract listModels(params?: ModelSearchPayload): Promise; diff --git a/core/src/async.ts b/core/src/async.ts index 3eb53d1..63c8f74 100644 --- a/core/src/async.ts +++ b/core/src/async.ts @@ -1,4 +1,5 @@ import type { ServerSentEvent } from "api-fetch-client"; +import { CompletionChunk } from "./types.js"; export async function* asyncMap(asyncIterable: AsyncIterable, callback: (value: T, index: number) => R) { let i = 0; @@ -15,22 +16,23 @@ export function oneAsyncIterator(value: T): AsyncIterable { } /** - * Given a ReadableStream of server seent events, tran + * Given a ReadableStream of server sent events, tran */ -export function transformSSEStream(stream: ReadableStream, transform: (data: string) => string): ReadableStream & AsyncIterable { +export function transformSSEStream(stream: ReadableStream, transform: (data: string) => CompletionChunk): ReadableStream & AsyncIterable { // on node and bun the readablestream is an async iterable - return stream.pipeThrough(new TransformStream({ + return stream.pipeThrough(new TransformStream({ transform(event: ServerSentEvent, controller) { if (event.type === 'event' && event.data && event.data !== '[DONE]') { try { - controller.enqueue(transform(event.data) ?? ''); + const result = transform(event.data) ?? 
'' + controller.enqueue(result); } catch (err) { // double check for the last event whicb is not a JSON - at this time togetherai and mistralai returrns the string [DONE] // do nothing - happens if data is not a JSON - the last event data is the [DONE] string } } } - })) as ReadableStream & AsyncIterable; + })) as ReadableStream & AsyncIterable; } export class EventStream implements AsyncIterable{ diff --git a/core/src/types.ts b/core/src/types.ts index 9e1d40e..e88f761 100644 --- a/core/src/types.ts +++ b/core/src/types.ts @@ -40,6 +40,18 @@ export interface ResultValidationError { data?: string; } +//ResultT should be either JSONObject or string +//Internal structure used in driver implementation. +export interface CompletionChunkObject { + result: ResultT; + token_usage?: ExecutionTokenUsage; + finish_reason?: "stop" | "length" | string; +} + +//Internal structure used in driver implementation. +export type CompletionChunk = CompletionChunkObject | string; + +//ResultT should be either JSONObject or string export interface Completion { // the driver impl must return the result and optionally the token_usage. the execution time is computed by the extended abstract driver result: ResultT; @@ -69,6 +81,10 @@ export interface ExecutionResponse extends Completion { * The time it took to execute the request in seconds */ execution_time?: number; + /** + * The number of chunks for streamed executions + */ + chunks?: number; } @@ -118,6 +134,8 @@ export interface ExecutionOptions extends PromptOptions { top_p?: number; /** + * Currently not supported, will be ignored. + * Should be an integer. * Only supported for OpenAI. Look at OpenAI documentation for more detailsx */ top_logprobs?: number; diff --git a/drivers/src/bedrock/index.ts b/drivers/src/bedrock/index.ts index 87aaf61..bead67d 100644 --- a/drivers/src/bedrock/index.ts +++ b/drivers/src/bedrock/index.ts @@ -1,7 +1,7 @@ import { Bedrock, CreateModelCustomizationJobCommand, FoundationModelSummary, GetModelCustomizationJobCommand, GetModelCustomizationJobCommandOutput, ModelCustomizationJobStatus, StopModelCustomizationJobCommand } from "@aws-sdk/client-bedrock"; import { BedrockRuntime, InvokeModelCommandOutput, ResponseStream } from "@aws-sdk/client-bedrock-runtime"; import { S3Client } from "@aws-sdk/client-s3"; -import { AIModel, AbstractDriver, Completion, DataSource, DriverOptions, EmbeddingsOptions, EmbeddingsResult, ExecutionOptions, PromptOptions, PromptSegment, TrainingJob, TrainingJobStatus, TrainingOptions } from "@llumiverse/core"; +import { AIModel, AbstractDriver, Completion, DataSource, DriverOptions, EmbeddingsOptions, EmbeddingsResult, ExecutionOptions, PromptOptions, PromptSegment, CompletionChunkObject, TrainingJob, TrainingJobStatus, TrainingOptions, ExecutionTokenUsage } from "@llumiverse/core"; import { transformAsyncIterator } from "@llumiverse/core/async"; import { ClaudeMessagesPrompt, formatClaudePrompt } from "@llumiverse/core/formatters"; import { AwsCredentialIdentity, Provider } from "@smithy/types"; @@ -89,60 +89,194 @@ export class BedrockDriver extends AbstractDriver { - if (result.generation) { - // LLAMA2 - return [result.generation, result.stop_reason]; // comes in coirrect format (stop, length) - } else if (result.generations) { - // Cohere - return [result.generations[0].text, cohereFinishReason(result.generations[0].finish_reason)]; - } else if (result.chat_history) { - //Cohere Command R - return [result.text, cohereFinishReason(result.finish_reason)]; - } else if (result.completions) { - //A21 - 
return [result.completions[0].data?.text, a21FinishReason(result.completions[0].finishReason?.reason)]; - } else if (result.content) { - // Claude + //Update this when supporting new models + static getExtractedCompletionChunk(result: any, prompt?: BedrockPrompt): CompletionChunkObject { + //AWS universal token_usage + let token_usage = this.getAmazonInvocationMetrics(result); + if (result.generation || result.generation == '') { + // LLAMA3 + if (!token_usage) { + token_usage = { + prompt: result.prompt_token_count, + result: result.generation_token_count, + total: result.generation_token_count + result.prompt_token_count, + }; + } + return { + result: result.generation, + finish_reason: result.stop_reason, //already in "stop" or "length" format + token_usage: token_usage, + }; + } else if (result.generations) { + // Cohere Command (Non-R) + if (!token_usage) { + token_usage = { + prompt: result?.meta?.billed_units.input_tokens, + result: result?.meta?.billed_units.output_tokens, + total: result?.meta?.billed_units.input_tokens + result?.meta?.billed_units.output_tokens, + } + } + return { + result: result.generations[0].text, + finish_reason: cohereFinishReason(result.generations[0].finish_reason), + //Token usage not given in AWS docs, but is in cohere docs. + token_usage: token_usage + }; + } else if (result.chat_history) { + // Cohere Command R + if (!token_usage) { + token_usage = { + prompt: result?.meta?.billed_units.input_tokens, + result: result?.meta?.billed_units.output_tokens, + total: result?.meta?.billed_units.input_tokens + result?.meta?.billed_units.output_tokens, + } + } + return { + result: result.text, + finish_reason: cohereFinishReason(result.finish_reason), + token_usage: token_usage, + }; + } else if (result.event_type) { + // Cohere Command R streaming + return { + result: result.text, + finish_reason: cohereFinishReason(result.finish_reason), + token_usage: token_usage, + }; + } else if (result.completions) { + // A21 Jurassic + if (!token_usage) { + token_usage = { + prompt: result.prompt.tokens.length, + result: result.completions[0].data.tokens.length, + total: result.prompt.tokens.length + result.completions[0].data.tokens.length, + } + } + return { + result: result.completions[0].data?.text, + finish_reason: a21FinishReason(result.completions[0].finishReason?.reason), + token_usage: token_usage, + }; + } else if (result.content) { + // Claude + if (!token_usage) { + token_usage = { + prompt: result.usage?.input_tokens, + result: result.usage?.output_tokens, + total: result.usage?.input_tokens + result.usage?.output_tokens, + } + } + let res: string = ""; + if (prompt) { //if last prompt.messages is {, add { to the response const p = prompt as ClaudeMessagesPrompt; const lastMessage = (p as ClaudeMessagesPrompt).messages[p.messages.length - 1]; - const res = lastMessage.content[0].text === '{' ? '{' + result.content[0]?.text : result.content[0]?.text; - - return [res, claudeFinishReason(result.stop_reason)]; - - } else if (result.outputs) { - // mistral - return [result.outputs[0]?.text, result.outputs[0]?.stop_reason]; // the stop reason is in the expected format ("stop" and "length") - } else if (result.results) { - // Amazon Titan - return [result.results[0]?.outputText ?? '', titanFinishReason(result.results[0]?.completionReason)]; - } else if (result.completion) { // TODO: who uses this? - return [result.completion]; + res = lastMessage.content[0].text === '{' ? '{' + (result.content[0]?.text ?? '') : (result.content[0]?.text ?? 
''); } else { - return [result.toString()]; + res = result.content[0].text } - }; + return { + result: res, + finish_reason: claudeFinishReason(result.stop_reason), + token_usage: token_usage, + }; + } + else if (result.delta || result.type) { // claude-v2:1 when streaming + if (!token_usage) { + token_usage = { + prompt: result.usage?.input_tokens, + result: result.usage?.output_tokens, + total: result.usage?.input_tokens + result.usage?.output_tokens, + } + } + let res: string = ""; + if (result.type == 'content_block_start'){ + if (prompt) { + //if last prompt.messages is {, add { to the response + const p = prompt as ClaudeMessagesPrompt; + const lastMessage = (p as ClaudeMessagesPrompt).messages[p.messages.length - 1]; + res = lastMessage.content[0].text === '{' ? '{' + (result?.content_block[0]?.text ?? '') : (result?.content_block[0]?.text ?? ''); + } else { + res = result.content_block[0]?.text; + } + } else { // content_block_delta + res = result.delta?.text || ''; + } + return { + result: res, + finish_reason: claudeFinishReason(result.delta?.stop_reason), + token_usage: token_usage, + }; + } else if (result.outputs) { + // Mistral + return { + result: result.outputs[0]?.text, + finish_reason: result.outputs[0]?.stop_reason, // the stop reason is in the expected format ("stop" and "length") + token_usage: token_usage, + }; + //Token usage not supported + } else if (result.results) { + // Amazon Titan non-streaming + if (!token_usage) { + token_usage = { + prompt: result.inputTextTokenCount, + result: result.results[0].tokenCount, + total: result.inputTextTokenCount + result.results[0].tokenCount, + } + } + return { + result: result.results[0]?.outputText ?? '', + finish_reason: titanFinishReason(result.results[0]?.completionReason), + token_usage: token_usage, + }; + } else if (result.chunks) { + // Amazon Titan streaming + const decoder = new TextDecoder(); + const chunk = decoder.decode(result.chunks); + const result_chunk = JSON.parse(chunk); + if (!token_usage) { + token_usage = { + prompt: result_chunk.inputTextTokenCount, + result: result_chunk.totalOutputTextTokenCount, + total: result_chunk.inputTextTokenCount + result_chunk.totalOutputTextTokenCount, + } + } + return { + result: result_chunk.outputText, + finish_reason: titanFinishReason(result_chunk.completionReason), + token_usage: token_usage, + }; + } else if (result.completion) { // TODO: who uses this? + return { + result: result.completion, + token_usage: token_usage, + }; + } else { // Fallback + return { + result: result, + token_usage: token_usage, + }; + } + }; - const [text, finish_reason] = getTextAnsStopReason(); + extractDataFromResponse(prompt: BedrockPrompt, response: InvokeModelCommandOutput): Completion { - const promptLength = typeof prompt === 'string' ? 
prompt.length : - (prompt.system || '').length + prompt.messages.reduce((acc, m) => acc + m.content.length, 0); - return { - result: text, - token_usage: { - result: text?.length, - prompt: promptLength, - total: text?.length + promptLength, - }, - finish_reason - } + const decoder = new TextDecoder(); + const body = decoder.decode(response.body); + const result = JSON.parse(body); + + return BedrockDriver.getExtractedCompletionChunk(result, prompt); } async requestCompletion(prompt: BedrockPrompt, options: ExecutionOptions): Promise { @@ -173,7 +307,7 @@ export class BedrockDriver extends AbstractDriver> { + async requestCompletionStream(prompt: BedrockPrompt, options: ExecutionOptions): Promise> { const payload = this.preparePayload(prompt, options); const executor = this.getExecutor(); return executor.invokeModelWithResponseStream({ @@ -187,46 +321,11 @@ export class BedrockDriver extends AbstractDriver { - if (typeof prompt === 'object' && (prompt as ClaudeMessagesPrompt).messages) { - const p = prompt as ClaudeMessagesPrompt; - const lastMessage = p.messages[p.messages.length - 1]; - return lastMessage.content[0].text === '{'; - } - return false; - }; - return transformAsyncIterator(res.body, (stream: ResponseStream) => { const segment = JSON.parse(decoder.decode(stream.chunk?.bytes)); //console.log("Debug Segment for model " + options.model, JSON.stringify(segment)); - if (segment.delta) { // who is this? - return segment.delta.text || ''; - } else if (segment.completion) { // who is this? - return segment.completion; - } else if (segment.text) { //cohere - return segment.text; - } else if (segment.completions) { - return segment.completions[0].data?.text; - } else if (segment.generation) { - return segment.generation; - } else if (segment.generations) { - return segment.generations[0].text; - } else if (segment.outputs) { - // mistral.mixtral-8x7b-instruct-v0:1 - return segment.outputs[0].text; - //segment.outputs[0].stop_reason; - } else if (segment.outputText) { - // Amazon Titan - return segment.outputText; - //completionReason - // token count too - } else { - segment.toString(); - } - - }, - () => addBracket() ? '{' : '' - ); + return BedrockDriver.getExtractedCompletionChunk(segment, prompt); + }); }).catch((err) => { this.logger.error("[Bedrock] Failed to stream", err); diff --git a/drivers/src/groq/index.ts b/drivers/src/groq/index.ts index 4e732c9..c6f13d3 100644 --- a/drivers/src/groq/index.ts +++ b/drivers/src/groq/index.ts @@ -1,4 +1,4 @@ -import { AIModel, AbstractDriver, Completion, DriverOptions, EmbeddingsOptions, EmbeddingsResult, ExecutionOptions, PromptSegment } from "@llumiverse/core"; +import { AIModel, AbstractDriver, Completion, DriverOptions, EmbeddingsOptions, EmbeddingsResult, ExecutionOptions, PromptSegment, CompletionChunkObject } from "@llumiverse/core"; import { transformAsyncIterator } from "@llumiverse/core/async"; import { OpenAITextMessage, formatOpenAILikeTextPrompt, getJSONSafetyNotice } from "@llumiverse/core/formatters"; import Groq from "groq-sdk"; @@ -85,7 +85,7 @@ export class GroqDriver extends AbstractDriver> { + async requestCompletionStream(messages: OpenAITextMessage[], options: ExecutionOptions): Promise> { const res = await this.client.chat.completions.create({ model: options.model, @@ -96,8 +96,15 @@ export class GroqDriver extends AbstractDriver res.choices[0].delta.content || ''); - + return transformAsyncIterator(res, (res) => ({ + result: res.choices[0].delta.content ?? 
'', + finish_reason: res.choices[0].finish_reason, + token_usage: { + prompt: res.x_groq?.usage?.prompt_tokens, + result: res.x_groq?.usage?.completion_tokens, + total: res.x_groq?.usage?.total_tokens, + }, + } as CompletionChunkObject)); } async listModels(): Promise[]> { diff --git a/drivers/src/huggingface_ie.ts b/drivers/src/huggingface_ie.ts index feac2e4..193c629 100644 --- a/drivers/src/huggingface_ie.ts +++ b/drivers/src/huggingface_ie.ts @@ -9,7 +9,8 @@ import { AbstractDriver, DriverOptions, EmbeddingsResult, - ExecutionOptions + ExecutionOptions, + CompletionChunkObject } from "@llumiverse/core"; import { transformAsyncIterator } from "@llumiverse/core/async"; import { FetchClient } from "api-fetch-client"; @@ -74,11 +75,22 @@ export class HuggingFaceIEDriver extends AbstractDriver { //special like are not part of the result - if (val.token.special) return ""; - return val.token.text; + if (val.token.special) return {result:""}; + let finish_reason = val.details?.finish_reason as string; + if (finish_reason === "eos_token") { + finish_reason = "stop"; + } + return { + result: val.token.text ?? '', + finish_reason: finish_reason, + token_usage:{ + result: val.details?.generated_tokens ?? 0, + } + } as CompletionChunkObject; }); } @@ -98,12 +110,10 @@ export class HuggingFaceIEDriver extends AbstractDriver> { + async requestCompletionStream(messages: OpenAITextMessage[], options: ExecutionOptions): Promise> { const stream = await this.client.post('/v1/chat/completions', { payload: _makeChatCompletionRequest({ model: options.model, @@ -102,7 +102,15 @@ export class MistralAIDriver extends AbstractDriver { const json = JSON.parse(data); - return json.choices[0]?.delta.content ?? ''; + return { + result: json.choices[0]?.delta.content ?? '', + finish_reason: json.choices[0]?.finish_reason, //Uses expected "stop" , "length" format + token_usage: { + prompt: json.usage?.prompt_tokens, + result: json.usage?.completion_tokens, + total: json.usage?.total_tokens, + }, + }; }); } diff --git a/drivers/src/openai/index.ts b/drivers/src/openai/index.ts index 2648705..3bf38b1 100644 --- a/drivers/src/openai/index.ts +++ b/drivers/src/openai/index.ts @@ -9,6 +9,7 @@ import { ExecutionOptions, ExecutionTokenUsage, ModelType, + CompletionChunkObject, TrainingJob, TrainingJobStatus, TrainingOptions, @@ -54,14 +55,13 @@ export abstract class BaseOpenAIDriver extends AbstractDriver< }; const choice = result.choices[0]; - const finish_reason = choice.finish_reason; //if no schema, return content if (!options.result_schema) { return { result: choice.message.content as string, token_usage: tokenInfo, - finish_reason + finish_reason: choice.finish_reason, //Uses expected "stop" , "length" format } } @@ -75,23 +75,39 @@ export abstract class BaseOpenAIDriver extends AbstractDriver< return { result: data, token_usage: tokenInfo, - finish_reason + finish_reason: choice.finish_reason, }; } async requestCompletionStream(prompt: OpenAI.Chat.Completions.ChatCompletionMessageParam[], options: ExecutionOptions): Promise { const mapFn = options.result_schema ? (chunk: OpenAI.Chat.Completions.ChatCompletionChunk) => { - return ( - chunk.choices[0]?.delta?.tool_calls?.[0].function?.arguments ?? "" - ); + return { + result: chunk.choices[0]?.delta?.tool_calls?.[0].function?.arguments ?? "", + finish_reason: chunk.choices[0]?.finish_reason, //Uses expected "stop" , "length" format + token_usage: { + prompt: chunk.usage?.prompt_tokens, + result: chunk.usage?.completion_tokens, + total: (chunk.usage?.prompt_tokens ?? 
0) + (chunk.usage?.completion_tokens ?? 0), + } + } as CompletionChunkObject; } : (chunk: OpenAI.Chat.Completions.ChatCompletionChunk) => { - return chunk.choices[0]?.delta?.content ?? ""; + return { + result: chunk.choices[0]?.delta.content ?? "", + finish_reason: chunk.choices[0]?.finish_reason, + token_usage: { + prompt: chunk.usage?.prompt_tokens, + result: chunk.usage?.completion_tokens, + total: (chunk.usage?.prompt_tokens ?? 0) + (chunk.usage?.completion_tokens ?? 0), + } + } as CompletionChunkObject; }; + //TODO: OpenAI o1 support requires max_completions_tokens const stream = (await this.service.chat.completions.create({ stream: true, + stream_options: {include_usage: true}, model: options.model, messages: prompt, temperature: options.temperature, @@ -130,7 +146,8 @@ export abstract class BaseOpenAIDriver extends AbstractDriver< } as OpenAI.Chat.ChatCompletionTool, ] : undefined; - + + //TODO: OpenAI o1 support requires max_completions_tokens const res = await this.service.chat.completions.create({ stream: false, model: options.model, diff --git a/drivers/src/replicate.ts b/drivers/src/replicate.ts index 5922953..bf9a294 100644 --- a/drivers/src/replicate.ts +++ b/drivers/src/replicate.ts @@ -2,6 +2,7 @@ import { AIModel, AbstractDriver, Completion, + CompletionChunk, DataSource, DriverOptions, EmbeddingsResult, @@ -55,19 +56,14 @@ export class ReplicateDriver extends AbstractDriver { }); } - extractDataFromResponse(prompt: string, response: Prediction): Completion { + extractDataFromResponse(response: Prediction): Completion { const text = response.output.join(""); return { result: text, - token_usage: { - result: response.output.length, - prompt: prompt.length, - total: response.output.length + prompt.length, - }, }; } - async requestCompletionStream(prompt: string, options: ExecutionOptions): Promise> { + async requestCompletionStream(prompt: string, options: ExecutionOptions): Promise> { const model = ReplicateDriver.parseModelId(options.model); const predictionData = { input: { @@ -82,7 +78,7 @@ export class ReplicateDriver extends AbstractDriver { const prediction = await this.service.predictions.create(predictionData); - const stream = new EventStream(); + const stream = new EventStream(); const source = new EventSource(prediction.urls.stream!); source.addEventListener("output", (e: any) => { @@ -133,11 +129,6 @@ export class ReplicateDriver extends AbstractDriver { const text = res.output.join(""); return { result: text, - token_usage: { - result: res.output.length, - prompt: prompt.length, - total: res.output.length + prompt.length, - }, original_response: options.include_original_response ? res : undefined, }; } diff --git a/drivers/src/test/TestValidationErrorCompletionStream.ts b/drivers/src/test/TestValidationErrorCompletionStream.ts index e5ad75e..2cf1e16 100644 --- a/drivers/src/test/TestValidationErrorCompletionStream.ts +++ b/drivers/src/test/TestValidationErrorCompletionStream.ts @@ -12,9 +12,9 @@ export class TestValidationErrorCompletionStream implements CompletionStream> { - + async requestCompletionStream(prompt: string, options: ExecutionOptions): Promise> { const stream = await this.fetchClient.post('/v1/completions', { payload: { model: options.model, @@ -78,7 +77,15 @@ export class TogetherAIDriver extends AbstractDriver { const json = JSON.parse(data); - return json.choices[0]?.text ?? ''; + return { + result: json.choices[0]?.text ?? 
'', + finish_reason: json.choices[0]?.finish_reason, //Uses expected "stop" , "length" format + token_usage: { + prompt: json.usage?.prompt_tokens, + result: json.usage?.completion_tokens, + total: json.usage?.prompt_tokens + json.usage?.completion_tokens, + } + }; }); } diff --git a/drivers/src/vertexai/index.ts b/drivers/src/vertexai/index.ts index 5aa3efc..465f525 100644 --- a/drivers/src/vertexai/index.ts +++ b/drivers/src/vertexai/index.ts @@ -1,5 +1,5 @@ import { GenerateContentRequest, VertexAI } from "@google-cloud/vertexai"; -import { AIModel, AbstractDriver, Completion, DriverOptions, EmbeddingsResult, ExecutionOptions, ModelSearchPayload, PromptOptions, PromptSegment } from "@llumiverse/core"; +import { AIModel, AbstractDriver, Completion, CompletionChunkObject, DriverOptions, EmbeddingsResult, ExecutionOptions, ModelSearchPayload, PromptOptions, PromptSegment } from "@llumiverse/core"; import { FetchClient } from "api-fetch-client"; import { GoogleAuth, GoogleAuthOptions } from "google-auth-library"; import { JSONClient } from "google-auth-library/build/src/auth/googleauth.js"; @@ -58,7 +58,7 @@ export class VertexAIDriver extends AbstractDriver> { return getModelDefinition(options.model).requestCompletion(this, prompt, options); } - async requestCompletionStream(prompt: GenerateContentRequest, options: ExecutionOptions): Promise> { + async requestCompletionStream(prompt: GenerateContentRequest, options: ExecutionOptions): Promise> { return getModelDefinition(options.model).requestCompletionStream(this, prompt, options); } diff --git a/drivers/src/vertexai/models.ts b/drivers/src/vertexai/models.ts index e76cd00..f227a06 100644 --- a/drivers/src/vertexai/models.ts +++ b/drivers/src/vertexai/models.ts @@ -1,4 +1,4 @@ -import { AIModel, Completion, ExecutionOptions, ModelType, PromptOptions, PromptSegment } from "@llumiverse/core"; +import { AIModel, Completion, CompletionChunkObject, ExecutionOptions, ModelType, PromptOptions, PromptSegment } from "@llumiverse/core"; import { VertexAIDriver } from "./index.js"; import { GeminiModelDefinition } from "./models/gemini.js"; @@ -10,7 +10,7 @@ export interface ModelDefinition { versions?: string[]; // the versions of the model that are available. 
ex: ['001', '002'] createPrompt: (driver: VertexAIDriver, segments: PromptSegment[], options: PromptOptions) => Promise; requestCompletion: (driver: VertexAIDriver, prompt: PromptT, options: ExecutionOptions) => Promise; - requestCompletionStream: (driver: VertexAIDriver, promp: PromptT, options: ExecutionOptions) => Promise>; + requestCompletionStream: (driver: VertexAIDriver, promp: PromptT, options: ExecutionOptions) => Promise>; } export function getModelName(model: string) { diff --git a/drivers/src/vertexai/models/gemini.ts b/drivers/src/vertexai/models/gemini.ts index 8ba81fe..c9f5b1f 100644 --- a/drivers/src/vertexai/models/gemini.ts +++ b/drivers/src/vertexai/models/gemini.ts @@ -1,5 +1,5 @@ import { Content, FinishReason, GenerateContentRequest, HarmBlockThreshold, HarmCategory, InlineDataPart, ModelParams, ResponseSchema, TextPart } from "@google-cloud/vertexai"; -import { AIModel, Completion, ExecutionOptions, ExecutionTokenUsage, PromptOptions, PromptRole, PromptSegment, readStreamAsBase64 } from "@llumiverse/core"; +import { AIModel, Completion, CompletionChunkObject, ExecutionOptions, ExecutionTokenUsage, PromptOptions, PromptRole, PromptSegment, readStreamAsBase64 } from "@llumiverse/core"; import { asyncMap } from "@llumiverse/core/async"; import { VertexAIDriver } from "../index.js"; import { BuiltinModels, ModelDefinition } from "../models.js"; @@ -173,26 +173,46 @@ export class GeminiModelDefinition implements ModelDefinition> { + async requestCompletionStream(driver: VertexAIDriver, prompt: GenerateContentRequest, options: ExecutionOptions): Promise> { const model = getGenerativeModel(driver, options); const streamingResp = await model.generateContentStream(prompt); const stream = asyncMap(streamingResp.stream, async (item) => { + const usage = item.usageMetadata; + const token_usage: ExecutionTokenUsage = { + prompt: usage?.promptTokenCount, + result: usage?.candidatesTokenCount, + total: usage?.totalTokenCount, + } if (item.candidates && item.candidates.length > 0) { for (const candidate of item.candidates) { + let finish_reason: string | undefined; + switch (candidate.finishReason) { + case FinishReason.MAX_TOKENS: finish_reason = "length"; break; + case FinishReason.STOP: finish_reason = "stop"; break; + default: finish_reason = candidate.finishReason; + } if (candidate.content?.role === 'model') { const text = collectTextParts(candidate.content); - if (text) return text; + return { + result:text, + token_usage: token_usage, + finish_reason: finish_reason, + }; } } } - return ''; + //No normal output, returning block reason if it exists. + return { + result: item.promptFeedback?.blockReasonMessage ?? "", + finish_reason: item.promptFeedback?.blockReason ?? 
"", + }; }); return stream; diff --git a/drivers/src/watsonx/index.ts b/drivers/src/watsonx/index.ts index c4ddfb8..21cdcc5 100644 --- a/drivers/src/watsonx/index.ts +++ b/drivers/src/watsonx/index.ts @@ -1,4 +1,4 @@ -import { AIModel, AbstractDriver, Completion, DriverOptions, EmbeddingsOptions, EmbeddingsResult, ExecutionOptions } from "@llumiverse/core"; +import { AIModel, AbstractDriver, Completion, DriverOptions, EmbeddingsOptions, EmbeddingsResult, ExecutionOptions, CompletionChunk } from "@llumiverse/core"; import { transformSSEStream } from "@llumiverse/core/async"; import { FetchClient } from "api-fetch-client"; import { GenerateEmbeddingPayload, GenerateEmbeddingResponse, WatsonAuthToken, WatsonxListModelResponse, WatsonxModelSpec, WatsonxTextGenerationPayload, WatsonxTextGenerationResponse } from "./interfaces.js"; @@ -51,12 +51,12 @@ export class WatsonxDriver extends AbstractDriver result: result.generated_token_count, total: result.input_token_count + result.generated_token_count, }, - finish_reason: result.stop_reason, + finish_reason: watsonFinishReason(result.stop_reason), original_response: options.include_original_response ? res : undefined, } } - async requestCompletionStream(prompt: string, options: ExecutionOptions): Promise> { + async requestCompletionStream(prompt: string, options: ExecutionOptions): Promise> { const payload: WatsonxTextGenerationPayload = { model_id: options.model, @@ -75,7 +75,15 @@ export class WatsonxDriver extends AbstractDriver return transformSSEStream(stream, (data: string) => { const json = JSON.parse(data) as WatsonxTextGenerationResponse; - return json.results[0]?.generated_text ?? ''; + return { + result: json.results[0]?.generated_text ?? '', + finish_reason: watsonFinishReason(json.results[0]?.stop_reason), + token_usage: { + prompt: json.results[0].input_token_count, + result: json.results[0].generated_token_count, + total: json.results[0].input_token_count + json.results[0].generated_token_count, + }, + }; }); } @@ -158,6 +166,15 @@ export class WatsonxDriver extends AbstractDriver } +function watsonFinishReason(reason: string | undefined) { + if (!reason) return undefined; + switch (reason) { + case 'eos_token': return "stop"; + case 'max_tokens': return "length"; + default: return reason; + } +} + /*interface ListModelsParams extends ModelSearchPayload { diff --git a/drivers/test/all-models.test.ts b/drivers/test/all-models.test.ts index 1312290..820bdcb 100644 --- a/drivers/test/all-models.test.ts +++ b/drivers/test/all-models.test.ts @@ -193,7 +193,7 @@ describe.concurrent.each(drivers)("Driver $name", ({ name, driver, models }) => test.each(models)(`${name}: execute prompt on %s`, { timeout: TIMEOUT, retry: 3 }, async (model) => { const r = await driver.execute(testPrompt_color, { model, temperature: 0.5, max_tokens: 1024 }); console.debug("Result for " + model, JSON.stringify(r)); - assertCompletionOk(r); + assertCompletionOk(r, model, driver); }); test.each(models)(`${name}: execute prompt with streaming on %s`, { timeout: TIMEOUT, retry: 3 }, async (model) => { @@ -205,7 +205,7 @@ describe.concurrent.each(drivers)("Driver $name", ({ name, driver, models }) => test.each(models)(`${name}: execute prompt with schema on %s`, { timeout: TIMEOUT, retry: 3 }, async (model) => { const r = await driver.execute(testPrompt_color, { model, temperature: 0.5, max_tokens: 1024, result_schema: testSchema_color }); console.log("Result for " + model, JSON.stringify(r.result)); - assertCompletionOk(r); + assertCompletionOk(r, model, driver); 
     });
 
     test.each(models)(`${name}: execute prompt with streaming and schema on %s`, { timeout: TIMEOUT, retry: 3 }, async (model) => {
diff --git a/drivers/test/assertions.ts b/drivers/test/assertions.ts
index f21a478..4acdbc4 100644
--- a/drivers/test/assertions.ts
+++ b/drivers/test/assertions.ts
@@ -1,11 +1,19 @@
+import { Bedrock } from '@aws-sdk/client-bedrock';
 import { CompletionStream, ExecutionResponse, extractAndParseJSON } from '@llumiverse/core';
 import { expect } from "vitest";
+import { BedrockDriver } from '../src';
 
-export function assertCompletionOk(r: ExecutionResponse) {
+export function assertCompletionOk(r: ExecutionResponse, model?: string, driver?) {
     expect(r.error).toBeFalsy();
     expect(r.prompt).toBeTruthy();
-    expect(r.token_usage).toBeTruthy();
+    //TODO: This just checks for existence of the object,
+    //could do with more thorough test however not all models support token_usage.
+    //Only create the object when there is meaningful information you want to interpret as a pass.
+    if (!(driver?.provider == 'bedrock' && model?.includes("mistral"))) { //Skip if bedrock:mistral, token_usage not supported.
+        expect(r.token_usage).toBeTruthy();
+    }
+    expect(r.finish_reason).toBeTruthy();
     //if r.result is string, it should be longer than 2
     if (typeof r.result === 'string') {
         expect(r.result.length).toBeGreaterThan(2);
@@ -31,6 +39,7 @@ export async function assertStreamingCompletionOk(stream: CompletionStream, json
     expect(r.error).toBeFalsy();
     expect(r.prompt).toBeTruthy();
     expect(r.token_usage).toBeTruthy();
+    expect(r.finish_reason).toBeTruthy();
     if (typeof r.result === "string") expect(r.result?.length).toBeGreaterThan(2);
     return out;
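
For reference, the sketch below shows how a consumer is expected to fold the new CompletionChunk stream into a final completion, i.e. the same accumulation that DefaultCompletionStream performs in core/src/CompletionStream.ts above. CompletionChunk and ExecutionTokenUsage are the types this patch adds to @llumiverse/core; the accumulate() helper and the AccumulatedCompletion shape are illustrative only and not part of the patch.

import type { CompletionChunk, ExecutionTokenUsage } from "@llumiverse/core";

// Illustrative shape only; the patch stores these fields on ExecutionResponse.
interface AccumulatedCompletion {
    result: string;
    finish_reason?: string;
    token_usage?: ExecutionTokenUsage;
    chunks: number;
}

// Hypothetical helper mirroring DefaultCompletionStream's loop in this patch.
async function accumulate(stream: AsyncIterable<CompletionChunk>): Promise<AccumulatedCompletion> {
    const parts: string[] = [];
    let finish_reason: string | undefined;
    let promptTokens = 0;
    let resultTokens: number | undefined;

    for await (const chunk of stream) {
        // Drivers may still yield plain strings; treat them as content-only chunks.
        if (typeof chunk === "string") {
            parts.push(chunk);
            continue;
        }
        // Keep the last non-empty finish_reason; empty ones can arrive after "stop"/"length".
        if (chunk.finish_reason) {
            finish_reason = chunk.finish_reason;
        }
        // Providers report usage cumulatively (some only once per stream), so merge with max.
        if (chunk.token_usage) {
            promptTokens = Math.max(promptTokens, chunk.token_usage.prompt ?? 0);
            resultTokens = Math.max(resultTokens ?? 0, chunk.token_usage.result ?? 0);
        }
        if (chunk.result) {
            parts.push(String(chunk.result));
        }
    }

    return {
        result: parts.join(""),
        finish_reason,
        // Leave token_usage undefined when nothing was reported, so callers can truthy-check it.
        token_usage: resultTokens
            ? { prompt: promptTokens, result: resultTokens, total: promptTokens + resultTokens }
            : undefined,
        chunks: parts.length,
    };
}

Two conventions from the diff matter here: token counts arrive cumulatively rather than per chunk, so they are merged with Math.max instead of summed, and an empty finish_reason arriving after "stop" or "length" must not overwrite the value already captured.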