diff --git a/AUTHORS b/AUTHORS index 10e2ebb94..abd7b1de3 100644 --- a/AUTHORS +++ b/AUTHORS @@ -18,6 +18,7 @@ Asta Xie Bulat Gaifullin Carlos Nieto Chris Moos +Chris Piraino Daniel Nichter Daniƫl van Eeden Dave Protasowski @@ -34,6 +35,7 @@ INADA Naoki Jacek Szwec James Harr Jian Zhen +John Shahid Joshua Prunier Julien Lefevre Julien Schmidt @@ -58,6 +60,7 @@ Runrioter Wung Soroush Pour Stan Putrya Stanley Gunawan +Thomas Parrott Xiangyu Hu Xiaobing Jiang Xiuming Chen diff --git a/driver_test.go b/driver_test.go index 206e07cc9..7f3bb19fa 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1773,6 +1773,88 @@ func TestCustomDial(t *testing.T) { } } +// Ensure mariadb's ER_CONNECTION_KILLED will cause the query to be restarted +func TestConnectionLost(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + var proxyConn net.Conn + + killCh := make(chan struct{}) + + // our custom dial function which justs wraps net.Dial here + RegisterDial("mydial", func(addr string) (net.Conn, error) { + conn, err := net.Dial(prot, addr) + if err != nil { + return nil, err + } + + var clientConn net.Conn + proxyConn, clientConn = net.Pipe() + go io.Copy(conn, proxyConn) + + bytesCh := make(chan []byte) + go func() { + for { + bs := make([]byte, 1024) + n, err := conn.Read(bs) + if err == io.EOF { + return + } + if err != nil { + panic(err) + } + bytesCh <- bs[:n] + } + }() + go func() { + for { + select { + case bs := <-bytesCh: + _, err := proxyConn.Write(bs) + if err == io.ErrClosedPipe { + return + } + if err != nil { + panic(err) + } + case <-killCh: + go func() { + proxyConn.Write([]byte{ + 0x08, // packet size + 0x00, + 0x00, + 0x00, // sequence 0 + 0xFF, // err_packet + 0x87, // ER_CONNECTION_KILLED error + 0x07, + 0x00, // sql_state_marker + }) + }() + } + } + }() + return clientConn, err + }) + + db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s&strict=true", user, pass, addr, dbname)) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + + if _, err = db.Exec("DO 1"); err != nil { + t.Fatalf("connection failed: %s", err.Error()) + } + + killCh <- struct{}{} + + if _, err = db.Exec("DO 1"); err != nil { + t.Fatalf("connection failed: %s", err.Error()) + } +} + func TestSQLInjection(t *testing.T) { createTest := func(arg string) func(dbt *DBTest) { return func(dbt *DBTest) { diff --git a/packets.go b/packets.go index 9715067c4..8e96677d6 100644 --- a/packets.go +++ b/packets.go @@ -46,8 +46,15 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { if data[3] > mc.sequence { return nil, ErrPktSyncMul } - return nil, ErrPktSync + + // The MariaDB server sends an error packet with sequence numer 0 during + // server shutdown. Continue to process it so the specific error can be + // detected. + if data[3] != 0 { + return nil, ErrPktSync + } } + mc.sequence++ // packets with length 0 terminate a previous packet which is a @@ -585,6 +592,13 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error { pos = 9 } + // If error code is for ER_CONNECTION_KILLED, then mark connection as bad. + // https://mariadb.com/kb/en/mariadb/mariadb-error-codes/ + if errno == 1927 { + errLog.Print("Error ", errno, ": ", string(data[pos:])) + return driver.ErrBadConn + } + // Error Message [string] return &MySQLError{ Number: errno, diff --git a/packets_test.go b/packets_test.go index 31c892d85..b3ab09640 100644 --- a/packets_test.go +++ b/packets_test.go @@ -115,9 +115,9 @@ func TestReadPacketWrongSequenceID(t *testing.T) { } // too low sequence id - conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} + conn.data = []byte{0x01, 0x00, 0x00, 0x01, 0xff} conn.maxReads = 1 - mc.sequence = 1 + mc.sequence = 2 _, err := mc.readPacket() if err != ErrPktSync { t.Errorf("expected ErrPktSync, got %v", err)