diff --git a/docs/docs/api/Dispatcher.md b/docs/docs/api/Dispatcher.md index 67819ecd525..ce14bfe1cd2 100644 --- a/docs/docs/api/Dispatcher.md +++ b/docs/docs/api/Dispatcher.md @@ -986,6 +986,57 @@ client.dispatch( ); ``` +##### `dns` + +The `dns` interceptor enables you to cache DNS lookups for a given duration, per origin. + +>It is well suited for scenarios where you want to cache DNS lookups to avoid the overhead of resolving the same domain multiple times + +**Options** +- `maxTTL` - The maximum time-to-live (in milliseconds) of the DNS cache. It should be a positive integer. Default: `10000`. + - Set `0` to disable TTL. +- `maxItems` - The maximum number of items to cache. It should be a positive integer. Default: `Infinity`. +- `dualStack` - Whether to resolve both IPv4 and IPv6 addresses. Default: `true`. + - It will also attempt a happy-eyeballs-like approach to connect to the available addresses in case of a connection failure. +- `affinity` - Whether to use IPv4 or IPv6 addresses. Default: `4`. + - It can be either `'4` or `6`. + - It will only take effect if `dualStack` is `false`. +- `lookup: (hostname: string, options: LookupOptions, callback: (err: NodeJS.ErrnoException | null, addresses: DNSInterceptorRecord[]) => void) => void` - Custom lookup function. Default: `dns.lookup`. + - For more info see [dns.lookup](https://nodejs.org/api/dns.html#dns_dns_lookup_hostname_options_callback). +- `pick: (origin: URL, records: DNSInterceptorRecords, affinity: 4 | 6) => DNSInterceptorRecord` - Custom pick function. Default: `RoundRobin`. + - The function should return a single record from the records array. + - By default a simplified version of Round Robin is used. + - The `records` property can be mutated to store the state of the balancing algorithm. + +> The `Dispatcher#options` also gets extended with the options `dns.affinity`, `dns.dualStack`, `dns.lookup` and `dns.pick` which can be used to configure the interceptor at a request-per-request basis. + + +**DNSInterceptorRecord** +It represents a DNS record. +- `family` - (`number`) The IP family of the address. It can be either `4` or `6`. +- `address` - (`string`) The IP address. + +**DNSInterceptorOriginRecords** +It represents a map of DNS IP addresses records for a single origin. +- `4.ips` - (`DNSInterceptorRecord[] | null`) The IPv4 addresses. +- `6.ips` - (`DNSInterceptorRecord[] | null`) The IPv6 addresses. + +**Example - Basic DNS Interceptor** + +```js +const { Client, interceptors } = require("undici"); +const { dns } = interceptors; + +const client = new Agent().compose([ + dns({ ...opts }) +]) + +const response = await client.request({ + origin: `http://localhost:3030`, + ...requestOpts +}) +``` + ##### `Response Error Interceptor` **Introduction** diff --git a/index.js b/index.js index 7a68d04abb3..0c37ed4853b 100644 --- a/index.js +++ b/index.js @@ -41,7 +41,8 @@ module.exports.createRedirectInterceptor = createRedirectInterceptor module.exports.interceptors = { redirect: require('./lib/interceptor/redirect'), retry: require('./lib/interceptor/retry'), - dump: require('./lib/interceptor/dump') + dump: require('./lib/interceptor/dump'), + dns: require('./lib/interceptor/dns') } module.exports.buildConnector = buildConnector diff --git a/lib/interceptor/dns.js b/lib/interceptor/dns.js new file mode 100644 index 00000000000..917732646e6 --- /dev/null +++ b/lib/interceptor/dns.js @@ -0,0 +1,375 @@ +'use strict' +const { isIP } = require('node:net') +const { lookup } = require('node:dns') +const DecoratorHandler = require('../handler/decorator-handler') +const { InvalidArgumentError, InformationalError } = require('../core/errors') +const maxInt = Math.pow(2, 31) - 1 + +class DNSInstance { + #maxTTL = 0 + #maxItems = 0 + #records = new Map() + dualStack = true + affinity = null + lookup = null + pick = null + + constructor (opts) { + this.#maxTTL = opts.maxTTL + this.#maxItems = opts.maxItems + this.dualStack = opts.dualStack + this.affinity = opts.affinity + this.lookup = opts.lookup ?? this.#defaultLookup + this.pick = opts.pick ?? this.#defaultPick + } + + get full () { + return this.#records.size === this.#maxItems + } + + runLookup (origin, opts, cb) { + const ips = this.#records.get(origin.hostname) + + // If full, we just return the origin + if (ips == null && this.full) { + cb(null, origin.origin) + return + } + + const newOpts = { + affinity: this.affinity, + dualStack: this.dualStack, + lookup: this.lookup, + pick: this.pick, + ...opts.dns, + maxTTL: this.#maxTTL, + maxItems: this.#maxItems + } + + // If no IPs we lookup + if (ips == null) { + this.lookup(origin, newOpts, (err, addresses) => { + if (err || addresses == null || addresses.length === 0) { + cb(err ?? new InformationalError('No DNS entries found')) + return + } + + this.setRecords(origin, addresses) + const records = this.#records.get(origin.hostname) + + const ip = this.pick( + origin, + records, + newOpts.affinity + ) + + let port + if (typeof ip.port === 'number') { + port = `:${ip.port}` + } else if (origin.port !== '') { + port = `:${origin.port}` + } else { + port = '' + } + + cb( + null, + `${origin.protocol}//${ + ip.family === 6 ? `[${ip.address}]` : ip.address + }${port}` + ) + }) + } else { + // If there's IPs we pick + const ip = this.pick( + origin, + ips, + newOpts.affinity + ) + + // If no IPs we lookup - deleting old records + if (ip == null) { + this.#records.delete(origin.hostname) + this.runLookup(origin, opts, cb) + return + } + + let port + if (typeof ip.port === 'number') { + port = `:${ip.port}` + } else if (origin.port !== '') { + port = `:${origin.port}` + } else { + port = '' + } + + cb( + null, + `${origin.protocol}//${ + ip.family === 6 ? `[${ip.address}]` : ip.address + }${port}` + ) + } + } + + #defaultLookup (origin, opts, cb) { + lookup( + origin.hostname, + { + all: true, + family: this.dualStack === false ? this.affinity : 0, + order: 'ipv4first' + }, + (err, addresses) => { + if (err) { + return cb(err) + } + + const results = new Map() + + for (const addr of addresses) { + // On linux we found duplicates, we attempt to remove them with + // the latest record + results.set(`${addr.address}:${addr.family}`, addr) + } + + cb(null, results.values()) + } + ) + } + + #defaultPick (origin, hostnameRecords, affinity) { + let ip = null + const { records, offset } = hostnameRecords + + let family + if (this.dualStack) { + if (affinity == null) { + // Balance between ip families + if (offset == null || offset === maxInt) { + hostnameRecords.offset = 0 + affinity = 4 + } else { + hostnameRecords.offset++ + affinity = (hostnameRecords.offset & 1) === 1 ? 6 : 4 + } + } + + if (records[affinity] != null && records[affinity].ips.length > 0) { + family = records[affinity] + } else { + family = records[affinity === 4 ? 6 : 4] + } + } else { + family = records[affinity] + } + + // If no IPs we return null + if (family == null || family.ips.length === 0) { + return ip + } + + if (family.offset == null || family.offset === maxInt) { + family.offset = 0 + } else { + family.offset++ + } + + const position = family.offset % family.ips.length + ip = family.ips[position] ?? null + + if (ip == null) { + return ip + } + + if (Date.now() - ip.timestamp > ip.ttl) { // record TTL is already in ms + // We delete expired records + // It is possible that they have different TTL, so we manage them individually + family.ips.splice(position, 1) + return this.pick(origin, hostnameRecords, affinity) + } + + return ip + } + + setRecords (origin, addresses) { + const timestamp = Date.now() + const records = { records: { 4: null, 6: null } } + for (const record of addresses) { + record.timestamp = timestamp + if (typeof record.ttl === 'number') { + // The record TTL is expected to be in ms + record.ttl = Math.min(record.ttl, this.#maxTTL) + } else { + record.ttl = this.#maxTTL + } + + const familyRecords = records.records[record.family] ?? { ips: [] } + + familyRecords.ips.push(record) + records.records[record.family] = familyRecords + } + + this.#records.set(origin.hostname, records) + } + + getHandler (meta, opts) { + return new DNSDispatchHandler(this, meta, opts) + } +} + +class DNSDispatchHandler extends DecoratorHandler { + #state = null + #opts = null + #dispatch = null + #handler = null + #origin = null + + constructor (state, { origin, handler, dispatch }, opts) { + super(handler) + this.#origin = origin + this.#handler = handler + this.#opts = { ...opts } + this.#state = state + this.#dispatch = dispatch + } + + onError (err) { + switch (err.code) { + case 'ETIMEDOUT': + case 'ECONNREFUSED': { + if (this.#state.dualStack) { + // We delete the record and retry + this.#state.runLookup(this.#origin, this.#opts, (err, newOrigin) => { + if (err) { + return this.#handler.onError(err) + } + + const dispatchOpts = { + ...this.#opts, + origin: newOrigin + } + + this.#dispatch(dispatchOpts, this) + }) + + // if dual-stack disabled, we error out + return + } + + this.#handler.onError(err) + return + } + case 'ENOTFOUND': + this.#state.deleteRecord(this.#origin) + // eslint-disable-next-line no-fallthrough + default: + this.#handler.onError(err) + break + } + } +} + +module.exports = interceptorOpts => { + if ( + interceptorOpts?.maxTTL != null && + (typeof interceptorOpts?.maxTTL !== 'number' || interceptorOpts?.maxTTL < 0) + ) { + throw new InvalidArgumentError('Invalid maxTTL. Must be a positive number') + } + + if ( + interceptorOpts?.maxItems != null && + (typeof interceptorOpts?.maxItems !== 'number' || + interceptorOpts?.maxItems < 1) + ) { + throw new InvalidArgumentError( + 'Invalid maxItems. Must be a positive number and greater than zero' + ) + } + + if ( + interceptorOpts?.affinity != null && + interceptorOpts?.affinity !== 4 && + interceptorOpts?.affinity !== 6 + ) { + throw new InvalidArgumentError('Invalid affinity. Must be either 4 or 6') + } + + if ( + interceptorOpts?.dualStack != null && + typeof interceptorOpts?.dualStack !== 'boolean' + ) { + throw new InvalidArgumentError('Invalid dualStack. Must be a boolean') + } + + if ( + interceptorOpts?.lookup != null && + typeof interceptorOpts?.lookup !== 'function' + ) { + throw new InvalidArgumentError('Invalid lookup. Must be a function') + } + + if ( + interceptorOpts?.pick != null && + typeof interceptorOpts?.pick !== 'function' + ) { + throw new InvalidArgumentError('Invalid pick. Must be a function') + } + + const dualStack = interceptorOpts?.dualStack ?? true + let affinity + if (dualStack) { + affinity = interceptorOpts?.affinity ?? null + } else { + affinity = interceptorOpts?.affinity ?? 4 + } + + const opts = { + maxTTL: interceptorOpts?.maxTTL ?? 10e3, // Expressed in ms + lookup: interceptorOpts?.lookup ?? null, + pick: interceptorOpts?.pick ?? null, + dualStack, + affinity, + maxItems: interceptorOpts?.maxItems ?? Infinity + } + + const instance = new DNSInstance(opts) + + return dispatch => { + return function dnsInterceptor (origDispatchOpts, handler) { + const origin = + origDispatchOpts.origin.constructor === URL + ? origDispatchOpts.origin + : new URL(origDispatchOpts.origin) + + if (isIP(origin.hostname) !== 0) { + return dispatch(origDispatchOpts, handler) + } + + instance.runLookup(origin, origDispatchOpts, (err, newOrigin) => { + if (err) { + return handler.onError(err) + } + + let dispatchOpts = null + dispatchOpts = { + ...origDispatchOpts, + servername: origin.hostname, // For SNI on TLS + origin: newOrigin, + headers: { + host: origin.hostname, + ...origDispatchOpts.headers + } + } + + dispatch( + dispatchOpts, + instance.getHandler({ origin, dispatch, handler }, origDispatchOpts) + ) + }) + + return true + } + } +} diff --git a/test/interceptors/dns.js b/test/interceptors/dns.js new file mode 100644 index 00000000000..6b4b30b13cc --- /dev/null +++ b/test/interceptors/dns.js @@ -0,0 +1,1721 @@ +'use strict' + +const { test, after } = require('node:test') +const { isIP } = require('node:net') +const { lookup } = require('node:dns') +const { createServer } = require('node:http') +const { createServer: createSecureServer } = require('node:https') +const { once } = require('node:events') +const { setTimeout: sleep } = require('node:timers/promises') + +const { tspl } = require('@matteo.collina/tspl') +const pem = require('https-pem') + +const { interceptors, Agent } = require('../..') +const { dns } = interceptors + +test('Should validate options', t => { + t = tspl(t, { plan: 10 }) + + t.throws(() => dns({ dualStack: 'true' }), { code: 'UND_ERR_INVALID_ARG' }) + t.throws(() => dns({ dualStack: 0 }), { code: 'UND_ERR_INVALID_ARG' }) + t.throws(() => dns({ affinity: '4' }), { code: 'UND_ERR_INVALID_ARG' }) + t.throws(() => dns({ affinity: 7 }), { code: 'UND_ERR_INVALID_ARG' }) + t.throws(() => dns({ maxTTL: -1 }), { code: 'UND_ERR_INVALID_ARG' }) + t.throws(() => dns({ maxTTL: '0' }), { code: 'UND_ERR_INVALID_ARG' }) + t.throws(() => dns({ maxItems: '1' }), { code: 'UND_ERR_INVALID_ARG' }) + t.throws(() => dns({ maxItems: -1 }), { code: 'UND_ERR_INVALID_ARG' }) + t.throws(() => dns({ lookup: {} }), { code: 'UND_ERR_INVALID_ARG' }) + t.throws(() => dns({ pick: [] }), { code: 'UND_ERR_INVALID_ARG' }) +}) + +test('Should automatically resolve IPs (dual stack)', async t => { + t = tspl(t, { plan: 8 }) + + const hostsnames = [] + const server = createServer() + const requestOptions = { + method: 'GET', + path: '/', + headers: { + 'content-type': 'application/json' + } + } + + server.on('request', (req, res) => { + res.writeHead(200, { 'content-type': 'text/plain' }) + res.end('hello world!') + }) + + server.listen(0) + + await once(server, 'listening') + + const client = new Agent().compose([ + dispatch => { + return (opts, handler) => { + const url = new URL(opts.origin) + + t.equal(hostsnames.includes(url.hostname), false) + + if (url.hostname[0] === '[') { + // [::1] -> ::1 + t.equal(isIP(url.hostname.slice(1, 4)), 6) + } else { + t.equal(isIP(url.hostname), 4) + } + + hostsnames.push(url.hostname) + + return dispatch(opts, handler) + } + }, + dns({ + lookup: (_origin, _opts, cb) => { + cb(null, [ + { + address: '::1', + family: 6 + }, + { + address: '127.0.0.1', + family: 4 + } + ]) + } + }) + ]) + + after(async () => { + await client.close() + server.close() + + await once(server, 'close') + }) + + const response = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response.statusCode, 200) + t.equal(await response.body.text(), 'hello world!') + + const response2 = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response2.statusCode, 200) + t.equal(await response2.body.text(), 'hello world!') +}) + +test('Should respect DNS origin hostname for SNI on TLS', async t => { + t = tspl(t, { plan: 12 }) + + const hostsnames = [] + const server = createSecureServer(pem) + const requestOptions = { + method: 'GET', + path: '/', + headers: { + 'content-type': 'application/json' + } + } + + server.on('request', (req, res) => { + t.equal(req.headers.host, 'localhost') + res.writeHead(200, { 'content-type': 'text/plain' }) + res.end('hello world!') + }) + + server.listen(0) + + await once(server, 'listening') + + const client = new Agent({ + connect: { + rejectUnauthorized: false + } + }).compose([ + dispatch => { + return (opts, handler) => { + const url = new URL(opts.origin) + + t.equal(hostsnames.includes(url.hostname), false) + t.equal(opts.servername, 'localhost') + + if (url.hostname[0] === '[') { + // [::1] -> ::1 + t.equal(isIP(url.hostname.slice(1, 4)), 6) + } else { + t.equal(isIP(url.hostname), 4) + } + + hostsnames.push(url.hostname) + + return dispatch(opts, handler) + } + }, + dns({ + lookup: (_origin, _opts, cb) => { + cb(null, [ + { + address: '::1', + family: 6 + }, + { + address: '127.0.0.1', + family: 4 + } + ]) + } + }) + ]) + + after(async () => { + await client.close() + server.close() + + await once(server, 'close') + }) + + const response = await client.request({ + ...requestOptions, + origin: `https://localhost:${server.address().port}` + }) + + t.equal(response.statusCode, 200) + t.equal(await response.body.text(), 'hello world!') + + const response2 = await client.request({ + ...requestOptions, + origin: `https://localhost:${server.address().port}` + }) + + t.equal(response2.statusCode, 200) + t.equal(await response2.body.text(), 'hello world!') +}) + +test('Should recover on network errors (dual stack - 4)', async t => { + t = tspl(t, { plan: 8 }) + + let counter = 0 + const server = createServer() + const requestOptions = { + method: 'GET', + path: '/', + headers: { + 'content-type': 'application/json' + } + } + + server.on('request', (req, res) => { + res.writeHead(200, { 'content-type': 'text/plain' }) + res.end('hello world!') + }) + + server.listen(0, '::1') + + await once(server, 'listening') + + const client = new Agent().compose([ + dispatch => { + return (opts, handler) => { + ++counter + const url = new URL(opts.origin) + + switch (counter) { + case 1: + t.equal(isIP(url.hostname), 4) + break + + case 2: + // [::1] -> ::1 + t.equal(isIP(url.hostname.slice(1, 4)), 6) + break + + case 3: + // [::1] -> ::1 + t.equal(isIP(url.hostname), 4) + break + + case 4: + // [::1] -> ::1 + t.equal(isIP(url.hostname.slice(1, 4)), 6) + break + default: + t.fail('should not reach this point') + } + + return dispatch(opts, handler) + } + }, + dns({ + lookup: (_origin, _opts, cb) => { + cb(null, [ + { + address: '::1', + family: 6 + }, + { + address: '127.0.0.1', + family: 4 + } + ]) + } + }) + ]) + + after(async () => { + await client.close() + server.close() + + await once(server, 'close') + }) + + const response = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response.statusCode, 200) + t.equal(await response.body.text(), 'hello world!') + + const response2 = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response2.statusCode, 200) + t.equal(await response2.body.text(), 'hello world!') +}) + +test('Should recover on network errors (dual stack - 6)', async t => { + t = tspl(t, { plan: 7 }) + + let counter = 0 + const server = createServer() + const requestOptions = { + method: 'GET', + path: '/', + headers: { + 'content-type': 'application/json' + } + } + + server.on('request', (req, res) => { + res.writeHead(200, { 'content-type': 'text/plain' }) + res.end('hello world!') + }) + + server.listen(0, '127.0.0.1') + + await once(server, 'listening') + + const client = new Agent().compose([ + dispatch => { + return (opts, handler) => { + ++counter + const url = new URL(opts.origin) + + switch (counter) { + case 1: + t.equal(isIP(url.hostname), 4) + break + + case 2: + // [::1] -> ::1 + t.equal(isIP(url.hostname.slice(1, 4)), 6) + break + + case 3: + // [::1] -> ::1 + t.equal(isIP(url.hostname), 4) + break + default: + t.fail('should not reach this point') + } + + return dispatch(opts, handler) + } + }, + dns({ + lookup: (_origin, _opts, cb) => { + cb(null, [ + { + address: '::1', + family: 6 + }, + { + address: '127.0.0.1', + family: 4 + } + ]) + } + }) + ]) + + after(async () => { + await client.close() + server.close() + + await once(server, 'close') + }) + + const response = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response.statusCode, 200) + t.equal(await response.body.text(), 'hello world!') + + const response2 = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response2.statusCode, 200) + t.equal(await response2.body.text(), 'hello world!') +}) + +test('Should throw when on dual-stack disabled (4)', async t => { + t = tspl(t, { plan: 2 }) + + let counter = 0 + const requestOptions = { + method: 'GET', + path: '/', + headers: { + 'content-type': 'application/json' + } + } + + const client = new Agent().compose([ + dispatch => { + return (opts, handler) => { + ++counter + const url = new URL(opts.origin) + + switch (counter) { + case 1: + t.equal(isIP(url.hostname), 4) + break + + default: + t.fail('should not reach this point') + } + + return dispatch(opts, handler) + } + }, + dns({ dualStack: false, affinity: 4 }) + ]) + + const promise = client.request({ + ...requestOptions, + origin: 'http://localhost:1234' + }) + + await t.rejects(promise, 'ECONNREFUSED') + + await t.complete +}) + +test('Should throw when on dual-stack disabled (6)', async t => { + t = tspl(t, { plan: 2 }) + + let counter = 0 + const requestOptions = { + method: 'GET', + path: '/', + headers: { + 'content-type': 'application/json' + } + } + + const client = new Agent().compose([ + dispatch => { + return (opts, handler) => { + ++counter + const url = new URL(opts.origin) + + switch (counter) { + case 1: + // [::1] -> ::1 + t.equal(isIP(url.hostname.slice(1, 4)), 6) + break + + default: + t.fail('should not reach this point') + } + + return dispatch(opts, handler) + } + }, + dns({ dualStack: false, affinity: 6 }) + ]) + + const promise = client.request({ + ...requestOptions, + origin: 'http://localhost:9999' + }) + + await t.rejects(promise, 'ECONNREFUSED') + + await t.complete +}) + +test('Should automatically resolve IPs (dual stack disabled - 4)', async t => { + t = tspl(t, { plan: 6 }) + + let counter = 0 + const server = createServer() + const requestOptions = { + method: 'GET', + path: '/', + headers: { + 'content-type': 'application/json' + } + } + + server.on('request', (req, res) => { + res.writeHead(200, { 'content-type': 'text/plain' }) + res.end('hello world!') + }) + + server.listen(0) + + await once(server, 'listening') + + const client = new Agent().compose([ + dispatch => { + return (opts, handler) => { + ++counter + const url = new URL(opts.origin) + + switch (counter) { + case 1: + t.equal(isIP(url.hostname), 4) + break + + case 2: + // [::1] -> ::1 + t.equal(isIP(url.hostname), 4) + break + default: + t.fail('should not reach this point') + } + + return dispatch(opts, handler) + } + }, + dns({ dualStack: false }) + ]) + + after(async () => { + await client.close() + server.close() + + await once(server, 'close') + }) + + const response = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response.statusCode, 200) + t.equal(await response.body.text(), 'hello world!') + + const response2 = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response2.statusCode, 200) + t.equal(await response2.body.text(), 'hello world!') +}) + +test('Should automatically resolve IPs (dual stack disabled - 6)', async t => { + t = tspl(t, { plan: 6 }) + + let counter = 0 + const server = createServer() + const requestOptions = { + method: 'GET', + path: '/', + headers: { + 'content-type': 'application/json' + } + } + + server.on('request', (req, res) => { + res.writeHead(200, { 'content-type': 'text/plain' }) + res.end('hello world!') + }) + + server.listen(0) + + await once(server, 'listening') + + const client = new Agent().compose([ + dispatch => { + return (opts, handler) => { + ++counter + const url = new URL(opts.origin) + + switch (counter) { + case 1: + // [::1] -> ::1 + t.equal(isIP(url.hostname.slice(1, 4)), 6) + break + + case 2: + // [::1] -> ::1 + t.equal(isIP(url.hostname.slice(1, 4)), 6) + break + default: + t.fail('should not reach this point') + } + + return dispatch(opts, handler) + } + }, + dns({ dualStack: false, affinity: 6 }) + ]) + + after(async () => { + await client.close() + server.close() + + await once(server, 'close') + }) + + const response = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response.statusCode, 200) + t.equal(await response.body.text(), 'hello world!') + + const response2 = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response2.statusCode, 200) + t.equal(await response2.body.text(), 'hello world!') +}) + +test('Should we handle TTL (4)', async t => { + t = tspl(t, { plan: 10 }) + + let counter = 0 + let lookupCounter = 0 + const server = createServer() + const requestOptions = { + method: 'GET', + path: '/', + headers: { + 'content-type': 'application/json' + } + } + + server.on('request', (req, res) => { + res.writeHead(200, { 'content-type': 'text/plain' }) + res.end('hello world!') + }) + + server.listen(0, '127.0.0.1') + + await once(server, 'listening') + + const client = new Agent().compose([ + dispatch => { + return (opts, handler) => { + ++counter + const url = new URL(opts.origin) + + switch (counter) { + case 1: + t.equal(isIP(url.hostname), 4) + break + + case 2: + t.equal(isIP(url.hostname), 4) + break + + case 3: + t.equal(isIP(url.hostname), 4) + break + default: + t.fail('should not reach this point') + } + + return dispatch(opts, handler) + } + }, + dns({ + dualStack: false, + affinity: 4, + maxTTL: 400, + lookup: (origin, opts, cb) => { + ++lookupCounter + lookup( + origin.hostname, + { all: true, family: opts.affinity }, + cb + ) + } + }) + ]) + + after(async () => { + await client.close() + server.close() + + await once(server, 'close') + }) + + const response = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response.statusCode, 200) + t.equal(await response.body.text(), 'hello world!') + + await sleep(200) + + const response2 = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response2.statusCode, 200) + t.equal(await response2.body.text(), 'hello world!') + + await sleep(300) + + const response3 = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response3.statusCode, 200) + t.equal(await response3.body.text(), 'hello world!') + + t.equal(lookupCounter, 2) +}) + +test('Should we handle TTL (6)', async t => { + t = tspl(t, { plan: 10 }) + + let counter = 0 + let lookupCounter = 0 + const server = createServer() + const requestOptions = { + method: 'GET', + path: '/', + headers: { + 'content-type': 'application/json' + } + } + + server.on('request', (req, res) => { + res.writeHead(200, { 'content-type': 'text/plain' }) + res.end('hello world!') + }) + + server.listen(0, '::1') + + await once(server, 'listening') + + const client = new Agent().compose([ + dispatch => { + return (opts, handler) => { + ++counter + const url = new URL(opts.origin) + + switch (counter) { + case 1: + // [::1] -> ::1 + t.equal(isIP(url.hostname.slice(1, 4)), 6) + break + + case 2: + // [::1] -> ::1 + t.equal(isIP(url.hostname.slice(1, 4)), 6) + break + + case 3: + // [::1] -> ::1 + t.equal(isIP(url.hostname.slice(1, 4)), 6) + break + default: + t.fail('should not reach this point') + } + + return dispatch(opts, handler) + } + }, + dns({ + dualStack: false, + affinity: 6, + maxTTL: 400, + lookup: (origin, opts, cb) => { + ++lookupCounter + lookup( + origin.hostname, + { all: true, family: opts.affinity }, + cb + ) + } + }) + ]) + + after(async () => { + await client.close() + server.close() + + await once(server, 'close') + }) + + const response = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response.statusCode, 200) + t.equal(await response.body.text(), 'hello world!') + + await sleep(200) + + const response2 = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response2.statusCode, 200) + t.equal(await response2.body.text(), 'hello world!') + + await sleep(300) + + const response3 = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response3.statusCode, 200) + t.equal(await response3.body.text(), 'hello world!') + t.equal(lookupCounter, 2) +}) + +test('Should set lowest TTL between resolved and option maxTTL', async t => { + t = tspl(t, { plan: 9 }) + + let lookupCounter = 0 + const server = createServer() + const requestOptions = { + method: 'GET', + path: '/', + headers: { + 'content-type': 'application/json' + } + } + + server.on('request', (req, res) => { + res.writeHead(200, { 'content-type': 'text/plain' }) + res.end('hello world!') + }) + + server.listen(0, '127.0.0.1') + + await once(server, 'listening') + + const client = new Agent().compose( + dns({ + dualStack: false, + affinity: 4, + maxTTL: 200, + lookup: (origin, opts, cb) => { + ++lookupCounter + cb(null, [ + { + address: '127.0.0.1', + family: 4, + ttl: lookupCounter === 1 ? 50 : 500 + } + ]) + } + }) + ) + + after(async () => { + await client.close() + server.close() + + await once(server, 'close') + }) + + const response = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response.statusCode, 200) + t.equal(await response.body.text(), 'hello world!') + + await sleep(100) + + // 100ms: lookup since ttl = Math.min(50, maxTTL: 200) + const response2 = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response2.statusCode, 200) + t.equal(await response2.body.text(), 'hello world!') + + await sleep(100) + + // 100ms: cached since ttl = Math.min(500, maxTTL: 200) + const response3 = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response3.statusCode, 200) + t.equal(await response3.body.text(), 'hello world!') + + await sleep(150) + + // 250ms: lookup since ttl = Math.min(500, maxTTL: 200) + const response4 = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response4.statusCode, 200) + t.equal(await response4.body.text(), 'hello world!') + + t.equal(lookupCounter, 3) +}) + +test('Should use all dns entries (dual stack)', async t => { + t = tspl(t, { plan: 16 }) + + let counter = 0 + let lookupCounter = 0 + const server = createServer() + const requestOptions = { + method: 'GET', + path: '/', + headers: { + 'content-type': 'application/json' + } + } + + server.on('request', (req, res) => { + res.writeHead(200, { 'content-type': 'text/plain' }) + res.end('hello world!') + }) + + server.listen(0) + + await once(server, 'listening') + + const client = new Agent().compose([ + dispatch => { + return (opts, handler) => { + ++counter + const url = new URL(opts.origin) + switch (counter) { + case 1: + t.equal(url.hostname, '1.1.1.1') + break + + case 2: + t.equal(url.hostname, '[::1]') + break + + case 3: + t.equal(url.hostname, '2.2.2.2') + break + + case 4: + t.equal(url.hostname, '[::2]') + break + + case 5: + t.equal(url.hostname, '1.1.1.1') + break + default: + t.fail('should not reach this point') + } + + url.hostname = '127.0.0.1' + opts.origin = url.toString() + return dispatch(opts, handler) + } + }, + dns({ + lookup (origin, opts, cb) { + lookupCounter++ + cb(null, [ + { address: '::1', family: 6 }, + { address: '::2', family: 6 }, + { address: '1.1.1.1', family: 4 }, + { address: '2.2.2.2', family: 4 } + ]) + } + }) + ]) + + after(async () => { + await client.close() + server.close() + + await once(server, 'close') + }) + + for (let i = 0; i < 5; i++) { + const response = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response.statusCode, 200) + t.equal(await response.body.text(), 'hello world!') + } + + t.equal(lookupCounter, 1) +}) + +test('Should use all dns entries (dual stack disabled - 4)', async t => { + t = tspl(t, { plan: 10 }) + + let counter = 0 + let lookupCounter = 0 + const server = createServer() + const requestOptions = { + method: 'GET', + path: '/', + headers: { + 'content-type': 'application/json' + } + } + + server.on('request', (req, res) => { + res.writeHead(200, { 'content-type': 'text/plain' }) + res.end('hello world!') + }) + + server.listen(0) + + await once(server, 'listening') + + const client = new Agent().compose([ + dispatch => { + return (opts, handler) => { + ++counter + const url = new URL(opts.origin) + + switch (counter) { + case 1: + t.equal(url.hostname, '1.1.1.1') + break + + case 2: + t.equal(url.hostname, '2.2.2.2') + break + + case 3: + t.equal(url.hostname, '1.1.1.1') + break + default: + t.fail('should not reach this point') + } + + url.hostname = '127.0.0.1' + opts.origin = url.toString() + return dispatch(opts, handler) + } + }, + dns({ + dualStack: false, + lookup (origin, opts, cb) { + lookupCounter++ + cb(null, [ + { address: '1.1.1.1', family: 4 }, + { address: '2.2.2.2', family: 4 } + ]) + } + }) + ]) + + after(async () => { + await client.close() + server.close() + + await once(server, 'close') + }) + + const response1 = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response1.statusCode, 200) + t.equal(await response1.body.text(), 'hello world!') + + const response2 = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response2.statusCode, 200) + t.equal(await response2.body.text(), 'hello world!') + + const response3 = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response3.statusCode, 200) + t.equal(await response3.body.text(), 'hello world!') + + t.equal(lookupCounter, 1) +}) + +test('Should use all dns entries (dual stack disabled - 6)', async t => { + t = tspl(t, { plan: 10 }) + + let counter = 0 + let lookupCounter = 0 + const server = createServer() + const requestOptions = { + method: 'GET', + path: '/', + headers: { + 'content-type': 'application/json' + } + } + + server.on('request', (req, res) => { + res.writeHead(200, { 'content-type': 'text/plain' }) + res.end('hello world!') + }) + + server.listen(0) + + await once(server, 'listening') + + const client = new Agent().compose([ + dispatch => { + return (opts, handler) => { + ++counter + const url = new URL(opts.origin) + + switch (counter) { + case 1: + t.equal(url.hostname, '[::1]') + break + + case 2: + t.equal(url.hostname, '[::2]') + break + + case 3: + t.equal(url.hostname, '[::1]') + break + default: + t.fail('should not reach this point') + } + + url.hostname = '127.0.0.1' + opts.origin = url.toString() + return dispatch(opts, handler) + } + }, + dns({ + dualStack: false, + affinity: 6, + lookup (origin, opts, cb) { + lookupCounter++ + cb(null, [ + { address: '::1', family: 6 }, + { address: '::2', family: 6 } + ]) + } + }) + ]) + + after(async () => { + await client.close() + server.close() + + await once(server, 'close') + }) + + const response1 = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response1.statusCode, 200) + t.equal(await response1.body.text(), 'hello world!') + + const response2 = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response2.statusCode, 200) + t.equal(await response2.body.text(), 'hello world!') + + const response3 = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response3.statusCode, 200) + t.equal(await response3.body.text(), 'hello world!') + + t.equal(lookupCounter, 1) +}) + +test('Should handle single family resolved (dual stack)', async t => { + t = tspl(t, { plan: 7 }) + + let counter = 0 + let lookupCounter = 0 + const server = createServer() + const requestOptions = { + method: 'GET', + path: '/', + headers: { + 'content-type': 'application/json' + } + } + + server.on('request', (req, res) => { + res.writeHead(200, { 'content-type': 'text/plain' }) + res.end('hello world!') + }) + + server.listen(0) + + await once(server, 'listening') + + const client = new Agent().compose([ + dispatch => { + return (opts, handler) => { + ++counter + const url = new URL(opts.origin) + + switch (counter) { + case 1: + t.equal(isIP(url.hostname), 4) + break + + case 2: + // [::1] -> ::1 + t.equal(isIP(url.hostname.slice(1, 4)), 6) + break + default: + t.fail('should not reach this point') + } + + return dispatch(opts, handler) + } + }, + dns({ + lookup (origin, opts, cb) { + lookupCounter++ + if (lookupCounter === 1) { + cb(null, [ + { address: '127.0.0.1', family: 4, ttl: 50 } + ]) + } else { + cb(null, [ + { address: '::1', family: 6, ttl: 50 } + ]) + } + } + }) + ]) + + after(async () => { + await client.close() + server.close() + + await once(server, 'close') + }) + + const response = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response.statusCode, 200) + t.equal(await response.body.text(), 'hello world!') + + await sleep(100) + + const response2 = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response2.statusCode, 200) + t.equal(await response2.body.text(), 'hello world!') + + t.equal(lookupCounter, 2) +}) + +test('Should prefer affinity (dual stack - 4)', async t => { + t = tspl(t, { plan: 10 }) + + let counter = 0 + let lookupCounter = 0 + const server = createServer() + const requestOptions = { + method: 'GET', + path: '/', + headers: { + 'content-type': 'application/json' + } + } + + server.on('request', (req, res) => { + res.writeHead(200, { 'content-type': 'text/plain' }) + res.end('hello world!') + }) + + server.listen(0) + + await once(server, 'listening') + + const client = new Agent().compose([ + dispatch => { + return (opts, handler) => { + ++counter + const url = new URL(opts.origin) + + switch (counter) { + case 1: + t.equal(url.hostname, '1.1.1.1') + break + + case 2: + t.equal(url.hostname, '2.2.2.2') + break + + case 3: + t.equal(url.hostname, '1.1.1.1') + break + default: + t.fail('should not reach this point') + } + + url.hostname = '127.0.0.1' + opts.origin = url.toString() + return dispatch(opts, handler) + } + }, + dns({ + affinity: 4, + lookup (origin, opts, cb) { + lookupCounter++ + cb(null, [ + { address: '1.1.1.1', family: 4 }, + { address: '2.2.2.2', family: 4 }, + { address: '::1', family: 6 }, + { address: '::2', family: 6 } + ]) + } + }) + ]) + + after(async () => { + await client.close() + server.close() + + await once(server, 'close') + }) + + const response = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response.statusCode, 200) + t.equal(await response.body.text(), 'hello world!') + + await sleep(100) + + const response2 = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response2.statusCode, 200) + t.equal(await response2.body.text(), 'hello world!') + + const response3 = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response3.statusCode, 200) + t.equal(await response3.body.text(), 'hello world!') + + t.equal(lookupCounter, 1) +}) + +test('Should prefer affinity (dual stack - 6)', async t => { + t = tspl(t, { plan: 10 }) + + let counter = 0 + let lookupCounter = 0 + const server = createServer() + const requestOptions = { + method: 'GET', + path: '/', + headers: { + 'content-type': 'application/json' + } + } + + server.on('request', (req, res) => { + res.writeHead(200, { 'content-type': 'text/plain' }) + res.end('hello world!') + }) + + server.listen(0) + + await once(server, 'listening') + + const client = new Agent().compose([ + dispatch => { + return (opts, handler) => { + ++counter + const url = new URL(opts.origin) + + switch (counter) { + case 1: + t.equal(url.hostname, '[::1]') + break + + case 2: + t.equal(url.hostname, '[::2]') + break + + case 3: + t.equal(url.hostname, '[::1]') + break + default: + t.fail('should not reach this point') + } + + url.hostname = '127.0.0.1' + opts.origin = url.toString() + return dispatch(opts, handler) + } + }, + dns({ + affinity: 6, + lookup (origin, opts, cb) { + lookupCounter++ + cb(null, [ + { address: '1.1.1.1', family: 4 }, + { address: '2.2.2.2', family: 4 }, + { address: '::1', family: 6 }, + { address: '::2', family: 6 } + ]) + } + }) + ]) + + after(async () => { + await client.close() + server.close() + + await once(server, 'close') + }) + + const response = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response.statusCode, 200) + t.equal(await response.body.text(), 'hello world!') + + await sleep(100) + + const response2 = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response2.statusCode, 200) + t.equal(await response2.body.text(), 'hello world!') + + const response3 = await client.request({ + ...requestOptions, + origin: `http://localhost:${server.address().port}` + }) + + t.equal(response3.statusCode, 200) + t.equal(await response3.body.text(), 'hello world!') + + t.equal(lookupCounter, 1) +}) + +test('Should use resolved ports (4)', async t => { + t = tspl(t, { plan: 5 }) + + let lookupCounter = 0 + const server1 = createServer() + const server2 = createServer() + const requestOptions = { + method: 'GET', + path: '/', + headers: { + 'content-type': 'application/json' + } + } + + server1.on('request', (req, res) => { + res.writeHead(200, { 'content-type': 'text/plain' }) + res.end('hello world!') + }) + + server1.listen(0) + + server2.on('request', (req, res) => { + res.writeHead(200, { 'content-type': 'text/plain' }) + res.end('hello world! (x2)') + }) + server2.listen(0) + + await Promise.all([once(server1, 'listening'), once(server2, 'listening')]) + + const client = new Agent().compose([ + dns({ + lookup (origin, opts, cb) { + lookupCounter++ + cb(null, [ + { address: '127.0.0.1', family: 4, port: server1.address().port }, + { address: '127.0.0.1', family: 4, port: server2.address().port } + ]) + } + }) + ]) + + after(async () => { + await client.close() + server1.close() + server2.close() + + await Promise.all([once(server1, 'close'), once(server2, 'close')]) + }) + + const response = await client.request({ + ...requestOptions, + origin: 'http://localhost' + }) + + t.equal(response.statusCode, 200) + t.equal(await response.body.text(), 'hello world!') + + const response2 = await client.request({ + ...requestOptions, + origin: 'http://localhost' + }) + + t.equal(response2.statusCode, 200) + t.equal(await response2.body.text(), 'hello world! (x2)') + + t.equal(lookupCounter, 1) +}) + +test('Should use resolved ports (6)', async t => { + t = tspl(t, { plan: 5 }) + + let lookupCounter = 0 + const server1 = createServer() + const server2 = createServer() + const requestOptions = { + method: 'GET', + path: '/', + headers: { + 'content-type': 'application/json' + } + } + + server1.on('request', (req, res) => { + res.writeHead(200, { 'content-type': 'text/plain' }) + res.end('hello world!') + }) + + server1.listen(0, '::1') + + server2.on('request', (req, res) => { + res.writeHead(200, { 'content-type': 'text/plain' }) + res.end('hello world! (x2)') + }) + server2.listen(0, '::1') + + await Promise.all([once(server1, 'listening'), once(server2, 'listening')]) + + const client = new Agent().compose([ + dns({ + lookup (origin, opts, cb) { + lookupCounter++ + cb(null, [ + { address: '::1', family: 6, port: server1.address().port }, + { address: '::1', family: 6, port: server2.address().port } + ]) + } + }) + ]) + + after(async () => { + await client.close() + server1.close() + server2.close() + + await Promise.all([once(server1, 'close'), once(server2, 'close')]) + }) + + const response = await client.request({ + ...requestOptions, + origin: 'http://localhost' + }) + + t.equal(response.statusCode, 200) + t.equal(await response.body.text(), 'hello world!') + + const response2 = await client.request({ + ...requestOptions, + origin: 'http://localhost' + }) + + t.equal(response2.statusCode, 200) + t.equal(await response2.body.text(), 'hello world! (x2)') + + t.equal(lookupCounter, 1) +}) + +test('Should handle max cached items', async t => { + t = tspl(t, { plan: 9 }) + + let counter = 0 + const server1 = createServer() + const server2 = createServer() + const requestOptions = { + method: 'GET', + path: '/', + headers: { + 'content-type': 'application/json' + } + } + + server1.on('request', (req, res) => { + res.writeHead(200, { 'content-type': 'text/plain' }) + res.end('hello world!') + }) + + server1.listen(0) + + server2.on('request', (req, res) => { + res.writeHead(200, { 'content-type': 'text/plain' }) + res.end('hello world! (x2)') + }) + server2.listen(0) + + await Promise.all([once(server1, 'listening'), once(server2, 'listening')]) + + const client = new Agent().compose([ + dispatch => { + return (opts, handler) => { + ++counter + const url = new URL(opts.origin) + + switch (counter) { + case 1: + t.equal(isIP(url.hostname), 4) + break + + case 2: + // [::1] -> ::1 + t.equal(isIP(url.hostname.slice(1, 4)), 6) + break + + case 3: + t.equal(url.hostname, 'developer.mozilla.org') + // Rewrite origin to avoid reaching internet + opts.origin = `http://127.0.0.1:${server2.address().port}` + break + default: + t.fails('should not reach this point') + } + + return dispatch(opts, handler) + } + }, + dns({ + maxItems: 1, + lookup: (_origin, _opts, cb) => { + cb(null, [ + { + address: '::1', + family: 6 + }, + { + address: '127.0.0.1', + family: 4 + } + ]) + } + }) + ]) + + after(async () => { + await client.close() + server1.close() + server2.close() + + await Promise.all([once(server1, 'close'), once(server2, 'close')]) + }) + + const response = await client.request({ + ...requestOptions, + origin: `http://localhost:${server1.address().port}` + }) + + t.equal(response.statusCode, 200) + t.equal(await response.body.text(), 'hello world!') + + const response2 = await client.request({ + ...requestOptions, + origin: `http://localhost:${server1.address().port}` + }) + + t.equal(response2.statusCode, 200) + t.equal(await response2.body.text(), 'hello world!') + + const response3 = await client.request({ + ...requestOptions, + origin: 'https://developer.mozilla.org' + }) + + t.equal(response3.statusCode, 200) + t.equal(await response3.body.text(), 'hello world! (x2)') +}) diff --git a/types/interceptors.d.ts b/types/interceptors.d.ts index 24166b61f4f..65e9397554e 100644 --- a/types/interceptors.d.ts +++ b/types/interceptors.d.ts @@ -1,3 +1,5 @@ +import { LookupOptions } from 'node:dns' + import Dispatcher from "./dispatcher"; import RetryHandler from "./retry-handler"; @@ -9,6 +11,18 @@ declare namespace Interceptors { export type RedirectInterceptorOpts = { maxRedirections?: number } export type ResponseErrorInterceptorOpts = { throwOnError: boolean } + // DNS interceptor + export type DNSInterceptorRecord = { address: string, ttl: number, family: 4 | 6 } + export type DNSInterceptorOriginRecords = { 4: { ips: DNSInterceptorRecord[] } | null, 6: { ips: DNSInterceptorRecord[] } | null } + export type DNSInterceptorOpts = { + maxTTL?: number + maxItems?: number + lookup?: (hostname: string, options: LookupOptions, callback: (err: NodeJS.ErrnoException | null, addresses: DNSInterceptorRecord[]) => void) => void + pick?: (origin: URL, records: DNSInterceptorOriginRecords, affinity: 4 | 6) => DNSInterceptorRecord + dualStack?: boolean + affinity?: 4 | 6 + } + export function createRedirectInterceptor(opts: RedirectInterceptorOpts): Dispatcher.DispatcherComposeInterceptor export function dump(opts?: DumpInterceptorOpts): Dispatcher.DispatcherComposeInterceptor export function retry(opts?: RetryInterceptorOpts): Dispatcher.DispatcherComposeInterceptor