@@ -82,7 +82,13 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
8282 return nil , nil , fmt .Errorf ("failed to generate Sec-WebSocket-Key: %w" , err )
8383 }
8484
85- resp , err := handshakeRequest (ctx , urls , opts , secWebSocketKey )
85+ var copts * compressionOptions
86+ if opts .CompressionMode != CompressionDisabled {
87+ copts = opts .CompressionMode .opts ()
88+ copts .clientMaxWindowBits = 13
89+ }
90+
91+ resp , err := handshakeRequest (ctx , urls , opts , copts , secWebSocketKey )
8692 if err != nil {
8793 return nil , resp , err
8894 }
@@ -104,7 +110,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
104110 }
105111 }()
106112
107- copts , err : = verifyServerResponse (opts , secWebSocketKey , resp )
113+ err = verifyServerResponse (opts , copts , secWebSocketKey , resp )
108114 if err != nil {
109115 return nil , resp , err
110116 }
@@ -125,7 +131,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
125131 }), resp , nil
126132}
127133
128- func handshakeRequest (ctx context.Context , urls string , opts * DialOptions , secWebSocketKey string ) (* http.Response , error ) {
134+ func handshakeRequest (ctx context.Context , urls string , opts * DialOptions , copts * compressionOptions , secWebSocketKey string ) (* http.Response , error ) {
129135 if opts .HTTPClient .Timeout > 0 {
130136 return nil , errors .New ("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67" )
131137 }
@@ -153,9 +159,7 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWe
153159 if len (opts .Subprotocols ) > 0 {
154160 req .Header .Set ("Sec-WebSocket-Protocol" , strings .Join (opts .Subprotocols , "," ))
155161 }
156- if opts .CompressionMode != CompressionDisabled {
157- copts := opts .CompressionMode .opts ()
158- copts .clientMaxWindowBits = 8
162+ if copts != nil {
159163 copts .setHeader (req .Header )
160164 }
161165
@@ -178,32 +182,32 @@ func secWebSocketKey(rr io.Reader) (string, error) {
178182 return base64 .StdEncoding .EncodeToString (b ), nil
179183}
180184
181- func verifyServerResponse (opts * DialOptions , secWebSocketKey string , resp * http.Response ) ( * compressionOptions , error ) {
185+ func verifyServerResponse (opts * DialOptions , copts * compressionOptions , secWebSocketKey string , resp * http.Response ) error {
182186 if resp .StatusCode != http .StatusSwitchingProtocols {
183- return nil , fmt .Errorf ("expected handshake response status code %v but got %v" , http .StatusSwitchingProtocols , resp .StatusCode )
187+ return fmt .Errorf ("expected handshake response status code %v but got %v" , http .StatusSwitchingProtocols , resp .StatusCode )
184188 }
185189
186190 if ! headerContainsToken (resp .Header , "Connection" , "Upgrade" ) {
187- return nil , fmt .Errorf ("WebSocket protocol violation: Connection header %q does not contain Upgrade" , resp .Header .Get ("Connection" ))
191+ return fmt .Errorf ("WebSocket protocol violation: Connection header %q does not contain Upgrade" , resp .Header .Get ("Connection" ))
188192 }
189193
190194 if ! headerContainsToken (resp .Header , "Upgrade" , "WebSocket" ) {
191- return nil , fmt .Errorf ("WebSocket protocol violation: Upgrade header %q does not contain websocket" , resp .Header .Get ("Upgrade" ))
195+ return fmt .Errorf ("WebSocket protocol violation: Upgrade header %q does not contain websocket" , resp .Header .Get ("Upgrade" ))
192196 }
193197
194198 if resp .Header .Get ("Sec-WebSocket-Accept" ) != secWebSocketAccept (secWebSocketKey ) {
195- return nil , fmt .Errorf ("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q" ,
199+ return fmt .Errorf ("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q" ,
196200 resp .Header .Get ("Sec-WebSocket-Accept" ),
197201 secWebSocketKey ,
198202 )
199203 }
200204
201205 err := verifySubprotocol (opts .Subprotocols , resp )
202206 if err != nil {
203- return nil , err
207+ return err
204208 }
205209
206- return verifyServerExtensions (resp .Header )
210+ return verifyServerExtensions (copts , resp .Header )
207211}
208212
209213func verifySubprotocol (subprotos []string , resp * http.Response ) error {
@@ -221,19 +225,20 @@ func verifySubprotocol(subprotos []string, resp *http.Response) error {
221225 return fmt .Errorf ("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q" , proto )
222226}
223227
224- func verifyServerExtensions (h http.Header ) ( * compressionOptions , error ) {
228+ func verifyServerExtensions (copts * compressionOptions , h http.Header ) error {
225229 exts := websocketExtensions (h )
226230 if len (exts ) == 0 {
227- return nil , nil
231+ return nil
228232 }
229233
230234 ext := exts [0 ]
231- if ext .name != "permessage-deflate" || len (exts ) > 1 {
232- return nil , fmt .Errorf ("WebSocket protcol violation: unsupported extensions from server: %+v" , exts [1 :])
235+ if ext .name != "permessage-deflate" || len (exts ) > 1 || copts == nil {
236+ return fmt .Errorf ("WebSocket protcol violation: unsupported extensions from server: %+v" , exts [1 :])
233237 }
234238
235- copts := & compressionOptions {}
236- copts .clientMaxWindowBits = 8
239+ // Let the server decide its context takeover.
240+ copts .serverNoContextTakeover = false
241+
237242 for _ , p := range ext .params {
238243 switch p {
239244 case "client_no_context_takeover" :
@@ -244,28 +249,30 @@ func verifyServerExtensions(h http.Header) (*compressionOptions, error) {
244249 continue
245250 }
246251
247- if false && strings .HasPrefix (p , "server_max_window_bits" ) {
248- bits , ok := parseExtensionParameter (p , 0 )
252+ if strings .HasPrefix (p , "server_max_window_bits" ) {
253+ bits , ok := parseExtensionParameter (p )
249254 if ! ok || bits < 8 || bits > 16 {
250- return nil , fmt .Errorf ("invalid server_max_window_bits: %q" , p )
255+ return fmt .Errorf ("invalid server_max_window_bits: %q" , p )
251256 }
252257 copts .serverMaxWindowBits = bits
253258 continue
254259 }
255260
256- if false && strings .HasPrefix (p , "client_max_window_bits" ) {
257- bits , ok := parseExtensionParameter (p , 0 )
261+ if strings .HasPrefix (p , "client_max_window_bits" ) {
262+ bits , ok := parseExtensionParameter (p )
258263 if ! ok || bits < 8 || bits > 16 {
259- return nil , fmt .Errorf ("invalid client_max_window_bits: %q" , p )
264+ return fmt .Errorf ("invalid client_max_window_bits: %q" , p )
265+ }
266+ if copts .clientMaxWindowBits > bits {
267+ copts .clientMaxWindowBits = bits
260268 }
261- copts .clientMaxWindowBits = 8
262269 continue
263270 }
264271
265- return nil , fmt .Errorf ("unsupported permessage-deflate parameter: %q" , p )
272+ return fmt .Errorf ("unsupported permessage-deflate parameter: %q" , p )
266273 }
267274
268- return copts , nil
275+ return nil
269276}
270277
271278var readerPool sync.Pool
0 commit comments