@@ -5,6 +5,7 @@ import type { ServerOptions as HttpsServerOptions } from 'node:https'
5
5
import { createServer as createHttpsServer } from 'node:https'
6
6
import type { Socket } from 'node:net'
7
7
import type { Duplex } from 'node:stream'
8
+ import crypto from 'node:crypto'
8
9
import colors from 'picocolors'
9
10
import type { WebSocket as WebSocketRaw } from 'ws'
10
11
import { WebSocketServer as WebSocketServerRaw_ } from 'ws'
@@ -91,12 +92,34 @@ const wsServerEvents = [
91
92
'message' ,
92
93
]
93
94
95
+ // we only allow websockets to be connected if it has a valid token
96
+ // this is to prevent untrusted origins to connect to the server
97
+ // for example, Cross-site WebSocket hijacking
98
+ //
99
+ // we should check the token before calling wss.handleUpgrade
100
+ // otherwise untrusted ws clients will be included in wss.clients
101
+ //
102
+ // using the query params means the token might be logged out in server or middleware logs
103
+ // but we assume that is not an issue since the token is regenerated for each process
104
+ function hasValidToken ( config : ResolvedConfig , url : URL ) {
105
+ const token = url . searchParams . get ( 'token' )
106
+ if ( ! token ) return false
107
+
108
+ try {
109
+ const isValidToken = crypto . timingSafeEqual (
110
+ Buffer . from ( token ) ,
111
+ Buffer . from ( config . webSocketToken ) ,
112
+ )
113
+ return isValidToken
114
+ } catch { } // an error is thrown when the length is incorrect
115
+ return false
116
+ }
117
+
94
118
export function createWebSocketServer (
95
119
server : Server | null ,
96
120
config : ResolvedConfig ,
97
121
httpsOptions ?: HttpsServerOptions ,
98
122
) : WebSocketServer {
99
- let wss : WebSocketServerRaw_
100
123
let wsHttpServer : Server | undefined = undefined
101
124
102
125
const hmr = isObject ( config . server . hmr ) && config . server . hmr
@@ -115,21 +138,50 @@ export function createWebSocketServer(
115
138
const port = hmrPort || 24678
116
139
const host = ( hmr && hmr . host ) || undefined
117
140
141
+ const shouldHandle = ( req : IncomingMessage ) => {
142
+ if ( config . legacy ?. skipWebSocketTokenCheck ) {
143
+ return true
144
+ }
145
+
146
+ // If the Origin header is set, this request might be coming from a browser.
147
+ // Browsers always sets the Origin header for WebSocket connections.
148
+ if ( req . headers . origin ) {
149
+ const parsedUrl = new URL ( `http://example.com${ req . url ! } ` )
150
+ return hasValidToken ( config , parsedUrl )
151
+ }
152
+
153
+ // We allow non-browser requests to connect without a token
154
+ // for backward compat and convenience
155
+ // This is fine because if you can sent a request without the SOP limitation,
156
+ // you can also send a normal HTTP request to the server.
157
+ return true
158
+ }
159
+ const handleUpgrade = (
160
+ req : IncomingMessage ,
161
+ socket : Duplex ,
162
+ head : Buffer ,
163
+ _isPing : boolean ,
164
+ ) => {
165
+ wss . handleUpgrade ( req , socket as Socket , head , ( ws ) => {
166
+ wss . emit ( 'connection' , ws , req )
167
+ } )
168
+ }
169
+ const wss : WebSocketServerRaw_ = new WebSocketServerRaw ( { noServer : true } )
170
+ wss . shouldHandle = shouldHandle
171
+
118
172
if ( wsServer ) {
119
173
let hmrBase = config . base
120
174
const hmrPath = hmr ? hmr . path : undefined
121
175
if ( hmrPath ) {
122
176
hmrBase = path . posix . join ( hmrBase , hmrPath )
123
177
}
124
- wss = new WebSocketServerRaw ( { noServer : true } )
125
178
hmrServerWsListener = ( req , socket , head ) => {
179
+ const parsedUrl = new URL ( `http://example.com${ req . url ! } ` )
126
180
if (
127
181
req . headers [ 'sec-websocket-protocol' ] === HMR_HEADER &&
128
- req . url === hmrBase
182
+ parsedUrl . pathname === hmrBase
129
183
) {
130
- wss . handleUpgrade ( req , socket as Socket , head , ( ws ) => {
131
- wss . emit ( 'connection' , ws , req )
132
- } )
184
+ handleUpgrade ( req , socket as Socket , head , false )
133
185
}
134
186
}
135
187
wsServer . on ( 'upgrade' , hmrServerWsListener )
@@ -153,9 +205,22 @@ export function createWebSocketServer(
153
205
} else {
154
206
wsHttpServer = createHttpServer ( route )
155
207
}
156
- // vite dev server in middleware mode
157
- // need to call ws listen manually
158
- wss = new WebSocketServerRaw ( { server : wsHttpServer } )
208
+ wsHttpServer . on ( 'upgrade' , ( req , socket , head ) => {
209
+ handleUpgrade ( req , socket as Socket , head , false )
210
+ } )
211
+ wsHttpServer . on ( 'error' , ( e : Error & { code : string } ) => {
212
+ if ( e . code === 'EADDRINUSE' ) {
213
+ config . logger . error (
214
+ colors . red ( `WebSocket server error: Port is already in use` ) ,
215
+ { error : e } ,
216
+ )
217
+ } else {
218
+ config . logger . error (
219
+ colors . red ( `WebSocket server error:\n${ e . stack || e . message } ` ) ,
220
+ { error : e } ,
221
+ )
222
+ }
223
+ } )
159
224
}
160
225
161
226
wss . on ( 'connection' , ( socket ) => {
0 commit comments