Skip to content

Commit 0fb0a6b

Browse files
committed
Fix deflate extension negotation
1 parent 94f9b71 commit 0fb0a6b

File tree

8 files changed

+84
-53
lines changed

8 files changed

+84
-53
lines changed

accept.go

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionM
209209

210210
func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {
211211
copts := mode.opts()
212-
copts.serverMaxWindowBits = 8
212+
copts.serverMaxWindowBits = 13
213213

214214
for _, p := range ext.params {
215215
switch p {
@@ -222,26 +222,30 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi
222222
}
223223

224224
if strings.HasPrefix(p, "client_max_window_bits") {
225+
if p == "client_max_window_bits" {
226+
copts.clientMaxWindowBits = 15
227+
continue
228+
}
229+
bits, ok := parseExtensionParameter(p)
230+
if !ok || bits < 8 || bits > 16 {
231+
err := fmt.Errorf("invalid client_max_window_bits: %q", p)
232+
http.Error(w, err.Error(), http.StatusBadRequest)
233+
return nil, err
234+
}
235+
copts.clientMaxWindowBits = bits
225236
continue
226-
227-
// bits, ok := parseExtensionParameter(p, 15)
228-
// if !ok || bits < 8 || bits > 16 {
229-
// err := fmt.Errorf("invalid client_max_window_bits: %q", p)
230-
// http.Error(w, err.Error(), http.StatusBadRequest)
231-
// return nil, err
232-
// }
233-
// copts.clientMaxWindowBits = bits
234-
// continue
235237
}
236238

237-
if false && strings.HasPrefix(p, "server_max_window_bits") {
238-
// We always send back 8 but make sure to validate.
239-
bits, ok := parseExtensionParameter(p, 0)
239+
if strings.HasPrefix(p, "server_max_window_bits") {
240+
bits, ok := parseExtensionParameter(p)
240241
if !ok || bits < 8 || bits > 16 {
241242
err := fmt.Errorf("invalid server_max_window_bits: %q", p)
242243
http.Error(w, err.Error(), http.StatusBadRequest)
243244
return nil, err
244245
}
246+
if copts.serverMaxWindowBits > bits {
247+
copts.serverMaxWindowBits = bits
248+
}
245249
continue
246250
}
247251

@@ -256,14 +260,9 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi
256260
}
257261

258262
// parseExtensionParameter parses the value in the extension parameter p.
259-
// It falls back to defaultVal if there is no value.
260-
// If defaultVal == 0, then ok == false if there is no value.
261-
func parseExtensionParameter(p string, defaultVal int) (int, bool) {
263+
func parseExtensionParameter(p string) (int, bool) {
262264
ps := strings.Split(p, "=")
263265
if len(ps) == 1 {
264-
if defaultVal > 0 {
265-
return defaultVal, true
266-
}
267266
return 0, false
268267
}
269268
i, e := strconv.Atoi(strings.Trim(ps[1], `"`))

accept_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,11 +323,12 @@ func Test_acceptCompression(t *testing.T) {
323323
name: "permessage-deflate",
324324
mode: CompressionNoContextTakeover,
325325
reqSecWebSocketExtensions: "permessage-deflate; client_max_window_bits",
326-
respSecWebSocketExtensions: "permessage-deflate; client_no_context_takeover; server_no_context_takeover",
326+
respSecWebSocketExtensions: "permessage-deflate; client_no_context_takeover; server_no_context_takeover; server_max_window_bits=8; client_max_window_bits=15",
327327
expCopts: &compressionOptions{
328328
clientNoContextTakeover: true,
329329
serverNoContextTakeover: true,
330330
serverMaxWindowBits: 8,
331+
clientMaxWindowBits: 15,
331332
},
332333
},
333334
{

autobahn_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ var excludedAutobahnCases = []string{
2828

2929
// We skip the tests related to requestMaxWindowBits as that is unimplemented due
3030
// to limitations in compress/flate. See https://github.com/golang/go/issues/3155
31+
// Same with klauspost/compress which doesn't allow adjusting the sliding window size.
3132
"13.3.*", "13.4.*", "13.5.*", "13.6.*",
3233
}
3334

compress_notjs.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@ func (copts *compressionOptions) setHeader(h http.Header) {
3434
if copts.serverNoContextTakeover {
3535
s += "; server_no_context_takeover"
3636
}
37-
if false && copts.serverMaxWindowBits > 0 {
37+
if copts.serverMaxWindowBits > 0 {
3838
s += fmt.Sprintf("; server_max_window_bits=%v", copts.serverMaxWindowBits)
3939
}
40-
if false && copts.clientMaxWindowBits > 0 {
40+
if copts.clientMaxWindowBits > 0 {
4141
s += fmt.Sprintf("; client_max_window_bits=%v", copts.clientMaxWindowBits)
4242
}
4343
h.Set("Sec-WebSocket-Extensions", s)
@@ -147,6 +147,10 @@ func (sw *slidingWindow) init(n int) {
147147
return
148148
}
149149

150+
if n == 0 {
151+
n = 32768
152+
}
153+
150154
p := slidingWindowPool(n)
151155
buf, ok := p.Get().([]byte)
152156
if ok {

dial.go

Lines changed: 35 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,13 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
8282
return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err)
8383
}
8484

85-
resp, err := handshakeRequest(ctx, urls, opts, secWebSocketKey)
85+
var copts *compressionOptions
86+
if opts.CompressionMode != CompressionDisabled {
87+
copts = opts.CompressionMode.opts()
88+
copts.clientMaxWindowBits = 13
89+
}
90+
91+
resp, err := handshakeRequest(ctx, urls, opts, copts, secWebSocketKey)
8692
if err != nil {
8793
return nil, resp, err
8894
}
@@ -104,7 +110,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
104110
}
105111
}()
106112

107-
copts, err := verifyServerResponse(opts, secWebSocketKey, resp)
113+
err = verifyServerResponse(opts, copts, secWebSocketKey, resp)
108114
if err != nil {
109115
return nil, resp, err
110116
}
@@ -125,7 +131,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
125131
}), resp, nil
126132
}
127133

128-
func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWebSocketKey string) (*http.Response, error) {
134+
func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) {
129135
if opts.HTTPClient.Timeout > 0 {
130136
return nil, errors.New("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67")
131137
}
@@ -153,9 +159,7 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWe
153159
if len(opts.Subprotocols) > 0 {
154160
req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
155161
}
156-
if opts.CompressionMode != CompressionDisabled {
157-
copts := opts.CompressionMode.opts()
158-
copts.clientMaxWindowBits = 8
162+
if copts != nil {
159163
copts.setHeader(req.Header)
160164
}
161165

@@ -178,32 +182,32 @@ func secWebSocketKey(rr io.Reader) (string, error) {
178182
return base64.StdEncoding.EncodeToString(b), nil
179183
}
180184

181-
func verifyServerResponse(opts *DialOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) {
185+
func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSocketKey string, resp *http.Response) error {
182186
if resp.StatusCode != http.StatusSwitchingProtocols {
183-
return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
187+
return fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
184188
}
185189

186190
if !headerContainsToken(resp.Header, "Connection", "Upgrade") {
187-
return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection"))
191+
return fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection"))
188192
}
189193

190194
if !headerContainsToken(resp.Header, "Upgrade", "WebSocket") {
191-
return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade"))
195+
return fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade"))
192196
}
193197

194198
if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(secWebSocketKey) {
195-
return nil, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q",
199+
return fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q",
196200
resp.Header.Get("Sec-WebSocket-Accept"),
197201
secWebSocketKey,
198202
)
199203
}
200204

201205
err := verifySubprotocol(opts.Subprotocols, resp)
202206
if err != nil {
203-
return nil, err
207+
return err
204208
}
205209

206-
return verifyServerExtensions(resp.Header)
210+
return verifyServerExtensions(copts, resp.Header)
207211
}
208212

209213
func verifySubprotocol(subprotos []string, resp *http.Response) error {
@@ -221,19 +225,20 @@ func verifySubprotocol(subprotos []string, resp *http.Response) error {
221225
return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto)
222226
}
223227

224-
func verifyServerExtensions(h http.Header) (*compressionOptions, error) {
228+
func verifyServerExtensions(copts *compressionOptions, h http.Header) error {
225229
exts := websocketExtensions(h)
226230
if len(exts) == 0 {
227-
return nil, nil
231+
return nil
228232
}
229233

230234
ext := exts[0]
231-
if ext.name != "permessage-deflate" || len(exts) > 1 {
232-
return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:])
235+
if ext.name != "permessage-deflate" || len(exts) > 1 || copts == nil {
236+
return fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:])
233237
}
234238

235-
copts := &compressionOptions{}
236-
copts.clientMaxWindowBits = 8
239+
// Let the server decide its context takeover.
240+
copts.serverNoContextTakeover = false
241+
237242
for _, p := range ext.params {
238243
switch p {
239244
case "client_no_context_takeover":
@@ -244,28 +249,30 @@ func verifyServerExtensions(h http.Header) (*compressionOptions, error) {
244249
continue
245250
}
246251

247-
if false && strings.HasPrefix(p, "server_max_window_bits") {
248-
bits, ok := parseExtensionParameter(p, 0)
252+
if strings.HasPrefix(p, "server_max_window_bits") {
253+
bits, ok := parseExtensionParameter(p)
249254
if !ok || bits < 8 || bits > 16 {
250-
return nil, fmt.Errorf("invalid server_max_window_bits: %q", p)
255+
return fmt.Errorf("invalid server_max_window_bits: %q", p)
251256
}
252257
copts.serverMaxWindowBits = bits
253258
continue
254259
}
255260

256-
if false && strings.HasPrefix(p, "client_max_window_bits") {
257-
bits, ok := parseExtensionParameter(p, 0)
261+
if strings.HasPrefix(p, "client_max_window_bits") {
262+
bits, ok := parseExtensionParameter(p)
258263
if !ok || bits < 8 || bits > 16 {
259-
return nil, fmt.Errorf("invalid client_max_window_bits: %q", p)
264+
return fmt.Errorf("invalid client_max_window_bits: %q", p)
265+
}
266+
if copts.clientMaxWindowBits > bits {
267+
copts.clientMaxWindowBits = bits
260268
}
261-
copts.clientMaxWindowBits = 8
262269
continue
263270
}
264271

265-
return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
272+
return fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
266273
}
267274

268-
return copts, nil
275+
return nil
269276
}
270277

271278
var readerPool sync.Pool

dial_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ func Test_verifyServerHandshake(t *testing.T) {
221221
opts := &DialOptions{
222222
Subprotocols: strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ","),
223223
}
224-
_, err = verifyServerResponse(opts, key, resp)
224+
err = verifyServerResponse(opts, opts.CompressionMode.opts(), key, resp)
225225
if tc.success {
226226
assert.Success(t, err)
227227
} else {

read.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,16 @@ func newMsgReader(c *Conn) *msgReader {
8888
return mr
8989
}
9090

91+
func (mr *msgReader) maxWindowBits() int {
92+
if mr.c.client {
93+
return mr.c.copts.serverMaxWindowBits
94+
}
95+
return mr.c.copts.clientMaxWindowBits
96+
}
97+
9198
func (mr *msgReader) resetFlate() {
9299
if mr.flateContextTakeover() {
93-
mr.dict.init(32768)
100+
mr.dict.init(pow(2, mr.maxWindowBits()))
94101
}
95102
if mr.flateBufio == nil {
96103
mr.flateBufio = getBufioReader(mr.readFunc)

write.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"errors"
1111
"fmt"
1212
"io"
13+
"math"
1314
"sync"
1415
"time"
1516

@@ -89,17 +90,28 @@ func newMsgWriterState(c *Conn) *msgWriterState {
8990
return mw
9091
}
9192

93+
func (mw *msgWriterState) maxWindowBits() int {
94+
if mw.c.client {
95+
return mw.c.copts.clientMaxWindowBits
96+
}
97+
return mw.c.copts.serverMaxWindowBits
98+
}
99+
92100
func (mw *msgWriterState) ensureFlate() {
93101
if mw.trimWriter == nil {
94102
mw.trimWriter = &trimLastFourBytesWriter{
95103
w: writerFunc(mw.write),
96104
}
97105
}
98106

99-
mw.dict.init(8192)
107+
mw.dict.init(pow(2, mw.maxWindowBits()))
100108
mw.flate = true
101109
}
102110

111+
func pow(x, y int) int {
112+
return int(math.Pow(float64(x), float64(y)))
113+
}
114+
103115
func (mw *msgWriterState) flateContextTakeover() bool {
104116
if mw.c.client {
105117
return !mw.c.copts.clientNoContextTakeover

0 commit comments

Comments
 (0)