Skip to content

Commit

Permalink
Fix deflate extension negotation
Browse files Browse the repository at this point in the history
  • Loading branch information
nhooyr committed Feb 16, 2020
1 parent 94f9b71 commit 0fb0a6b
Show file tree
Hide file tree
Showing 8 changed files with 84 additions and 53 deletions.
37 changes: 18 additions & 19 deletions accept.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionM

func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {
copts := mode.opts()
copts.serverMaxWindowBits = 8
copts.serverMaxWindowBits = 13

for _, p := range ext.params {
switch p {
Expand All @@ -222,26 +222,30 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi
}

if strings.HasPrefix(p, "client_max_window_bits") {
if p == "client_max_window_bits" {
copts.clientMaxWindowBits = 15
continue
}
bits, ok := parseExtensionParameter(p)
if !ok || bits < 8 || bits > 16 {
err := fmt.Errorf("invalid client_max_window_bits: %q", p)
http.Error(w, err.Error(), http.StatusBadRequest)
return nil, err
}
copts.clientMaxWindowBits = bits
continue

// bits, ok := parseExtensionParameter(p, 15)
// if !ok || bits < 8 || bits > 16 {
// err := fmt.Errorf("invalid client_max_window_bits: %q", p)
// http.Error(w, err.Error(), http.StatusBadRequest)
// return nil, err
// }
// copts.clientMaxWindowBits = bits
// continue
}

if false && strings.HasPrefix(p, "server_max_window_bits") {
// We always send back 8 but make sure to validate.
bits, ok := parseExtensionParameter(p, 0)
if strings.HasPrefix(p, "server_max_window_bits") {
bits, ok := parseExtensionParameter(p)
if !ok || bits < 8 || bits > 16 {
err := fmt.Errorf("invalid server_max_window_bits: %q", p)
http.Error(w, err.Error(), http.StatusBadRequest)
return nil, err
}
if copts.serverMaxWindowBits > bits {
copts.serverMaxWindowBits = bits
}
continue
}

Expand All @@ -256,14 +260,9 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi
}

// parseExtensionParameter parses the value in the extension parameter p.
// It falls back to defaultVal if there is no value.
// If defaultVal == 0, then ok == false if there is no value.
func parseExtensionParameter(p string, defaultVal int) (int, bool) {
func parseExtensionParameter(p string) (int, bool) {
ps := strings.Split(p, "=")
if len(ps) == 1 {
if defaultVal > 0 {
return defaultVal, true
}
return 0, false
}
i, e := strconv.Atoi(strings.Trim(ps[1], `"`))
Expand Down
3 changes: 2 additions & 1 deletion accept_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,11 +323,12 @@ func Test_acceptCompression(t *testing.T) {
name: "permessage-deflate",
mode: CompressionNoContextTakeover,
reqSecWebSocketExtensions: "permessage-deflate; client_max_window_bits",
respSecWebSocketExtensions: "permessage-deflate; client_no_context_takeover; server_no_context_takeover",
respSecWebSocketExtensions: "permessage-deflate; client_no_context_takeover; server_no_context_takeover; server_max_window_bits=8; client_max_window_bits=15",
expCopts: &compressionOptions{
clientNoContextTakeover: true,
serverNoContextTakeover: true,
serverMaxWindowBits: 8,
clientMaxWindowBits: 15,
},
},
{
Expand Down
1 change: 1 addition & 0 deletions autobahn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ var excludedAutobahnCases = []string{

// We skip the tests related to requestMaxWindowBits as that is unimplemented due
// to limitations in compress/flate. See https://github.com/golang/go/issues/3155
// Same with klauspost/compress which doesn't allow adjusting the sliding window size.
"13.3.*", "13.4.*", "13.5.*", "13.6.*",
}

Expand Down
8 changes: 6 additions & 2 deletions compress_notjs.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ func (copts *compressionOptions) setHeader(h http.Header) {
if copts.serverNoContextTakeover {
s += "; server_no_context_takeover"
}
if false && copts.serverMaxWindowBits > 0 {
if copts.serverMaxWindowBits > 0 {
s += fmt.Sprintf("; server_max_window_bits=%v", copts.serverMaxWindowBits)
}
if false && copts.clientMaxWindowBits > 0 {
if copts.clientMaxWindowBits > 0 {
s += fmt.Sprintf("; client_max_window_bits=%v", copts.clientMaxWindowBits)
}
h.Set("Sec-WebSocket-Extensions", s)
Expand Down Expand Up @@ -147,6 +147,10 @@ func (sw *slidingWindow) init(n int) {
return
}

if n == 0 {
n = 32768
}

p := slidingWindowPool(n)
buf, ok := p.Get().([]byte)
if ok {
Expand Down
63 changes: 35 additions & 28 deletions dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,13 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err)
}

resp, err := handshakeRequest(ctx, urls, opts, secWebSocketKey)
var copts *compressionOptions
if opts.CompressionMode != CompressionDisabled {
copts = opts.CompressionMode.opts()
copts.clientMaxWindowBits = 13
}

resp, err := handshakeRequest(ctx, urls, opts, copts, secWebSocketKey)
if err != nil {
return nil, resp, err
}
Expand All @@ -104,7 +110,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
}
}()

copts, err := verifyServerResponse(opts, secWebSocketKey, resp)
err = verifyServerResponse(opts, copts, secWebSocketKey, resp)
if err != nil {
return nil, resp, err
}
Expand All @@ -125,7 +131,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
}), resp, nil
}

func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWebSocketKey string) (*http.Response, error) {
func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) {
if opts.HTTPClient.Timeout > 0 {
return nil, errors.New("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67")
}
Expand Down Expand Up @@ -153,9 +159,7 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWe
if len(opts.Subprotocols) > 0 {
req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
}
if opts.CompressionMode != CompressionDisabled {
copts := opts.CompressionMode.opts()
copts.clientMaxWindowBits = 8
if copts != nil {
copts.setHeader(req.Header)
}

Expand All @@ -178,32 +182,32 @@ func secWebSocketKey(rr io.Reader) (string, error) {
return base64.StdEncoding.EncodeToString(b), nil
}

func verifyServerResponse(opts *DialOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) {
func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSocketKey string, resp *http.Response) error {
if resp.StatusCode != http.StatusSwitchingProtocols {
return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
return fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
}

if !headerContainsToken(resp.Header, "Connection", "Upgrade") {
return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection"))
return fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection"))
}

if !headerContainsToken(resp.Header, "Upgrade", "WebSocket") {
return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade"))
return fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade"))
}

if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(secWebSocketKey) {
return nil, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q",
return fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q",
resp.Header.Get("Sec-WebSocket-Accept"),
secWebSocketKey,
)
}

err := verifySubprotocol(opts.Subprotocols, resp)
if err != nil {
return nil, err
return err
}

return verifyServerExtensions(resp.Header)
return verifyServerExtensions(copts, resp.Header)
}

func verifySubprotocol(subprotos []string, resp *http.Response) error {
Expand All @@ -221,19 +225,20 @@ func verifySubprotocol(subprotos []string, resp *http.Response) error {
return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto)
}

func verifyServerExtensions(h http.Header) (*compressionOptions, error) {
func verifyServerExtensions(copts *compressionOptions, h http.Header) error {
exts := websocketExtensions(h)
if len(exts) == 0 {
return nil, nil
return nil
}

ext := exts[0]
if ext.name != "permessage-deflate" || len(exts) > 1 {
return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:])
if ext.name != "permessage-deflate" || len(exts) > 1 || copts == nil {
return fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:])
}

copts := &compressionOptions{}
copts.clientMaxWindowBits = 8
// Let the server decide its context takeover.
copts.serverNoContextTakeover = false

for _, p := range ext.params {
switch p {
case "client_no_context_takeover":
Expand All @@ -244,28 +249,30 @@ func verifyServerExtensions(h http.Header) (*compressionOptions, error) {
continue
}

if false && strings.HasPrefix(p, "server_max_window_bits") {
bits, ok := parseExtensionParameter(p, 0)
if strings.HasPrefix(p, "server_max_window_bits") {
bits, ok := parseExtensionParameter(p)
if !ok || bits < 8 || bits > 16 {
return nil, fmt.Errorf("invalid server_max_window_bits: %q", p)
return fmt.Errorf("invalid server_max_window_bits: %q", p)
}
copts.serverMaxWindowBits = bits
continue
}

if false && strings.HasPrefix(p, "client_max_window_bits") {
bits, ok := parseExtensionParameter(p, 0)
if strings.HasPrefix(p, "client_max_window_bits") {
bits, ok := parseExtensionParameter(p)
if !ok || bits < 8 || bits > 16 {
return nil, fmt.Errorf("invalid client_max_window_bits: %q", p)
return fmt.Errorf("invalid client_max_window_bits: %q", p)
}
if copts.clientMaxWindowBits > bits {
copts.clientMaxWindowBits = bits
}
copts.clientMaxWindowBits = 8
continue
}

return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
return fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
}

return copts, nil
return nil
}

var readerPool sync.Pool
Expand Down
2 changes: 1 addition & 1 deletion dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ func Test_verifyServerHandshake(t *testing.T) {
opts := &DialOptions{
Subprotocols: strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ","),
}
_, err = verifyServerResponse(opts, key, resp)
err = verifyServerResponse(opts, opts.CompressionMode.opts(), key, resp)
if tc.success {
assert.Success(t, err)
} else {
Expand Down
9 changes: 8 additions & 1 deletion read.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,16 @@ func newMsgReader(c *Conn) *msgReader {
return mr
}

func (mr *msgReader) maxWindowBits() int {
if mr.c.client {
return mr.c.copts.serverMaxWindowBits
}
return mr.c.copts.clientMaxWindowBits
}

func (mr *msgReader) resetFlate() {
if mr.flateContextTakeover() {
mr.dict.init(32768)
mr.dict.init(pow(2, mr.maxWindowBits()))
}
if mr.flateBufio == nil {
mr.flateBufio = getBufioReader(mr.readFunc)
Expand Down
14 changes: 13 additions & 1 deletion write.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"errors"
"fmt"
"io"
"math"
"sync"
"time"

Expand Down Expand Up @@ -89,17 +90,28 @@ func newMsgWriterState(c *Conn) *msgWriterState {
return mw
}

func (mw *msgWriterState) maxWindowBits() int {
if mw.c.client {
return mw.c.copts.clientMaxWindowBits
}
return mw.c.copts.serverMaxWindowBits
}

func (mw *msgWriterState) ensureFlate() {
if mw.trimWriter == nil {
mw.trimWriter = &trimLastFourBytesWriter{
w: writerFunc(mw.write),
}
}

mw.dict.init(8192)
mw.dict.init(pow(2, mw.maxWindowBits()))
mw.flate = true
}

func pow(x, y int) int {
return int(math.Pow(float64(x), float64(y)))
}

func (mw *msgWriterState) flateContextTakeover() bool {
if mw.c.client {
return !mw.c.copts.clientNoContextTakeover
Expand Down

0 comments on commit 0fb0a6b

Please sign in to comment.