diff --git a/deno_dist/context.ts b/deno_dist/context.ts index 756cc5a64..d9bfce7d8 100644 --- a/deno_dist/context.ts +++ b/deno_dist/context.ts @@ -399,12 +399,12 @@ export class Context< headers?: HeaderRecord ): Response => { const { readable, writable } = new TransformStream() - const stream = new StreamingApi(writable) + const stream = new StreamingApi(writable, readable) cb(stream).finally(() => stream.close()) return typeof arg === 'number' - ? this.newResponse(readable, arg, headers) - : this.newResponse(readable, arg) + ? this.newResponse(stream.responseReadable, arg, headers) + : this.newResponse(stream.responseReadable, arg) } /** @deprecated diff --git a/deno_dist/helper/streaming/sse.ts b/deno_dist/helper/streaming/sse.ts index 10f641a4e..b951ac7f4 100644 --- a/deno_dist/helper/streaming/sse.ts +++ b/deno_dist/helper/streaming/sse.ts @@ -9,8 +9,8 @@ export interface SSEMessage { } export class SSEStreamingApi extends StreamingApi { - constructor(writable: WritableStream) { - super(writable) + constructor(writable: WritableStream, readable: ReadableStream) { + super(writable, readable) } async writeSSE(message: SSEMessage) { @@ -40,9 +40,9 @@ const setSSEHeaders = (context: Context) => { export const streamSSE = (c: Context, cb: (stream: SSEStreamingApi) => Promise) => { return stream(c, async (originalStream: StreamingApi) => { const { readable, writable } = new TransformStream() - const stream = new SSEStreamingApi(writable) + const stream = new SSEStreamingApi(writable, readable) - originalStream.pipe(readable).catch((err) => { + originalStream.pipe(stream.responseReadable).catch((err) => { console.error('Error in stream piping: ', err) stream.close() }) diff --git a/deno_dist/helper/streaming/stream.ts b/deno_dist/helper/streaming/stream.ts index 87b815b75..bd04ecdfb 100644 --- a/deno_dist/helper/streaming/stream.ts +++ b/deno_dist/helper/streaming/stream.ts @@ -3,7 +3,7 @@ import { StreamingApi } from '../../utils/stream.ts' export const stream = (c: Context, cb: (stream: StreamingApi) => Promise): Response => { const { readable, writable } = new TransformStream() - const stream = new StreamingApi(writable) + const stream = new StreamingApi(writable, readable) cb(stream).finally(() => stream.close()) - return c.newResponse(readable) + return c.newResponse(stream.responseReadable) } diff --git a/deno_dist/utils/stream.ts b/deno_dist/utils/stream.ts index 1fcda3223..a2978a403 100644 --- a/deno_dist/utils/stream.ts +++ b/deno_dist/utils/stream.ts @@ -2,11 +2,29 @@ export class StreamingApi { private writer: WritableStreamDefaultWriter private encoder: TextEncoder private writable: WritableStream + private abortSubscribers: (() => void | Promise)[] = [] + responseReadable: ReadableStream - constructor(writable: WritableStream) { + constructor(writable: WritableStream, _readable: ReadableStream) { this.writable = writable this.writer = writable.getWriter() this.encoder = new TextEncoder() + + const reader = _readable.getReader() + + this.responseReadable = new ReadableStream({ + async pull(controller) { + const { done, value } = await reader.read() + if (done) { + controller.close() + } else { + controller.enqueue(value) + } + }, + cancel: () => { + this.abortSubscribers.forEach((subscriber) => subscriber()) + }, + }) } async write(input: Uint8Array | string) { @@ -43,4 +61,8 @@ export class StreamingApi { await body.pipeTo(this.writable, { preventClose: true }) this.writer = this.writable.getWriter() } + + async onAbort(listener: () => void | Promise) { + this.abortSubscribers.push(listener) + } } diff --git a/src/context.test.ts b/src/context.test.ts index 19c5bd1a7..57e2d857c 100644 --- a/src/context.test.ts +++ b/src/context.test.ts @@ -342,4 +342,26 @@ describe('c.render', () => { expect(res.headers.get('foo')).toBe('bar') expect(await res.text()).toBe('title

content

') }) + + it('c.stream() - with aborted during writing', async () => { + let aborted = false + const res = c.stream(async (stream) => { + stream.onAbort(() => { + aborted = true + }) + for (let i = 0; i < 3; i++) { + await stream.write(new Uint8Array([i])) + await stream.sleep(1) + } + }) + if (!res.body) { + throw new Error('Body is null') + } + const reader = res.body.getReader() + for (let i = 0; i < 2; i++) { + await reader.read() + await reader.cancel() + } + expect(aborted).toBe(true) + }) }) diff --git a/src/context.ts b/src/context.ts index 88bd1f1f0..c151cefc6 100644 --- a/src/context.ts +++ b/src/context.ts @@ -399,12 +399,12 @@ export class Context< headers?: HeaderRecord ): Response => { const { readable, writable } = new TransformStream() - const stream = new StreamingApi(writable) + const stream = new StreamingApi(writable, readable) cb(stream).finally(() => stream.close()) return typeof arg === 'number' - ? this.newResponse(readable, arg, headers) - : this.newResponse(readable, arg) + ? this.newResponse(stream.responseReadable, arg, headers) + : this.newResponse(stream.responseReadable, arg) } /** @deprecated diff --git a/src/helper/streaming/sse.ts b/src/helper/streaming/sse.ts index c71f88d2b..2fc8e019f 100644 --- a/src/helper/streaming/sse.ts +++ b/src/helper/streaming/sse.ts @@ -9,8 +9,8 @@ export interface SSEMessage { } export class SSEStreamingApi extends StreamingApi { - constructor(writable: WritableStream) { - super(writable) + constructor(writable: WritableStream, readable: ReadableStream) { + super(writable, readable) } async writeSSE(message: SSEMessage) { @@ -40,9 +40,9 @@ const setSSEHeaders = (context: Context) => { export const streamSSE = (c: Context, cb: (stream: SSEStreamingApi) => Promise) => { return stream(c, async (originalStream: StreamingApi) => { const { readable, writable } = new TransformStream() - const stream = new SSEStreamingApi(writable) + const stream = new SSEStreamingApi(writable, readable) - originalStream.pipe(readable).catch((err) => { + originalStream.pipe(stream.responseReadable).catch((err) => { console.error('Error in stream piping: ', err) stream.close() }) diff --git a/src/helper/streaming/stream.ts b/src/helper/streaming/stream.ts index a5bab3c79..f82679d4f 100644 --- a/src/helper/streaming/stream.ts +++ b/src/helper/streaming/stream.ts @@ -3,7 +3,7 @@ import { StreamingApi } from '../../utils/stream' export const stream = (c: Context, cb: (stream: StreamingApi) => Promise): Response => { const { readable, writable } = new TransformStream() - const stream = new StreamingApi(writable) + const stream = new StreamingApi(writable, readable) cb(stream).finally(() => stream.close()) - return c.newResponse(readable) + return c.newResponse(stream.responseReadable) } diff --git a/src/utils/stream.test.ts b/src/utils/stream.test.ts index 0eefe7668..5ce92ef17 100644 --- a/src/utils/stream.test.ts +++ b/src/utils/stream.test.ts @@ -3,8 +3,8 @@ import { StreamingApi } from './stream' describe('StreamingApi', () => { it('write(string)', async () => { const { readable, writable } = new TransformStream() - const api = new StreamingApi(writable) - const reader = readable.getReader() + const api = new StreamingApi(writable, readable) + const reader = api.responseReadable.getReader() api.write('foo') expect((await reader.read()).value).toEqual(new TextEncoder().encode('foo')) api.write('bar') @@ -13,8 +13,8 @@ describe('StreamingApi', () => { it('write(Uint8Array)', async () => { const { readable, writable } = new TransformStream() - const api = new StreamingApi(writable) - const reader = readable.getReader() + const api = new StreamingApi(writable, readable) + const reader = api.responseReadable.getReader() api.write(new Uint8Array([1, 2, 3])) expect((await reader.read()).value).toEqual(new Uint8Array([1, 2, 3])) api.write(new Uint8Array([4, 5, 6])) @@ -23,8 +23,8 @@ describe('StreamingApi', () => { it('writeln(string)', async () => { const { readable, writable } = new TransformStream() - const api = new StreamingApi(writable) - const reader = readable.getReader() + const api = new StreamingApi(writable, readable) + const reader = api.responseReadable.getReader() api.writeln('foo') expect((await reader.read()).value).toEqual(new TextEncoder().encode('foo\n')) api.writeln('bar') @@ -44,7 +44,7 @@ describe('StreamingApi', () => { const { readable: receiverReadable, writable: receiverWritable } = new TransformStream() - const api = new StreamingApi(receiverWritable) + const api = new StreamingApi(receiverWritable, receiverReadable) // pipe readable to api in other scope ;(async () => { @@ -52,34 +52,48 @@ describe('StreamingApi', () => { })() // read data from api - const reader = receiverReadable.getReader() + const reader = api.responseReadable.getReader() expect((await reader.read()).value).toEqual(new TextEncoder().encode('foo')) expect((await reader.read()).value).toEqual(new TextEncoder().encode('bar')) }) it('close()', async () => { const { readable, writable } = new TransformStream() - const api = new StreamingApi(writable) - const reader = readable.getReader() + const api = new StreamingApi(writable, readable) + const reader = api.responseReadable.getReader() await api.close() expect((await reader.read()).done).toBe(true) }) it('should not throw an error in write()', async () => { - const { writable } = new TransformStream() - const api = new StreamingApi(writable) + const { readable, writable } = new TransformStream() + const api = new StreamingApi(writable, readable) await api.close() const write = () => api.write('foo') expect(write).not.toThrow() }) it('should not throw an error in close()', async () => { - const { writable } = new TransformStream() - const api = new StreamingApi(writable) + const { readable, writable } = new TransformStream() + const api = new StreamingApi(writable, readable) const close = async () => { await api.close() await api.close() } expect(close).not.toThrow() }) + + it('onAbort()', async () => { + const { readable, writable } = new TransformStream() + const handleAbort1 = vi.fn() + const handleAbort2 = vi.fn() + const api = new StreamingApi(writable, readable) + api.onAbort(handleAbort1) + api.onAbort(handleAbort2) + expect(handleAbort1).not.toBeCalled() + expect(handleAbort2).not.toBeCalled() + await api.responseReadable.cancel() + expect(handleAbort1).toBeCalled() + expect(handleAbort2).toBeCalled() + }) }) diff --git a/src/utils/stream.ts b/src/utils/stream.ts index 1fcda3223..a2978a403 100644 --- a/src/utils/stream.ts +++ b/src/utils/stream.ts @@ -2,11 +2,29 @@ export class StreamingApi { private writer: WritableStreamDefaultWriter private encoder: TextEncoder private writable: WritableStream + private abortSubscribers: (() => void | Promise)[] = [] + responseReadable: ReadableStream - constructor(writable: WritableStream) { + constructor(writable: WritableStream, _readable: ReadableStream) { this.writable = writable this.writer = writable.getWriter() this.encoder = new TextEncoder() + + const reader = _readable.getReader() + + this.responseReadable = new ReadableStream({ + async pull(controller) { + const { done, value } = await reader.read() + if (done) { + controller.close() + } else { + controller.enqueue(value) + } + }, + cancel: () => { + this.abortSubscribers.forEach((subscriber) => subscriber()) + }, + }) } async write(input: Uint8Array | string) { @@ -43,4 +61,8 @@ export class StreamingApi { await body.pipeTo(this.writable, { preventClose: true }) this.writer = this.writable.getWriter() } + + async onAbort(listener: () => void | Promise) { + this.abortSubscribers.push(listener) + } }