diff --git a/accept.go b/accept.go index 5a162de0..3eee08c9 100644 --- a/accept.go +++ b/accept.go @@ -209,7 +209,7 @@ func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionM func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) { copts := mode.opts() - copts.serverMaxWindowBits = 8 + copts.serverMaxWindowBits = 13 for _, p := range ext.params { switch p { @@ -222,26 +222,30 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi } if strings.HasPrefix(p, "client_max_window_bits") { + if p == "client_max_window_bits" { + copts.clientMaxWindowBits = 15 + continue + } + bits, ok := parseExtensionParameter(p) + if !ok || bits < 8 || bits > 16 { + err := fmt.Errorf("invalid client_max_window_bits: %q", p) + http.Error(w, err.Error(), http.StatusBadRequest) + return nil, err + } + copts.clientMaxWindowBits = bits continue - - // bits, ok := parseExtensionParameter(p, 15) - // if !ok || bits < 8 || bits > 16 { - // err := fmt.Errorf("invalid client_max_window_bits: %q", p) - // http.Error(w, err.Error(), http.StatusBadRequest) - // return nil, err - // } - // copts.clientMaxWindowBits = bits - // continue } - if false && strings.HasPrefix(p, "server_max_window_bits") { - // We always send back 8 but make sure to validate. - bits, ok := parseExtensionParameter(p, 0) + if strings.HasPrefix(p, "server_max_window_bits") { + bits, ok := parseExtensionParameter(p) if !ok || bits < 8 || bits > 16 { err := fmt.Errorf("invalid server_max_window_bits: %q", p) http.Error(w, err.Error(), http.StatusBadRequest) return nil, err } + if copts.serverMaxWindowBits > bits { + copts.serverMaxWindowBits = bits + } continue } @@ -256,14 +260,9 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi } // parseExtensionParameter parses the value in the extension parameter p. -// It falls back to defaultVal if there is no value. -// If defaultVal == 0, then ok == false if there is no value. -func parseExtensionParameter(p string, defaultVal int) (int, bool) { +func parseExtensionParameter(p string) (int, bool) { ps := strings.Split(p, "=") if len(ps) == 1 { - if defaultVal > 0 { - return defaultVal, true - } return 0, false } i, e := strconv.Atoi(strings.Trim(ps[1], `"`)) diff --git a/accept_test.go b/accept_test.go index 555f0dc0..188f1ec0 100644 --- a/accept_test.go +++ b/accept_test.go @@ -323,11 +323,12 @@ func Test_acceptCompression(t *testing.T) { name: "permessage-deflate", mode: CompressionNoContextTakeover, reqSecWebSocketExtensions: "permessage-deflate; client_max_window_bits", - respSecWebSocketExtensions: "permessage-deflate; client_no_context_takeover; server_no_context_takeover", + respSecWebSocketExtensions: "permessage-deflate; client_no_context_takeover; server_no_context_takeover; server_max_window_bits=8; client_max_window_bits=15", expCopts: &compressionOptions{ clientNoContextTakeover: true, serverNoContextTakeover: true, serverMaxWindowBits: 8, + clientMaxWindowBits: 15, }, }, { diff --git a/autobahn_test.go b/autobahn_test.go index 50473534..e56a4912 100644 --- a/autobahn_test.go +++ b/autobahn_test.go @@ -28,6 +28,7 @@ var excludedAutobahnCases = []string{ // We skip the tests related to requestMaxWindowBits as that is unimplemented due // to limitations in compress/flate. See https://github.com/golang/go/issues/3155 + // Same with klauspost/compress which doesn't allow adjusting the sliding window size. "13.3.*", "13.4.*", "13.5.*", "13.6.*", } diff --git a/compress_notjs.go b/compress_notjs.go index ef82eb4d..4f892dc0 100644 --- a/compress_notjs.go +++ b/compress_notjs.go @@ -34,10 +34,10 @@ func (copts *compressionOptions) setHeader(h http.Header) { if copts.serverNoContextTakeover { s += "; server_no_context_takeover" } - if false && copts.serverMaxWindowBits > 0 { + if copts.serverMaxWindowBits > 0 { s += fmt.Sprintf("; server_max_window_bits=%v", copts.serverMaxWindowBits) } - if false && copts.clientMaxWindowBits > 0 { + if copts.clientMaxWindowBits > 0 { s += fmt.Sprintf("; client_max_window_bits=%v", copts.clientMaxWindowBits) } h.Set("Sec-WebSocket-Extensions", s) @@ -147,6 +147,10 @@ func (sw *slidingWindow) init(n int) { return } + if n == 0 { + n = 32768 + } + p := slidingWindowPool(n) buf, ok := p.Get().([]byte) if ok { diff --git a/dial.go b/dial.go index 8ff39597..dcc27f7a 100644 --- a/dial.go +++ b/dial.go @@ -82,7 +82,13 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) ( return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err) } - resp, err := handshakeRequest(ctx, urls, opts, secWebSocketKey) + var copts *compressionOptions + if opts.CompressionMode != CompressionDisabled { + copts = opts.CompressionMode.opts() + copts.clientMaxWindowBits = 13 + } + + resp, err := handshakeRequest(ctx, urls, opts, copts, secWebSocketKey) if err != nil { return nil, resp, err } @@ -104,7 +110,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) ( } }() - copts, err := verifyServerResponse(opts, secWebSocketKey, resp) + err = verifyServerResponse(opts, copts, secWebSocketKey, resp) if err != nil { return nil, resp, err } @@ -125,7 +131,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) ( }), resp, nil } -func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWebSocketKey string) (*http.Response, error) { +func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) { if opts.HTTPClient.Timeout > 0 { return nil, errors.New("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67") } @@ -153,9 +159,7 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWe if len(opts.Subprotocols) > 0 { req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) } - if opts.CompressionMode != CompressionDisabled { - copts := opts.CompressionMode.opts() - copts.clientMaxWindowBits = 8 + if copts != nil { copts.setHeader(req.Header) } @@ -178,21 +182,21 @@ func secWebSocketKey(rr io.Reader) (string, error) { return base64.StdEncoding.EncodeToString(b), nil } -func verifyServerResponse(opts *DialOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) { +func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSocketKey string, resp *http.Response) error { if resp.StatusCode != http.StatusSwitchingProtocols { - return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode) + return fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode) } if !headerContainsToken(resp.Header, "Connection", "Upgrade") { - return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection")) + return fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection")) } if !headerContainsToken(resp.Header, "Upgrade", "WebSocket") { - return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade")) + return fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade")) } if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(secWebSocketKey) { - return nil, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q", + return fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q", resp.Header.Get("Sec-WebSocket-Accept"), secWebSocketKey, ) @@ -200,10 +204,10 @@ func verifyServerResponse(opts *DialOptions, secWebSocketKey string, resp *http. err := verifySubprotocol(opts.Subprotocols, resp) if err != nil { - return nil, err + return err } - return verifyServerExtensions(resp.Header) + return verifyServerExtensions(copts, resp.Header) } func verifySubprotocol(subprotos []string, resp *http.Response) error { @@ -221,19 +225,20 @@ func verifySubprotocol(subprotos []string, resp *http.Response) error { return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto) } -func verifyServerExtensions(h http.Header) (*compressionOptions, error) { +func verifyServerExtensions(copts *compressionOptions, h http.Header) error { exts := websocketExtensions(h) if len(exts) == 0 { - return nil, nil + return nil } ext := exts[0] - if ext.name != "permessage-deflate" || len(exts) > 1 { - return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:]) + if ext.name != "permessage-deflate" || len(exts) > 1 || copts == nil { + return fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:]) } - copts := &compressionOptions{} - copts.clientMaxWindowBits = 8 + // Let the server decide its context takeover. + copts.serverNoContextTakeover = false + for _, p := range ext.params { switch p { case "client_no_context_takeover": @@ -244,28 +249,30 @@ func verifyServerExtensions(h http.Header) (*compressionOptions, error) { continue } - if false && strings.HasPrefix(p, "server_max_window_bits") { - bits, ok := parseExtensionParameter(p, 0) + if strings.HasPrefix(p, "server_max_window_bits") { + bits, ok := parseExtensionParameter(p) if !ok || bits < 8 || bits > 16 { - return nil, fmt.Errorf("invalid server_max_window_bits: %q", p) + return fmt.Errorf("invalid server_max_window_bits: %q", p) } copts.serverMaxWindowBits = bits continue } - if false && strings.HasPrefix(p, "client_max_window_bits") { - bits, ok := parseExtensionParameter(p, 0) + if strings.HasPrefix(p, "client_max_window_bits") { + bits, ok := parseExtensionParameter(p) if !ok || bits < 8 || bits > 16 { - return nil, fmt.Errorf("invalid client_max_window_bits: %q", p) + return fmt.Errorf("invalid client_max_window_bits: %q", p) + } + if copts.clientMaxWindowBits > bits { + copts.clientMaxWindowBits = bits } - copts.clientMaxWindowBits = 8 continue } - return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p) + return fmt.Errorf("unsupported permessage-deflate parameter: %q", p) } - return copts, nil + return nil } var readerPool sync.Pool diff --git a/dial_test.go b/dial_test.go index 06084cc5..2347b69f 100644 --- a/dial_test.go +++ b/dial_test.go @@ -221,7 +221,7 @@ func Test_verifyServerHandshake(t *testing.T) { opts := &DialOptions{ Subprotocols: strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ","), } - _, err = verifyServerResponse(opts, key, resp) + err = verifyServerResponse(opts, opts.CompressionMode.opts(), key, resp) if tc.success { assert.Success(t, err) } else { diff --git a/read.go b/read.go index a1efecab..06209420 100644 --- a/read.go +++ b/read.go @@ -88,9 +88,16 @@ func newMsgReader(c *Conn) *msgReader { return mr } +func (mr *msgReader) maxWindowBits() int { + if mr.c.client { + return mr.c.copts.serverMaxWindowBits + } + return mr.c.copts.clientMaxWindowBits +} + func (mr *msgReader) resetFlate() { if mr.flateContextTakeover() { - mr.dict.init(32768) + mr.dict.init(pow(2, mr.maxWindowBits())) } if mr.flateBufio == nil { mr.flateBufio = getBufioReader(mr.readFunc) diff --git a/write.go b/write.go index 81b9141a..d1057729 100644 --- a/write.go +++ b/write.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "io" + "math" "sync" "time" @@ -89,6 +90,13 @@ func newMsgWriterState(c *Conn) *msgWriterState { return mw } +func (mw *msgWriterState) maxWindowBits() int { + if mw.c.client { + return mw.c.copts.clientMaxWindowBits + } + return mw.c.copts.serverMaxWindowBits +} + func (mw *msgWriterState) ensureFlate() { if mw.trimWriter == nil { mw.trimWriter = &trimLastFourBytesWriter{ @@ -96,10 +104,14 @@ func (mw *msgWriterState) ensureFlate() { } } - mw.dict.init(8192) + mw.dict.init(pow(2, mw.maxWindowBits())) mw.flate = true } +func pow(x, y int) int { + return int(math.Pow(float64(x), float64(y))) +} + func (mw *msgWriterState) flateContextTakeover() bool { if mw.c.client { return !mw.c.copts.clientNoContextTakeover