From 9a82d29331635977475ba925cdbacda68edd0472 Mon Sep 17 00:00:00 2001 From: Tomas Dvorak Date: Thu, 12 Dec 2024 14:06:37 +0100 Subject: [PATCH] fix(llms): add missing events for stream method Signed-off-by: Tomas Dvorak --- src/llms/base.ts | 38 +++++++++++++++++++++++++++++++++----- 1 file changed, 33 insertions(+), 5 deletions(-) diff --git a/src/llms/base.ts b/src/llms/base.ts index 5357c0eb..9f87b020 100644 --- a/src/llms/base.ts +++ b/src/llms/base.ts @@ -234,12 +234,40 @@ export abstract class BaseLLM< async (run) => { const cacheEntry = await this.createCacheAccessor(input, options); - const tokens: TOutput[] = []; - for await (const token of cacheEntry.value || this._stream(input, options ?? {}, run)) { - tokens.push(token); - emit(token); + try { + await run.emitter.emit("start", { input, options }); + + const tokenEmitter = run.emitter.child({ groupId: "tokens" }); + const chunks: TOutput[] = []; + const controller = createAbortController(options?.signal); + + for await (const chunk of cacheEntry.value || + this._stream(input, { ...options, signal: controller.signal }, run)) { + if (controller.signal.aborted) { + continue; + } + + chunks.push(chunk); + await tokenEmitter.emit("newToken", { + value: chunk, + callbacks: { abort: () => controller.abort() }, + }); + emit(chunk); + } + const result = this._mergeChunks(chunks); + await run.emitter.emit("success", { value: result }); + cacheEntry.resolve(chunks); + } catch (error) { + await run.emitter.emit("error", { input, error, options }); + await cacheEntry.reject(error); + if (error instanceof LLMError) { + throw error; + } else { + throw new LLMError(`LLM has occurred an error.`, [error]); + } + } finally { + await run.emitter.emit("finish", null); } - cacheEntry.resolve(tokens); }, ).middleware(INSTRUMENTATION_ENABLED ? createTelemetryMiddleware() : doNothing()); });