From 181a5dddb650f1b060b88cbe3bf7293ddfecebdf Mon Sep 17 00:00:00 2001 From: Stainless Bot <107565488+stainless-bot@users.noreply.github.com> Date: Fri, 1 Mar 2024 01:32:50 +0100 Subject: [PATCH] fix(ChatCompletionStream): abort on async iterator break and handle errors (#699) `break`-ing the async iterator did not previously abort the request which increases usage. Errors are now handled more effectively in the async iterator. --- src/lib/ChatCompletionRunFunctions.test.ts | 53 +++++++++++++++++++++- src/lib/ChatCompletionStream.ts | 35 +++++++++++--- 2 files changed, 81 insertions(+), 7 deletions(-) diff --git a/src/lib/ChatCompletionRunFunctions.test.ts b/src/lib/ChatCompletionRunFunctions.test.ts index bb360b217..b524218ae 100644 --- a/src/lib/ChatCompletionRunFunctions.test.ts +++ b/src/lib/ChatCompletionRunFunctions.test.ts @@ -1,5 +1,5 @@ import OpenAI from 'openai'; -import { OpenAIError } from 'openai/error'; +import { OpenAIError, APIConnectionError } from 'openai/error'; import { PassThrough } from 'stream'; import { ParsingToolFunction, @@ -2207,6 +2207,7 @@ describe('resource completions', () => { await listener.sanityCheck(); }); }); + describe('stream', () => { test('successful flow', async () => { const { fetch, handleRequest } = mockStreamingChatCompletionFetch(); @@ -2273,5 +2274,55 @@ describe('resource completions', () => { expect(listener.finalMessage).toEqual({ role: 'assistant', content: 'The weather is great today!' }); await listener.sanityCheck(); }); + test('handles network errors', async () => { + const { fetch, handleRequest } = mockFetch(); + + const openai = new OpenAI({ apiKey: '...', fetch }); + + const stream = openai.beta.chat.completions.stream( + { + max_tokens: 1024, + model: 'gpt-3.5-turbo', + messages: [{ role: 'user', content: 'Say hello there!' }], + }, + { maxRetries: 0 }, + ); + + handleRequest(async () => { + throw new Error('mock request error'); + }).catch(() => {}); + + async function runStream() { + await stream.done(); + } + + await expect(runStream).rejects.toThrow(APIConnectionError); + }); + test('handles network errors on async iterator', async () => { + const { fetch, handleRequest } = mockFetch(); + + const openai = new OpenAI({ apiKey: '...', fetch }); + + const stream = openai.beta.chat.completions.stream( + { + max_tokens: 1024, + model: 'gpt-3.5-turbo', + messages: [{ role: 'user', content: 'Say hello there!' }], + }, + { maxRetries: 0 }, + ); + + handleRequest(async () => { + throw new Error('mock request error'); + }).catch(() => {}); + + async function runStream() { + for await (const _event of stream) { + continue; + } + } + + await expect(runStream).rejects.toThrow(APIConnectionError); + }); }); }); diff --git a/src/lib/ChatCompletionStream.ts b/src/lib/ChatCompletionStream.ts index a2aa7032e..2ea040383 100644 --- a/src/lib/ChatCompletionStream.ts +++ b/src/lib/ChatCompletionStream.ts @@ -210,13 +210,16 @@ export class ChatCompletionStream [Symbol.asyncIterator](): AsyncIterator { const pushQueue: ChatCompletionChunk[] = []; - const readQueue: ((chunk: ChatCompletionChunk | undefined) => void)[] = []; + const readQueue: { + resolve: (chunk: ChatCompletionChunk | undefined) => void; + reject: (err: unknown) => void; + }[] = []; let done = false; this.on('chunk', (chunk) => { const reader = readQueue.shift(); if (reader) { - reader(chunk); + reader.resolve(chunk); } else { pushQueue.push(chunk); } @@ -225,7 +228,23 @@ export class ChatCompletionStream this.on('end', () => { done = true; for (const reader of readQueue) { - reader(undefined); + reader.resolve(undefined); + } + readQueue.length = 0; + }); + + this.on('abort', (err) => { + done = true; + for (const reader of readQueue) { + reader.reject(err); + } + readQueue.length = 0; + }); + + this.on('error', (err) => { + done = true; + for (const reader of readQueue) { + reader.reject(err); } readQueue.length = 0; }); @@ -236,13 +255,17 @@ export class ChatCompletionStream if (done) { return { value: undefined, done: true }; } - return new Promise((resolve) => readQueue.push(resolve)).then( - (chunk) => (chunk ? { value: chunk, done: false } : { value: undefined, done: true }), - ); + return new Promise((resolve, reject) => + readQueue.push({ resolve, reject }), + ).then((chunk) => (chunk ? { value: chunk, done: false } : { value: undefined, done: true })); } const chunk = pushQueue.shift()!; return { value: chunk, done: false }; }, + return: async () => { + this.abort(); + return { value: undefined, done: true }; + }, }; }