Skip to content

Commit

Permalink
fix(ClientRequest): passthrough Upgrade requests correctly (#683)
Browse files Browse the repository at this point in the history
  • Loading branch information
kettanaito authored Dec 2, 2024
1 parent 24f5a31 commit cd32d01
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 5 deletions.
31 changes: 27 additions & 4 deletions src/interceptors/ClientRequest/MockHttpSocket.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,23 +58,43 @@ export class MockHttpSocket extends MockSocket {
private requestStream?: Readable
private shouldKeepAlive?: boolean

private responseType: 'mock' | 'bypassed' = 'bypassed'
private socketState: 'unknown' | 'mock' | 'passthrough' = 'unknown'
private responseParser: HTTPParser<1>
private responseStream?: Readable
private originalSocket?: net.Socket

constructor(options: MockHttpSocketOptions) {
super({
write: (chunk, encoding, callback) => {
this.writeBuffer.push([chunk, encoding, callback])
// Buffer the writes so they can be flushed in case of the original connection
// and when reading the request body in the interceptor. If the connection has
// been established, no need to buffer the chunks anymore, they will be forwarded.
if (this.socketState !== 'passthrough') {
this.writeBuffer.push([chunk, encoding, callback])
}

if (chunk) {
/**
* Forward any writes to the mock socket to the underlying original socket.
* This ensures functional duplex connections, like WebSocket.
* @see https://github.com/mswjs/interceptors/issues/682
*/
if (this.socketState === 'passthrough') {
this.originalSocket?.write(chunk, encoding, callback)
}

this.requestParser.execute(
Buffer.isBuffer(chunk) ? chunk : Buffer.from(chunk, encoding)
)
}
},
read: (chunk) => {
if (chunk !== null) {
/**
* @todo We need to free the parser if the connection has been
* upgraded to a non-HTTP protocol. It won't be able to parse data
* from that point onward anyway. No need to keep it in memory.
*/
this.responseParser.execute(
Buffer.isBuffer(chunk) ? chunk : Buffer.from(chunk)
)
Expand Down Expand Up @@ -151,11 +171,14 @@ export class MockHttpSocket extends MockSocket {
* its data/events through this Socket.
*/
public passthrough(): void {
this.socketState = 'passthrough'

if (this.destroyed) {
return
}

const socket = this.createConnection()
this.originalSocket = socket

// If the developer destroys the socket, destroy the original connection.
this.once('error', (error) => {
Expand Down Expand Up @@ -276,7 +299,7 @@ export class MockHttpSocket extends MockSocket {
// First, emit all the connection events
// to emulate a successful connection.
this.mockConnect()
this.responseType = 'mock'
this.socketState = 'mock'

// Flush the write buffer to trigger write callbacks
// if it hasn't been flushed already (e.g. someone started reading request stream).
Expand Down Expand Up @@ -581,7 +604,7 @@ export class MockHttpSocket extends MockSocket {

this.responseListenersPromise = this.onResponse({
response,
isMockedResponse: this.responseType === 'mock',
isMockedResponse: this.socketState === 'mock',
requestId: Reflect.get(this.request, kRequestId),
request: this.request,
socket: this,
Expand Down
1 change: 0 additions & 1 deletion src/interceptors/Socket/MockSocket.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ export class MockSocket extends net.Socket {
args as WriteArgs
)
this.options.write(chunk, encoding, callback)

return super.end.apply(this, args as any)
}

Expand Down
40 changes: 40 additions & 0 deletions test/modules/http/compliance/http-upgrade.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/**
* @see https://github.com/mswjs/interceptors/issues/682
*/
// @vitest-environment node-with-websocket
import { vi, it, expect, beforeAll, afterEach, afterAll } from 'vitest'
import { Server } from 'socket.io'
import { io } from 'socket.io-client'
import { ClientRequestInterceptor } from '../../../../src/interceptors/ClientRequest'

const interceptor = new ClientRequestInterceptor()
const server = new Server(51678)

beforeAll(() => {
interceptor.apply()
})

afterEach(() => {
interceptor.removeAllListeners()
})

afterAll(async () => {
interceptor.dispose()
await new Promise<void>((resolve, reject) => {
server.disconnectSockets()
server.close((error) => {
if (error) reject(error)
resolve()
})
})
})

it('bypasses a WebSocket upgrade request', async () => {
const client = io(`http://localhost:51678`, {
transports: ['websocket'],
})

await vi.waitFor(async () => {
expect(client.connected).toBe(true)
})
})

0 comments on commit cd32d01

Please sign in to comment.