Skip to content

Commit

Permalink
Fix negotation of flate parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
nhooyr committed Feb 16, 2020
1 parent 0fb0a6b commit 2b3efae
Show file tree
Hide file tree
Showing 7 changed files with 18 additions and 90 deletions.
26 changes: 1 addition & 25 deletions accept.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,6 @@ func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionM

func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {
copts := mode.opts()
copts.serverMaxWindowBits = 13

for _, p := range ext.params {
switch p {
Expand All @@ -222,30 +221,7 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi
}

if strings.HasPrefix(p, "client_max_window_bits") {
if p == "client_max_window_bits" {
copts.clientMaxWindowBits = 15
continue
}
bits, ok := parseExtensionParameter(p)
if !ok || bits < 8 || bits > 16 {
err := fmt.Errorf("invalid client_max_window_bits: %q", p)
http.Error(w, err.Error(), http.StatusBadRequest)
return nil, err
}
copts.clientMaxWindowBits = bits
continue
}

if strings.HasPrefix(p, "server_max_window_bits") {
bits, ok := parseExtensionParameter(p)
if !ok || bits < 8 || bits > 16 {
err := fmt.Errorf("invalid server_max_window_bits: %q", p)
http.Error(w, err.Error(), http.StatusBadRequest)
return nil, err
}
if copts.serverMaxWindowBits > bits {
copts.serverMaxWindowBits = bits
}
// We cannot adjust the read sliding window so cannot make use of this.
continue
}

Expand Down
4 changes: 1 addition & 3 deletions accept_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,12 +323,10 @@ func Test_acceptCompression(t *testing.T) {
name: "permessage-deflate",
mode: CompressionNoContextTakeover,
reqSecWebSocketExtensions: "permessage-deflate; client_max_window_bits",
respSecWebSocketExtensions: "permessage-deflate; client_no_context_takeover; server_no_context_takeover; server_max_window_bits=8; client_max_window_bits=15",
respSecWebSocketExtensions: "permessage-deflate; client_no_context_takeover; server_no_context_takeover",
expCopts: &compressionOptions{
clientNoContextTakeover: true,
serverNoContextTakeover: true,
serverMaxWindowBits: 8,
clientMaxWindowBits: 15,
},
},
{
Expand Down
10 changes: 0 additions & 10 deletions compress_notjs.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
package websocket

import (
"fmt"
"io"
"net/http"
"sync"
Expand All @@ -20,10 +19,7 @@ func (m CompressionMode) opts() *compressionOptions {

type compressionOptions struct {
clientNoContextTakeover bool
clientMaxWindowBits int

serverNoContextTakeover bool
serverMaxWindowBits int
}

func (copts *compressionOptions) setHeader(h http.Header) {
Expand All @@ -34,12 +30,6 @@ func (copts *compressionOptions) setHeader(h http.Header) {
if copts.serverNoContextTakeover {
s += "; server_no_context_takeover"
}
if copts.serverMaxWindowBits > 0 {
s += fmt.Sprintf("; server_max_window_bits=%v", copts.serverMaxWindowBits)
}
if copts.clientMaxWindowBits > 0 {
s += fmt.Sprintf("; client_max_window_bits=%v", copts.clientMaxWindowBits)
}
h.Set("Sec-WebSocket-Extensions", s)
}

Expand Down
48 changes: 13 additions & 35 deletions dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
var copts *compressionOptions
if opts.CompressionMode != CompressionDisabled {
copts = opts.CompressionMode.opts()
copts.clientMaxWindowBits = 13
}

resp, err := handshakeRequest(ctx, urls, opts, copts, secWebSocketKey)
Expand All @@ -110,7 +109,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
}
}()

err = verifyServerResponse(opts, copts, secWebSocketKey, resp)
copts, err = verifyServerResponse(opts, copts, secWebSocketKey, resp)
if err != nil {
return nil, resp, err
}
Expand Down Expand Up @@ -182,29 +181,29 @@ func secWebSocketKey(rr io.Reader) (string, error) {
return base64.StdEncoding.EncodeToString(b), nil
}

func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSocketKey string, resp *http.Response) error {
func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) {
if resp.StatusCode != http.StatusSwitchingProtocols {
return fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
}

if !headerContainsToken(resp.Header, "Connection", "Upgrade") {
return fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection"))
return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection"))
}

if !headerContainsToken(resp.Header, "Upgrade", "WebSocket") {
return fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade"))
return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade"))
}

if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(secWebSocketKey) {
return fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q",
return nil, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q",
resp.Header.Get("Sec-WebSocket-Accept"),
secWebSocketKey,
)
}

err := verifySubprotocol(opts.Subprotocols, resp)
if err != nil {
return err
return nil, err
}

return verifyServerExtensions(copts, resp.Header)
Expand All @@ -225,19 +224,18 @@ func verifySubprotocol(subprotos []string, resp *http.Response) error {
return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto)
}

func verifyServerExtensions(copts *compressionOptions, h http.Header) error {
func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compressionOptions, error) {
exts := websocketExtensions(h)
if len(exts) == 0 {
return nil
return nil, nil
}

ext := exts[0]
if ext.name != "permessage-deflate" || len(exts) > 1 || copts == nil {
return fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:])
return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:])
}

