From 55e0250b4ad8d0d24d2993cbce0ebdabc5710c93 Mon Sep 17 00:00:00 2001 From: m4rc3l05 <15786310+M4RC3L05@users.noreply.github.com> Date: Tue, 6 Feb 2024 23:00:37 +0000 Subject: [PATCH] Catch when nodejs request is aborted When we receive a request, we start listening to the close event on the response. When fired, we check if the request was aborted using `incoming.destroyed`, if it was, we dispatch an abort event to the existing request abort signal. Without this, whe the client abortes the request, the signal on the request was not being called with the abort event. Signed-off-by: m4rc3l05 <15786310+M4RC3L05@users.noreply.github.com> --- src/listener.ts | 9 +++- src/request.ts | 15 +++++- test/listener.test.ts | 116 ++++++++++++++++++++++++++++++++++++++++++ test/request.test.ts | 30 ++++++++++- 4 files changed, 166 insertions(+), 4 deletions(-) diff --git a/src/listener.ts b/src/listener.ts index 2388344..cde6cea 100644 --- a/src/listener.ts +++ b/src/listener.ts @@ -1,6 +1,6 @@ import type { IncomingMessage, ServerResponse, OutgoingHttpHeaders } from 'node:http' import type { Http2ServerRequest, Http2ServerResponse } from 'node:http2' -import { newRequest } from './request' +import { getAbortController, newRequest } from './request' import { cacheKey } from './response' import type { CustomErrorHandler, FetchCallback, HttpBindings } from './types' import { writeFromReadableStream, buildOutgoingHttpHeaders } from './utils' @@ -143,6 +143,13 @@ export const getRequestListener = ( // so generate a pseudo Request object with only the minimum required information. const req = newRequest(incoming) + // Detect if request was aborted. + outgoing.on('close', () => { + if (incoming.destroyed) { + req[getAbortController]().abort() + } + }) + try { res = fetchCallback(req, { incoming, outgoing } as HttpBindings) as | Response diff --git a/src/request.ts b/src/request.ts index bd4a0ca..e12f830 100644 --- a/src/request.ts +++ b/src/request.ts @@ -25,7 +25,8 @@ Object.defineProperty(global, 'Request', { const newRequestFromIncoming = ( method: string, url: string, - incoming: IncomingMessage | Http2ServerRequest + incoming: IncomingMessage | Http2ServerRequest, + abortController: AbortController ): Request => { const headerRecord: [string, string][] = [] const rawHeaders = incoming.rawHeaders @@ -39,6 +40,7 @@ const newRequestFromIncoming = ( const init = { method: method, headers: headerRecord, + signal: abortController.signal, } as RequestInit if (!(method === 'GET' || method === 'HEAD')) { @@ -53,6 +55,8 @@ const getRequestCache = Symbol('getRequestCache') const requestCache = Symbol('requestCache') const incomingKey = Symbol('incomingKey') const urlKey = Symbol('urlKey') +const abortControllerKey = Symbol('abortControllerKey') +export const getAbortController = Symbol('getAbortController') const requestPrototype: Record = { get method() { @@ -63,11 +67,18 @@ const requestPrototype: Record = { return this[urlKey] }, + [getAbortController]() { + this[getRequestCache]() + return this[abortControllerKey] + }, + [getRequestCache]() { + this[abortControllerKey] ||= new AbortController() return (this[requestCache] ||= newRequestFromIncoming( this.method, this[urlKey], - this[incomingKey] + this[incomingKey], + this[abortControllerKey] )) }, } diff --git a/test/listener.test.ts b/test/listener.test.ts index 5c10d84..77e6c01 100644 --- a/test/listener.test.ts +++ b/test/listener.test.ts @@ -89,3 +89,119 @@ describe('Error handling - async fetchCallback', () => { expect(res.text).toBe('error handler did not return a response') }) }) + +describe('Abort request', () => { + let onAbort: (req: Request) => void + let reqReadyResolve: () => void + let reqReadyPromise: Promise + const fetchCallback = async (req: Request) => { + req.signal.addEventListener('abort', () => onAbort(req)) + reqReadyResolve?.() + await new Promise(() => {}) // never resolve + } + + const requestListener = getRequestListener(fetchCallback) + + const server = createServer(async (req, res) => { + await requestListener(req, res) + }) + + beforeEach(() => { + reqReadyPromise = new Promise((r) => { + reqReadyResolve = r + }) + }) + + afterAll(() => { + server.close() + }) + + it('should emit an abort event when the nodejs request is aborted', async () => { + const requests: Request[] = [] + const abortedPromise = new Promise((resolve) => { + onAbort = (req) => { + requests.push(req) + resolve() + } + }) + + const req = request(server) + .get('/abort') + .end(() => {}) + + await reqReadyPromise + + req.abort() + + await abortedPromise + + expect(requests).toHaveLength(1) + const abortedReq = requests[0] + expect(abortedReq).toBeInstanceOf(Request) + expect(abortedReq.signal.aborted).toBe(true) + }) + + it('should emit an abort event when the nodejs request is aborted on multiple requests', async () => { + const requests: Request[] = [] + + { + const abortedPromise = new Promise((resolve) => { + onAbort = (req) => { + requests.push(req) + resolve() + } + }) + + reqReadyPromise = new Promise((r) => { + reqReadyResolve = r + }) + + const req = request(server) + .get('/abort') + .end(() => {}) + + await reqReadyPromise + + req.abort() + + await abortedPromise + } + + expect(requests).toHaveLength(1) + + for (const abortedReq of requests) { + expect(abortedReq).toBeInstanceOf(Request) + expect(abortedReq.signal.aborted).toBe(true) + } + + { + const abortedPromise = new Promise((resolve) => { + onAbort = (req) => { + requests.push(req) + resolve() + } + }) + + reqReadyPromise = new Promise((r) => { + reqReadyResolve = r + }) + + const req = request(server) + .get('/abort') + .end(() => {}) + + await reqReadyPromise + + req.abort() + + await abortedPromise + } + + expect(requests).toHaveLength(2) + + for (const abortedReq of requests) { + expect(abortedReq).toBeInstanceOf(Request) + expect(abortedReq.signal.aborted).toBe(true) + } + }) +}) diff --git a/test/request.test.ts b/test/request.test.ts index 9d852f1..d071e13 100644 --- a/test/request.test.ts +++ b/test/request.test.ts @@ -1,5 +1,5 @@ import type { IncomingMessage } from 'node:http' -import { newRequest, Request, GlobalRequest } from '../src/request' +import { newRequest, Request, GlobalRequest, getAbortController } from '../src/request' describe('Request', () => { describe('newRequest', () => { @@ -40,6 +40,34 @@ describe('Request', () => { expect(req).toBeInstanceOf(global.Request) expect(req.url).toBe('http://localhost/foo.txt') }) + + it('should generate only one `AbortController` per `Request` object created', async () => { + const req = newRequest({ + headers: { + host: 'localhost/..', + }, + rawHeaders: ['host', 'localhost/..'], + url: '/foo.txt', + } as IncomingMessage) + const req2 = newRequest({ + headers: { + host: 'localhost/..', + }, + rawHeaders: ['host', 'localhost/..'], + url: '/foo.txt', + } as IncomingMessage) + + const x = req[getAbortController]() + const y = req[getAbortController]() + const z = req2[getAbortController]() + + expect(x).toBeInstanceOf(AbortController) + expect(y).toBeInstanceOf(AbortController) + expect(z).toBeInstanceOf(AbortController) + expect(x).toBe(y) + expect(z).not.toBe(x) + expect(z).not.toBe(y) + }) }) describe('GlobalRequest', () => {