diff --git a/ci/test.sh b/ci/test.sh index 1d4a8b07..3c476d93 100755 --- a/ci/test.sh +++ b/ci/test.sh @@ -10,18 +10,16 @@ argv=( "--junitfile=ci/out/websocket/testReport.xml" "--format=short-verbose" -- - -race "-vet=off" - "-bench=." ) -# Interactive usage probably does not want to enable benchmarks, race detection -# turn off vet or use gotestsum by default. +# Interactive usage does not want to turn off vet or use gotestsum by default. if [[ $# -gt 0 ]]; then argv=(go test "$@") fi -# We always want coverage. +# We always want coverage and race detection. argv+=( + -race "-coverprofile=ci/out/coverage.prof" "-coverpkg=./..." ) diff --git a/cmp_test.go b/cmp_test.go new file mode 100644 index 00000000..ad4cd75a --- /dev/null +++ b/cmp_test.go @@ -0,0 +1,53 @@ +package websocket_test + +import ( + "reflect" + + "github.com/google/go-cmp/cmp" +) + +// https://github.com/google/go-cmp/issues/40#issuecomment-328615283 +func cmpDiff(exp, act interface{}) string { + return cmp.Diff(exp, act, deepAllowUnexported(exp, act)) +} + +func deepAllowUnexported(vs ...interface{}) cmp.Option { + m := make(map[reflect.Type]struct{}) + for _, v := range vs { + structTypes(reflect.ValueOf(v), m) + } + var typs []interface{} + for t := range m { + typs = append(typs, reflect.New(t).Elem().Interface()) + } + return cmp.AllowUnexported(typs...) +} + +func structTypes(v reflect.Value, m map[reflect.Type]struct{}) { + if !v.IsValid() { + return + } + switch v.Kind() { + case reflect.Ptr: + if !v.IsNil() { + structTypes(v.Elem(), m) + } + case reflect.Interface: + if !v.IsNil() { + structTypes(v.Elem(), m) + } + case reflect.Slice, reflect.Array: + for i := 0; i < v.Len(); i++ { + structTypes(v.Index(i), m) + } + case reflect.Map: + for _, k := range v.MapKeys() { + structTypes(v.MapIndex(k), m) + } + case reflect.Struct: + m[v.Type()] = struct{}{} + for i := 0; i < v.NumField(); i++ { + structTypes(v.Field(i), m) + } + } +} diff --git a/docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md index f003e743..74a25540 100644 --- a/docs/CONTRIBUTING.md +++ b/docs/CONTRIBUTING.md @@ -43,12 +43,10 @@ For coverage details locally, please see `ci/out/coverage.html` after running `c See [ci/image/Dockerfile](ci/image/Dockerfile) for the installation of the CI dependencies on Ubuntu. -You can also run tests normally with `go test`. -`ci/test.sh` just passes a default set of flags to `go test` to collect coverage, -enable the race detector, run benchmarks and also prettifies the output. +You can also run tests normally with `go test`. `ci/test.sh` just passes a default set of flags to +`go test` to collect coverage, enable the race detector and also prettifies the output. -If you pass flags to `ci/test.sh`, it will pass those flags directly to `go test` but will also -collect coverage for you. This is nice for when you don't want to wait for benchmarks -or the race detector but want to have coverage. +You can pass flags to `ci/test.sh` if you want to run a specific test or otherwise +control the behaviour of `go test`. Coverage percentage from codecov and the CI scripts will be different because they are calculated differently. diff --git a/go.mod b/go.mod index 35d500dd..c9cc6fc4 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,8 @@ require ( github.com/google/go-cmp v0.2.0 github.com/kr/pretty v0.1.0 // indirect go.coder.com/go-tools v0.0.0-20190317003359-0c6a35b74a16 + go.uber.org/atomic v1.4.0 // indirect + go.uber.org/multierr v1.1.0 golang.org/x/lint v0.0.0-20190409202823-959b441ac422 golang.org/x/net v0.0.0-20190424112056-4829fb13d2c6 golang.org/x/text v0.3.2 // indirect diff --git a/go.sum b/go.sum index b9e3737c..187a2285 100644 --- a/go.sum +++ b/go.sum @@ -40,6 +40,10 @@ github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0 github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= go.coder.com/go-tools v0.0.0-20190317003359-0c6a35b74a16 h1:3gGa1bM0nG7Ruhu5b7wKnoOOwAD/fJ8iyyAcpOzDG3A= go.coder.com/go-tools v0.0.0-20190317003359-0c6a35b74a16/go.mod h1:iKV5yK9t+J5nG9O3uF6KYdPEz3dyfMyB15MN1rbQ8Qw= +go.uber.org/atomic v1.4.0 h1:cxzIVoETapQEqDhQu3QfnvXAV4AlzcvUCxkVUFw3+EU= +go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/multierr v1.1.0 h1:HoEmRHQPVSqub6w2z2d2EOVs2fjyFRGyofhKuyDq0QI= +go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= golang.org/x/crypto v0.0.0-20180426230345-b49d69b5da94/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= diff --git a/websocket.go b/websocket.go index 833c1209..6f28a4bf 100644 --- a/websocket.go +++ b/websocket.go @@ -166,9 +166,9 @@ func (c *Conn) timeoutLoop() { case readCtx = <-c.setReadTimeout: case <-readCtx.Done(): - c.close(xerrors.Errorf("data read timed out: %w", readCtx.Err())) + c.close(xerrors.Errorf("read timed out: %w", readCtx.Err())) case <-writeCtx.Done(): - c.close(xerrors.Errorf("data write timed out: %w", writeCtx.Err())) + c.close(xerrors.Errorf("write timed out: %w", writeCtx.Err())) } } } diff --git a/websocket_test.go b/websocket_test.go index 1963ce70..e6529f3b 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -23,9 +23,8 @@ import ( "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes" - "github.com/golang/protobuf/ptypes/duration" "github.com/golang/protobuf/ptypes/timestamp" - "github.com/google/go-cmp/cmp" + "go.uber.org/multierr" "golang.org/x/xerrors" "nhooyr.io/websocket" @@ -41,103 +40,6 @@ func TestHandshake(t *testing.T) { client func(ctx context.Context, url string) error server func(w http.ResponseWriter, r *http.Request) error }{ - { - name: "handshake", - server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{"myproto"}, - }) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - return nil - }, - client: func(ctx context.Context, u string) error { - c, resp, err := websocket.Dial(ctx, u, &websocket.DialOptions{ - Subprotocols: []string{"myproto"}, - }) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - - checkHeader := func(h, exp string) { - t.Helper() - value := resp.Header.Get(h) - if exp != value { - t.Errorf("expected different value for header %v: %v", h, cmp.Diff(exp, value)) - } - } - - checkHeader("Connection", "Upgrade") - checkHeader("Upgrade", "websocket") - checkHeader("Sec-WebSocket-Protocol", "myproto") - - c.Close(websocket.StatusNormalClosure, "") - return nil - }, - }, - { - name: "defaultSubprotocol", - server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, nil) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - - if c.Subprotocol() != "" { - return xerrors.Errorf("unexpected subprotocol: %v", c.Subprotocol()) - } - return nil - }, - client: func(ctx context.Context, u string) error { - c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{ - Subprotocols: []string{"meow"}, - }) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - - if c.Subprotocol() != "" { - return xerrors.Errorf("unexpected subprotocol: %v", c.Subprotocol()) - } - return nil - }, - }, - { - name: "subprotocol", - server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{"echo", "lar"}, - }) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - - if c.Subprotocol() != "echo" { - return xerrors.Errorf("unexpected subprotocol: %q", c.Subprotocol()) - } - return nil - }, - client: func(ctx context.Context, u string) error { - c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{ - Subprotocols: []string{"poof", "echo"}, - }) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - - if c.Subprotocol() != "echo" { - return xerrors.Errorf("unexpected subprotocol: %q", c.Subprotocol()) - } - return nil - }, - }, { name: "badOrigin", server: func(w http.ResponseWriter, r *http.Request) error { @@ -174,7 +76,7 @@ func TestHandshake(t *testing.T) { if err != nil { return err } - defer c.Close(websocket.StatusInternalError, "") + c.Close(websocket.StatusNormalClosure, "") return nil }, client: func(ctx context.Context, u string) error { @@ -186,7 +88,7 @@ func TestHandshake(t *testing.T) { if err != nil { return err } - defer c.Close(websocket.StatusInternalError, "") + c.Close(websocket.StatusNormalClosure, "") return nil }, }, @@ -199,7 +101,7 @@ func TestHandshake(t *testing.T) { if err != nil { return err } - defer c.Close(websocket.StatusInternalError, "") + c.Close(websocket.StatusNormalClosure, "") return nil }, client: func(ctx context.Context, u string) error { @@ -211,7 +113,7 @@ func TestHandshake(t *testing.T) { if err != nil { return err } - defer c.Close(websocket.StatusInternalError, "") + c.Close(websocket.StatusNormalClosure, "") return nil }, }, @@ -229,7 +131,7 @@ func TestHandshake(t *testing.T) { if err != nil { return err } - c.Close(websocket.StatusInternalError, "") + c.Close(websocket.StatusNormalClosure, "") return nil }, client: func(ctx context.Context, u string) error { @@ -257,7 +159,7 @@ func TestHandshake(t *testing.T) { if err != nil { return err } - c.Close(websocket.StatusInternalError, "") + c.Close(websocket.StatusNormalClosure, "") return nil }, }, @@ -288,33 +190,76 @@ func TestConn(t *testing.T) { t.Parallel() testCases := []struct { - name string - client func(ctx context.Context, c *websocket.Conn) error - server func(ctx context.Context, c *websocket.Conn) error + name string + + acceptOpts *websocket.AcceptOptions + server func(ctx context.Context, c *websocket.Conn) error + + dialOpts *websocket.DialOptions + response func(resp *http.Response) error + client func(ctx context.Context, c *websocket.Conn) error }{ + { + name: "handshake", + acceptOpts: &websocket.AcceptOptions{ + Subprotocols: []string{"myproto"}, + }, + dialOpts: &websocket.DialOptions{ + Subprotocols: []string{"myproto"}, + }, + response: func(resp *http.Response) error { + headers := map[string]string{ + "Connection": "Upgrade", + "Upgrade": "websocket", + "Sec-WebSocket-Protocol": "myproto", + } + for h, exp := range headers { + value := resp.Header.Get(h) + err := assertEqualf(exp, value, "unexpected value for header %v", h) + if err != nil { + return err + } + } + return nil + }, + }, + { + name: "handshake/defaultSubprotocol", + server: func(ctx context.Context, c *websocket.Conn) error { + return assertSubprotocol(c, "") + }, + client: func(ctx context.Context, c *websocket.Conn) error { + return assertSubprotocol(c, "") + }, + }, + { + name: "handshake/subprotocolPriority", + acceptOpts: &websocket.AcceptOptions{ + Subprotocols: []string{"echo", "lar"}, + }, + server: func(ctx context.Context, c *websocket.Conn) error { + return assertSubprotocol(c, "echo") + }, + dialOpts: &websocket.DialOptions{ + Subprotocols: []string{"poof", "echo"}, + }, + client: func(ctx context.Context, c *websocket.Conn) error { + return assertSubprotocol(c, "echo") + }, + }, { name: "closeError", server: func(ctx context.Context, c *websocket.Conn) error { return wsjson.Write(ctx, c, "hello") }, client: func(ctx context.Context, c *websocket.Conn) error { - var m string - err := wsjson.Read(ctx, c, &m) + err := assertJSONRead(ctx, c, "hello") if err != nil { return err } - if m != "hello" { - return xerrors.Errorf("recieved unexpected msg but expected hello: %+v", m) - } - _, _, err = c.Reader(ctx) - var cerr websocket.CloseError - if !xerrors.As(err, &cerr) || cerr.Code != websocket.StatusInternalError { - return xerrors.Errorf("unexpected error: %+v", err) - } - - return nil + return assertCloseStatus(err, websocket.StatusInternalError) }, }, { @@ -327,11 +272,13 @@ func TestConn(t *testing.T) { time.Sleep(1) nc.SetWriteDeadline(time.Now().Add(time.Second * 15)) - if nc.LocalAddr() != (websocket.Addr{}) { - return xerrors.Errorf("net conn local address is not equal to websocket.Addr") + err := assertEqualf(websocket.Addr{}, nc.LocalAddr(), "net conn local address is not equal to websocket.Addr") + if err != nil { + return err } - if nc.RemoteAddr() != (websocket.Addr{}) { - return xerrors.Errorf("net conn remote address is not equal to websocket.Addr") + err = assertEqualf(websocket.Addr{}, nc.RemoteAddr(), "net conn remote address is not equal to websocket.Addr") + if err != nil { + return err } for i := 0; i < 3; i++ { @@ -345,62 +292,38 @@ func TestConn(t *testing.T) { }, client: func(ctx context.Context, c *websocket.Conn) error { nc := websocket.NetConn(c, websocket.MessageBinary) - defer nc.Close() nc.SetReadDeadline(time.Time{}) time.Sleep(1) nc.SetReadDeadline(time.Now().Add(time.Second * 15)) - read := func() error { - p := make([]byte, len("hello")) - // We do not use io.ReadFull here as it masks EOFs. - // See https://github.com/nhooyr/websocket/issues/100#issuecomment-508148024 - _, err := nc.Read(p) - if err != nil { - return err - } - - if string(p) != "hello" { - return xerrors.Errorf("unexpected payload %q received", string(p)) - } - return nil - } - for i := 0; i < 3; i++ { - err := read() + err := assertNetConnRead(nc, "hello") if err != nil { return err } } // Ensure the close frame is converted to an EOF and multiple read's after all return EOF. - err := read() - if err != io.EOF { - return err - } - - err = read() - if err != io.EOF { + err2 := assertNetConnRead(nc, "hello") + err := assertEqualf(io.EOF, err2, "unexpected error") + if err != nil { return err } - return nil + err2 = assertNetConnRead(nc, "hello") + return assertEqualf(io.EOF, err2, "unexpected error") }, }, { name: "netConn/badReadMsgType", server: func(ctx context.Context, c *websocket.Conn) error { nc := websocket.NetConn(c, websocket.MessageBinary) - defer nc.Close() nc.SetDeadline(time.Now().Add(time.Second * 15)) _, err := nc.Read(make([]byte, 1)) - if err == nil || !strings.Contains(err.Error(), "unexpected frame type read") { - return xerrors.Errorf("expected error: %+v", err) - } - - return nil + return assertErrorContains(err, "unexpected frame type") }, client: func(ctx context.Context, c *websocket.Conn) error { err := wsjson.Write(ctx, c, "meow") @@ -409,12 +332,7 @@ func TestConn(t *testing.T) { } _, _, err = c.Read(ctx) - cerr := &websocket.CloseError{} - if !xerrors.As(err, cerr) || cerr.Code != websocket.StatusUnsupportedData { - return xerrors.Errorf("expected close error with code StatusUnsupportedData: %+v", err) - } - - return nil + return assertCloseStatus(err, websocket.StatusUnsupportedData) }, }, { @@ -425,205 +343,105 @@ func TestConn(t *testing.T) { nc.SetDeadline(time.Now().Add(time.Second * 15)) - _, err := nc.Read(make([]byte, 1)) - cerr := &websocket.CloseError{} - if !xerrors.As(err, cerr) || cerr.Code != websocket.StatusBadGateway { - return xerrors.Errorf("expected close error with code StatusBadGateway: %+v", err) - } - - _, err = nc.Write([]byte{0xff}) - if err == nil || !strings.Contains(err.Error(), "websocket closed") { - return xerrors.Errorf("expected writes to fail after reading a close frame: %v", err) + _, err2 := nc.Read(make([]byte, 1)) + err := assertCloseStatus(err2, websocket.StatusBadGateway) + if err != nil { + return err } - return nil + _, err2 = nc.Write([]byte{0xff}) + return assertErrorContains(err2, "websocket closed") }, client: func(ctx context.Context, c *websocket.Conn) error { return c.Close(websocket.StatusBadGateway, "") }, }, { - name: "jsonEcho", + name: "wsjson/echo", server: func(ctx context.Context, c *websocket.Conn) error { - write := func() error { - v := map[string]interface{}{ - "anmol": "wowow", - } - err := wsjson.Write(ctx, c, v) - return err - } - err := write() - if err != nil { - return err - } - err = write() - if err != nil { - return err - } - - c.Close(websocket.StatusNormalClosure, "") - return nil + return wsjson.Write(ctx, c, "meow") }, client: func(ctx context.Context, c *websocket.Conn) error { - read := func() error { - var v interface{} - err := wsjson.Read(ctx, c, &v) - if err != nil { - return err - } - - exp := map[string]interface{}{ - "anmol": "wowow", - } - if !reflect.DeepEqual(exp, v) { - return xerrors.Errorf("expected %v but got %v", exp, v) - } - return nil - } - err := read() - if err != nil { - return err - } - err = read() - if err != nil { - return err - } - - c.Close(websocket.StatusNormalClosure, "") - return nil + return assertJSONRead(ctx, c, "meow") }, }, { - name: "protobufEcho", + name: "protobuf/echo", server: func(ctx context.Context, c *websocket.Conn) error { - write := func() error { - err := wspb.Write(ctx, c, ptypes.DurationProto(100)) - return err - } - err := write() - if err != nil { - return err - } - - c.Close(websocket.StatusNormalClosure, "") - return nil + return wspb.Write(ctx, c, ptypes.DurationProto(100)) }, client: func(ctx context.Context, c *websocket.Conn) error { - read := func() error { - var v duration.Duration - err := wspb.Read(ctx, c, &v) - if err != nil { - return err - } - - d, err := ptypes.Duration(&v) - if err != nil { - return xerrors.Errorf("failed to convert duration.Duration to time.Duration: %w", err) - } - const exp = time.Duration(100) - if !reflect.DeepEqual(exp, d) { - return xerrors.Errorf("expected %v but got %v", exp, d) - } - return nil - } - err := read() - if err != nil { - return err - } - - c.Close(websocket.StatusNormalClosure, "") - return nil + return assertProtobufRead(ctx, c, ptypes.DurationProto(100)) }, }, { name: "ping", server: func(ctx context.Context, c *websocket.Conn) error { - errc := make(chan error, 1) - go func() { - _, _, err2 := c.Read(ctx) - errc <- err2 - }() + ctx = c.CloseRead(ctx) err := c.Ping(ctx) if err != nil { return err } - err = c.Write(ctx, websocket.MessageText, []byte("hi")) + err = wsjson.Write(ctx, c, "hi") if err != nil { return err } - err = <-errc - var ce websocket.CloseError - if xerrors.As(err, &ce) && ce.Code == websocket.StatusNormalClosure { - return nil - } - return xerrors.Errorf("unexpected error: %w", err) + <-ctx.Done() + err = c.Ping(context.Background()) + return assertCloseStatus(err, websocket.StatusNormalClosure) }, client: func(ctx context.Context, c *websocket.Conn) error { // We read a message from the connection and then keep reading until // the Ping completes. - done := make(chan struct{}) + pingErrc := make(chan error, 1) go func() { - _, _, err := c.Read(ctx) - if err != nil { - c.Close(websocket.StatusInternalError, err.Error()) - return - } - - close(done) - - c.Read(ctx) + pingErrc <- c.Ping(ctx) }() - err := c.Ping(ctx) + // Once this completes successfully, that means they sent their ping and we responded to it. + err := assertJSONRead(ctx, c, "hi") if err != nil { return err } - <-done + // Now we need to ensure we're reading for their pong from our ping. + // Need new var to not race with above goroutine. + ctx2 := c.CloseRead(ctx) - c.Close(websocket.StatusNormalClosure, "") - return nil + // Now we wait for our pong. + select { + case err = <-pingErrc: + return err + case <-ctx2.Done(): + return xerrors.Errorf("failed to wait for pong: %w", ctx2.Err()) + } }, }, { name: "readLimit", server: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - if err == nil || !strings.Contains(err.Error(), "read limited at") { - return xerrors.Errorf("expected error but got nil: %+v", err) - } - return nil + _, _, err2 := c.Read(ctx) + return assertErrorContains(err2, "read limited at 32768 bytes") }, client: func(ctx context.Context, c *websocket.Conn) error { - c.CloseRead(ctx) - err := c.Write(ctx, websocket.MessageBinary, []byte(strings.Repeat("x", 32769))) if err != nil { return err } - err = c.Ping(ctx) - - var ce websocket.CloseError - if !xerrors.As(err, &ce) || ce.Code != websocket.StatusMessageTooBig { - return xerrors.Errorf("unexpected error: %w", err) - } - - return nil + _, _, err2 := c.Read(ctx) + return assertCloseStatus(err2, websocket.StatusMessageTooBig) }, }, { name: "wsjson/binary", server: func(ctx context.Context, c *websocket.Conn) error { var v interface{} - err := wsjson.Read(ctx, c, &v) - if err == nil || !strings.Contains(err.Error(), "unexpected frame type") { - return xerrors.Errorf("expected error: %v", err) - } - return nil + err2 := wsjson.Read(ctx, c, &v) + return assertErrorContains(err2, "unexpected frame type") }, client: func(ctx context.Context, c *websocket.Conn) error { return wspb.Write(ctx, c, ptypes.DurationProto(100)) @@ -633,11 +451,8 @@ func TestConn(t *testing.T) { name: "wsjson/badRead", server: func(ctx context.Context, c *websocket.Conn) error { var v interface{} - err := wsjson.Read(ctx, c, &v) - if err == nil || !strings.Contains(err.Error(), "failed to unmarshal json") { - return xerrors.Errorf("expected error: %v", err) - } - return nil + err2 := wsjson.Read(ctx, c, &v) + return assertErrorContains(err2, "failed to unmarshal json") }, client: func(ctx context.Context, c *websocket.Conn) error { return c.Write(ctx, websocket.MessageText, []byte("notjson")) @@ -646,18 +461,12 @@ func TestConn(t *testing.T) { { name: "wsjson/badWrite", server: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - if err == nil || !strings.Contains(err.Error(), "StatusInternalError") { - return xerrors.Errorf("expected error: %v", err) - } - return nil + _, _, err2 := c.Read(ctx) + return assertCloseStatus(err2, websocket.StatusNormalClosure) }, client: func(ctx context.Context, c *websocket.Conn) error { err := wsjson.Write(ctx, c, fmt.Println) - if err == nil { - return xerrors.Errorf("expected error: %v", err) - } - return nil + return assertErrorContains(err, "failed to encode json") }, }, { @@ -665,10 +474,7 @@ func TestConn(t *testing.T) { server: func(ctx context.Context, c *websocket.Conn) error { var v proto.Message err := wspb.Read(ctx, c, v) - if err == nil || !strings.Contains(err.Error(), "unexpected frame type") { - return xerrors.Errorf("expected error: %v", err) - } - return nil + return assertErrorContains(err, "unexpected frame type") }, client: func(ctx context.Context, c *websocket.Conn) error { return wsjson.Write(ctx, c, "hi") @@ -679,10 +485,7 @@ func TestConn(t *testing.T) { server: func(ctx context.Context, c *websocket.Conn) error { var v timestamp.Timestamp err := wspb.Read(ctx, c, &v) - if err == nil || !strings.Contains(err.Error(), "failed to unmarshal protobuf") { - return xerrors.Errorf("expected error: %v", err) - } - return nil + return assertErrorContains(err, "failed to unmarshal protobuf") }, client: func(ctx context.Context, c *websocket.Conn) error { return c.Write(ctx, websocket.MessageBinary, []byte("notpb")) @@ -692,17 +495,11 @@ func TestConn(t *testing.T) { name: "wspb/badWrite", server: func(ctx context.Context, c *websocket.Conn) error { _, _, err := c.Read(ctx) - if err == nil || !strings.Contains(err.Error(), "StatusInternalError") { - return xerrors.Errorf("expected error: %v", err) - } - return nil + return assertCloseStatus(err, websocket.StatusNormalClosure) }, client: func(ctx context.Context, c *websocket.Conn) error { err := wspb.Write(ctx, c, nil) - if err == nil { - return xerrors.Errorf("expected error: %v", err) - } - return nil + return assertErrorIs(proto.ErrNil, err) }, }, { @@ -712,11 +509,7 @@ func TestConn(t *testing.T) { }, client: func(ctx context.Context, c *websocket.Conn) error { _, _, err := c.Read(ctx) - cerr := &websocket.CloseError{} - if !xerrors.As(err, cerr) || cerr.Code != websocket.StatusInternalError { - return xerrors.Errorf("expected close error with StatusInternalError: %+v", err) - } - return nil + return assertCloseStatus(err, websocket.StatusInternalError) }, }, { @@ -725,14 +518,16 @@ func TestConn(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() err := c.Ping(ctx) - if err == nil || !xerrors.Is(err, context.DeadlineExceeded) { - return xerrors.Errorf("expected nil error: %+v", err) - } - return nil + return assertErrorIs(context.DeadlineExceeded, err) }, client: func(ctx context.Context, c *websocket.Conn) error { - c.Read(ctx) - return nil + _, _, err := c.Read(ctx) + err1 := assertErrorContains(err, "connection reset") + err2 := assertErrorIs(io.EOF, err) + if err1 != nil || err2 != nil { + return nil + } + return multierr.Combine(err1, err2) }, }, { @@ -743,14 +538,11 @@ func TestConn(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() err := c.Write(ctx, websocket.MessageBinary, []byte("meow")) - if !xerrors.Is(err, context.DeadlineExceeded) { - return xerrors.Errorf("expected deadline exceeded error: %+v", err) - } - return nil + return assertErrorIs(context.DeadlineExceeded, err) }, client: func(ctx context.Context, c *websocket.Conn) error { - time.Sleep(time.Second) - return nil + _, _, err := c.Read(ctx) + return assertErrorIs(io.EOF, err) }, }, { @@ -759,14 +551,11 @@ func TestConn(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() _, _, err := c.Read(ctx) - if !xerrors.Is(err, context.DeadlineExceeded) { - return xerrors.Errorf("expected deadline exceeded error: %+v", err) - } - return nil + return assertErrorIs(context.DeadlineExceeded, err) }, client: func(ctx context.Context, c *websocket.Conn) error { - c.Read(ctx) - return nil + _, _, err := c.Read(ctx) + return assertErrorIs(io.EOF, err) }, }, { @@ -777,18 +566,11 @@ func TestConn(t *testing.T) { return err } _, _, err = c.Read(ctx) - cerr := &websocket.CloseError{} - if !xerrors.As(err, cerr) || cerr.Code != websocket.StatusProtocolError { - return xerrors.Errorf("expected close error with StatusProtocolError: %+v", err) - } - return nil + return assertErrorContains(err, "unknown opcode") }, client: func(ctx context.Context, c *websocket.Conn) error { _, _, err := c.Read(ctx) - if err == nil || !strings.Contains(err.Error(), "opcode") { - return xerrors.Errorf("expected error that contains opcode: %+v", err) - } - return nil + return assertErrorContains(err, "unknown opcode") }, }, { @@ -821,18 +603,11 @@ func TestConn(t *testing.T) { return err } _, _, err = c.Read(ctx) - cerr := &websocket.CloseError{} - if !xerrors.As(err, cerr) || cerr.Code != websocket.StatusProtocolError { - return xerrors.Errorf("expected close error with StatusProtocolError: %+v", err) - } - return nil + return assertCloseStatus(err, websocket.StatusProtocolError) }, client: func(ctx context.Context, c *websocket.Conn) error { _, _, err := c.Read(ctx) - if err == nil || !strings.Contains(err.Error(), "too large") { - return xerrors.Errorf("expected error that contains too large: %+v", err) - } - return nil + return assertErrorContains(err, "too large") }, }, { @@ -847,18 +622,11 @@ func TestConn(t *testing.T) { return err } _, _, err = c.Read(ctx) - cerr := &websocket.CloseError{} - if !xerrors.As(err, cerr) || cerr.Code != websocket.StatusProtocolError { - return xerrors.Errorf("expected close error with StatusProtocolError: %+v", err) - } - return nil + return assertCloseStatus(err, websocket.StatusProtocolError) }, client: func(ctx context.Context, c *websocket.Conn) error { _, _, err := c.Read(ctx) - if err == nil || !strings.Contains(err.Error(), "fragmented") { - return xerrors.Errorf("expected error that contains fragmented: %+v", err) - } - return nil + return assertErrorContains(err, "fragmented") }, }, { @@ -869,18 +637,11 @@ func TestConn(t *testing.T) { return err } _, _, err = c.Read(ctx) - cerr := &websocket.CloseError{} - if !xerrors.As(err, cerr) || cerr.Code != websocket.StatusProtocolError { - return xerrors.Errorf("expected close error with StatusProtocolError: %+v", err) - } - return nil + return assertCloseStatus(err, websocket.StatusProtocolError) }, client: func(ctx context.Context, c *websocket.Conn) error { _, _, err := c.Read(ctx) - if err == nil || !strings.Contains(err.Error(), "invalid status code") { - return xerrors.Errorf("expected error that contains invalid status code: %+v", err) - } - return nil + return assertErrorContains(err, "invalid status code") }, }, { @@ -896,10 +657,7 @@ func TestConn(t *testing.T) { return err } _, _, err = c.Reader(ctx) - if err == nil || !strings.Contains(err.Error(), "previous message not read to completion") { - return xerrors.Errorf("expected non nil error: %v", err) - } - return nil + return assertErrorContains(err, "previous message not read to completion") }, client: func(ctx context.Context, c *websocket.Conn) error { err := c.Write(ctx, websocket.MessageBinary, []byte(strings.Repeat("x", 11))) @@ -907,10 +665,7 @@ func TestConn(t *testing.T) { return err } _, _, err = c.Read(ctx) - if err == nil { - return xerrors.Errorf("expected non nil error: %v", err) - } - return nil + return assertCloseStatus(err, websocket.StatusInternalError) }, }, { @@ -926,10 +681,7 @@ func TestConn(t *testing.T) { return err } _, _, err = c.Reader(ctx) - if err == nil || !strings.Contains(err.Error(), "previous message not read to completion") { - return xerrors.Errorf("expected non nil error: %v", err) - } - return nil + return assertErrorContains(err, "previous message not read to completion") }, client: func(ctx context.Context, c *websocket.Conn) error { w, err := c.Writer(ctx, websocket.MessageBinary) @@ -953,10 +705,7 @@ func TestConn(t *testing.T) { return xerrors.Errorf("failed to flush: %w", err) } _, _, err = c.Read(ctx) - if err == nil { - return xerrors.Errorf("expected non nil error: %v", err) - } - return nil + return assertCloseStatus(err, websocket.StatusInternalError) }, }, { @@ -972,10 +721,7 @@ func TestConn(t *testing.T) { return err } _, _, err = c.Reader(ctx) - if err == nil || !strings.Contains(err.Error(), "received new data message without finishing") { - return xerrors.Errorf("expected non nil error: %v", err) - } - return nil + return assertErrorContains(err, "received new data message without finishing") }, client: func(ctx context.Context, c *websocket.Conn) error { w, err := c.Writer(ctx, websocket.MessageBinary) @@ -995,27 +741,18 @@ func TestConn(t *testing.T) { return xerrors.Errorf("expected non nil error") } _, _, err = c.Read(ctx) - if err == nil || !strings.Contains(err.Error(), "received new data message without finishing") { - return xerrors.Errorf("expected non nil error: %v", err) - } - return nil + return assertErrorContains(err, "received new data message without finishing") }, }, { name: "continuationFrameWithoutDataFrame", server: func(ctx context.Context, c *websocket.Conn) error { _, _, err := c.Reader(ctx) - if err == nil || !strings.Contains(err.Error(), "received continuation frame not after data") { - return xerrors.Errorf("expected non nil error: %v", err) - } - return nil + return assertErrorContains(err, "received continuation frame not after data") }, client: func(ctx context.Context, c *websocket.Conn) error { _, err := c.WriteFrame(ctx, false, websocket.OPContinuation, []byte(strings.Repeat("x", 10))) - if err != nil { - return xerrors.Errorf("expected non nil error") - } - return nil + return err }, }, { @@ -1031,21 +768,22 @@ func TestConn(t *testing.T) { if err != nil { return err } - _, b, err := c.Read(ctx) + err = assertEqualf("hi", v, "unexpected JSON") if err != nil { return err } - if string(b) != "hi" { - return xerrors.Errorf("expected hi but got %q", string(b)) + _, b, err := c.Read(ctx) + if err != nil { + return err } - return nil + return assertEqualf("hi", string(b), "unexpected JSON") }, client: func(ctx context.Context, c *websocket.Conn) error { err := wsjson.Write(ctx, c, "hi") if err != nil { return err } - return c.Write(ctx, websocket.MessageBinary, []byte("hi")) + return c.Write(ctx, websocket.MessageText, []byte("hi")) }, }, { @@ -1057,10 +795,7 @@ func TestConn(t *testing.T) { } p := make([]byte, 11) _, err = io.ReadFull(r, p) - if err == nil || !strings.Contains(err.Error(), "received new data message without finishing") { - return xerrors.Errorf("expected non nil error: %v", err) - } - return nil + return assertErrorContains(err, "received new data message without finishing") }, client: func(ctx context.Context, c *websocket.Conn) error { w, err := c.Writer(ctx, websocket.MessageBinary) @@ -1080,10 +815,7 @@ func TestConn(t *testing.T) { return xerrors.Errorf("expected non nil error") } _, _, err = c.Read(ctx) - if err == nil { - return xerrors.Errorf("expected non nil error: %v", err) - } - return nil + return assertCloseStatus(err, websocket.StatusProtocolError) }, }, { @@ -1098,10 +830,7 @@ func TestConn(t *testing.T) { return err } _, err = r.Read(make([]byte, 1)) - if err == nil || !strings.Contains(err.Error(), "cannot use EOFed reader") { - return xerrors.Errorf("expected non nil error: %+v", err) - } - return nil + return assertErrorContains(err, "cannot use EOFed reader") }, client: func(ctx context.Context, c *websocket.Conn) error { return c.Write(ctx, websocket.MessageBinary, []byte("hi")) @@ -1111,10 +840,7 @@ func TestConn(t *testing.T) { name: "eofInPayload", server: func(ctx context.Context, c *websocket.Conn) error { _, _, err := c.Read(ctx) - if err == nil || !strings.Contains(err.Error(), "failed to read frame payload") { - return xerrors.Errorf("expected failed to read frame payload: %v", err) - } - return nil + return assertErrorContains(err, "failed to read frame payload") }, client: func(ctx context.Context, c *websocket.Conn) error { _, err := c.WriteHalfFrame(ctx) @@ -1131,11 +857,14 @@ func TestConn(t *testing.T) { tls := rand.Intn(2) == 1 s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, nil) + c, err := websocket.Accept(w, r, tc.acceptOpts) if err != nil { return err } defer c.Close(websocket.StatusInternalError, "") + if tc.server == nil { + return nil + } return tc.server(r.Context(), c) }, tls) defer closeFn() @@ -1145,21 +874,35 @@ func TestConn(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - opts := &websocket.DialOptions{} + opts := tc.dialOpts if tls { + if opts == nil { + opts = &websocket.DialOptions{} + } opts.HTTPClient = s.Client() } - c, _, err := websocket.Dial(ctx, wsURL, opts) + c, resp, err := websocket.Dial(ctx, wsURL, opts) if err != nil { t.Fatal(err) } defer c.Close(websocket.StatusInternalError, "") - err = tc.client(ctx, c) - if err != nil { - t.Fatalf("client failed: %+v", err) + if tc.response != nil { + err = tc.response(resp) + if err != nil { + t.Fatalf("response asserter failed: %+v", err) + } } + + if tc.client != nil { + err = tc.client(ctx, c) + if err != nil { + t.Fatalf("client failed: %+v", err) + } + } + + c.Close(websocket.StatusNormalClosure, "") }) } } @@ -1174,7 +917,7 @@ func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request) e atomic.AddInt64(&conns, 1) defer atomic.AddInt64(&conns, -1) - ctx, cancel := context.WithTimeout(r.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(r.Context(), time.Second*10) defer cancel() r = r.WithContext(ctx) @@ -1598,3 +1341,67 @@ func BenchmarkConn(b *testing.B) { } }) } + +func assertCloseStatus(err error, code websocket.StatusCode) error { + var cerr websocket.CloseError + if !xerrors.As(err, &cerr) { + return xerrors.Errorf("no websocket close error in error chain: %+v", err) + } + return assertEqualf(code, cerr.Code, "unexpected status code") +} + +func assertJSONRead(ctx context.Context, c *websocket.Conn, exp interface{}) (err error) { + var act interface{} + err = wsjson.Read(ctx, c, &act) + if err != nil { + return err + } + + return assertEqualf(exp, act, "unexpected JSON") +} + +func assertProtobufRead(ctx context.Context, c *websocket.Conn, exp interface{}) error { + expType := reflect.TypeOf(exp) + actv := reflect.New(expType.Elem()) + act := actv.Interface().(proto.Message) + err := wspb.Read(ctx, c, act) + if err != nil { + return err + } + + return assertEqualf(exp, act, "unexpected protobuf") +} + +func assertSubprotocol(c *websocket.Conn, exp string) error { + return assertEqualf(exp, c.Subprotocol(), "unexpected subprotocol") +} + +func assertEqualf(exp, act interface{}, f string, v ...interface{}) error { + if diff := cmpDiff(exp, act); diff != "" { + return xerrors.Errorf(f+": %v", append(v, diff)) + } + return nil +} + +func assertNetConnRead(r io.Reader, exp string) error { + act := make([]byte, len(exp)) + _, err := r.Read(act) + if err != nil { + return err + } + return assertEqualf(exp, string(act), "unexpected net conn read") +} + +func assertErrorContains(err error, exp string) error { + if err == nil || !strings.Contains(err.Error(), exp) { + return xerrors.Errorf("expected error that contains %q but got: %+v", exp, err) + } + return nil +} + +func assertErrorIs(exp, act error) error { + if !xerrors.Is(act, exp) { + return xerrors.Errorf("expected error %+v to be in %+v", exp, act) + } + return nil +}