Rebase on main after completion_output merged
LeonRuggiero committed Nov 5, 2024
2 parents 91af894 + dd7622d commit 59ea337
Showing 19 changed files with 387 additions and 162 deletions.
.github/workflows/build-and-test.yml (2 changes: 1 addition & 1 deletion)
@@ -53,7 +53,7 @@ jobs:
audience: sts.amazonaws.com
role-to-assume: arn:aws:iam::716085231028:role/ComposablePromptExecutor
role-session-name: github-actions
aws-region: us-west-2
aws-region: us-east-1

- run: npx vitest
env:
core/src/CompletionStream.ts (40 changes: 30 additions & 10 deletions)
@@ -1,5 +1,5 @@
import { AbstractDriver } from "./Driver.js";
import { CompletionStream, DriverOptions, ExecutionOptions, ExecutionResponse } from "./types.js";
import { CompletionStream, DriverOptions, ExecutionOptions, ExecutionResponse, ExecutionTokenUsage } from "./types.js";

export class DefaultCompletionStream<PromptT = any> implements CompletionStream<PromptT> {

@@ -27,27 +27,47 @@ export class DefaultCompletionStream<PromptT = any> implements CompletionStream<
const start = Date.now();
const stream = await this.driver.requestCompletionStream(this.prompt, this.options);

let finish_reason: string | undefined = undefined;
let promptTokens: number = 0;
let resultTokens: number | undefined = undefined;
for await (const chunk of stream) {
if (chunk) {
chunks.push(chunk);
yield chunk;
if (typeof chunk === 'string') {
chunks.push(chunk);
yield chunk;
} else {
if (chunk.finish_reason) { //Do not replace non-null values with null values
finish_reason = chunk.finish_reason; //Used to skip empty finish_reason chunks coming after "stop" or "length"
}
if (chunk.token_usage) {
//Tokens returned include prior parts of stream,
//so overwrite rather than accumulate
//Math.max used as some models report final token count at beginning of stream
promptTokens = Math.max(promptTokens, chunk.token_usage.prompt ?? 0);
resultTokens = Math.max(resultTokens ?? 0, chunk.token_usage.result ?? 0);
}
if (chunk.result) {
chunks.push(chunk.result);
yield chunk.result;
}
}
}
}

const content = chunks.join('');

const promptTokens = typeof this.prompt === 'string' ? this.prompt.length : JSON.stringify(this.prompt).length;
const resultTokens = content.length; //TODO use chunks.length ?
// Return undefined for the ExecutionTokenUsage object if there is nothing to fill it with.
// Allows for checking for truthiness on token_usage, rather than its internals. For testing and downstream usage.
let tokens: ExecutionTokenUsage | undefined = resultTokens ?
{ prompt: promptTokens, result: resultTokens, total: resultTokens + promptTokens, } : undefined

this.completion = {
result: content,
prompt: this.prompt,
execution_time: Date.now() - start,
token_usage: {
prompt: promptTokens,
result: resultTokens,
total: resultTokens + promptTokens,
}
token_usage: tokens,
finish_reason: finish_reason,
chunks: chunks.length,
}

this.driver.validateResult(this.completion, this.options);
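For context on the new loop: a minimal sketch of the consumer-side accumulation, assuming a hypothetical stream that mixes bare strings with chunk objects (the contents are illustrative, and ExecutionTokenUsage's fields are assumed optional, as the ?? 0 guards suggest). Because providers report cumulative counts, and some report the final count early, Math.max over whatever arrives yields the final totals:

import { CompletionChunk } from "./types.js";

// Hypothetical stream: bare strings are still accepted alongside chunk objects,
// and token counts are cumulative, so the consumer overwrites rather than sums.
async function* mockStream(): AsyncIterable<CompletionChunk> {
    yield { result: "Hello", token_usage: { prompt: 12, result: 7 } };
    yield ",";
    yield { result: " world", finish_reason: "stop", token_usage: { prompt: 12, result: 9 } };
}

let promptTokens = 0;
let resultTokens: number | undefined = undefined;
for await (const chunk of mockStream()) {
    if (typeof chunk !== "string" && chunk.token_usage) {
        promptTokens = Math.max(promptTokens, chunk.token_usage.prompt ?? 0);
        resultTokens = Math.max(resultTokens ?? 0, chunk.token_usage.result ?? 0);
    }
}
// promptTokens === 12, resultTokens === 9: the final cumulative counts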
core/src/Driver.ts (5 changes: 3 additions & 2 deletions)
@@ -22,7 +22,8 @@ import {
PromptSegment,
TrainingJob,
TrainingOptions,
TrainingPromptOptions
TrainingPromptOptions,
CompletionChunk
} from "./types.js";
import { validateResult } from "./validation.js";

@@ -223,7 +224,7 @@ export abstract class AbstractDriver<OptionsT extends DriverOptions = DriverOpti

abstract requestCompletion(prompt: PromptT, options: ExecutionOptions): Promise<Completion>;

abstract requestCompletionStream(prompt: PromptT, options: ExecutionOptions): Promise<AsyncIterable<string>>;
abstract requestCompletionStream(prompt: PromptT, options: ExecutionOptions): Promise<AsyncIterable<CompletionChunk>>;

//list models available for this environment
abstract listModels(params?: ModelSearchPayload): Promise<AIModel[]>;
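The widened return type ripples into every concrete driver. A hedged sketch of an adapter a driver's requestCompletionStream might use, with an assumed provider event shape (text, stop, and usage are placeholders, not a real SDK):

import { CompletionChunk } from "./types.js";

// Hypothetical adapter: converts assumed provider events into CompletionChunk objects.
async function* toCompletionChunks(
    providerEvents: AsyncIterable<{ text: string; stop?: string; usage?: { input: number; output: number } }>
): AsyncIterable<CompletionChunk> {
    for await (const ev of providerEvents) {
        yield {
            result: ev.text,
            finish_reason: ev.stop,
            token_usage: ev.usage
                ? { prompt: ev.usage.input, result: ev.usage.output }
                : undefined,
        };
    }
}

// A driver's requestCompletionStream could then wrap its provider call, e.g.:
// return toCompletionChunks(await providerClient.streamChat(payload));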
core/src/async.ts (12 changes: 7 additions & 5 deletions)
@@ -1,4 +1,5 @@
import type { ServerSentEvent } from "api-fetch-client";
import { CompletionChunk } from "./types.js";

export async function* asyncMap<T, R>(asyncIterable: AsyncIterable<T>, callback: (value: T, index: number) => R) {
let i = 0;
@@ -15,22 +16,23 @@ export function oneAsyncIterator<T>(value: T): AsyncIterable<T> {
}

/**
* Given a ReadableStream of server seent events, tran
* Given a ReadableStream of server sent events, tran
*/
export function transformSSEStream(stream: ReadableStream<ServerSentEvent>, transform: (data: string) => string): ReadableStream<string> & AsyncIterable<string> {
export function transformSSEStream(stream: ReadableStream<ServerSentEvent>, transform: (data: string) => CompletionChunk): ReadableStream<CompletionChunk> & AsyncIterable<CompletionChunk> {
// on node and bun the readablestream is an async iterable
return stream.pipeThrough(new TransformStream<ServerSentEvent, string>({
return stream.pipeThrough(new TransformStream<ServerSentEvent, CompletionChunk>({
transform(event: ServerSentEvent, controller) {
if (event.type === 'event' && event.data && event.data !== '[DONE]') {
try {
controller.enqueue(transform(event.data) ?? '');
const result = transform(event.data) ?? ''
controller.enqueue(result);
} catch (err) {
// double check for the last event whicb is not a JSON - at this time togetherai and mistralai returrns the string [DONE]
// do nothing - happens if data is not a JSON - the last event data is the [DONE] string
}
}
}
})) as ReadableStream<string> & AsyncIterable<string>;
})) as ReadableStream<CompletionChunk> & AsyncIterable<CompletionChunk>;
}

export class EventStream<T, ReturnT = any> implements AsyncIterable<T>{
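To illustrate the new transform contract, here is a sketch of a callback that parses one SSE data payload into a CompletionChunkObject. The JSON shape (choices, delta, usage) is an assumption modeled loosely on OpenAI-style streams, not the format of any driver in this repository:

import { CompletionChunk } from "./types.js";

// Assumed payload shape; adjust per provider. transformSSEStream already filters
// the non-JSON [DONE] sentinel and swallows JSON.parse errors, so this may throw safely.
const toChunk = (data: string): CompletionChunk => {
    const event = JSON.parse(data);
    const choice = event.choices?.[0] ?? {};
    return {
        result: choice.delta?.content ?? "",
        finish_reason: choice.finish_reason ?? undefined,
        token_usage: event.usage
            ? { prompt: event.usage.prompt_tokens, result: event.usage.completion_tokens }
            : undefined,
    };
};

// Usage: for await (const chunk of transformSSEStream(sseStream, toChunk)) { ... }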
core/src/types.ts (18 changes: 18 additions & 0 deletions)
@@ -40,6 +40,18 @@ export interface ResultValidationError {
data?: string;
}

//ResultT should be either JSONObject or string
//Internal structure used in driver implementation.
export interface CompletionChunkObject<ResultT = any> {
result: ResultT;
token_usage?: ExecutionTokenUsage;
finish_reason?: "stop" | "length" | string;
}

//Internal structure used in driver implementation.
export type CompletionChunk = CompletionChunkObject | string;

//ResultT should be either JSONObject or string
export interface Completion<ResultT = any> {
// the driver impl must return the result and optionally the token_usage. the execution time is computed by the extended abstract driver
result: ResultT;
@@ -69,6 +81,10 @@ export interface ExecutionResponse<PromptT = any> extends Completion {
* The time it took to execute the request in seconds
*/
execution_time?: number;
/**
* The number of chunks for streamed executions
*/
chunks?: number;
}


@@ -118,6 +134,8 @@ export interface ExecutionOptions extends PromptOptions {
top_p?: number;

/**
* Currently not supported, will be ignored.
* Should be an integer.
* Only supported for OpenAI. Look at OpenAI documentation for more details
*/
top_logprobs?: number;
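Downstream, the now-optional token_usage plus the new chunks count allow a single truthiness check instead of inspecting individual fields. A small hypothetical consumer:

import { ExecutionResponse } from "./types.js";

// token_usage === undefined now means "driver reported nothing",
// which is distinct from a reported zero count.
function describeUsage(response: ExecutionResponse): string {
    if (!response.token_usage) {
        return "no token usage reported";
    }
    const { prompt, result, total } = response.token_usage;
    const chunkNote = response.chunks ? ` over ${response.chunks} chunks` : "";
    return `${prompt} prompt + ${result} result = ${total} tokens${chunkNote}`;
}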
(14 more changed files not shown)
