diff --git a/connection.go b/connection.go index c0746e5..139bd8f 100644 --- a/connection.go +++ b/connection.go @@ -67,16 +67,11 @@ func (c *connection) Read(b []byte) (int, error) { func (c *connection) Write(b []byte) (int, error) { if c.err != nil { - return 0, c.err + return 0, io.ErrClosedPipe } - msg := newMessage(c.connID, b) metrics.AddSMTotalTransmitBytesOnWS(c.session.clientKey, float64(len(msg.Bytes()))) - n, err := c.session.writeMessage(c.writeDeadline, msg) - if err != nil { - return 0, err - } - return n, c.err + return c.session.writeMessage(c.writeDeadline, msg) } func (c *connection) writeErr(err error) { diff --git a/wsconn.go b/wsconn.go index 776c8dc..c0a8eae 100644 --- a/wsconn.go +++ b/wsconn.go @@ -1,6 +1,8 @@ package remotedialer import ( + "context" + "fmt" "io" "sync" "time" @@ -22,15 +24,28 @@ func newWSConn(conn *websocket.Conn) *wsConn { } func (w *wsConn) WriteMessage(messageType int, deadline time.Time, data []byte) error { - w.Lock() - defer w.Unlock() - if err := w.conn.SetWriteDeadline(deadline); err != nil { - return err + if deadline.IsZero() { + w.Lock() + defer w.Unlock() + return w.conn.WriteMessage(messageType, data) } - if err := w.conn.SetReadDeadline(deadline); err != nil { + + ctx, cancel := context.WithDeadline(context.Background(), deadline) + defer cancel() + + done := make(chan error, 1) + go func() { + w.Lock() + defer w.Unlock() + done <- w.conn.WriteMessage(messageType, data) + }() + + select { + case <-ctx.Done(): + return fmt.Errorf("i/o timeout") + case err := <-done: return err } - return w.conn.WriteMessage(messageType, data) } func (w *wsConn) NextReader() (int, io.Reader, error) { @@ -41,12 +56,21 @@ func (w *wsConn) setupDeadline() { w.conn.SetReadDeadline(time.Now().Add(PingWaitDuration)) w.conn.SetPingHandler(func(string) error { w.Lock() - w.conn.WriteControl(websocket.PongMessage, []byte(""), time.Now().Add(PingWaitDuration)) + err := w.conn.WriteControl(websocket.PongMessage, []byte(""), time.Now().Add(PingWaitDuration)) w.Unlock() - return w.conn.SetReadDeadline(time.Now().Add(PingWaitDuration)) + if err != nil { + return err + } + if err := w.conn.SetReadDeadline(time.Now().Add(PingWaitDuration)); err != nil { + return err + } + return w.conn.SetWriteDeadline(time.Now().Add(PingWaitDuration)) }) w.conn.SetPongHandler(func(string) error { - return w.conn.SetReadDeadline(time.Now().Add(PingWaitDuration)) + if err := w.conn.SetReadDeadline(time.Now().Add(PingWaitDuration)); err != nil { + return err + } + return w.conn.SetWriteDeadline(time.Now().Add(PingWaitDuration)) }) }