diff --git a/accept.go b/accept.go index 0f3b0d16..74afeabe 100644 --- a/accept.go +++ b/accept.go @@ -15,6 +15,7 @@ import ( "net/textproto" "net/url" "path" + "strconv" "strings" "github.com/coder/websocket/internal/errd" @@ -298,15 +299,34 @@ func acceptDeflate(ext websocketExtension, mode CompressionMode) (*compressionOp case "server_no_context_takeover": copts.serverNoContextTakeover = true continue - case "client_max_window_bits", - "server_max_window_bits=15": + case "client_max_window_bits": + copts.clientMaxWindowBits = 15 // default + continue + case "server_max_window_bits": + copts.serverMaxWindowBits = 15 // default continue } if strings.HasPrefix(p, "client_max_window_bits=") { - // We can't adjust the deflate window, but decoding with a larger window is acceptable. + // We don't need to change decoder settings; larger window decoder can read smaller windows. + if v, err := strconv.Atoi(strings.TrimPrefix(p, "client_max_window_bits=")); err == nil { + if v >= 8 && v <= 15 { + copts.clientMaxWindowBits = v + } + } continue } + + if strings.HasPrefix(p, "server_max_window_bits=") { + vstr := strings.TrimPrefix(p, "server_max_window_bits=") + v, err := strconv.Atoi(vstr) + if err != nil || v < 8 || v > 15 { + return nil, false // invalid per RFC + } + copts.serverMaxWindowBits = v + continue + } + return nil, false } return copts, true diff --git a/accept_test.go b/accept_test.go index aeea1d8a..3835cbd6 100644 --- a/accept_test.go +++ b/accept_test.go @@ -515,6 +515,22 @@ func Test_selectDeflate(t *testing.T) { expCopts: &compressionOptions{ clientNoContextTakeover: true, serverNoContextTakeover: true, + + clientMaxWindowBits: 15, + serverMaxWindowBits: 0, + }, + expOK: true, + }, + { + name: "permessage-deflate/custom-client-window-bits", + mode: CompressionNoContextTakeover, + header: "permessage-deflate; client_max_window_bits=12", + expCopts: &compressionOptions{ + clientNoContextTakeover: true, + serverNoContextTakeover: true, + + clientMaxWindowBits: 12, + serverMaxWindowBits: 0, }, expOK: true, }, @@ -531,6 +547,9 @@ func Test_selectDeflate(t *testing.T) { expCopts: &compressionOptions{ clientNoContextTakeover: true, serverNoContextTakeover: true, + + clientMaxWindowBits: 15, + serverMaxWindowBits: 0, }, expOK: true, }, diff --git a/compress.go b/compress.go index 41bd5bdb..37c54250 100644 --- a/compress.go +++ b/compress.go @@ -3,7 +3,10 @@ package websocket import ( - "compress/flate" + "strconv" + + "github.com/klauspost/compress/flate" + "io" "sync" ) @@ -53,12 +56,18 @@ func (m CompressionMode) opts() *compressionOptions { return &compressionOptions{ clientNoContextTakeover: m == CompressionNoContextTakeover, serverNoContextTakeover: m == CompressionNoContextTakeover, + + serverMaxWindowBits: 0, + clientMaxWindowBits: 0, } } type compressionOptions struct { clientNoContextTakeover bool serverNoContextTakeover bool + + serverMaxWindowBits int + clientMaxWindowBits int } func (copts *compressionOptions) String() string { @@ -69,6 +78,11 @@ func (copts *compressionOptions) String() string { if copts.serverNoContextTakeover { s += "; server_no_context_takeover" } + + if copts.clientMaxWindowBits != 0 { + s += "; client_max_window_bits=" + strconv.Itoa(copts.clientMaxWindowBits) + } + return s } @@ -147,20 +161,24 @@ func putFlateReader(fr io.Reader) { flateReaderPool.Put(fr) } -var flateWriterPool sync.Pool +var flateWriterPool [16]sync.Pool -func getFlateWriter(w io.Writer) *flate.Writer { - fw, ok := flateWriterPool.Get().(*flate.Writer) +func getFlateWriter(w io.Writer, bits int) *flate.Writer { + fw, ok := flateWriterPool[bits].Get().(*flate.Writer) if !ok { - fw, _ = flate.NewWriter(w, flate.BestSpeed) + if bits == 0 { + fw, _ = flate.NewWriter(w, flate.BestCompression) + } else { + fw, _ = flate.NewWriterWindow(w, 1<