Skip to content

Commit d7227c2

Browse files
abursavichnhooyr
authored andcommitted
Server selects first acceptable compression offer
Unacceptable offers are declined without rejecting the request.
1 parent 2291d83 commit d7227c2

File tree

4 files changed

+89
-65
lines changed

4 files changed

+89
-65
lines changed

accept.go

+18-23
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,9 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
123123
w.Header().Set("Sec-WebSocket-Protocol", subproto)
124124
}
125125

126-
copts, err := acceptCompression(r, w, opts.CompressionMode)
127-
if err != nil {
128-
return nil, err
126+
copts, ok := selectDeflate(websocketExtensions(r.Header), opts.CompressionMode)
127+
if ok {
128+
w.Header().Set("Sec-WebSocket-Extensions", copts.String())
129129
}
130130

131131
w.WriteHeader(http.StatusSwitchingProtocols)
@@ -238,25 +238,26 @@ func selectSubprotocol(r *http.Request, subprotocols []string) string {
238238
return ""
239239
}
240240

241-
func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionMode) (*compressionOptions, error) {
241+
func selectDeflate(extensions []websocketExtension, mode CompressionMode) (*compressionOptions, bool) {
242242
if mode == CompressionDisabled {
243-
return nil, nil
243+
return nil, false
244244
}
245-
246-
for _, ext := range websocketExtensions(r.Header) {
245+
for _, ext := range extensions {
247246
switch ext.name {
248247
// We used to implement x-webkit-deflate-fram too but Safari has bugs.
249248
// See https://github.com/nhooyr/websocket/issues/218
250249
case "permessage-deflate":
251-
return acceptDeflate(w, ext, mode)
250+
copts, ok := acceptDeflate(ext, mode)
251+
if ok {
252+
return copts
253+
}
252254
}
253255
}
254-
return nil, nil
256+
return nil, false
255257
}
256258

257-
func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {
259+
func acceptDeflate(ext websocketExtension, mode CompressionMode) (*compressionOptions, bool) {
258260
copts := mode.opts()
259-
260261
for _, p := range ext.params {
261262
switch p {
262263
case "client_no_context_takeover":
@@ -265,24 +266,18 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi
265266
case "server_no_context_takeover":
266267
copts.serverNoContextTakeover = true
267268
continue
268-
case "server_max_window_bits=15":
269+
case "client_max_window_bits",
270+
"server_max_window_bits=15":
269271
continue
270272
}
271273

272-
if strings.HasPrefix(p, "client_max_window_bits") {
273-
// We cannot adjust the read sliding window so cannot make use of this.
274-
// By not responding to it, we tell the client we're ignoring it.
274+
if strings.HasPrefix(p, "client_max_window_bits=") {
275+
// We can't adjust the deflate window, but decoding with a larger window is acceptable.
275276
continue
276277
}
277-
278-
err := fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
279-
http.Error(w, err.Error(), http.StatusBadRequest)
280-
return nil, err
278+
return nil, false
281279
}
282-
283-
copts.setHeader(w.Header())
284-
285-
return copts, nil
280+
return copts, true
286281
}
287282

288283
func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool {

accept_test.go

+68-38
Original file line numberDiff line numberDiff line change
@@ -62,20 +62,47 @@ func TestAccept(t *testing.T) {
6262
t.Run("badCompression", func(t *testing.T) {
6363
t.Parallel()
6464

65-
w := mockHijacker{
66-
ResponseWriter: httptest.NewRecorder(),
65+
newRequest := func(extensions string) *http.Request {
66+
r := httptest.NewRequest("GET", "/", nil)
67+
r.Header.Set("Connection", "Upgrade")
68+
r.Header.Set("Upgrade", "websocket")
69+
r.Header.Set("Sec-WebSocket-Version", "13")
70+
r.Header.Set("Sec-WebSocket-Key", "meow123")
71+
r.Header.Set("Sec-WebSocket-Extensions", extensions)
72+
return r
73+
}
74+
newResponseWriter := func() http.ResponseWriter {
75+
return mockHijacker{
76+
ResponseWriter: httptest.NewRecorder(),
77+
hijack: func() (net.Conn, *bufio.ReadWriter, error) {
78+
return nil, nil, errors.New("hijack error")
79+
},
80+
}
6781
}
68-
r := httptest.NewRequest("GET", "/", nil)
69-
r.Header.Set("Connection", "Upgrade")
70-
r.Header.Set("Upgrade", "websocket")
71-
r.Header.Set("Sec-WebSocket-Version", "13")
72-
r.Header.Set("Sec-WebSocket-Key", "meow123")
73-
r.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; harharhar")
7482

75-
_, err := Accept(w, r, &AcceptOptions{
76-
CompressionMode: CompressionContextTakeover,
83+
t.Run("withoutFallback", func(t *testing.T) {
84+
t.Parallel()
85+
86+
w := newResponseWriter()
87+
r := newRequest("permessage-deflate; harharhar")
88+
_, _ = Accept(w, r, &AcceptOptions{
89+
CompressionMode: CompressionNoContextTakeover,
90+
})
91+
assert.Equal(t, "extension header", w.Header().Get("Sec-WebSocket-Extensions"), "")
92+
})
93+
t.Run("withFallback", func(t *testing.T) {
94+
t.Parallel()
95+
96+
w := newResponseWriter()
97+
r := newRequest("permessage-deflate; harharhar, permessage-deflate")
98+
_, _ = Accept(w, r, &AcceptOptions{
99+
CompressionMode: CompressionNoContextTakeover,
100+
})
101+
assert.Equal(t, "extension header",
102+
w.Header().Get("Sec-WebSocket-Extensions"),
103+
CompressionNoContextTakeover.opts().String(),
104+
)
77105
})
78-
assert.Contains(t, err, `unsupported permessage-deflate parameter`)
79106
})
80107

81108
t.Run("requireHttpHijacker", func(t *testing.T) {
@@ -344,42 +371,53 @@ func Test_authenticateOrigin(t *testing.T) {
344371
}
345372
}
346373

347-
func Test_acceptCompression(t *testing.T) {
374+
func Test_selectDeflate(t *testing.T) {
348375
t.Parallel()
349376

350377
testCases := []struct {
351-
name string
352-
mode CompressionMode
353-
reqSecWebSocketExtensions string
354-
respSecWebSocketExtensions string
355-
expCopts *compressionOptions
356-
error bool
378+
name string
379+
mode CompressionMode
380+
header string
381+
expCopts *compressionOptions
382+
expOK bool
357383
}{
358384
{
359385
name: "disabled",
360386
mode: CompressionDisabled,
361387
expCopts: nil,
388+
expOK: false,
362389
},
363390
{
364391
name: "noClientSupport",
365392
mode: CompressionNoContextTakeover,
366393
expCopts: nil,
394+
expOK: false,
367395
},
368396
{
369-
name: "permessage-deflate",
370-
mode: CompressionNoContextTakeover,
371-
reqSecWebSocketExtensions: "permessage-deflate; client_max_window_bits",
372-
respSecWebSocketExtensions: "permessage-deflate; client_no_context_takeover; server_no_context_takeover",
397+
name: "permessage-deflate",
398+
mode: CompressionNoContextTakeover,
399+
header: "permessage-deflate; client_max_window_bits",
373400
expCopts: &compressionOptions{
374401
clientNoContextTakeover: true,
375402
serverNoContextTakeover: true,
376403
},
404+
expOK: true,
405+
},
406+
{
407+
name: "permessage-deflate/unknown-parameter",
408+
mode: CompressionNoContextTakeover,
409+
header: "permessage-deflate; meow",
410+
expOK: false,
377411
},
378412
{
379-
name: "permessage-deflate/error",
380-
mode: CompressionNoContextTakeover,
381-
reqSecWebSocketExtensions: "permessage-deflate; meow",
382-
error: true,
413+
name: "permessage-deflate/unknown-parameter",
414+
mode: CompressionNoContextTakeover,
415+
header: "permessage-deflate; meow, permessage-deflate; client_max_window_bits",
416+
expCopts: &compressionOptions{
417+
clientNoContextTakeover: true,
418+
serverNoContextTakeover: true,
419+
},
420+
expOK: true,
383421
},
384422
// {
385423
// name: "x-webkit-deflate-frame",
@@ -404,19 +442,11 @@ func Test_acceptCompression(t *testing.T) {
404442
t.Run(tc.name, func(t *testing.T) {
405443
t.Parallel()
406444

407-
r := httptest.NewRequest(http.MethodGet, "/", nil)
408-
r.Header.Set("Sec-WebSocket-Extensions", tc.reqSecWebSocketExtensions)
409-
410-
w := httptest.NewRecorder()
411-
copts, err := acceptCompression(r, w, tc.mode)
412-
if tc.error {
413-
assert.Error(t, err)
414-
return
415-
}
416-
417-
assert.Success(t, err)
445+
h := http.Header{}
446+
h.Set("Sec-WebSocket-Extensions", tc.header)
447+
copts, ok := selectDeflate(websocketExtensions(h), tc.mode)
448+
assert.Equal(t, "selected options", tc.expOK, ok)
418449
assert.Equal(t, "compression options", tc.expCopts, copts)
419-
assert.Equal(t, "Sec-WebSocket-Extensions", tc.respSecWebSocketExtensions, w.Header().Get("Sec-WebSocket-Extensions"))
420450
})
421451
}
422452
}

compress.go

+2-3
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ package websocket
66
import (
77
"compress/flate"
88
"io"
9-
"net/http"
109
"sync"
1110
)
1211

@@ -65,15 +64,15 @@ type compressionOptions struct {
6564
serverNoContextTakeover bool
6665
}
6766

68-
func (copts *compressionOptions) setHeader(h http.Header) {
67+
func (copts *compressionOptions) String() string {
6968
s := "permessage-deflate"
7069
if copts.clientNoContextTakeover {
7170
s += "; client_no_context_takeover"
7271
}
7372
if copts.serverNoContextTakeover {
7473
s += "; server_no_context_takeover"
7574
}
76-
h.Set("Sec-WebSocket-Extensions", s)
75+
return s
7776
}
7877

7978
// These bytes are required to get flate.Reader to return.

dial.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts
185185
req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
186186
}
187187
if copts != nil {
188-
copts.setHeader(req.Header)
188+
req.Header.Set("Sec-WebSocket-Extensions", copts.String())
189189
}
190190

191191
resp, err := opts.HTTPClient.Do(req)

0 commit comments

Comments
 (0)