diff --git a/connection.go b/connection.go index a4ae8cc36..b176833f1 100644 --- a/connection.go +++ b/connection.go @@ -581,43 +581,62 @@ func pack(h *smallWBuf, enc *encoder, reqid uint32, return } -func (conn *Connection) writeAuthRequest(w *bufio.Writer, scramble []byte) (err error) { +func (conn *Connection) writeRequest(w *bufio.Writer, req Request) (err error) { var packet smallWBuf - req := newAuthRequest(conn.opts.User, string(scramble)) err = pack(&packet, newEncoder(&packet), 0, req, ignoreStreamId, conn.Schema) if err != nil { - return errors.New("auth: pack error " + err.Error()) + return errors.New("pack error " + err.Error()) } if err := write(w, packet.b); err != nil { - return errors.New("auth: write error " + err.Error()) + return errors.New("write error " + err.Error()) } if err = w.Flush(); err != nil { - return errors.New("auth: flush error " + err.Error()) + return errors.New("flush error " + err.Error()) } return } -func (conn *Connection) readAuthResponse(r io.Reader) (err error) { +func (conn *Connection) writeAuthRequest(w *bufio.Writer, scramble []byte) (err error) { + req := newAuthRequest(conn.opts.User, string(scramble)) + + err = conn.writeRequest(w, req) + if err != nil { + return fmt.Errorf("auth: %w", err) + } + + return nil +} + +func (conn *Connection) readResponse(r io.Reader) (resp Response, err error) { respBytes, err := conn.read(r) if err != nil { - return errors.New("auth: read error " + err.Error()) + return resp, errors.New("read error " + err.Error()) } - resp := Response{buf: smallBuf{b: respBytes}} + resp = Response{buf: smallBuf{b: respBytes}} err = resp.decodeHeader(conn.dec) if err != nil { - return errors.New("auth: decode response header error " + err.Error()) + return resp, errors.New("decode response header error " + err.Error()) } err = resp.decodeBody() if err != nil { switch err.(type) { case Error: - return err + return resp, err default: - return errors.New("auth: decode response body error " + err.Error()) + return resp, errors.New("decode response body error " + err.Error()) } } - return + return resp, nil +} + +func (conn *Connection) readAuthResponse(r io.Reader) (err error) { + _, err = conn.readResponse(r) + if err != nil { + return fmt.Errorf("auth: %w", err) + } + + return nil } func (conn *Connection) createConnection(reconnect bool) (err error) {