diff --git a/src/listener.ts b/src/listener.ts index 2388344..cde6cea 100644 --- a/src/listener.ts +++ b/src/listener.ts @@ -1,6 +1,6 @@ import type { IncomingMessage, ServerResponse, OutgoingHttpHeaders } from 'node:http' import type { Http2ServerRequest, Http2ServerResponse } from 'node:http2' -import { newRequest } from './request' +import { getAbortController, newRequest } from './request' import { cacheKey } from './response' import type { CustomErrorHandler, FetchCallback, HttpBindings } from './types' import { writeFromReadableStream, buildOutgoingHttpHeaders } from './utils' @@ -143,6 +143,13 @@ export const getRequestListener = ( // so generate a pseudo Request object with only the minimum required information. const req = newRequest(incoming) + // Detect if request was aborted. + outgoing.on('close', () => { + if (incoming.destroyed) { + req[getAbortController]().abort() + } + }) + try { res = fetchCallback(req, { incoming, outgoing } as HttpBindings) as | Response diff --git a/src/request.ts b/src/request.ts index bd4a0ca..e12f830 100644 --- a/src/request.ts +++ b/src/request.ts @@ -25,7 +25,8 @@ Object.defineProperty(global, 'Request', { const newRequestFromIncoming = ( method: string, url: string, - incoming: IncomingMessage | Http2ServerRequest + incoming: IncomingMessage | Http2ServerRequest, + abortController: AbortController ): Request => { const headerRecord: [string, string][] = [] const rawHeaders = incoming.rawHeaders @@ -39,6 +40,7 @@ const newRequestFromIncoming = ( const init = { method: method, headers: headerRecord, + signal: abortController.signal, } as RequestInit if (!(method === 'GET' || method === 'HEAD')) { @@ -53,6 +55,8 @@ const getRequestCache = Symbol('getRequestCache') const requestCache = Symbol('requestCache') const incomingKey = Symbol('incomingKey') const urlKey = Symbol('urlKey') +const abortControllerKey = Symbol('abortControllerKey') +export const getAbortController = Symbol('getAbortController') const requestPrototype: Record = { get method() { @@ -63,11 +67,18 @@ const requestPrototype: Record = { return this[urlKey] }, + [getAbortController]() { + this[getRequestCache]() + return this[abortControllerKey] + }, + [getRequestCache]() { + this[abortControllerKey] ||= new AbortController() return (this[requestCache] ||= newRequestFromIncoming( this.method, this[urlKey], - this[incomingKey] + this[incomingKey], + this[abortControllerKey] )) }, } diff --git a/test/listener.test.ts b/test/listener.test.ts index 5c10d84..77e6c01 100644 --- a/test/listener.test.ts +++ b/test/listener.test.ts @@ -89,3 +89,119 @@ describe('Error handling - async fetchCallback', () => { expect(res.text).toBe('error handler did not return a response') }) }) + +describe('Abort request', () => { + let onAbort: (req: Request) => void + let reqReadyResolve: () => void + let reqReadyPromise: Promise + const fetchCallback = async (req: Request) => { + req.signal.addEventListener('abort', () => onAbort(req)) + reqReadyResolve?.() + await new Promise(() => {}) // never resolve + } + + const requestListener = getRequestListener(fetchCallback) + + const server = createServer(async (req, res) => { + await requestListener(req, res) + }) + + beforeEach(() => { + reqReadyPromise = new Promise((r) => { + reqReadyResolve = r + }) + }) + + afterAll(() => { + server.close() + }) + + it('should emit an abort event when the nodejs request is aborted', async () => { + const requests: Request[] = [] + const abortedPromise = new Promise((resolve) => { + onAbort = (req) => { + requests.push(req) + resolve() + } + }) + + const req = request(server) + .get('/abort') + .end(() => {}) + + await reqReadyPromise + + req.abort() + + await abortedPromise + + expect(requests).toHaveLength(1) + const abortedReq = requests[0] + expect(abortedReq).toBeInstanceOf(Request) + expect(abortedReq.signal.aborted).toBe(true) + }) + + it('should emit an abort event when the nodejs request is aborted on multiple requests', async () => { + const requests: Request[] = [] + + { + const abortedPromise = new Promise((resolve) => { + onAbort = (req) => { + requests.push(req) + resolve() + } + }) + + reqReadyPromise = new Promise((r) => { + reqReadyResolve = r + }) + + const req = request(server) + .get('/abort') + .end(() => {}) + + await reqReadyPromise + + req.abort() + + await abortedPromise + } + + expect(requests).toHaveLength(1) + + for (const abortedReq of requests) { + expect(abortedReq).toBeInstanceOf(Request) + expect(abortedReq.signal.aborted).toBe(true) + } + + { + const abortedPromise = new Promise((resolve) => { + onAbort = (req) => { + requests.push(req) + resolve() + } + }) + + reqReadyPromise = new Promise((r) => { + reqReadyResolve = r + }) + + const req = request(server) + .get('/abort') + .end(() => {}) + + await reqReadyPromise + + req.abort() + + await abortedPromise + } + + expect(requests).toHaveLength(2) + + for (const abortedReq of requests) { + expect(abortedReq).toBeInstanceOf(Request) + expect(abortedReq.signal.aborted).toBe(true) + } + }) +}) diff --git a/test/request.test.ts b/test/request.test.ts index 9d852f1..d071e13 100644 --- a/test/request.test.ts +++ b/test/request.test.ts @@ -1,5 +1,5 @@ import type { IncomingMessage } from 'node:http' -import { newRequest, Request, GlobalRequest } from '../src/request' +import { newRequest, Request, GlobalRequest, getAbortController } from '../src/request' describe('Request', () => { describe('newRequest', () => { @@ -40,6 +40,34 @@ describe('Request', () => { expect(req).toBeInstanceOf(global.Request) expect(req.url).toBe('http://localhost/foo.txt') }) + + it('should generate only one `AbortController` per `Request` object created', async () => { + const req = newRequest({ + headers: { + host: 'localhost/..', + }, + rawHeaders: ['host', 'localhost/..'], + url: '/foo.txt', + } as IncomingMessage) + const req2 = newRequest({ + headers: { + host: 'localhost/..', + }, + rawHeaders: ['host', 'localhost/..'], + url: '/foo.txt', + } as IncomingMessage) + + const x = req[getAbortController]() + const y = req[getAbortController]() + const z = req2[getAbortController]() + + expect(x).toBeInstanceOf(AbortController) + expect(y).toBeInstanceOf(AbortController) + expect(z).toBeInstanceOf(AbortController) + expect(x).toBe(y) + expect(z).not.toBe(x) + expect(z).not.toBe(y) + }) }) describe('GlobalRequest', () => {