// Let the server decide its context takeover.
copts.serverNoContextTakeover = false
copts = &*copts

for _, p := range ext.params {
switch p {
Expand All @@ -249,30 +247,10 @@ func verifyServerExtensions(copts *compressionOptions, h http.Header) error {
continue
}

if strings.HasPrefix(p, "server_max_window_bits") {
bits, ok := parseExtensionParameter(p)
if !ok || bits < 8 || bits > 16 {
return fmt.Errorf("invalid server_max_window_bits: %q", p)
}
copts.serverMaxWindowBits = bits
continue
}

if strings.HasPrefix(p, "client_max_window_bits") {
bits, ok := parseExtensionParameter(p)
if !ok || bits < 8 || bits > 16 {
return fmt.Errorf("invalid client_max_window_bits: %q", p)
}
if copts.clientMaxWindowBits > bits {
copts.clientMaxWindowBits = bits
}
continue
}

return fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
}

return nil
return copts, nil
}

var readerPool sync.Pool
Expand Down
2 changes: 1 addition & 1 deletion dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ func Test_verifyServerHandshake(t *testing.T) {
opts := &DialOptions{
Subprotocols: strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ","),
}
err = verifyServerResponse(opts, opts.CompressionMode.opts(), key, resp)
_, err = verifyServerResponse(opts, opts.CompressionMode.opts(), key, resp)
if tc.success {
assert.Success(t, err)
} else {
Expand Down
9 changes: 1 addition & 8 deletions read.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,16 +88,9 @@ func newMsgReader(c *Conn) *msgReader {
return mr
}

func (mr *msgReader) maxWindowBits() int {
if mr.c.client {
return mr.c.copts.serverMaxWindowBits
}
return mr.c.copts.clientMaxWindowBits
}

func (mr *msgReader) resetFlate() {
if mr.flateContextTakeover() {
mr.dict.init(pow(2, mr.maxWindowBits()))
mr.dict.init(32768)
}
if mr.flateBufio == nil {
mr.flateBufio = getBufioReader(mr.readFunc)
Expand Down
9 changes: 1 addition & 8 deletions write.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,21 +90,14 @@ func newMsgWriterState(c *Conn) *msgWriterState {
return mw
}

func (mw *msgWriterState) maxWindowBits() int {
if mw.c.client {
return mw.c.copts.clientMaxWindowBits
}
return mw.c.copts.serverMaxWindowBits
}

func (mw *msgWriterState) ensureFlate() {
if mw.trimWriter == nil {
mw.trimWriter = &trimLastFourBytesWriter{
w: writerFunc(mw.write),
}
}

mw.dict.init(pow(2, mw.maxWindowBits()))
mw.dict.init(8192)
mw.flate = true
}

Expand Down

0 comments on commit 2b3efae

Please sign in to comment.