Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement stream.onAbort #1871

Merged
merged 2 commits into from
Jan 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions deno_dist/context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -399,12 +399,12 @@ export class Context<
headers?: HeaderRecord
): Response => {
const { readable, writable } = new TransformStream()
const stream = new StreamingApi(writable)
const stream = new StreamingApi(writable, readable)
cb(stream).finally(() => stream.close())

return typeof arg === 'number'
? this.newResponse(readable, arg, headers)
: this.newResponse(readable, arg)
? this.newResponse(stream.responseReadable, arg, headers)
: this.newResponse(stream.responseReadable, arg)
}

/** @deprecated
Expand Down
8 changes: 4 additions & 4 deletions deno_dist/helper/streaming/sse.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ export interface SSEMessage {
}

export class SSEStreamingApi extends StreamingApi {
constructor(writable: WritableStream) {
super(writable)
constructor(writable: WritableStream, readable: ReadableStream) {
super(writable, readable)
}

async writeSSE(message: SSEMessage) {
Expand Down Expand Up @@ -40,9 +40,9 @@ const setSSEHeaders = (context: Context) => {
export const streamSSE = (c: Context, cb: (stream: SSEStreamingApi) => Promise<void>) => {
return stream(c, async (originalStream: StreamingApi) => {
const { readable, writable } = new TransformStream()
const stream = new SSEStreamingApi(writable)
const stream = new SSEStreamingApi(writable, readable)

originalStream.pipe(readable).catch((err) => {
originalStream.pipe(stream.responseReadable).catch((err) => {
console.error('Error in stream piping: ', err)
stream.close()
})
Expand Down
4 changes: 2 additions & 2 deletions deno_dist/helper/streaming/stream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { StreamingApi } from '../../utils/stream.ts'

export const stream = (c: Context, cb: (stream: StreamingApi) => Promise<void>): Response => {
const { readable, writable } = new TransformStream()
const stream = new StreamingApi(writable)
const stream = new StreamingApi(writable, readable)
cb(stream).finally(() => stream.close())
return c.newResponse(readable)
return c.newResponse(stream.responseReadable)
}
24 changes: 23 additions & 1 deletion deno_dist/utils/stream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,29 @@ export class StreamingApi {
private writer: WritableStreamDefaultWriter<Uint8Array>
private encoder: TextEncoder
private writable: WritableStream
private abortSubscribers: (() => void | Promise<void>)[] = []
responseReadable: ReadableStream

constructor(writable: WritableStream) {
constructor(writable: WritableStream, _readable: ReadableStream) {
this.writable = writable
this.writer = writable.getWriter()
this.encoder = new TextEncoder()

const reader = _readable.getReader()

this.responseReadable = new ReadableStream({
async pull(controller) {
const { done, value } = await reader.read()
if (done) {
controller.close()
} else {
controller.enqueue(value)
}
},
cancel: () => {
this.abortSubscribers.forEach((subscriber) => subscriber())
},
})
}

async write(input: Uint8Array | string) {
Expand Down Expand Up @@ -43,4 +61,8 @@ export class StreamingApi {
await body.pipeTo(this.writable, { preventClose: true })
this.writer = this.writable.getWriter()
}

async onAbort(listener: () => void | Promise<void>) {
this.abortSubscribers.push(listener)
}
}
22 changes: 22 additions & 0 deletions src/context.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -342,4 +342,26 @@ describe('c.render', () => {
expect(res.headers.get('foo')).toBe('bar')
expect(await res.text()).toBe('<html><head>title</head><body><h1>content</h1></body></html>')
})

it('c.stream() - with aborted during writing', async () => {
let aborted = false
const res = c.stream(async (stream) => {
stream.onAbort(() => {
aborted = true
})
for (let i = 0; i < 3; i++) {
await stream.write(new Uint8Array([i]))
await stream.sleep(1)
}
})
if (!res.body) {
throw new Error('Body is null')
}
const reader = res.body.getReader()
for (let i = 0; i < 2; i++) {
await reader.read()
await reader.cancel()
}
expect(aborted).toBe(true)
})
})
6 changes: 3 additions & 3 deletions src/context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -399,12 +399,12 @@ export class Context<
headers?: HeaderRecord
): Response => {
const { readable, writable } = new TransformStream()
const stream = new StreamingApi(writable)
const stream = new StreamingApi(writable, readable)
cb(stream).finally(() => stream.close())

return typeof arg === 'number'
? this.newResponse(readable, arg, headers)
: this.newResponse(readable, arg)
? this.newResponse(stream.responseReadable, arg, headers)
: this.newResponse(stream.responseReadable, arg)
}

/** @deprecated
Expand Down
8 changes: 4 additions & 4 deletions src/helper/streaming/sse.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ export interface SSEMessage {
}

export class SSEStreamingApi extends StreamingApi {
constructor(writable: WritableStream) {
super(writable)
constructor(writable: WritableStream, readable: ReadableStream) {
super(writable, readable)
}

async writeSSE(message: SSEMessage) {
Expand Down Expand Up @@ -40,9 +40,9 @@ const setSSEHeaders = (context: Context) => {
export const streamSSE = (c: Context, cb: (stream: SSEStreamingApi) => Promise<void>) => {
return stream(c, async (originalStream: StreamingApi) => {
const { readable, writable } = new TransformStream()
const stream = new SSEStreamingApi(writable)
const stream = new SSEStreamingApi(writable, readable)

originalStream.pipe(readable).catch((err) => {
originalStream.pipe(stream.responseReadable).catch((err) => {
console.error('Error in stream piping: ', err)
stream.close()
})
Expand Down
4 changes: 2 additions & 2 deletions src/helper/streaming/stream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { StreamingApi } from '../../utils/stream'

export const stream = (c: Context, cb: (stream: StreamingApi) => Promise<void>): Response => {
const { readable, writable } = new TransformStream()
const stream = new StreamingApi(writable)
const stream = new StreamingApi(writable, readable)
cb(stream).finally(() => stream.close())
return c.newResponse(readable)
return c.newResponse(stream.responseReadable)
}
42 changes: 28 additions & 14 deletions src/utils/stream.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ import { StreamingApi } from './stream'
describe('StreamingApi', () => {
it('write(string)', async () => {
const { readable, writable } = new TransformStream()
const api = new StreamingApi(writable)
const reader = readable.getReader()
const api = new StreamingApi(writable, readable)
const reader = api.responseReadable.getReader()
api.write('foo')
expect((await reader.read()).value).toEqual(new TextEncoder().encode('foo'))
api.write('bar')
Expand All @@ -13,8 +13,8 @@ describe('StreamingApi', () => {

it('write(Uint8Array)', async () => {
const { readable, writable } = new TransformStream()
const api = new StreamingApi(writable)
const reader = readable.getReader()
const api = new StreamingApi(writable, readable)
const reader = api.responseReadable.getReader()
api.write(new Uint8Array([1, 2, 3]))
expect((await reader.read()).value).toEqual(new Uint8Array([1, 2, 3]))
api.write(new Uint8Array([4, 5, 6]))
Expand All @@ -23,8 +23,8 @@ describe('StreamingApi', () => {

it('writeln(string)', async () => {
const { readable, writable } = new TransformStream()
const api = new StreamingApi(writable)
const reader = readable.getReader()
const api = new StreamingApi(writable, readable)
const reader = api.responseReadable.getReader()
api.writeln('foo')
expect((await reader.read()).value).toEqual(new TextEncoder().encode('foo\n'))
api.writeln('bar')
Expand All @@ -44,42 +44,56 @@ describe('StreamingApi', () => {

const { readable: receiverReadable, writable: receiverWritable } = new TransformStream()

const api = new StreamingApi(receiverWritable)
const api = new StreamingApi(receiverWritable, receiverReadable)

// pipe readable to api in other scope
;(async () => {
await api.pipe(senderReadable)
})()

// read data from api
const reader = receiverReadable.getReader()
const reader = api.responseReadable.getReader()
expect((await reader.read()).value).toEqual(new TextEncoder().encode('foo'))
expect((await reader.read()).value).toEqual(new TextEncoder().encode('bar'))
})

it('close()', async () => {
const { readable, writable } = new TransformStream()
const api = new StreamingApi(writable)
const reader = readable.getReader()
const api = new StreamingApi(writable, readable)
const reader = api.responseReadable.getReader()
await api.close()
expect((await reader.read()).done).toBe(true)
})

it('should not throw an error in write()', async () => {
const { writable } = new TransformStream()
const api = new StreamingApi(writable)
const { readable, writable } = new TransformStream()
const api = new StreamingApi(writable, readable)
await api.close()
const write = () => api.write('foo')
expect(write).not.toThrow()
})

it('should not throw an error in close()', async () => {
const { writable } = new TransformStream()
const api = new StreamingApi(writable)
const { readable, writable } = new TransformStream()
const api = new StreamingApi(writable, readable)
const close = async () => {
await api.close()
await api.close()
}
expect(close).not.toThrow()
})

it('onAbort()', async () => {
const { readable, writable } = new TransformStream()
const handleAbort1 = vi.fn()
const handleAbort2 = vi.fn()
const api = new StreamingApi(writable, readable)
api.onAbort(handleAbort1)
api.onAbort(handleAbort2)
expect(handleAbort1).not.toBeCalled()
expect(handleAbort2).not.toBeCalled()
await api.responseReadable.cancel()
expect(handleAbort1).toBeCalled()
expect(handleAbort2).toBeCalled()
})
})
24 changes: 23 additions & 1 deletion src/utils/stream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,29 @@ export class StreamingApi {
private writer: WritableStreamDefaultWriter<Uint8Array>
private encoder: TextEncoder
private writable: WritableStream
private abortSubscribers: (() => void | Promise<void>)[] = []
responseReadable: ReadableStream

constructor(writable: WritableStream) {
constructor(writable: WritableStream, _readable: ReadableStream) {
this.writable = writable
this.writer = writable.getWriter()
this.encoder = new TextEncoder()

const reader = _readable.getReader()

this.responseReadable = new ReadableStream({
async pull(controller) {
const { done, value } = await reader.read()
if (done) {
controller.close()
} else {
controller.enqueue(value)
}
},
cancel: () => {
this.abortSubscribers.forEach((subscriber) => subscriber())
},
})
}

async write(input: Uint8Array | string) {
Expand Down Expand Up @@ -43,4 +61,8 @@ export class StreamingApi {
await body.pipeTo(this.writable, { preventClose: true })
this.writer = this.writable.getWriter()
}

async onAbort(listener: () => void | Promise<void>) {
this.abortSubscribers.push(listener)
}
}
Loading