From a2a2d31cb3d23134087d033f88b340bf3b25b686 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Mon, 1 Jul 2019 10:29:59 -0400 Subject: [PATCH 1/3] Add NetConn adapter Closes #100 --- netconn.go | 116 ++++++++++++++++++++++++++++++++++++++++++++++ websocket_test.go | 48 +++++++++++++++++++ 2 files changed, 164 insertions(+) create mode 100644 netconn.go diff --git a/netconn.go b/netconn.go new file mode 100644 index 00000000..0de2f1cb --- /dev/null +++ b/netconn.go @@ -0,0 +1,116 @@ +package websocket + +import ( + "context" + "golang.org/x/xerrors" + "io" + "math" + "net" + "time" +) + +// NetConn converts a *websocket.Conn into a net.Conn. +// Every Write to the net.Conn will correspond to a binary message +// write on *webscoket.Conn. +// Close will close the *websocket.Conn with StatusNormalClosure. +// When a deadline is hit, the connection will be closed. This is +// different from most net.Conn implementations where only the +// reading/writing goroutines are interrupted but the connection is kept alive. +// The Addr methods will return zero value net.TCPAddr. +func NetConn(c *Conn) net.Conn { + nc := &netConn{ + c: c, + } + + var cancel context.CancelFunc + nc.writeContext, cancel = context.WithCancel(context.Background()) + nc.writeTimer = time.AfterFunc(math.MaxInt64, cancel) + nc.writeTimer.Stop() + + nc.readContext, cancel = context.WithCancel(context.Background()) + nc.readTimer = time.AfterFunc(math.MaxInt64, cancel) + nc.readTimer.Stop() + + return nc +} + +type netConn struct { + c *Conn + + writeTimer *time.Timer + writeContext context.Context + + readTimer *time.Timer + readContext context.Context + + reader io.Reader +} + +var _ net.Conn = &netConn{} + +func (c *netConn) Close() error { + return c.c.Close(StatusNormalClosure, "") +} + +func (c *netConn) Write(p []byte) (int, error) { + err := c.c.Write(c.writeContext, MessageBinary, p) + if err != nil { + return 0, err + } + return len(p), nil +} + +func (c *netConn) Read(p []byte) (int, error) { + if c.reader == nil { + typ, r, err := c.c.Reader(c.readContext) + if err != nil { + return 0, err + } + if typ != MessageBinary { + c.c.Close(StatusUnsupportedData, "can only accept binary messages") + return 0, xerrors.Errorf("unexpected frame type read for net conn adapter (expected %v): %v", MessageBinary, typ) + } + c.reader = r + } + + n, err := c.reader.Read(p) + if err == io.EOF { + c.reader = nil + } + return n, err +} + +type unknownAddr struct { +} + +func (a unknownAddr) Network() string { + return "unknown" +} + +func (a unknownAddr) String() string { + return "unknown" +} + +func (c *netConn) RemoteAddr() net.Addr { + return unknownAddr{} +} + +func (c *netConn) LocalAddr() net.Addr { + return unknownAddr{} +} + +func (c *netConn) SetDeadline(t time.Time) error { + c.SetWriteDeadline(t) + c.SetReadDeadline(t) + return nil +} + +func (c *netConn) SetWriteDeadline(t time.Time) error { + c.writeTimer.Reset(t.Sub(time.Now())) + return nil +} + +func (c *netConn) SetReadDeadline(t time.Time) error { + c.readTimer.Reset(t.Sub(time.Now())) + return nil +} diff --git a/websocket_test.go b/websocket_test.go index 2d7db271..2112ff7e 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -118,6 +118,54 @@ func TestHandshake(t *testing.T) { return nil }, }, + { + name: "netConn", + server: func(w http.ResponseWriter, r *http.Request) error { + c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + if err != nil { + return err + } + defer c.Close(websocket.StatusInternalError, "") + + nc := websocket.NetConn(c) + defer nc.Close() + + nc.SetWriteDeadline(time.Now().Add(time.Second * 10)) + + _, err = nc.Write([]byte("hello")) + if err != nil { + return err + } + + return nil + }, + client: func(ctx context.Context, u string) error { + c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{ + Subprotocols: []string{"meow"}, + }) + if err != nil { + return err + } + defer c.Close(websocket.StatusInternalError, "") + + nc := websocket.NetConn(c) + defer nc.Close() + + nc.SetReadDeadline(time.Now().Add(time.Second * 10)) + + p := make([]byte, len("hello")) + _, err = io.ReadFull(nc, p) + if err != nil { + return err + } + + if string(p) != "hello" { + return xerrors.Errorf("unexpected payload %q received", string(p)) + } + + return nil + }, + }, { name: "defaultSubprotocol", server: func(w http.ResponseWriter, r *http.Request) error { From 2e4b1105932814e737c4fa3b5048bc9d72d7dea3 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Mon, 1 Jul 2019 10:36:09 -0400 Subject: [PATCH 2/3] Protect against Reader after CloseRead Closes #101 --- netconn.go | 3 ++- websocket.go | 11 ++++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/netconn.go b/netconn.go index 0de2f1cb..184d5d6c 100644 --- a/netconn.go +++ b/netconn.go @@ -2,11 +2,12 @@ package websocket import ( "context" - "golang.org/x/xerrors" "io" "math" "net" "time" + + "golang.org/x/xerrors" ) // NetConn converts a *websocket.Conn into a net.Conn. diff --git a/websocket.go b/websocket.go index e7fb0dfa..f875a142 100644 --- a/websocket.go +++ b/websocket.go @@ -12,6 +12,7 @@ import ( "runtime" "strconv" "sync" + "sync/atomic" "time" "golang.org/x/xerrors" @@ -64,6 +65,7 @@ type Conn struct { previousReader *messageReader // readFrameLock is acquired to read from bw. readFrameLock chan struct{} + readClosed int64 readHeaderBuf []byte controlPayloadBuf []byte @@ -329,6 +331,10 @@ func (c *Conn) handleControl(ctx context.Context, h header) error { // See https://github.com/nhooyr/websocket/issues/87#issue-451703332 // Most users should not need this. func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { + if atomic.LoadInt64(&c.readClosed) == 1 { + return 0, nil, xerrors.Errorf("websocket connection read closed") + } + typ, r, err := c.reader(ctx) if err != nil { return 0, nil, xerrors.Errorf("failed to get reader: %w", err) @@ -395,10 +401,13 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) { // Use this when you do not want to read data messages from the connection anymore but will // want to write messages to it. func (c *Conn) CloseRead(ctx context.Context) context.Context { + atomic.StoreInt64(&c.readClosed, 1) + ctx, cancel := context.WithCancel(ctx) go func() { defer cancel() - c.Reader(ctx) + // We use the unexported reader so that we don't get the read closed error. + c.reader(ctx) c.Close(StatusPolicyViolation, "unexpected data message") }() return ctx From 9d31b8d2a78704c6a994508c29116342335b1a5c Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Mon, 1 Jul 2019 10:56:23 -0400 Subject: [PATCH 3/3] Fix docs on NetConn --- netconn.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netconn.go b/netconn.go index 184d5d6c..3e43d905 100644 --- a/netconn.go +++ b/netconn.go @@ -17,7 +17,7 @@ import ( // When a deadline is hit, the connection will be closed. This is // different from most net.Conn implementations where only the // reading/writing goroutines are interrupted but the connection is kept alive. -// The Addr methods will return zero value net.TCPAddr. +// The Addr methods will return a mock net.Addr. func NetConn(c *Conn) net.Conn { nc := &netConn{ c: c,