diff --git a/vendor/src/github.com/gorilla/websocket/AUTHORS b/vendor/src/github.com/gorilla/websocket/AUTHORS index b003eca0..1931f400 100755 --- a/vendor/src/github.com/gorilla/websocket/AUTHORS +++ b/vendor/src/github.com/gorilla/websocket/AUTHORS @@ -4,5 +4,6 @@ # Please keep the list sorted. Gary Burd +Google LLC (https://opensource.google.com/) Joachim Bauch diff --git a/vendor/src/github.com/gorilla/websocket/README.md b/vendor/src/github.com/gorilla/websocket/README.md index 33c3d2be..19aa2e75 100755 --- a/vendor/src/github.com/gorilla/websocket/README.md +++ b/vendor/src/github.com/gorilla/websocket/README.md @@ -1,14 +1,14 @@ # Gorilla WebSocket +[![GoDoc](https://godoc.org/github.com/gorilla/websocket?status.svg)](https://godoc.org/github.com/gorilla/websocket) +[![CircleCI](https://circleci.com/gh/gorilla/websocket.svg?style=svg)](https://circleci.com/gh/gorilla/websocket) + Gorilla WebSocket is a [Go](http://golang.org/) implementation of the [WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol. -[![Build Status](https://travis-ci.org/gorilla/websocket.svg?branch=master)](https://travis-ci.org/gorilla/websocket) -[![GoDoc](https://godoc.org/github.com/gorilla/websocket?status.svg)](https://godoc.org/github.com/gorilla/websocket) - ### Documentation -* [API Reference](http://godoc.org/github.com/gorilla/websocket) +* [API Reference](https://pkg.go.dev/github.com/gorilla/websocket?tab=doc) * [Chat example](https://github.com/gorilla/websocket/tree/master/examples/chat) * [Command example](https://github.com/gorilla/websocket/tree/master/examples/command) * [Client and server example](https://github.com/gorilla/websocket/tree/master/examples/echo) @@ -27,7 +27,7 @@ package API is stable. ### Protocol Compliance The Gorilla WebSocket package passes the server tests in the [Autobahn Test -Suite](http://autobahn.ws/testsuite) using the application in the [examples/autobahn +Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn). ### Gorilla WebSocket compared with other packages @@ -40,7 +40,7 @@ subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn RFC 6455 Features -Passes Autobahn Test SuiteYesNo +Passes Autobahn Test SuiteYesNo Receive fragmented messageYesNo, see note 1 Send close messageYesNo Send pings and receive pongsYesNo @@ -51,7 +51,7 @@ subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn Write message using io.WriteCloserYesNo, see note 3 -Notes: +Notes: 1. Large messages are fragmented in [Chrome's new WebSocket implementation](http://www.ietf.org/mail-archive/web/hybi/current/msg10503.html). 2. The application can get the type of a received data message by implementing diff --git a/vendor/src/github.com/gorilla/websocket/client.go b/vendor/src/github.com/gorilla/websocket/client.go index 43a87c75..c4b62fbc 100755 --- a/vendor/src/github.com/gorilla/websocket/client.go +++ b/vendor/src/github.com/gorilla/websocket/client.go @@ -5,15 +5,15 @@ package websocket import ( - "bufio" "bytes" + "context" "crypto/tls" - "encoding/base64" "errors" "io" "io/ioutil" "net" "net/http" + "net/http/httptrace" "net/url" "strings" "time" @@ -53,6 +53,10 @@ type Dialer struct { // NetDial is nil, net.Dial is used. NetDial func(network, addr string) (net.Conn, error) + // NetDialContext specifies the dial function for creating TCP connections. If + // NetDialContext is nil, net.DialContext is used. + NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error) + // Proxy specifies a function to return a proxy for a given // Request. If the function returns a non-nil error, the // request is aborted with the provided error. @@ -66,11 +70,22 @@ type Dialer struct { // HandshakeTimeout specifies the duration for the handshake to complete. HandshakeTimeout time.Duration - // ReadBufferSize and WriteBufferSize specify I/O buffer sizes. If a buffer + // ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer // size is zero, then a useful default size is used. The I/O buffer sizes // do not limit the size of the messages that can be sent or received. ReadBufferSize, WriteBufferSize int + // WriteBufferPool is a pool of buffers for write operations. If the value + // is not set, then write buffers are allocated to the connection for the + // lifetime of the connection. + // + // A pool is most useful when the application has a modest volume of writes + // across a large number of connections. + // + // Applications should use a single pool for each unique value of + // WriteBufferSize. + WriteBufferPool BufferPool + // Subprotocols specifies the client's requested subprotocols. Subprotocols []string @@ -86,52 +101,13 @@ type Dialer struct { Jar http.CookieJar } -var errMalformedURL = errors.New("malformed ws or wss URL") - -// parseURL parses the URL. -// -// This function is a replacement for the standard library url.Parse function. -// In Go 1.4 and earlier, url.Parse loses information from the path. -func parseURL(s string) (*url.URL, error) { - // From the RFC: - // - // ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ] - // wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ] - var u url.URL - switch { - case strings.HasPrefix(s, "ws://"): - u.Scheme = "ws" - s = s[len("ws://"):] - case strings.HasPrefix(s, "wss://"): - u.Scheme = "wss" - s = s[len("wss://"):] - default: - return nil, errMalformedURL - } - - if i := strings.Index(s, "?"); i >= 0 { - u.RawQuery = s[i+1:] - s = s[:i] - } - - if i := strings.Index(s, "/"); i >= 0 { - u.Opaque = s[i:] - s = s[:i] - } else { - u.Opaque = "/" - } - - u.Host = s - - if strings.Contains(u.Host, "@") { - // Don't bother parsing user information because user information is - // not allowed in websocket URIs. - return nil, errMalformedURL - } - - return &u, nil +// Dial creates a new client connection by calling DialContext with a background context. +func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { + return d.DialContext(context.Background(), urlStr, requestHeader) } +var errMalformedURL = errors.New("malformed ws or wss URL") + func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) { hostPort = u.Host hostNoPort = u.Host @@ -150,26 +126,29 @@ func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) { return hostPort, hostNoPort } -// DefaultDialer is a dialer with all fields set to the default zero values. +// DefaultDialer is a dialer with all fields set to the default values. var DefaultDialer = &Dialer{ - Proxy: http.ProxyFromEnvironment, + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: 45 * time.Second, } -// Dial creates a new client connection. Use requestHeader to specify the +// nilDialer is dialer to use when receiver is nil. +var nilDialer = *DefaultDialer + +// DialContext creates a new client connection. Use requestHeader to specify the // origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie). // Use the response.Header to get the selected subprotocol // (Sec-WebSocket-Protocol) and cookies (Set-Cookie). // +// The context will be used in the request and in the Dialer. +// // If the WebSocket handshake fails, ErrBadHandshake is returned along with a // non-nil *http.Response so that callers can handle redirects, authentication, // etcetera. The response body may not contain the entire response and does not // need to be closed by the application. -func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { - +func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { if d == nil { - d = &Dialer{ - Proxy: http.ProxyFromEnvironment, - } + d = &nilDialer } challengeKey, err := generateChallengeKey() @@ -177,7 +156,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re return nil, nil, err } - u, err := parseURL(urlStr) + u, err := url.Parse(urlStr) if err != nil { return nil, nil, err } @@ -205,6 +184,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re Header: make(http.Header), Host: u.Host, } + req = req.WithContext(ctx) // Set the cookies present in the cookie jar of the dialer if d.Jar != nil { @@ -237,45 +217,83 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re k == "Sec-Websocket-Extensions" || (k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0): return nil, nil, errors.New("websocket: duplicate header not allowed: " + k) + case k == "Sec-Websocket-Protocol": + req.Header["Sec-WebSocket-Protocol"] = vs default: req.Header[k] = vs } } if d.EnableCompression { - req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover") + req.Header["Sec-WebSocket-Extensions"] = []string{"permessage-deflate; server_no_context_takeover; client_no_context_takeover"} } - hostPort, hostNoPort := hostPortNoPort(u) - - var proxyURL *url.URL - // Check wether the proxy method has been configured - if d.Proxy != nil { - proxyURL, err = d.Proxy(req) - } - if err != nil { - return nil, nil, err + if d.HandshakeTimeout != 0 { + var cancel func() + ctx, cancel = context.WithTimeout(ctx, d.HandshakeTimeout) + defer cancel() } - var targetHostPort string - if proxyURL != nil { - targetHostPort, _ = hostPortNoPort(proxyURL) + // Get network dial function. + var netDial func(network, add string) (net.Conn, error) + + if d.NetDialContext != nil { + netDial = func(network, addr string) (net.Conn, error) { + return d.NetDialContext(ctx, network, addr) + } + } else if d.NetDial != nil { + netDial = d.NetDial } else { - targetHostPort = hostPort + netDialer := &net.Dialer{} + netDial = func(network, addr string) (net.Conn, error) { + return netDialer.DialContext(ctx, network, addr) + } } - var deadline time.Time - if d.HandshakeTimeout != 0 { - deadline = time.Now().Add(d.HandshakeTimeout) + // If needed, wrap the dial function to set the connection deadline. + if deadline, ok := ctx.Deadline(); ok { + forwardDial := netDial + netDial = func(network, addr string) (net.Conn, error) { + c, err := forwardDial(network, addr) + if err != nil { + return nil, err + } + err = c.SetDeadline(deadline) + if err != nil { + c.Close() + return nil, err + } + return c, nil + } } - netDial := d.NetDial - if netDial == nil { - netDialer := &net.Dialer{Deadline: deadline} - netDial = netDialer.Dial + // If needed, wrap the dial function to connect through a proxy. + if d.Proxy != nil { + proxyURL, err := d.Proxy(req) + if err != nil { + return nil, nil, err + } + if proxyURL != nil { + dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial)) + if err != nil { + return nil, nil, err + } + netDial = dialer.Dial + } + } + + hostPort, hostNoPort := hostPortNoPort(u) + trace := httptrace.ContextClientTrace(ctx) + if trace != nil && trace.GetConn != nil { + trace.GetConn(hostPort) } - netConn, err := netDial("tcp", targetHostPort) + netConn, err := netDial("tcp", hostPort) + if trace != nil && trace.GotConn != nil { + trace.GotConn(httptrace.GotConnInfo{ + Conn: netConn, + }) + } if err != nil { return nil, nil, err } @@ -286,42 +304,6 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re } }() - if err := netConn.SetDeadline(deadline); err != nil { - return nil, nil, err - } - - if proxyURL != nil { - connectHeader := make(http.Header) - if user := proxyURL.User; user != nil { - proxyUser := user.Username() - if proxyPassword, passwordSet := user.Password(); passwordSet { - credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword)) - connectHeader.Set("Proxy-Authorization", "Basic "+credential) - } - } - connectReq := &http.Request{ - Method: "CONNECT", - URL: &url.URL{Opaque: hostPort}, - Host: hostPort, - Header: connectHeader, - } - - connectReq.Write(netConn) - - // Read response. - // Okay to use and discard buffered reader here, because - // TLS server will not speak until spoken to. - br := bufio.NewReader(netConn) - resp, err := http.ReadResponse(br, connectReq) - if err != nil { - return nil, nil, err - } - if resp.StatusCode != 200 { - f := strings.SplitN(resp.Status, " ", 2) - return nil, nil, errors.New(f[1]) - } - } - if u.Scheme == "https" { cfg := cloneTLSConfig(d.TLSClientConfig) if cfg.ServerName == "" { @@ -329,22 +311,31 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re } tlsConn := tls.Client(netConn, cfg) netConn = tlsConn - if err := tlsConn.Handshake(); err != nil { - return nil, nil, err + + var err error + if trace != nil { + err = doHandshakeWithTrace(trace, tlsConn, cfg) + } else { + err = doHandshake(tlsConn, cfg) } - if !cfg.InsecureSkipVerify { - if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { - return nil, nil, err - } + + if err != nil { + return nil, nil, err } } - conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize) + conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, nil, nil) if err := req.Write(netConn); err != nil { return nil, nil, err } + if trace != nil && trace.GotFirstResponseByte != nil { + if peek, err := conn.br.Peek(1); err == nil && len(peek) == 1 { + trace.GotFirstResponseByte() + } + } + resp, err := http.ReadResponse(conn.br, req) if err != nil { return nil, nil, err @@ -357,8 +348,8 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re } if resp.StatusCode != 101 || - !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || - !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || + !tokenListContainsValue(resp.Header, "Upgrade", "websocket") || + !tokenListContainsValue(resp.Header, "Connection", "upgrade") || resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) { // Before closing the network connection on return from this // function, slurp up some of the response to aid application @@ -390,3 +381,15 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re netConn = nil // to avoid close in defer. return conn, resp, nil } + +func doHandshake(tlsConn *tls.Conn, cfg *tls.Config) error { + if err := tlsConn.Handshake(); err != nil { + return err + } + if !cfg.InsecureSkipVerify { + if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { + return err + } + } + return nil +} diff --git a/vendor/src/github.com/gorilla/websocket/client_server_test.go b/vendor/src/github.com/gorilla/websocket/client_server_test.go index 7d39da68..5fd2c85a 100755 --- a/vendor/src/github.com/gorilla/websocket/client_server_test.go +++ b/vendor/src/github.com/gorilla/websocket/client_server_test.go @@ -5,14 +5,21 @@ package websocket import ( + "bytes" + "context" "crypto/tls" "crypto/x509" "encoding/base64" + "encoding/binary" + "fmt" "io" "io/ioutil" + "log" + "net" "net/http" "net/http/cookiejar" "net/http/httptest" + "net/http/httptrace" "net/url" "reflect" "strings" @@ -31,9 +38,10 @@ var cstUpgrader = Upgrader{ } var cstDialer = Dialer{ - Subprotocols: []string{"p1", "p2"}, - ReadBufferSize: 1024, - WriteBufferSize: 1024, + Subprotocols: []string{"p1", "p2"}, + ReadBufferSize: 1024, + WriteBufferSize: 1024, + HandshakeTimeout: 30 * time.Second, } type cstHandler struct{ *testing.T } @@ -41,6 +49,7 @@ type cstHandler struct{ *testing.T } type cstServer struct { *httptest.Server URL string + t *testing.T } const ( @@ -68,18 +77,18 @@ func newTLSServer(t *testing.T) *cstServer { func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if r.URL.Path != cstPath { t.Logf("path=%v, want %v", r.URL.Path, cstPath) - http.Error(w, "bad path", 400) + http.Error(w, "bad path", http.StatusBadRequest) return } if r.URL.RawQuery != cstRawQuery { t.Logf("query=%v, want %v", r.URL.RawQuery, cstRawQuery) - http.Error(w, "bad path", 400) + http.Error(w, "bad path", http.StatusBadRequest) return } subprotos := Subprotocols(r) if !reflect.DeepEqual(subprotos, cstDialer.Subprotocols) { t.Logf("subprotols=%v, want %v", subprotos, cstDialer.Subprotocols) - http.Error(w, "bad protocol", 400) + http.Error(w, "bad protocol", http.StatusBadRequest) return } ws, err := cstUpgrader.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}}) @@ -143,8 +152,9 @@ func TestProxyDial(t *testing.T) { s := newServer(t) defer s.Close() - surl, _ := url.Parse(s.URL) + surl, _ := url.Parse(s.Server.URL) + cstDialer := cstDialer // make local copy for modification on next line. cstDialer.Proxy = http.ProxyURL(surl) connect := false @@ -155,13 +165,13 @@ func TestProxyDial(t *testing.T) { func(w http.ResponseWriter, r *http.Request) { if r.Method == "CONNECT" { connect = true - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) return } if !connect { - t.Log("connect not recieved") - http.Error(w, "connect not recieved", 405) + t.Log("connect not received") + http.Error(w, "connect not received", http.StatusMethodNotAllowed) return } origHandler.ServeHTTP(w, r) @@ -173,16 +183,16 @@ func TestProxyDial(t *testing.T) { } defer ws.Close() sendRecv(t, ws) - - cstDialer.Proxy = http.ProxyFromEnvironment } func TestProxyAuthorizationDial(t *testing.T) { s := newServer(t) defer s.Close() - surl, _ := url.Parse(s.URL) + surl, _ := url.Parse(s.Server.URL) surl.User = url.UserPassword("username", "password") + + cstDialer := cstDialer // make local copy for modification on next line. cstDialer.Proxy = http.ProxyURL(surl) connect := false @@ -195,13 +205,13 @@ func TestProxyAuthorizationDial(t *testing.T) { expectedProxyAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("username:password")) if r.Method == "CONNECT" && proxyAuth == expectedProxyAuth { connect = true - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) return } if !connect { - t.Log("connect with proxy authorization not recieved") - http.Error(w, "connect with proxy authorization not recieved", 405) + t.Log("connect with proxy authorization not received") + http.Error(w, "connect with proxy authorization not received", http.StatusMethodNotAllowed) return } origHandler.ServeHTTP(w, r) @@ -213,8 +223,6 @@ func TestProxyAuthorizationDial(t *testing.T) { } defer ws.Close() sendRecv(t, ws) - - cstDialer.Proxy = http.ProxyFromEnvironment } func TestDial(t *testing.T) { @@ -237,7 +245,7 @@ func TestDialCookieJar(t *testing.T) { d := cstDialer d.Jar = jar - u, _ := parseURL(s.URL) + u, _ := url.Parse(s.URL) switch u.Scheme { case "ws": @@ -246,7 +254,7 @@ func TestDialCookieJar(t *testing.T) { u.Scheme = "https" } - cookies := []*http.Cookie{&http.Cookie{Name: "gorilla", Value: "ws", Path: "/"}} + cookies := []*http.Cookie{{Name: "gorilla", Value: "ws", Path: "/"}} d.Jar.SetCookies(u, cookies) ws, _, err := d.Dial(s.URL, nil) @@ -277,10 +285,7 @@ func TestDialCookieJar(t *testing.T) { sendRecv(t, ws) } -func TestDialTLS(t *testing.T) { - s := newTLSServer(t) - defer s.Close() - +func rootCAs(t *testing.T, s *httptest.Server) *x509.CertPool { certs := x509.NewCertPool() for _, c := range s.TLS.Certificates { roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1]) @@ -291,9 +296,15 @@ func TestDialTLS(t *testing.T) { certs.AddCert(root) } } + return certs +} + +func TestDialTLS(t *testing.T) { + s := newTLSServer(t) + defer s.Close() d := cstDialer - d.TLSClientConfig = &tls.Config{RootCAs: certs} + d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)} ws, _, err := d.Dial(s.URL, nil) if err != nil { t.Fatalf("Dial: %v", err) @@ -302,43 +313,97 @@ func TestDialTLS(t *testing.T) { sendRecv(t, ws) } -func xTestDialTLSBadCert(t *testing.T) { - // This test is deactivated because of noisy logging from the net/http package. - s := newTLSServer(t) +func TestDialTimeout(t *testing.T) { + s := newServer(t) defer s.Close() - ws, _, err := cstDialer.Dial(s.URL, nil) + d := cstDialer + d.HandshakeTimeout = -1 + ws, _, err := d.Dial(s.URL, nil) if err == nil { ws.Close() t.Fatalf("Dial: nil") } } -func TestDialTLSNoVerify(t *testing.T) { - s := newTLSServer(t) +// requireDeadlineNetConn fails the current test when Read or Write are called +// with no deadline. +type requireDeadlineNetConn struct { + t *testing.T + c net.Conn + readDeadlineIsSet bool + writeDeadlineIsSet bool +} + +func (c *requireDeadlineNetConn) SetDeadline(t time.Time) error { + c.writeDeadlineIsSet = !t.Equal(time.Time{}) + c.readDeadlineIsSet = c.writeDeadlineIsSet + return c.c.SetDeadline(t) +} + +func (c *requireDeadlineNetConn) SetReadDeadline(t time.Time) error { + c.readDeadlineIsSet = !t.Equal(time.Time{}) + return c.c.SetDeadline(t) +} + +func (c *requireDeadlineNetConn) SetWriteDeadline(t time.Time) error { + c.writeDeadlineIsSet = !t.Equal(time.Time{}) + return c.c.SetDeadline(t) +} + +func (c *requireDeadlineNetConn) Write(p []byte) (int, error) { + if !c.writeDeadlineIsSet { + c.t.Fatalf("write with no deadline") + } + return c.c.Write(p) +} + +func (c *requireDeadlineNetConn) Read(p []byte) (int, error) { + if !c.readDeadlineIsSet { + c.t.Fatalf("read with no deadline") + } + return c.c.Read(p) +} + +func (c *requireDeadlineNetConn) Close() error { return c.c.Close() } +func (c *requireDeadlineNetConn) LocalAddr() net.Addr { return c.c.LocalAddr() } +func (c *requireDeadlineNetConn) RemoteAddr() net.Addr { return c.c.RemoteAddr() } + +func TestHandshakeTimeout(t *testing.T) { + s := newServer(t) defer s.Close() d := cstDialer - d.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + d.NetDial = func(n, a string) (net.Conn, error) { + c, err := net.Dial(n, a) + return &requireDeadlineNetConn{c: c, t: t}, err + } ws, _, err := d.Dial(s.URL, nil) if err != nil { - t.Fatalf("Dial: %v", err) + t.Fatal("Dial:", err) } - defer ws.Close() - sendRecv(t, ws) + ws.Close() } -func TestDialTimeout(t *testing.T) { +func TestHandshakeTimeoutInContext(t *testing.T) { s := newServer(t) defer s.Close() d := cstDialer - d.HandshakeTimeout = -1 - ws, _, err := d.Dial(s.URL, nil) - if err == nil { - ws.Close() - t.Fatalf("Dial: nil") + d.HandshakeTimeout = 0 + d.NetDialContext = func(ctx context.Context, n, a string) (net.Conn, error) { + netDialer := &net.Dialer{} + c, err := netDialer.DialContext(ctx, n, a) + return &requireDeadlineNetConn{c: c, t: t}, err } + + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(30*time.Second)) + defer cancel() + ws, _, err := d.DialContext(ctx, s.URL, nil) + if err != nil { + t.Fatal("Dial:", err) + } + ws.Close() } func TestDialBadScheme(t *testing.T) { @@ -398,9 +463,17 @@ func TestBadMethod(t *testing.T) { })) defer s.Close() - resp, err := http.PostForm(s.URL, url.Values{}) + req, err := http.NewRequest("POST", s.URL, strings.NewReader("")) + if err != nil { + t.Fatalf("NewRequest returned error %v", err) + } + req.Header.Set("Connection", "upgrade") + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Sec-Websocket-Version", "13") + + resp, err := http.DefaultClient.Do(req) if err != nil { - t.Fatalf("PostForm returned error %v", err) + t.Fatalf("Do returned error %v", err) } resp.Body.Close() if resp.StatusCode != http.StatusMethodNotAllowed { @@ -408,6 +481,23 @@ func TestBadMethod(t *testing.T) { } } +func TestDialExtraTokensInRespHeaders(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + challengeKey := r.Header.Get("Sec-Websocket-Key") + w.Header().Set("Upgrade", "foo, websocket") + w.Header().Set("Connection", "upgrade, keep-alive") + w.Header().Set("Sec-Websocket-Accept", computeAcceptKey(challengeKey)) + w.WriteHeader(101) + })) + defer s.Close() + + ws, _, err := cstDialer.Dial(makeWsProto(s.URL), nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() +} + func TestHandshake(t *testing.T) { s := newServer(t) defer s.Close() @@ -468,45 +558,365 @@ func TestRespOnBadHandshake(t *testing.T) { } } -// TestHostHeader confirms that the host header provided in the call to Dial is -// sent to the server. -func TestHostHeader(t *testing.T) { +type testLogWriter struct { + t *testing.T +} + +func (w testLogWriter) Write(p []byte) (int, error) { + w.t.Logf("%s", p) + return len(p), nil +} + +// TestHost tests handling of host names and confirms that it matches net/http. +func TestHost(t *testing.T) { + + upgrader := Upgrader{} + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if IsWebSocketUpgrade(r) { + c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}}) + if err != nil { + t.Fatal(err) + } + c.Close() + } else { + w.Header().Set("X-Test-Host", r.Host) + } + }) + + server := httptest.NewServer(handler) + defer server.Close() + + tlsServer := httptest.NewTLSServer(handler) + defer tlsServer.Close() + + addrs := map[*httptest.Server]string{server: server.Listener.Addr().String(), tlsServer: tlsServer.Listener.Addr().String()} + wsProtos := map[*httptest.Server]string{server: "ws://", tlsServer: "wss://"} + httpProtos := map[*httptest.Server]string{server: "http://", tlsServer: "https://"} + + // Avoid log noise from net/http server by logging to testing.T + server.Config.ErrorLog = log.New(testLogWriter{t}, "", 0) + tlsServer.Config.ErrorLog = server.Config.ErrorLog + + cas := rootCAs(t, tlsServer) + + tests := []struct { + fail bool // true if dial / get should fail + server *httptest.Server // server to use + url string // host for request URI + header string // optional request host header + tls string // optional host for tls ServerName + wantAddr string // expected host for dial + wantHeader string // expected request header on server + insecureSkipVerify bool + }{ + { + server: server, + url: addrs[server], + wantAddr: addrs[server], + wantHeader: addrs[server], + }, + { + server: tlsServer, + url: addrs[tlsServer], + wantAddr: addrs[tlsServer], + wantHeader: addrs[tlsServer], + }, + + { + server: server, + url: addrs[server], + header: "badhost.com", + wantAddr: addrs[server], + wantHeader: "badhost.com", + }, + { + server: tlsServer, + url: addrs[tlsServer], + header: "badhost.com", + wantAddr: addrs[tlsServer], + wantHeader: "badhost.com", + }, + + { + server: server, + url: "example.com", + header: "badhost.com", + wantAddr: "example.com:80", + wantHeader: "badhost.com", + }, + { + server: tlsServer, + url: "example.com", + header: "badhost.com", + wantAddr: "example.com:443", + wantHeader: "badhost.com", + }, + + { + server: server, + url: "badhost.com", + header: "example.com", + wantAddr: "badhost.com:80", + wantHeader: "example.com", + }, + { + fail: true, + server: tlsServer, + url: "badhost.com", + header: "example.com", + wantAddr: "badhost.com:443", + }, + { + server: tlsServer, + url: "badhost.com", + insecureSkipVerify: true, + wantAddr: "badhost.com:443", + wantHeader: "badhost.com", + }, + { + server: tlsServer, + url: "badhost.com", + tls: "example.com", + wantAddr: "badhost.com:443", + wantHeader: "badhost.com", + }, + } + + for i, tt := range tests { + + tls := &tls.Config{ + RootCAs: cas, + ServerName: tt.tls, + InsecureSkipVerify: tt.insecureSkipVerify, + } + + var gotAddr string + dialer := Dialer{ + NetDial: func(network, addr string) (net.Conn, error) { + gotAddr = addr + return net.Dial(network, addrs[tt.server]) + }, + TLSClientConfig: tls, + } + + // Test websocket dial + + h := http.Header{} + if tt.header != "" { + h.Set("Host", tt.header) + } + c, resp, err := dialer.Dial(wsProtos[tt.server]+tt.url+"/", h) + if err == nil { + c.Close() + } + + check := func(protos map[*httptest.Server]string) { + name := fmt.Sprintf("%d: %s%s/ header[Host]=%q, tls.ServerName=%q", i+1, protos[tt.server], tt.url, tt.header, tt.tls) + if gotAddr != tt.wantAddr { + t.Errorf("%s: got addr %s, want %s", name, gotAddr, tt.wantAddr) + } + switch { + case tt.fail && err == nil: + t.Errorf("%s: unexpected success", name) + case !tt.fail && err != nil: + t.Errorf("%s: unexpected error %v", name, err) + case !tt.fail && err == nil: + if gotHost := resp.Header.Get("X-Test-Host"); gotHost != tt.wantHeader { + t.Errorf("%s: got host %s, want %s", name, gotHost, tt.wantHeader) + } + } + } + + check(wsProtos) + + // Confirm that net/http has same result + + transport := &http.Transport{ + Dial: dialer.NetDial, + TLSClientConfig: dialer.TLSClientConfig, + } + req, _ := http.NewRequest("GET", httpProtos[tt.server]+tt.url+"/", nil) + if tt.header != "" { + req.Host = tt.header + } + client := &http.Client{Transport: transport} + resp, err = client.Do(req) + if err == nil { + resp.Body.Close() + } + transport.CloseIdleConnections() + check(httpProtos) + } +} + +func TestDialCompression(t *testing.T) { s := newServer(t) defer s.Close() - specifiedHost := make(chan string, 1) - origHandler := s.Server.Config.Handler + dialer := cstDialer + dialer.EnableCompression = true + ws, _, err := dialer.Dial(s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + sendRecv(t, ws) +} - // Capture the request Host header. - s.Server.Config.Handler = http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - specifiedHost <- r.Host - origHandler.ServeHTTP(w, r) - }) +func TestSocksProxyDial(t *testing.T) { + s := newServer(t) + defer s.Close() + + proxyListener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen failed: %v", err) + } + defer proxyListener.Close() + go func() { + c1, err := proxyListener.Accept() + if err != nil { + t.Errorf("proxy accept failed: %v", err) + return + } + defer c1.Close() + + c1.SetDeadline(time.Now().Add(30 * time.Second)) + + buf := make([]byte, 32) + if _, err := io.ReadFull(c1, buf[:3]); err != nil { + t.Errorf("read failed: %v", err) + return + } + if want := []byte{5, 1, 0}; !bytes.Equal(want, buf[:len(want)]) { + t.Errorf("read %x, want %x", buf[:len(want)], want) + } + if _, err := c1.Write([]byte{5, 0}); err != nil { + t.Errorf("write failed: %v", err) + return + } + if _, err := io.ReadFull(c1, buf[:10]); err != nil { + t.Errorf("read failed: %v", err) + return + } + if want := []byte{5, 1, 0, 1}; !bytes.Equal(want, buf[:len(want)]) { + t.Errorf("read %x, want %x", buf[:len(want)], want) + return + } + buf[1] = 0 + if _, err := c1.Write(buf[:10]); err != nil { + t.Errorf("write failed: %v", err) + return + } + + ip := net.IP(buf[4:8]) + port := binary.BigEndian.Uint16(buf[8:10]) + + c2, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: ip, Port: int(port)}) + if err != nil { + t.Errorf("dial failed; %v", err) + return + } + defer c2.Close() + done := make(chan struct{}) + go func() { + io.Copy(c1, c2) + close(done) + }() + io.Copy(c2, c1) + <-done + }() + + purl, err := url.Parse("socks5://" + proxyListener.Addr().String()) + if err != nil { + t.Fatalf("parse failed: %v", err) + } + + cstDialer := cstDialer // make local copy for modification on next line. + cstDialer.Proxy = http.ProxyURL(purl) - ws, _, err := cstDialer.Dial(s.URL, http.Header{"Host": {"testhost"}}) + ws, _, err := cstDialer.Dial(s.URL, nil) if err != nil { t.Fatalf("Dial: %v", err) } defer ws.Close() + sendRecv(t, ws) +} + +func TestTracingDialWithContext(t *testing.T) { + + var headersWrote, requestWrote, getConn, gotConn, connectDone, gotFirstResponseByte bool + trace := &httptrace.ClientTrace{ + WroteHeaders: func() { + headersWrote = true + }, + WroteRequest: func(httptrace.WroteRequestInfo) { + requestWrote = true + }, + GetConn: func(hostPort string) { + getConn = true + }, + GotConn: func(info httptrace.GotConnInfo) { + gotConn = true + }, + ConnectDone: func(network, addr string, err error) { + connectDone = true + }, + GotFirstResponseByte: func() { + gotFirstResponseByte = true + }, + } + ctx := httptrace.WithClientTrace(context.Background(), trace) + + s := newTLSServer(t) + defer s.Close() + + d := cstDialer + d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)} + + ws, _, err := d.DialContext(ctx, s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } - if gotHost := <-specifiedHost; gotHost != "testhost" { - t.Fatalf("gotHost = %q, want \"testhost\"", gotHost) + if !headersWrote { + t.Fatal("Headers was not written") + } + if !requestWrote { + t.Fatal("Request was not written") + } + if !getConn { + t.Fatal("getConn was not called") + } + if !gotConn { + t.Fatal("gotConn was not called") + } + if !connectDone { + t.Fatal("connectDone was not called") + } + if !gotFirstResponseByte { + t.Fatal("GotFirstResponseByte was not called") } + defer ws.Close() sendRecv(t, ws) } -func TestDialCompression(t *testing.T) { - s := newServer(t) +func TestEmptyTracingDialWithContext(t *testing.T) { + + trace := &httptrace.ClientTrace{} + ctx := httptrace.WithClientTrace(context.Background(), trace) + + s := newTLSServer(t) defer s.Close() - dialer := cstDialer - dialer.EnableCompression = true - ws, _, err := dialer.Dial(s.URL, nil) + d := cstDialer + d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)} + + ws, _, err := d.DialContext(ctx, s.URL, nil) if err != nil { t.Fatalf("Dial: %v", err) } + defer ws.Close() sendRecv(t, ws) } diff --git a/vendor/src/github.com/gorilla/websocket/client_test.go b/vendor/src/github.com/gorilla/websocket/client_test.go index 7d2b0844..5aa27b37 100755 --- a/vendor/src/github.com/gorilla/websocket/client_test.go +++ b/vendor/src/github.com/gorilla/websocket/client_test.go @@ -6,49 +6,9 @@ package websocket import ( "net/url" - "reflect" "testing" ) -var parseURLTests = []struct { - s string - u *url.URL - rui string -}{ - {"ws://example.com/", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}, "/"}, - {"ws://example.com", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}, "/"}, - {"ws://example.com:7777/", &url.URL{Scheme: "ws", Host: "example.com:7777", Opaque: "/"}, "/"}, - {"wss://example.com/", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/"}, "/"}, - {"wss://example.com/a/b", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/a/b"}, "/a/b"}, - {"ss://example.com/a/b", nil, ""}, - {"ws://webmaster@example.com/", nil, ""}, - {"wss://example.com/a/b?x=y", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/a/b", RawQuery: "x=y"}, "/a/b?x=y"}, - {"wss://example.com?x=y", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/", RawQuery: "x=y"}, "/?x=y"}, -} - -func TestParseURL(t *testing.T) { - for _, tt := range parseURLTests { - u, err := parseURL(tt.s) - if tt.u != nil && err != nil { - t.Errorf("parseURL(%q) returned error %v", tt.s, err) - continue - } - if tt.u == nil { - if err == nil { - t.Errorf("parseURL(%q) did not return error", tt.s) - } - continue - } - if !reflect.DeepEqual(u, tt.u) { - t.Errorf("parseURL(%q) = %v, want %v", tt.s, u, tt.u) - continue - } - if u.RequestURI() != tt.rui { - t.Errorf("parseURL(%q).RequestURI() = %v, want %v", tt.s, u.RequestURI(), tt.rui) - } - } -} - var hostPortNoPortTests = []struct { u *url.URL hostPort, hostNoPort string diff --git a/vendor/src/github.com/gorilla/websocket/compression_test.go b/vendor/src/github.com/gorilla/websocket/compression_test.go index 659cf421..8a26b30f 100755 --- a/vendor/src/github.com/gorilla/websocket/compression_test.go +++ b/vendor/src/github.com/gorilla/websocket/compression_test.go @@ -43,7 +43,7 @@ func textMessages(num int) [][]byte { func BenchmarkWriteNoCompression(b *testing.B) { w := ioutil.Discard - c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024) + c := newTestConn(nil, w, false) messages := textMessages(100) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -54,7 +54,7 @@ func BenchmarkWriteNoCompression(b *testing.B) { func BenchmarkWriteWithCompression(b *testing.B) { w := ioutil.Discard - c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024) + c := newTestConn(nil, w, false) messages := textMessages(100) c.enableWriteCompression = true c.newCompressionWriter = compressNoContextTakeover @@ -66,7 +66,7 @@ func BenchmarkWriteWithCompression(b *testing.B) { } func TestValidCompressionLevel(t *testing.T) { - c := newConn(fakeNetConn{}, false, 1024, 1024) + c := newTestConn(nil, nil, false) for _, level := range []int{minCompressionLevel - 1, maxCompressionLevel + 1} { if err := c.SetCompressionLevel(level); err == nil { t.Errorf("no error for level %d", level) diff --git a/vendor/src/github.com/gorilla/websocket/conn.go b/vendor/src/github.com/gorilla/websocket/conn.go index 97e1dbac..ca46d2f7 100755 --- a/vendor/src/github.com/gorilla/websocket/conn.go +++ b/vendor/src/github.com/gorilla/websocket/conn.go @@ -76,7 +76,7 @@ const ( // is UTF-8 encoded text. PingMessage = 9 - // PongMessage denotes a ping control message. The optional message payload + // PongMessage denotes a pong control message. The optional message payload // is UTF-8 encoded text. PongMessage = 10 ) @@ -100,9 +100,8 @@ func (e *netError) Error() string { return e.msg } func (e *netError) Temporary() bool { return e.temporary } func (e *netError) Timeout() bool { return e.timeout } -// CloseError represents close frame. +// CloseError represents a close message. type CloseError struct { - // Code is defined in RFC 6455, section 11.7. Code int @@ -224,6 +223,20 @@ func isValidReceivedCloseCode(code int) bool { return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999) } +// BufferPool represents a pool of buffers. The *sync.Pool type satisfies this +// interface. The type of the value stored in a pool is not specified. +type BufferPool interface { + // Get gets a value from the pool or returns nil if the pool is empty. + Get() interface{} + // Put adds a value to the pool. + Put(interface{}) +} + +// writePoolData is the type added to the write buffer pool. This wrapper is +// used to prevent applications from peeking at and depending on the values +// added to the pool. +type writePoolData struct{ buf []byte } + // The Conn type represents a WebSocket connection. type Conn struct { conn net.Conn @@ -231,8 +244,10 @@ type Conn struct { subprotocol string // Write fields - mu chan bool // used as mutex to protect write to conn - writeBuf []byte // frame is constructed in this buffer. + mu chan struct{} // used as mutex to protect write to conn + writeBuf []byte // frame is constructed in this buffer. + writePool BufferPool + writeBufSize int writeDeadline time.Time writer io.WriteCloser // the current writer returned to the application isWriting bool // for best-effort concurrent write detection @@ -245,10 +260,12 @@ type Conn struct { newCompressionWriter func(io.WriteCloser, int) io.WriteCloser // Read fields - reader io.ReadCloser // the current reader returned to the application - readErr error - br *bufio.Reader - readRemaining int64 // bytes remaining in current frame. + reader io.ReadCloser // the current reader returned to the application + readErr error + br *bufio.Reader + // bytes remaining in current frame. + // set setReadRemaining to safely update this value and prevent overflow + readRemaining int64 readFinal bool // true the current message has more frames. readLength int64 // Message size. readLimit int64 // Maximum message size. @@ -264,64 +281,29 @@ type Conn struct { newDecompressionReader func(io.Reader) io.ReadCloser } -func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn { - return newConnBRW(conn, isServer, readBufferSize, writeBufferSize, nil) -} - -type writeHook struct { - p []byte -} - -func (wh *writeHook) Write(p []byte) (int, error) { - wh.p = p - return len(p), nil -} +func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, writeBufferPool BufferPool, br *bufio.Reader, writeBuf []byte) *Conn { -func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, brw *bufio.ReadWriter) *Conn { - mu := make(chan bool, 1) - mu <- true - - var br *bufio.Reader - if readBufferSize == 0 && brw != nil && brw.Reader != nil { - // Reuse the supplied bufio.Reader if the buffer has a useful size. - // This code assumes that peek on a reader returns - // bufio.Reader.buf[:0]. - brw.Reader.Reset(conn) - if p, err := brw.Reader.Peek(0); err == nil && cap(p) >= 256 { - br = brw.Reader - } - } if br == nil { if readBufferSize == 0 { readBufferSize = defaultReadBufferSize - } - if readBufferSize < maxControlFramePayloadSize { + } else if readBufferSize < maxControlFramePayloadSize { + // must be large enough for control frame readBufferSize = maxControlFramePayloadSize } br = bufio.NewReaderSize(conn, readBufferSize) } - var writeBuf []byte - if writeBufferSize == 0 && brw != nil && brw.Writer != nil { - // Use the bufio.Writer's buffer if the buffer has a useful size. This - // code assumes that bufio.Writer.buf[:1] is passed to the - // bufio.Writer's underlying writer. - var wh writeHook - brw.Writer.Reset(&wh) - brw.Writer.WriteByte(0) - brw.Flush() - if cap(wh.p) >= maxFrameHeaderSize+256 { - writeBuf = wh.p[:cap(wh.p)] - } + if writeBufferSize <= 0 { + writeBufferSize = defaultWriteBufferSize } + writeBufferSize += maxFrameHeaderSize - if writeBuf == nil { - if writeBufferSize == 0 { - writeBufferSize = defaultWriteBufferSize - } - writeBuf = make([]byte, writeBufferSize+maxFrameHeaderSize) + if writeBuf == nil && writeBufferPool == nil { + writeBuf = make([]byte, writeBufferSize) } + mu := make(chan struct{}, 1) + mu <- struct{}{} c := &Conn{ isServer: isServer, br: br, @@ -329,6 +311,8 @@ func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize in mu: mu, readFinal: true, writeBuf: writeBuf, + writePool: writeBufferPool, + writeBufSize: writeBufferSize, enableWriteCompression: true, compressionLevel: defaultCompressionLevel, } @@ -338,12 +322,24 @@ func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize in return c } +// setReadRemaining tracks the number of bytes remaining on the connection. If n +// overflows, an ErrReadLimit is returned. +func (c *Conn) setReadRemaining(n int64) error { + if n < 0 { + return ErrReadLimit + } + + c.readRemaining = n + return nil +} + // Subprotocol returns the negotiated protocol for the connection. func (c *Conn) Subprotocol() string { return c.subprotocol } -// Close closes the underlying network connection without sending or waiting for a close frame. +// Close closes the underlying network connection without sending or waiting +// for a close message. func (c *Conn) Close() error { return c.conn.Close() } @@ -370,9 +366,18 @@ func (c *Conn) writeFatal(err error) error { return err } -func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error { +func (c *Conn) read(n int) ([]byte, error) { + p, err := c.br.Peek(n) + if err == io.EOF { + err = errUnexpectedEOF + } + c.br.Discard(len(p)) + return p, err +} + +func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error { <-c.mu - defer func() { c.mu <- true }() + defer func() { c.mu <- struct{}{} }() c.writeErrMu.Lock() err := c.writeErr @@ -382,15 +387,14 @@ func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error { } c.conn.SetWriteDeadline(deadline) - for _, buf := range bufs { - if len(buf) > 0 { - _, err := c.conn.Write(buf) - if err != nil { - return c.writeFatal(err) - } - } + if len(buf1) == 0 { + _, err = c.conn.Write(buf0) + } else { + err = c.writeBufs(buf0, buf1) + } + if err != nil { + return c.writeFatal(err) } - if frameType == CloseMessage { c.writeFatal(ErrCloseSent) } @@ -425,7 +429,7 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er maskBytes(key, 0, buf[6:]) } - d := time.Hour * 1000 + d := 1000 * time.Hour if !deadline.IsZero() { d = deadline.Sub(time.Now()) if d < 0 { @@ -440,7 +444,7 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er case <-timer.C: return errWriteTimeout } - defer func() { c.mu <- true }() + defer func() { c.mu <- struct{}{} }() c.writeErrMu.Lock() err := c.writeErr @@ -460,7 +464,8 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er return err } -func (c *Conn) prepWrite(messageType int) error { +// beginMessage prepares a connection and message writer for a new message. +func (c *Conn) beginMessage(mw *messageWriter, messageType int) error { // Close previous writer if not already closed by the application. It's // probably better to return an error in this situation, but we cannot // change this without breaking existing applications. @@ -476,7 +481,23 @@ func (c *Conn) prepWrite(messageType int) error { c.writeErrMu.Lock() err := c.writeErr c.writeErrMu.Unlock() - return err + if err != nil { + return err + } + + mw.c = c + mw.frameType = messageType + mw.pos = maxFrameHeaderSize + + if c.writeBuf == nil { + wpd, ok := c.writePool.Get().(writePoolData) + if ok { + c.writeBuf = wpd.buf + } else { + c.writeBuf = make([]byte, c.writeBufSize) + } + } + return nil } // NextWriter returns a writer for the next message to send. The writer's Close @@ -484,17 +505,15 @@ func (c *Conn) prepWrite(messageType int) error { // // There can be at most one open writer on a connection. NextWriter closes the // previous writer if the application has not already done so. +// +// All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and +// PongMessage) are supported. func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { - if err := c.prepWrite(messageType); err != nil { + var mw messageWriter + if err := c.beginMessage(&mw, messageType); err != nil { return nil, err } - - mw := &messageWriter{ - c: c, - frameType: messageType, - pos: maxFrameHeaderSize, - } - c.writer = mw + c.writer = &mw if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) { w := c.newCompressionWriter(c.writer, c.compressionLevel) mw.compress = true @@ -511,10 +530,16 @@ type messageWriter struct { err error } -func (w *messageWriter) fatal(err error) error { +func (w *messageWriter) endMessage(err error) error { if w.err != nil { - w.err = err - w.c.writer = nil + return err + } + c := w.c + w.err = err + c.writer = nil + if c.writePool != nil { + c.writePool.Put(writePoolData{buf: c.writeBuf}) + c.writeBuf = nil } return err } @@ -528,7 +553,7 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error { // Check for invalid control frames. if isControl(w.frameType) && (!final || length > maxControlFramePayloadSize) { - return w.fatal(errInvalidControlFrame) + return w.endMessage(errInvalidControlFrame) } b0 := byte(w.frameType) @@ -573,7 +598,7 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error { copy(c.writeBuf[maxFrameHeaderSize-4:], key[:]) maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos]) if len(extra) > 0 { - return c.writeFatal(errors.New("websocket: internal error, extra used in client mode")) + return w.endMessage(c.writeFatal(errors.New("websocket: internal error, extra used in client mode"))) } } @@ -594,11 +619,11 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error { c.isWriting = false if err != nil { - return w.fatal(err) + return w.endMessage(err) } if final { - c.writer = nil + w.endMessage(errWriteClosed) return nil } @@ -696,11 +721,7 @@ func (w *messageWriter) Close() error { if w.err != nil { return w.err } - if err := w.flushFrame(true, nil); err != nil { - return err - } - w.err = errWriteClosed - return nil + return w.flushFrame(true, nil) } // WritePreparedMessage writes prepared message into connection. @@ -732,10 +753,10 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error { if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) { // Fast path with no allocations and single frame. - if err := c.prepWrite(messageType); err != nil { + var mw messageWriter + if err := c.beginMessage(&mw, messageType); err != nil { return err } - mw := messageWriter{c: c, frameType: messageType, pos: maxFrameHeaderSize} n := copy(c.writeBuf[mw.pos:], data) mw.pos += n data = data[n:] @@ -764,7 +785,6 @@ func (c *Conn) SetWriteDeadline(t time.Time) error { // Read methods func (c *Conn) advanceFrame() (int, error) { - // 1. Skip remainder of previous frame. if c.readRemaining > 0 { @@ -783,7 +803,7 @@ func (c *Conn) advanceFrame() (int, error) { final := p[0]&finalBit != 0 frameType := int(p[0] & 0xf) mask := p[1]&maskBit != 0 - c.readRemaining = int64(p[1] & 0x7f) + c.setReadRemaining(int64(p[1] & 0x7f)) c.readDecompress = false if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 { @@ -817,7 +837,17 @@ func (c *Conn) advanceFrame() (int, error) { return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType)) } - // 3. Read and parse frame length. + // 3. Read and parse frame length as per + // https://tools.ietf.org/html/rfc6455#section-5.2 + // + // The length of the "Payload data", in bytes: if 0-125, that is the payload + // length. + // - If 126, the following 2 bytes interpreted as a 16-bit unsigned + // integer are the payload length. + // - If 127, the following 8 bytes interpreted as + // a 64-bit unsigned integer (the most significant bit MUST be 0) are the + // payload length. Multibyte length quantities are expressed in network byte + // order. switch c.readRemaining { case 126: @@ -825,13 +855,19 @@ func (c *Conn) advanceFrame() (int, error) { if err != nil { return noFrame, err } - c.readRemaining = int64(binary.BigEndian.Uint16(p)) + + if err := c.setReadRemaining(int64(binary.BigEndian.Uint16(p))); err != nil { + return noFrame, err + } case 127: p, err := c.read(8) if err != nil { return noFrame, err } - c.readRemaining = int64(binary.BigEndian.Uint64(p)) + + if err := c.setReadRemaining(int64(binary.BigEndian.Uint64(p))); err != nil { + return noFrame, err + } } // 4. Handle frame masking. @@ -854,6 +890,12 @@ func (c *Conn) advanceFrame() (int, error) { if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage { c.readLength += c.readRemaining + // Don't allow readLength to overflow in the presence of a large readRemaining + // counter. + if c.readLength < 0 { + return noFrame, ErrReadLimit + } + if c.readLimit > 0 && c.readLength > c.readLimit { c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait)) return noFrame, ErrReadLimit @@ -867,7 +909,7 @@ func (c *Conn) advanceFrame() (int, error) { var payload []byte if c.readRemaining > 0 { payload, err = c.read(int(c.readRemaining)) - c.readRemaining = 0 + c.setReadRemaining(0) if err != nil { return noFrame, err } @@ -940,6 +982,7 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { c.readErr = hideTempErr(err) break } + if frameType == TextMessage || frameType == BinaryMessage { c.messageReader = &messageReader{c} c.reader = c.messageReader @@ -980,7 +1023,9 @@ func (r *messageReader) Read(b []byte) (int, error) { if c.isServer { c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n]) } - c.readRemaining -= int64(n) + rem := c.readRemaining + rem -= int64(n) + c.setReadRemaining(rem) if c.readRemaining > 0 && c.readErr == io.EOF { c.readErr = errUnexpectedEOF } @@ -1032,8 +1077,8 @@ func (c *Conn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) } -// SetReadLimit sets the maximum size for a message read from the peer. If a -// message exceeds the limit, the connection sends a close frame to the peer +// SetReadLimit sets the maximum size in bytes for a message read from the peer. If a +// message exceeds the limit, the connection sends a close message to the peer // and returns ErrReadLimit to the application. func (c *Conn) SetReadLimit(limit int64) { c.readLimit = limit @@ -1046,24 +1091,22 @@ func (c *Conn) CloseHandler() func(code int, text string) error { // SetCloseHandler sets the handler for close messages received from the peer. // The code argument to h is the received close code or CloseNoStatusReceived -// if the close message is empty. The default close handler sends a close frame -// back to the peer. +// if the close message is empty. The default close handler sends a close +// message back to the peer. // -// The application must read the connection to process close messages as -// described in the section on Control Frames above. +// The handler function is called from the NextReader, ReadMessage and message +// reader Read methods. The application must read the connection to process +// close messages as described in the section on Control Messages above. // -// The connection read methods return a CloseError when a close frame is +// The connection read methods return a CloseError when a close message is // received. Most applications should handle close messages as part of their // normal error handling. Applications should only set a close handler when the -// application must perform some action before sending a close frame back to +// application must perform some action before sending a close message back to // the peer. func (c *Conn) SetCloseHandler(h func(code int, text string) error) { if h == nil { h = func(code int, text string) error { - message := []byte{} - if code != CloseNoStatusReceived { - message = FormatCloseMessage(code, "") - } + message := FormatCloseMessage(code, "") c.WriteControl(CloseMessage, message, time.Now().Add(writeWait)) return nil } @@ -1077,11 +1120,12 @@ func (c *Conn) PingHandler() func(appData string) error { } // SetPingHandler sets the handler for ping messages received from the peer. -// The appData argument to h is the PING frame application data. The default +// The appData argument to h is the PING message application data. The default // ping handler sends a pong to the peer. // -// The application must read the connection to process ping messages as -// described in the section on Control Frames above. +// The handler function is called from the NextReader, ReadMessage and message +// reader Read methods. The application must read the connection to process +// ping messages as described in the section on Control Messages above. func (c *Conn) SetPingHandler(h func(appData string) error) { if h == nil { h = func(message string) error { @@ -1103,11 +1147,12 @@ func (c *Conn) PongHandler() func(appData string) error { } // SetPongHandler sets the handler for pong messages received from the peer. -// The appData argument to h is the PONG frame application data. The default +// The appData argument to h is the PONG message application data. The default // pong handler does nothing. // -// The application must read the connection to process ping messages as -// described in the section on Control Frames above. +// The handler function is called from the NextReader, ReadMessage and message +// reader Read methods. The application must read the connection to process +// pong messages as described in the section on Control Messages above. func (c *Conn) SetPongHandler(h func(appData string) error) { if h == nil { h = func(string) error { return nil } @@ -1141,7 +1186,14 @@ func (c *Conn) SetCompressionLevel(level int) error { } // FormatCloseMessage formats closeCode and text as a WebSocket close message. +// An empty message is returned for code CloseNoStatusReceived. func FormatCloseMessage(closeCode int, text string) []byte { + if closeCode == CloseNoStatusReceived { + // Return empty message because it's illegal to send + // CloseNoStatusReceived. Return non-nil value in case application + // checks for nil. + return []byte{} + } buf := make([]byte, 2+len(text)) binary.BigEndian.PutUint16(buf, uint16(closeCode)) copy(buf[2:], text) diff --git a/vendor/src/github.com/gorilla/websocket/conn_broadcast_test.go b/vendor/src/github.com/gorilla/websocket/conn_broadcast_test.go index 45038e48..cb88cbb0 100755 --- a/vendor/src/github.com/gorilla/websocket/conn_broadcast_test.go +++ b/vendor/src/github.com/gorilla/websocket/conn_broadcast_test.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build go1.7 - package websocket import ( @@ -70,7 +68,7 @@ func (b *broadcastBench) makeConns(numConns int) { conns := make([]*broadcastConn, numConns) for i := 0; i < numConns; i++ { - c := newConn(fakeNetConn{Reader: nil, Writer: b.w}, true, 1024, 1024) + c := newTestConn(nil, b.w, true) if b.compression { c.enableWriteCompression = true c.newCompressionWriter = compressNoContextTakeover diff --git a/vendor/src/github.com/gorilla/websocket/conn_read_legacy.go b/vendor/src/github.com/gorilla/websocket/conn_read_legacy.go deleted file mode 100755 index 018541cf..00000000 --- a/vendor/src/github.com/gorilla/websocket/conn_read_legacy.go +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build !go1.5 - -package websocket - -import "io" - -func (c *Conn) read(n int) ([]byte, error) { - p, err := c.br.Peek(n) - if err == io.EOF { - err = errUnexpectedEOF - } - if len(p) > 0 { - // advance over the bytes just read - io.ReadFull(c.br, p) - } - return p, err -} diff --git a/vendor/src/github.com/gorilla/websocket/conn_test.go b/vendor/src/github.com/gorilla/websocket/conn_test.go index 06e9bc3f..bd96e0ad 100755 --- a/vendor/src/github.com/gorilla/websocket/conn_test.go +++ b/vendor/src/github.com/gorilla/websocket/conn_test.go @@ -13,6 +13,7 @@ import ( "io/ioutil" "net" "reflect" + "sync" "testing" "testing/iotest" "time" @@ -47,8 +48,17 @@ func (a fakeAddr) String() string { return "str" } +// newTestConn creates a connnection backed by a fake network connection using +// default values for buffering. +func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn { + return newConn(fakeNetConn{Reader: r, Writer: w}, isServer, 1024, 1024, nil, nil, nil) +} + func TestFraming(t *testing.T) { - frameSizes := []int{0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 65536, 65537} + frameSizes := []int{ + 0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, + // 65536, 65537 + } var readChunkers = []struct { name string f func(io.Reader) io.Reader @@ -82,8 +92,8 @@ func TestFraming(t *testing.T) { for _, chunker := range readChunkers { var connBuf bytes.Buffer - wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024) - rc := newConn(fakeNetConn{Reader: chunker.f(&connBuf), Writer: nil}, !isServer, 1024, 1024) + wc := newTestConn(nil, &connBuf, isServer) + rc := newTestConn(chunker.f(&connBuf), nil, !isServer) if compress { wc.newCompressionWriter = compressNoContextTakeover rc.newDecompressionReader = decompressNoContextTakeover @@ -113,6 +123,8 @@ func TestFraming(t *testing.T) { t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err) continue } + + t.Logf("frame size: %d", n) rbuf, err := ioutil.ReadAll(r) if err != nil { t.Errorf("%s: ReadFull() returned rbuf, %v", name, err) @@ -143,8 +155,8 @@ func TestControl(t *testing.T) { for _, isWriteControl := range []bool{true, false} { name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl) var connBuf bytes.Buffer - wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024) - rc := newConn(fakeNetConn{Reader: &connBuf, Writer: nil}, !isServer, 1024, 1024) + wc := newTestConn(nil, &connBuf, isServer) + rc := newTestConn(&connBuf, nil, !isServer) if isWriteControl { wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)) } else { @@ -173,14 +185,178 @@ func TestControl(t *testing.T) { } } +// simpleBufferPool is an implementation of BufferPool for TestWriteBufferPool. +type simpleBufferPool struct { + v interface{} +} + +func (p *simpleBufferPool) Get() interface{} { + v := p.v + p.v = nil + return v +} + +func (p *simpleBufferPool) Put(v interface{}) { + p.v = v +} + +func TestWriteBufferPool(t *testing.T) { + const message = "Now is the time for all good people to come to the aid of the party." + + var buf bytes.Buffer + var pool simpleBufferPool + rc := newTestConn(&buf, nil, false) + + // Specify writeBufferSize smaller than message size to ensure that pooling + // works with fragmented messages. + wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, len(message)-1, &pool, nil, nil) + + if wc.writeBuf != nil { + t.Fatal("writeBuf not nil after create") + } + + // Part 1: test NextWriter/Write/Close + + w, err := wc.NextWriter(TextMessage) + if err != nil { + t.Fatalf("wc.NextWriter() returned %v", err) + } + + if wc.writeBuf == nil { + t.Fatal("writeBuf is nil after NextWriter") + } + + writeBufAddr := &wc.writeBuf[0] + + if _, err := io.WriteString(w, message); err != nil { + t.Fatalf("io.WriteString(w, message) returned %v", err) + } + + if err := w.Close(); err != nil { + t.Fatalf("w.Close() returned %v", err) + } + + if wc.writeBuf != nil { + t.Fatal("writeBuf not nil after w.Close()") + } + + if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { + t.Fatal("writeBuf not returned to pool") + } + + opCode, p, err := rc.ReadMessage() + if opCode != TextMessage || err != nil { + t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) + } + + if s := string(p); s != message { + t.Fatalf("message is %s, want %s", s, message) + } + + // Part 2: Test WriteMessage. + + if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil { + t.Fatalf("wc.WriteMessage() returned %v", err) + } + + if wc.writeBuf != nil { + t.Fatal("writeBuf not nil after wc.WriteMessage()") + } + + if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { + t.Fatal("writeBuf not returned to pool after WriteMessage") + } + + opCode, p, err = rc.ReadMessage() + if opCode != TextMessage || err != nil { + t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) + } + + if s := string(p); s != message { + t.Fatalf("message is %s, want %s", s, message) + } +} + +// TestWriteBufferPoolSync ensures that *sync.Pool works as a buffer pool. +func TestWriteBufferPoolSync(t *testing.T) { + var buf bytes.Buffer + var pool sync.Pool + wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil) + rc := newTestConn(&buf, nil, false) + + const message = "Hello World!" + for i := 0; i < 3; i++ { + if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil { + t.Fatalf("wc.WriteMessage() returned %v", err) + } + opCode, p, err := rc.ReadMessage() + if opCode != TextMessage || err != nil { + t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) + } + if s := string(p); s != message { + t.Fatalf("message is %s, want %s", s, message) + } + } +} + +// errorWriter is an io.Writer than returns an error on all writes. +type errorWriter struct{} + +func (ew errorWriter) Write(p []byte) (int, error) { return 0, errors.New("error") } + +// TestWriteBufferPoolError ensures that buffer is returned to pool after error +// on write. +func TestWriteBufferPoolError(t *testing.T) { + + // Part 1: Test NextWriter/Write/Close + + var pool simpleBufferPool + wc := newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil) + + w, err := wc.NextWriter(TextMessage) + if err != nil { + t.Fatalf("wc.NextWriter() returned %v", err) + } + + if wc.writeBuf == nil { + t.Fatal("writeBuf is nil after NextWriter") + } + + writeBufAddr := &wc.writeBuf[0] + + if _, err := io.WriteString(w, "Hello"); err != nil { + t.Fatalf("io.WriteString(w, message) returned %v", err) + } + + if err := w.Close(); err == nil { + t.Fatalf("w.Close() did not return error") + } + + if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { + t.Fatal("writeBuf not returned to pool") + } + + // Part 2: Test WriteMessage + + wc = newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil) + + if err := wc.WriteMessage(TextMessage, []byte("Hello")); err == nil { + t.Fatalf("wc.WriteMessage did not return error") + } + + if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { + t.Fatal("writeBuf not returned to pool") + } +} + func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) { const bufSize = 512 expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"} var b1, b2 bytes.Buffer - wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize) - rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024) + wc := newConn(&fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize, nil, nil, nil) + rc := newTestConn(&b1, &b2, true) w, _ := wc.NextWriter(BinaryMessage) w.Write(make([]byte, bufSize+bufSize/2)) @@ -206,8 +382,8 @@ func TestEOFWithinFrame(t *testing.T) { for n := 0; ; n++ { var b bytes.Buffer - wc := newConn(fakeNetConn{Reader: nil, Writer: &b}, false, 1024, 1024) - rc := newConn(fakeNetConn{Reader: &b, Writer: nil}, true, 1024, 1024) + wc := newTestConn(nil, &b, false) + rc := newTestConn(&b, nil, true) w, _ := wc.NextWriter(BinaryMessage) w.Write(make([]byte, bufSize)) @@ -240,8 +416,8 @@ func TestEOFBeforeFinalFrame(t *testing.T) { const bufSize = 512 var b1, b2 bytes.Buffer - wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize) - rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024) + wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, bufSize, nil, nil, nil) + rc := newTestConn(&b1, &b2, true) w, _ := wc.NextWriter(BinaryMessage) w.Write(make([]byte, bufSize+bufSize/2)) @@ -261,7 +437,7 @@ func TestEOFBeforeFinalFrame(t *testing.T) { } func TestWriteAfterMessageWriterClose(t *testing.T) { - wc := newConn(fakeNetConn{Reader: nil, Writer: &bytes.Buffer{}}, false, 1024, 1024) + wc := newTestConn(nil, &bytes.Buffer{}, false) w, _ := wc.NextWriter(BinaryMessage) io.WriteString(w, "hello") if err := w.Close(); err != nil { @@ -287,41 +463,97 @@ func TestWriteAfterMessageWriterClose(t *testing.T) { } func TestReadLimit(t *testing.T) { + t.Run("Test ReadLimit is enforced", func(t *testing.T) { + const readLimit = 512 + message := make([]byte, readLimit+1) - const readLimit = 512 - message := make([]byte, readLimit+1) + var b1, b2 bytes.Buffer + wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil) + rc := newTestConn(&b1, &b2, true) + rc.SetReadLimit(readLimit) - var b1, b2 bytes.Buffer - wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, readLimit-2) - rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024) - rc.SetReadLimit(readLimit) + // Send message at the limit with interleaved pong. + w, _ := wc.NextWriter(BinaryMessage) + w.Write(message[:readLimit-1]) + wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second)) + w.Write(message[:1]) + w.Close() - // Send message at the limit with interleaved pong. - w, _ := wc.NextWriter(BinaryMessage) - w.Write(message[:readLimit-1]) - wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second)) - w.Write(message[:1]) - w.Close() + // Send message larger than the limit. + wc.WriteMessage(BinaryMessage, message[:readLimit+1]) - // Send message larger than the limit. - wc.WriteMessage(BinaryMessage, message[:readLimit+1]) + op, _, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("1: NextReader() returned %d, %v", op, err) + } + op, r, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("2: NextReader() returned %d, %v", op, err) + } + _, err = io.Copy(ioutil.Discard, r) + if err != ErrReadLimit { + t.Fatalf("io.Copy() returned %v", err) + } + }) - op, _, err := rc.NextReader() - if op != BinaryMessage || err != nil { - t.Fatalf("1: NextReader() returned %d, %v", op, err) - } - op, r, err := rc.NextReader() - if op != BinaryMessage || err != nil { - t.Fatalf("2: NextReader() returned %d, %v", op, err) - } - _, err = io.Copy(ioutil.Discard, r) - if err != ErrReadLimit { - t.Fatalf("io.Copy() returned %v", err) - } + t.Run("Test that ReadLimit cannot be overflowed", func(t *testing.T) { + const readLimit = 1 + + var b1, b2 bytes.Buffer + rc := newTestConn(&b1, &b2, true) + rc.SetReadLimit(readLimit) + + // First, send a non-final binary message + b1.Write([]byte("\x02\x81")) + + // Mask key + b1.Write([]byte("\x00\x00\x00\x00")) + + // First payload + b1.Write([]byte("A")) + + // Next, send a negative-length, non-final continuation frame + b1.Write([]byte("\x00\xFF\x80\x00\x00\x00\x00\x00\x00\x00")) + + // Mask key + b1.Write([]byte("\x00\x00\x00\x00")) + + // Next, send a too long, final continuation frame + b1.Write([]byte("\x80\xFF\x00\x00\x00\x00\x00\x00\x00\x05")) + + // Mask key + b1.Write([]byte("\x00\x00\x00\x00")) + + // Too-long payload + b1.Write([]byte("BCDEF")) + + op, r, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("1: NextReader() returned %d, %v", op, err) + } + + var buf [10]byte + var read int + n, err := r.Read(buf[:]) + if err != nil && err != ErrReadLimit { + t.Fatalf("unexpected error testing read limit: %v", err) + } + read += n + + n, err = r.Read(buf[:]) + if err != nil && err != ErrReadLimit { + t.Fatalf("unexpected error testing read limit: %v", err) + } + read += n + + if err == nil && read > readLimit { + t.Fatalf("read limit exceeded: limit %d, read %d", readLimit, read) + } + }) } func TestAddrs(t *testing.T) { - c := newConn(&fakeNetConn{}, true, 1024, 1024) + c := newTestConn(nil, nil, true) if c.LocalAddr() != localAddr { t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr) } @@ -333,7 +565,7 @@ func TestAddrs(t *testing.T) { func TestUnderlyingConn(t *testing.T) { var b1, b2 bytes.Buffer fc := fakeNetConn{Reader: &b1, Writer: &b2} - c := newConn(fc, true, 1024, 1024) + c := newConn(fc, true, 1024, 1024, nil, nil, nil) ul := c.UnderlyingConn() if ul != fc { t.Fatalf("Underlying conn is not what it should be.") @@ -341,15 +573,14 @@ func TestUnderlyingConn(t *testing.T) { } func TestBufioReadBytes(t *testing.T) { - // Test calling bufio.ReadBytes for value longer than read buffer size. m := make([]byte, 512) m[len(m)-1] = '\n' var b1, b2 bytes.Buffer - wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, len(m)+64, len(m)+64) - rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64) + wc := newConn(fakeNetConn{Writer: &b1}, false, len(m)+64, len(m)+64, nil, nil, nil) + rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil) w, _ := wc.NextWriter(BinaryMessage) w.Write(m) @@ -366,7 +597,7 @@ func TestBufioReadBytes(t *testing.T) { t.Fatalf("ReadBytes() returned %v", err) } if len(p) != len(m) { - t.Fatalf("read returnd %d bytes, want %d bytes", len(p), len(m)) + t.Fatalf("read returned %d bytes, want %d bytes", len(p), len(m)) } } @@ -424,7 +655,7 @@ func (w blockingWriter) Write(p []byte) (int, error) { func TestConcurrentWritePanic(t *testing.T) { w := blockingWriter{make(chan struct{}), make(chan struct{})} - c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024) + c := newTestConn(nil, w, false) go func() { c.WriteMessage(TextMessage, []byte{}) }() @@ -450,7 +681,7 @@ func (r failingReader) Read(p []byte) (int, error) { } func TestFailedConnectionReadPanic(t *testing.T) { - c := newConn(fakeNetConn{Reader: failingReader{}, Writer: nil}, false, 1024, 1024) + c := newTestConn(failingReader{}, nil, false) defer func() { if v := recover(); v != nil { @@ -463,35 +694,3 @@ func TestFailedConnectionReadPanic(t *testing.T) { } t.Fatal("should not get here") } - -func TestBufioReuse(t *testing.T) { - brw := bufio.NewReadWriter(bufio.NewReader(nil), bufio.NewWriter(nil)) - c := newConnBRW(nil, false, 0, 0, brw) - - if c.br != brw.Reader { - t.Error("connection did not reuse bufio.Reader") - } - - var wh writeHook - brw.Writer.Reset(&wh) - brw.WriteByte(0) - brw.Flush() - if &c.writeBuf[0] != &wh.p[0] { - t.Error("connection did not reuse bufio.Writer") - } - - brw = bufio.NewReadWriter(bufio.NewReaderSize(nil, 0), bufio.NewWriterSize(nil, 0)) - c = newConnBRW(nil, false, 0, 0, brw) - - if c.br == brw.Reader { - t.Error("connection used bufio.Reader with small size") - } - - brw.Writer.Reset(&wh) - brw.WriteByte(0) - brw.Flush() - if &c.writeBuf[0] != &wh.p[0] { - t.Error("connection used bufio.Writer with small size") - } - -} diff --git a/vendor/src/github.com/gorilla/websocket/conn_read.go b/vendor/src/github.com/gorilla/websocket/conn_write.go similarity index 52% rename from vendor/src/github.com/gorilla/websocket/conn_read.go rename to vendor/src/github.com/gorilla/websocket/conn_write.go index 1ea15059..a509a21f 100755 --- a/vendor/src/github.com/gorilla/websocket/conn_read.go +++ b/vendor/src/github.com/gorilla/websocket/conn_write.go @@ -2,17 +2,14 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build go1.5 +// +build go1.8 package websocket -import "io" +import "net" -func (c *Conn) read(n int) ([]byte, error) { - p, err := c.br.Peek(n) - if err == io.EOF { - err = errUnexpectedEOF - } - c.br.Discard(len(p)) - return p, err +func (c *Conn) writeBufs(bufs ...[]byte) error { + b := net.Buffers(bufs) + _, err := b.WriteTo(c.conn) + return err } diff --git a/vendor/src/github.com/gorilla/websocket/conn_write_legacy.go b/vendor/src/github.com/gorilla/websocket/conn_write_legacy.go new file mode 100755 index 00000000..37edaff5 --- /dev/null +++ b/vendor/src/github.com/gorilla/websocket/conn_write_legacy.go @@ -0,0 +1,18 @@ +// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !go1.8 + +package websocket + +func (c *Conn) writeBufs(bufs ...[]byte) error { + for _, buf := range bufs { + if len(buf) > 0 { + if _, err := c.conn.Write(buf); err != nil { + return err + } + } + } + return nil +} diff --git a/vendor/src/github.com/gorilla/websocket/doc.go b/vendor/src/github.com/gorilla/websocket/doc.go index e291a952..8db0cef9 100755 --- a/vendor/src/github.com/gorilla/websocket/doc.go +++ b/vendor/src/github.com/gorilla/websocket/doc.go @@ -6,9 +6,8 @@ // // Overview // -// The Conn type represents a WebSocket connection. A server application uses -// the Upgrade function from an Upgrader object with a HTTP request handler -// to get a pointer to a Conn: +// The Conn type represents a WebSocket connection. A server application calls +// the Upgrader.Upgrade method from an HTTP request handler to get a *Conn: // // var upgrader = websocket.Upgrader{ // ReadBufferSize: 1024, @@ -31,10 +30,12 @@ // for { // messageType, p, err := conn.ReadMessage() // if err != nil { +// log.Println(err) // return // } -// if err = conn.WriteMessage(messageType, p); err != nil { -// return err +// if err := conn.WriteMessage(messageType, p); err != nil { +// log.Println(err) +// return // } // } // @@ -85,20 +86,26 @@ // and pong. Call the connection WriteControl, WriteMessage or NextWriter // methods to send a control message to the peer. // -// Connections handle received close messages by sending a close message to the -// peer and returning a *CloseError from the the NextReader, ReadMessage or the -// message Read method. +// Connections handle received close messages by calling the handler function +// set with the SetCloseHandler method and by returning a *CloseError from the +// NextReader, ReadMessage or the message Read method. The default close +// handler sends a close message to the peer. // -// Connections handle received ping and pong messages by invoking callback -// functions set with SetPingHandler and SetPongHandler methods. The callback -// functions are called from the NextReader, ReadMessage and the message Read -// methods. +// Connections handle received ping messages by calling the handler function +// set with the SetPingHandler method. The default ping handler sends a pong +// message to the peer. +// +// Connections handle received pong messages by calling the handler function +// set with the SetPongHandler method. The default pong handler does nothing. +// If an application sends ping messages, then the application should set a +// pong handler to receive the corresponding pong. // -// The default ping handler sends a pong to the peer. The application's reading -// goroutine can block for a short time while the handler writes the pong data -// to the connection. +// The control message handler functions are called from the NextReader, +// ReadMessage and message reader Read methods. The default close and ping +// handlers can block these methods for a short time when the handler writes to +// the connection. // -// The application must read the connection to process ping, pong and close +// The application must read the connection to process close, ping and pong // messages sent from the peer. If the application is not otherwise interested // in messages from the peer, then the application should start a goroutine to // read and discard messages from the peer. A simple example is: @@ -137,19 +144,59 @@ // method fails the WebSocket handshake with HTTP status 403. // // If the CheckOrigin field is nil, then the Upgrader uses a safe default: fail -// the handshake if the Origin request header is present and not equal to the -// Host request header. -// -// An application can allow connections from any origin by specifying a -// function that always returns true: -// -// var upgrader = websocket.Upgrader{ -// CheckOrigin: func(r *http.Request) bool { return true }, -// } -// -// The deprecated Upgrade function does not enforce an origin policy. It's the -// application's responsibility to check the Origin header before calling -// Upgrade. +// the handshake if the Origin request header is present and the Origin host is +// not equal to the Host request header. +// +// The deprecated package-level Upgrade function does not perform origin +// checking. The application is responsible for checking the Origin header +// before calling the Upgrade function. +// +// Buffers +// +// Connections buffer network input and output to reduce the number +// of system calls when reading or writing messages. +// +// Write buffers are also used for constructing WebSocket frames. See RFC 6455, +// Section 5 for a discussion of message framing. A WebSocket frame header is +// written to the network each time a write buffer is flushed to the network. +// Decreasing the size of the write buffer can increase the amount of framing +// overhead on the connection. +// +// The buffer sizes in bytes are specified by the ReadBufferSize and +// WriteBufferSize fields in the Dialer and Upgrader. The Dialer uses a default +// size of 4096 when a buffer size field is set to zero. The Upgrader reuses +// buffers created by the HTTP server when a buffer size field is set to zero. +// The HTTP server buffers have a size of 4096 at the time of this writing. +// +// The buffer sizes do not limit the size of a message that can be read or +// written by a connection. +// +// Buffers are held for the lifetime of the connection by default. If the +// Dialer or Upgrader WriteBufferPool field is set, then a connection holds the +// write buffer only when writing a message. +// +// Applications should tune the buffer sizes to balance memory use and +// performance. Increasing the buffer size uses more memory, but can reduce the +// number of system calls to read or write the network. In the case of writing, +// increasing the buffer size can reduce the number of frame headers written to +// the network. +// +// Some guidelines for setting buffer parameters are: +// +// Limit the buffer sizes to the maximum expected message size. Buffers larger +// than the largest message do not provide any benefit. +// +// Depending on the distribution of message sizes, setting the buffer size to +// a value less than the maximum expected message size can greatly reduce memory +// use with a small impact on performance. Here's an example: If 99% of the +// messages are smaller than 256 bytes and the maximum message size is 512 +// bytes, then a buffer size of 256 bytes will result in 1.01 more system calls +// than a buffer size of 512 bytes. The memory savings is 50%. +// +// A write buffer pool is useful when the application has a modest number +// writes over a large number of connections. when buffers are pooled, a larger +// buffer size has a reduced impact on total memory use and has the benefit of +// reducing system calls and frame overhead. // // Compression EXPERIMENTAL // diff --git a/vendor/src/github.com/gorilla/websocket/example_test.go b/vendor/src/github.com/gorilla/websocket/example_test.go index 96449eac..cd1883b8 100755 --- a/vendor/src/github.com/gorilla/websocket/example_test.go +++ b/vendor/src/github.com/gorilla/websocket/example_test.go @@ -23,10 +23,9 @@ var ( // This server application works with a client application running in the // browser. The client application does not explicitly close the websocket. The // only expected close message from the client has the code -// websocket.CloseGoingAway. All other other close messages are likely the +// websocket.CloseGoingAway. All other close messages are likely the // result of an application or protocol error and are logged to aid debugging. func ExampleIsUnexpectedCloseError() { - for { messageType, p, err := c.ReadMessage() if err != nil { @@ -35,11 +34,11 @@ func ExampleIsUnexpectedCloseError() { } return } - processMesage(messageType, p) + processMessage(messageType, p) } } -func processMesage(mt int, p []byte) {} +func processMessage(mt int, p []byte) {} // TestX prevents godoc from showing this entire file in the example. Remove // this function when a second example is added. diff --git a/vendor/src/github.com/gorilla/websocket/examples/autobahn/README.md b/vendor/src/github.com/gorilla/websocket/examples/autobahn/README.md index 075ac153..dde8525a 100755 --- a/vendor/src/github.com/gorilla/websocket/examples/autobahn/README.md +++ b/vendor/src/github.com/gorilla/websocket/examples/autobahn/README.md @@ -1,6 +1,6 @@ # Test Server -This package contains a server for the [Autobahn WebSockets Test Suite](http://autobahn.ws/testsuite). +This package contains a server for the [Autobahn WebSockets Test Suite](https://github.com/crossbario/autobahn-testsuite). To test the server, run diff --git a/vendor/src/github.com/gorilla/websocket/examples/autobahn/server.go b/vendor/src/github.com/gorilla/websocket/examples/autobahn/server.go index 3db880f9..c2d6ee50 100755 --- a/vendor/src/github.com/gorilla/websocket/examples/autobahn/server.go +++ b/vendor/src/github.com/gorilla/websocket/examples/autobahn/server.go @@ -157,11 +157,11 @@ func echoReadAllWritePreparedMessage(w http.ResponseWriter, r *http.Request) { func serveHome(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/" { - http.Error(w, "Not found.", 404) + http.Error(w, "Not found.", http.StatusNotFound) return } if r.Method != "GET" { - http.Error(w, "Method not allowed", 405) + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } w.Header().Set("Content-Type", "text/html; charset=utf-8") diff --git a/vendor/src/github.com/gorilla/websocket/examples/chat/README.md b/vendor/src/github.com/gorilla/websocket/examples/chat/README.md index 47c82f90..7baf3e32 100755 --- a/vendor/src/github.com/gorilla/websocket/examples/chat/README.md +++ b/vendor/src/github.com/gorilla/websocket/examples/chat/README.md @@ -1,6 +1,6 @@ # Chat Example -This application shows how to use use the +This application shows how to use the [websocket](https://github.com/gorilla/websocket) package to implement a simple web chat application. diff --git a/vendor/src/github.com/gorilla/websocket/examples/chat/client.go b/vendor/src/github.com/gorilla/websocket/examples/chat/client.go index 26468477..9461c1ea 100755 --- a/vendor/src/github.com/gorilla/websocket/examples/chat/client.go +++ b/vendor/src/github.com/gorilla/websocket/examples/chat/client.go @@ -64,7 +64,7 @@ func (c *Client) readPump() { for { _, message, err := c.conn.ReadMessage() if err != nil { - if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { log.Printf("error: %v", err) } break @@ -113,7 +113,7 @@ func (c *Client) writePump() { } case <-ticker.C: c.conn.SetWriteDeadline(time.Now().Add(writeWait)) - if err := c.conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil { + if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { return } } @@ -129,6 +129,9 @@ func serveWs(hub *Hub, w http.ResponseWriter, r *http.Request) { } client := &Client{hub: hub, conn: conn, send: make(chan []byte, 256)} client.hub.register <- client + + // Allow collection of memory referenced by the caller by doing all work in + // new goroutines. go client.writePump() - client.readPump() + go client.readPump() } diff --git a/vendor/src/github.com/gorilla/websocket/examples/chat/home.html b/vendor/src/github.com/gorilla/websocket/examples/chat/home.html index a39a0c27..bf866aff 100755 --- a/vendor/src/github.com/gorilla/websocket/examples/chat/home.html +++ b/vendor/src/github.com/gorilla/websocket/examples/chat/home.html @@ -92,7 +92,7 @@
- +
diff --git a/vendor/src/github.com/gorilla/websocket/examples/chat/hub.go b/vendor/src/github.com/gorilla/websocket/examples/chat/hub.go index 7f07ea07..bb5c0e3b 100755 --- a/vendor/src/github.com/gorilla/websocket/examples/chat/hub.go +++ b/vendor/src/github.com/gorilla/websocket/examples/chat/hub.go @@ -4,7 +4,7 @@ package main -// hub maintains the set of active clients and broadcasts messages to the +// Hub maintains the set of active clients and broadcasts messages to the // clients. type Hub struct { // Registered clients. diff --git a/vendor/src/github.com/gorilla/websocket/examples/chat/main.go b/vendor/src/github.com/gorilla/websocket/examples/chat/main.go index 74615d59..9d4737a6 100755 --- a/vendor/src/github.com/gorilla/websocket/examples/chat/main.go +++ b/vendor/src/github.com/gorilla/websocket/examples/chat/main.go @@ -15,11 +15,11 @@ var addr = flag.String("addr", ":8080", "http service address") func serveHome(w http.ResponseWriter, r *http.Request) { log.Println(r.URL) if r.URL.Path != "/" { - http.Error(w, "Not found", 404) + http.Error(w, "Not found", http.StatusNotFound) return } if r.Method != "GET" { - http.Error(w, "Method not allowed", 405) + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } http.ServeFile(w, r, "home.html") diff --git a/vendor/src/github.com/gorilla/websocket/examples/command/main.go b/vendor/src/github.com/gorilla/websocket/examples/command/main.go index 239c5c85..304f1a52 100755 --- a/vendor/src/github.com/gorilla/websocket/examples/command/main.go +++ b/vendor/src/github.com/gorilla/websocket/examples/command/main.go @@ -167,11 +167,11 @@ func serveWs(w http.ResponseWriter, r *http.Request) { func serveHome(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/" { - http.Error(w, "Not found", 404) + http.Error(w, "Not found", http.StatusNotFound) return } if r.Method != "GET" { - http.Error(w, "Method not allowed", 405) + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } http.ServeFile(w, r, "home.html") diff --git a/vendor/src/github.com/gorilla/websocket/examples/echo/client.go b/vendor/src/github.com/gorilla/websocket/examples/echo/client.go index 6578094e..bf0e6573 100755 --- a/vendor/src/github.com/gorilla/websocket/examples/echo/client.go +++ b/vendor/src/github.com/gorilla/websocket/examples/echo/client.go @@ -38,7 +38,6 @@ func main() { done := make(chan struct{}) go func() { - defer c.Close() defer close(done) for { _, message, err := c.ReadMessage() @@ -55,6 +54,8 @@ func main() { for { select { + case <-done: + return case t := <-ticker.C: err := c.WriteMessage(websocket.TextMessage, []byte(t.String())) if err != nil { @@ -63,8 +64,9 @@ func main() { } case <-interrupt: log.Println("interrupt") - // To cleanly close a connection, a client should send a close - // frame and wait for the server to close the connection. + + // Cleanly close the connection by sending a close message and then + // waiting (with timeout) for the server to close the connection. err := c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) if err != nil { log.Println("write close:", err) @@ -74,7 +76,6 @@ func main() { case <-done: case <-time.After(time.Second): } - c.Close() return } } diff --git a/vendor/src/github.com/gorilla/websocket/examples/echo/server.go b/vendor/src/github.com/gorilla/websocket/examples/echo/server.go index a685b097..61f3dd9b 100755 --- a/vendor/src/github.com/gorilla/websocket/examples/echo/server.go +++ b/vendor/src/github.com/gorilla/websocket/examples/echo/server.go @@ -55,6 +55,7 @@ func main() { var homeTemplate = template.Must(template.New("").Parse(` +