diff --git a/lib/web/fetch/util.js b/lib/web/fetch/util.js index 37c8c78fc5d..49ab59e8876 100644 --- a/lib/web/fetch/util.js +++ b/lib/web/fetch/util.js @@ -1476,5 +1476,6 @@ module.exports = { buildContentRange, parseMetadata, createInflate, - extractMimeType + extractMimeType, + getDecodeSplit } diff --git a/lib/web/websocket/connection.js b/lib/web/websocket/connection.js index 8f12933c959..45f68e1de93 100644 --- a/lib/web/websocket/connection.js +++ b/lib/web/websocket/connection.js @@ -13,6 +13,7 @@ const { CloseEvent } = require('./events') const { makeRequest } = require('../fetch/request') const { fetching } = require('../fetch/index') const { Headers } = require('../fetch/headers') +const { getDecodeSplit } = require('../fetch/util') const { kHeadersList } = require('../../core/symbols') /** @type {import('crypto')} */ @@ -176,9 +177,18 @@ function establishWebSocketConnection (url, protocols, ws, onEstablish, options) // the WebSocket Connection_. const secProtocol = response.headersList.get('Sec-WebSocket-Protocol') - if (secProtocol !== null && secProtocol !== request.headersList.get('Sec-WebSocket-Protocol')) { - failWebsocketConnection(ws, 'Protocol was not set in the opening handshake.') - return + if (secProtocol !== null) { + const requestProtocols = getDecodeSplit('sec-websocket-protocol', request.headersList) + + // The client can request that the server use a specific subprotocol by + // including the |Sec-WebSocket-Protocol| field in its handshake. If it + // is specified, the server needs to include the same field and one of + // the selected subprotocol values in its response for the connection to + // be established. + if (!requestProtocols.includes(secProtocol)) { + failWebsocketConnection(ws, 'Protocol was not set in the opening handshake.') + return + } } response.socket.on('data', onSocketData) diff --git a/test/websocket/issue-2844.js b/test/websocket/issue-2844.js new file mode 100644 index 00000000000..d103a1722ff --- /dev/null +++ b/test/websocket/issue-2844.js @@ -0,0 +1,73 @@ +'use strict' + +const { test } = require('node:test') +const { once } = require('node:events') +const { WebSocketServer } = require('ws') +const { WebSocket } = require('../..') +const { tspl } = require('@matteo.collina/tspl') + +test('The server must reply with at least one subprotocol the client sends', async (t) => { + const { completed, deepStrictEqual, fail } = tspl(t, { plan: 2 }) + + const wss = new WebSocketServer({ + handleProtocols: (protocols) => { + deepStrictEqual(protocols, new Set(['msgpack', 'json'])) + + return protocols.values().next().value + }, + port: 0 + }) + + wss.on('connection', (ws) => { + ws.on('error', fail) + ws.send('something') + }) + + await once(wss, 'listening') + + const ws = new WebSocket(`ws://localhost:${wss.address().port}`, { + protocols: ['msgpack', 'json'] + }) + + ws.onerror = fail + ws.onopen = () => deepStrictEqual(ws.protocol, 'msgpack') + + t.after(() => { + wss.close() + ws.close() + }) + + await completed +}) + +test('The connection fails when the client sends subprotocols that the server does not responc with', async (t) => { + const { completed, fail, ok } = tspl(t, { plan: 1 }) + + const wss = new WebSocketServer({ + handleProtocols: () => false, + port: 0 + }) + + wss.on('connection', (ws) => { + ws.on('error', fail) + ws.send('something') + }) + + await once(wss, 'listening') + + const ws = new WebSocket(`ws://localhost:${wss.address().port}`, { + protocols: ['json'] + }) + + ws.onerror = ok.bind(null, true) + // The server will try to send 'something', this ensures that the connection + // fails during the handshake and doesn't receive any messages. + ws.onmessage = fail + + t.after(() => { + wss.close() + ws.close() + }) + + await completed +}) diff --git a/test/wpt/server/websocket.mjs b/test/wpt/server/websocket.mjs index cc8ce78151b..9bb05d12612 100644 --- a/test/wpt/server/websocket.mjs +++ b/test/wpt/server/websocket.mjs @@ -8,7 +8,7 @@ import { server } from './server.mjs' const wss = new WebSocketServer({ server, - handleProtocols: (protocols) => [...protocols].join(', ') + handleProtocols: (protocols) => protocols.values().next().value }) wss.on('connection', (ws, request) => {