@@ -82,7 +82,13 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
82
82
return nil , nil , fmt .Errorf ("failed to generate Sec-WebSocket-Key: %w" , err )
83
83
}
84
84
85
- resp , err := handshakeRequest (ctx , urls , opts , secWebSocketKey )
85
+ var copts * compressionOptions
86
+ if opts .CompressionMode != CompressionDisabled {
87
+ copts = opts .CompressionMode .opts ()
88
+ copts .clientMaxWindowBits = 13
89
+ }
90
+
91
+ resp , err := handshakeRequest (ctx , urls , opts , copts , secWebSocketKey )
86
92
if err != nil {
87
93
return nil , resp , err
88
94
}
@@ -104,7 +110,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
104
110
}
105
111
}()
106
112
107
- copts , err : = verifyServerResponse (opts , secWebSocketKey , resp )
113
+ err = verifyServerResponse (opts , copts , secWebSocketKey , resp )
108
114
if err != nil {
109
115
return nil , resp , err
110
116
}
@@ -125,7 +131,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
125
131
}), resp , nil
126
132
}
127
133
128
- func handshakeRequest (ctx context.Context , urls string , opts * DialOptions , secWebSocketKey string ) (* http.Response , error ) {
134
+ func handshakeRequest (ctx context.Context , urls string , opts * DialOptions , copts * compressionOptions , secWebSocketKey string ) (* http.Response , error ) {
129
135
if opts .HTTPClient .Timeout > 0 {
130
136
return nil , errors .New ("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67" )
131
137
}
@@ -153,9 +159,7 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWe
153
159
if len (opts .Subprotocols ) > 0 {
154
160
req .Header .Set ("Sec-WebSocket-Protocol" , strings .Join (opts .Subprotocols , "," ))
155
161
}
156
- if opts .CompressionMode != CompressionDisabled {
157
- copts := opts .CompressionMode .opts ()
158
- copts .clientMaxWindowBits = 8
162
+ if copts != nil {
159
163
copts .setHeader (req .Header )
160
164
}
161
165
@@ -178,32 +182,32 @@ func secWebSocketKey(rr io.Reader) (string, error) {
178
182
return base64 .StdEncoding .EncodeToString (b ), nil
179
183
}
180
184
181
- func verifyServerResponse (opts * DialOptions , secWebSocketKey string , resp * http.Response ) ( * compressionOptions , error ) {
185
+ func verifyServerResponse (opts * DialOptions , copts * compressionOptions , secWebSocketKey string , resp * http.Response ) error {
182
186
if resp .StatusCode != http .StatusSwitchingProtocols {
183
- return nil , fmt .Errorf ("expected handshake response status code %v but got %v" , http .StatusSwitchingProtocols , resp .StatusCode )
187
+ return fmt .Errorf ("expected handshake response status code %v but got %v" , http .StatusSwitchingProtocols , resp .StatusCode )
184
188
}
185
189
186
190
if ! headerContainsToken (resp .Header , "Connection" , "Upgrade" ) {
187
- return nil , fmt .Errorf ("WebSocket protocol violation: Connection header %q does not contain Upgrade" , resp .Header .Get ("Connection" ))
191
+ return fmt .Errorf ("WebSocket protocol violation: Connection header %q does not contain Upgrade" , resp .Header .Get ("Connection" ))
188
192
}
189
193
190
194
if ! headerContainsToken (resp .Header , "Upgrade" , "WebSocket" ) {
191
- return nil , fmt .Errorf ("WebSocket protocol violation: Upgrade header %q does not contain websocket" , resp .Header .Get ("Upgrade" ))
195
+ return fmt .Errorf ("WebSocket protocol violation: Upgrade header %q does not contain websocket" , resp .Header .Get ("Upgrade" ))
192
196
}
193
197
194
198
if resp .Header .Get ("Sec-WebSocket-Accept" ) != secWebSocketAccept (secWebSocketKey ) {
195
- return nil , fmt .Errorf ("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q" ,
199
+ return fmt .Errorf ("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q" ,
196
200
resp .Header .Get ("Sec-WebSocket-Accept" ),
197
201
secWebSocketKey ,
198
202
)
199
203
}
200
204
201
205
err := verifySubprotocol (opts .Subprotocols , resp )
202
206
if err != nil {
203
- return nil , err
207
+ return err
204
208
}
205
209
206
- return verifyServerExtensions (resp .Header )
210
+ return verifyServerExtensions (copts , resp .Header )
207
211
}
208
212
209
213
func verifySubprotocol (subprotos []string , resp * http.Response ) error {
@@ -221,19 +225,20 @@ func verifySubprotocol(subprotos []string, resp *http.Response) error {
221
225
return fmt .Errorf ("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q" , proto )
222
226
}
223
227
224
- func verifyServerExtensions (h http.Header ) ( * compressionOptions , error ) {
228
+ func verifyServerExtensions (copts * compressionOptions , h http.Header ) error {
225
229
exts := websocketExtensions (h )
226
230
if len (exts ) == 0 {
227
- return nil , nil
231
+ return nil
228
232
}
229
233
230
234
ext := exts [0 ]
231
- if ext .name != "permessage-deflate" || len (exts ) > 1 {
232
- return nil , fmt .Errorf ("WebSocket protcol violation: unsupported extensions from server: %+v" , exts [1 :])
235
+ if ext .name != "permessage-deflate" || len (exts ) > 1 || copts == nil {
236
+ return fmt .Errorf ("WebSocket protcol violation: unsupported extensions from server: %+v" , exts [1 :])
233
237
}
234
238
235
- copts := & compressionOptions {}
236
- copts .clientMaxWindowBits = 8
239
+ // Let the server decide its context takeover.
240
+ copts .serverNoContextTakeover = false
241
+
237
242
for _ , p := range ext .params {
238
243
switch p {
239
244
case "client_no_context_takeover" :
@@ -244,28 +249,30 @@ func verifyServerExtensions(h http.Header) (*compressionOptions, error) {
244
249
continue
245
250
}
246
251
247
- if false && strings .HasPrefix (p , "server_max_window_bits" ) {
248
- bits , ok := parseExtensionParameter (p , 0 )
252
+ if strings .HasPrefix (p , "server_max_window_bits" ) {
253
+ bits , ok := parseExtensionParameter (p )
249
254
if ! ok || bits < 8 || bits > 16 {
250
- return nil , fmt .Errorf ("invalid server_max_window_bits: %q" , p )
255
+ return fmt .Errorf ("invalid server_max_window_bits: %q" , p )
251
256
}
252
257
copts .serverMaxWindowBits = bits
253
258
continue
254
259
}
255
260
256
- if false && strings .HasPrefix (p , "client_max_window_bits" ) {
257
- bits , ok := parseExtensionParameter (p , 0 )
261
+ if strings .HasPrefix (p , "client_max_window_bits" ) {
262
+ bits , ok := parseExtensionParameter (p )
258
263
if ! ok || bits < 8 || bits > 16 {
259
- return nil , fmt .Errorf ("invalid client_max_window_bits: %q" , p )
264
+ return fmt .Errorf ("invalid client_max_window_bits: %q" , p )
265
+ }
266
+ if copts .clientMaxWindowBits > bits {
267
+ copts .clientMaxWindowBits = bits
260
268
}
261
- copts .clientMaxWindowBits = 8
262
269
continue
263
270
}
264
271
265
- return nil , fmt .Errorf ("unsupported permessage-deflate parameter: %q" , p )
272
+ return fmt .Errorf ("unsupported permessage-deflate parameter: %q" , p )
266
273
}
267
274
268
- return copts , nil
275
+ return nil
269
276
}
270
277
271
278
var readerPool sync.Pool
0 commit comments