Skip to content

Commit

Permalink
fix(streaming): handle special line characters and fix multi-byte cha…
Browse files Browse the repository at this point in the history
…racter decoding (#757)
  • Loading branch information
stainless-bot authored Apr 4, 2024
1 parent 47ca41d commit 8dcdda2
Show file tree
Hide file tree
Showing 2 changed files with 338 additions and 27 deletions.
120 changes: 94 additions & 26 deletions src/streaming.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,29 +23,6 @@ export class Stream<Item> implements AsyncIterable<Item> {

static fromSSEResponse<Item>(response: Response, controller: AbortController) {
let consumed = false;
const decoder = new SSEDecoder();

async function* iterMessages(): AsyncGenerator<ServerSentEvent, void, unknown> {
if (!response.body) {
controller.abort();
throw new OpenAIError(`Attempted to iterate over a response with no body`);
}

const lineDecoder = new LineDecoder();

const iter = readableStreamAsyncIterable<Bytes>(response.body);
for await (const chunk of iter) {
for (const line of lineDecoder.decode(chunk)) {
const sse = decoder.decode(line);
if (sse) yield sse;
}
}

for (const line of lineDecoder.flush()) {
const sse = decoder.decode(line);
if (sse) yield sse;
}
}

async function* iterator(): AsyncIterator<Item, any, undefined> {
if (consumed) {
Expand All @@ -54,7 +31,7 @@ export class Stream<Item> implements AsyncIterable<Item> {
consumed = true;
let done = false;
try {
for await (const sse of iterMessages()) {
for await (const sse of _iterSSEMessages(response, controller)) {
if (done) continue;

if (sse.data.startsWith('[DONE]')) {
Expand Down Expand Up @@ -220,6 +197,97 @@ export class Stream<Item> implements AsyncIterable<Item> {
}
}

/**
 * Streams the server-sent events contained in `response`.
 *
 * The body is first cut into complete SSE chunks by `iterSSEChunks`; each
 * chunk is split into lines by a `LineDecoder`, and every line is handed to
 * an `SSEDecoder`. Whatever non-empty result the SSE decoder returns for a
 * line is yielded to the caller.
 *
 * @param response - The response whose body is consumed.
 * @param controller - Aborted when the response turns out to have no body.
 * @throws OpenAIError when `response.body` is missing.
 */
export async function* _iterSSEMessages(
  response: Response,
  controller: AbortController,
): AsyncGenerator<ServerSentEvent, void, unknown> {
  if (!response.body) {
    controller.abort();
    throw new OpenAIError(`Attempted to iterate over a response with no body`);
  }

  const sseDecoder = new SSEDecoder();
  const lineDecoder = new LineDecoder();

  const byteStream = readableStreamAsyncIterable<Bytes>(response.body);
  for await (const eventBytes of iterSSEChunks(byteStream)) {
    for (const rawLine of lineDecoder.decode(eventBytes)) {
      const event = sseDecoder.decode(rawLine);
      if (event) {
        yield event;
      }
    }
  }

  // The body is exhausted: push out whatever the line decoder still buffers.
  for (const rawLine of lineDecoder.flush()) {
    const event = sseDecoder.decode(rawLine);
    if (event) {
      yield event;
    }
  }
}

/**
 * Re-chunks an async byte stream on SSE event boundaries.
 *
 * Incoming pieces are appended to an internal buffer. Each time the buffer
 * contains a double new-line (as located by `findDoubleNewlineIndex`), the
 * bytes up to and including that terminator are yielded and removed from the
 * buffer. Any bytes left once the source is exhausted are yielded as a final
 * chunk.
 *
 * @param iterator - Source of raw pieces; `null`/`undefined` pieces are skipped.
 */
async function* iterSSEChunks(iterator: AsyncIterableIterator<Bytes>): AsyncGenerator<Uint8Array> {
  let pending = new Uint8Array();

  for await (const piece of iterator) {
    if (piece == null) {
      continue;
    }

    // Normalize the piece to a Uint8Array, whatever form it arrived in.
    let bytes: Uint8Array;
    if (piece instanceof ArrayBuffer) {
      bytes = new Uint8Array(piece);
    } else if (typeof piece === 'string') {
      bytes = new TextEncoder().encode(piece);
    } else {
      bytes = piece;
    }

    // Append the new bytes after whatever is still buffered.
    const merged = new Uint8Array(pending.length + bytes.length);
    merged.set(pending, 0);
    merged.set(bytes, pending.length);
    pending = merged;

    // Emit every complete chunk currently sitting in the buffer.
    for (
      let boundary = findDoubleNewlineIndex(pending);
      boundary !== -1;
      boundary = findDoubleNewlineIndex(pending)
    ) {
      yield pending.slice(0, boundary);
      pending = pending.slice(boundary);
    }
  }

  if (pending.length > 0) {
    yield pending;
  }
}

/**
 * Searches `buffer` for the first blank-line terminator.
 *
 * The recognized end patterns are `\n\n`, `\r\r` and `\r\n\r\n`.
 *
 * @param buffer - The bytes to scan.
 * @returns The index right after the first occurrence of any pattern, or -1
 *   if none of the patterns are found.
 */
function findDoubleNewlineIndex(buffer: Uint8Array): number {
  const newline = 0x0a; // \n
  const carriage = 0x0d; // \r

  // Iterate up to `length - 1` (not `length - 2`) so that the final byte pair
  // is examined too; otherwise a terminator sitting exactly at the end of the
  // buffer would go unnoticed until more data arrived.
  for (let i = 0; i < buffer.length - 1; i++) {
    if (buffer[i] === newline && buffer[i + 1] === newline) {
      // \n\n
      return i + 2;
    }
    if (buffer[i] === carriage && buffer[i + 1] === carriage) {
      // \r\r
      return i + 2;
    }
    if (
      buffer[i] === carriage &&
      buffer[i + 1] === newline &&
      i + 3 < buffer.length &&
      buffer[i + 2] === carriage &&
      buffer[i + 3] === newline
    ) {
      // \r\n\r\n
      return i + 4;
    }
  }

  return -1;
}

class SSEDecoder {
private data: string[];
private event: string | null;
Expand Down Expand Up @@ -283,8 +351,8 @@ class SSEDecoder {
*/
class LineDecoder {
// prettier-ignore
static NEWLINE_CHARS = new Set(['\n', '\r', '\x0b', '\x0c', '\x1c', '\x1d', '\x1e', '\x85', '\u2028', '\u2029']);
static NEWLINE_REGEXP = /\r\n|[\n\r\x0b\x0c\x1c\x1d\x1e\x85\u2028\u2029]/g;
static NEWLINE_CHARS = new Set(['\n', '\r']);
static NEWLINE_REGEXP = /\r\n|[\n\r]/g;

buffer: string[];
trailingCR: boolean;
Expand Down
Loading

0 comments on commit 8dcdda2

Please sign in to comment.