diff --git a/go/mysql/conn.go b/go/mysql/conn.go index 9ee244b6dd4..39b17be87a0 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -188,6 +188,14 @@ type PrepareData struct { BindVars map[string]*querypb.BindVariable } +type execResult byte + +const ( + execSuccess execResult = iota + execErr + connErr +) + // bufPool is used to allocate and free buffers in an efficient way. var bufPool = bucketpool.New(connBufferSize, MaxPacketSize) @@ -756,7 +764,7 @@ func (c *Conn) writeEOFPacket(flags uint16, warnings uint16) error { // handleNextCommand is called in the server loop to process // incoming packets. -func (c *Conn) handleNextCommand(handler Handler) error { +func (c *Conn) handleNextCommand(handler Handler) bool { c.sequence = 0 data, err := c.readEphemeralPacket() if err != nil { @@ -764,78 +772,72 @@ func (c *Conn) handleNextCommand(handler Handler) error { if err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") { log.Errorf("Error reading packet from %s: %v", c, err) } - return err + return false } switch data[0] { case ComQuit: c.recycleReadPacket() - return errors.New("ComQuit") + return false case ComInitDB: db := c.parseComInitDB(data) c.recycleReadPacket() - if err := c.execQuery(fmt.Sprintf("use `%s`", db), handler, false); err != nil { - return err - } + res := c.execQuery(fmt.Sprintf("use `%s`", db), handler, false) + return res != connErr case ComQuery: - err := func() error { - c.startWriterBuffering() - defer func() { - if err := c.endWriterBuffering(); err != nil { - log.Errorf("conn %v: flush() failed: %v", c.ID(), err) - } - }() + c.startWriterBuffering() + defer func() { + if err := c.endWriterBuffering(); err != nil { + log.Errorf("conn %v: flush() failed: %v", c.ID(), err) + } + }() + + queryStart := time.Now() + query := c.parseComQuery(data) + c.recycleReadPacket() - queryStart := time.Now() - query := c.parseComQuery(data) - c.recycleReadPacket() - - var queries []string - if c.Capabilities&CapabilityClientMultiStatements != 0 { - queries, err = sqlparser.SplitStatementToPieces(query) - if err != nil { - log.Errorf("Conn %v: Error splitting query: %v", c, err) - if werr := c.writeErrorPacketFromError(err); werr != nil { - // If we can't even write the error, we're done. - log.Errorf("Conn %v: Error writing query error: %v", c, werr) - return werr - } + var queries []string + if c.Capabilities&CapabilityClientMultiStatements != 0 { + queries, err = sqlparser.SplitStatementToPieces(query) + if err != nil { + log.Errorf("Conn %v: Error splitting query: %v", c, err) + if werr := c.writeErrorPacketFromError(err); werr != nil { + // If we can't even write the error, we're done. + log.Errorf("Conn %v: Error writing query error: %v", c, werr) + return false } - } else { - queries = []string{query} } - for index, sql := range queries { - more := false - if index != len(queries)-1 { - more = true - } - if err := c.execQuery(sql, handler, more); err != nil { - return err - } + } else { + queries = []string{query} + } + for index, sql := range queries { + more := false + if index != len(queries)-1 { + more = true + } + res := c.execQuery(sql, handler, more) + if res != execSuccess { + return res != connErr } - - timings.Record(queryTimingKey, queryStart) - - return nil - }() - if err != nil { - return err } + timings.Record(queryTimingKey, queryStart) + case ComPing: c.recycleReadPacket() // Return error if listener was shut down and OK otherwise if c.listener.isShutdown() { if err := c.writeErrorPacket(ERServerShutdown, SSServerShutdown, "Server shutdown in progress"); err != nil { log.Errorf("Error writing ComPing error to %s: %v", c, err) - return err + return false } } else { if err := c.writeOKPacket(0, 0, c.StatusFlags, 0); err != nil { log.Errorf("Error writing ComPing result to %s: %v", c, err) - return err + return false } } + case ComSetOption: operation, ok := c.parseComSetOption(data) c.recycleReadPacket() @@ -849,20 +851,21 @@ func (c *Conn) handleNextCommand(handler Handler) error { log.Errorf("Got unhandled packet (ComSetOption default) from client %v, returning error: %v", c.ConnectionID, data) if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); err != nil { log.Errorf("Error writing error packet to client: %v", err) - return err + return false } } if err := c.writeEndResult(false, 0, 0, 0); err != nil { log.Errorf("Error writeEndResult error %v ", err) - return err + return false } } else { log.Errorf("Got unhandled packet (ComSetOption else) from client %v, returning error: %v", c.ConnectionID, data) if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); err != nil { log.Errorf("Error writing error packet to client: %v", err) - return err + return false } } + case ComPrepare: query := c.parseComPrepare(data) c.recycleReadPacket() @@ -875,17 +878,22 @@ func (c *Conn) handleNextCommand(handler Handler) error { if werr := c.writeErrorPacketFromError(err); werr != nil { // If we can't even write the error, we're done. log.Errorf("Conn %v: Error writing query error: %v", c, werr) - return werr + return false + } + } + if len(queries) != 1 { + log.Errorf("Conn %v: can not prepare multiple statements", c, err) + if werr := c.writeErrorPacketFromError(err); werr != nil { + // If we can't even write the error, we're done. + log.Errorf("Conn %v: Error writing query error: %v", c, werr) + return false } + return true } } else { queries = []string{query} } - if len(queries) != 1 { - return fmt.Errorf("can not prepare multiple statements") - } - // Popoulate PrepareData c.StatementID++ prepare := &PrepareData{ @@ -899,7 +907,7 @@ func (c *Conn) handleNextCommand(handler Handler) error { if werr := c.writeErrorPacketFromError(err); werr != nil { // If we can't even write the error, we're done. log.Errorf("Conn %v: Error writing prepared statement error: %v", c, werr) - return werr + return false } } @@ -934,128 +942,136 @@ func (c *Conn) handleNextCommand(handler Handler) error { if werr := c.writeErrorPacketFromError(err); werr != nil { // If we can't even write the error, we're done. log.Error("Error writing query error to client %v: %v", c.ConnectionID, werr) - return werr + return false } - return nil + return true } if err := c.writePrepare(fld, c.PrepareData[c.StatementID]); err != nil { - return err + log.Error("Error writing prepare data to client %v: %v", c.ConnectionID, err) + return false } case ComStmtExecute: - err := func() error { - c.startWriterBuffering() + c.startWriterBuffering() + defer func() { + if err := c.endWriterBuffering(); err != nil { + log.Errorf("conn %v: flush() failed: %v", c.ID(), err) + } + }() + queryStart := time.Now() + stmtID, _, err := c.parseComStmtExecute(c.PrepareData, data) + c.recycleReadPacket() + + if stmtID != uint32(0) { defer func() { - if err := c.endWriterBuffering(); err != nil { - log.Errorf("conn %v: flush() failed: %v", c.ID(), err) - } + // Allocate a new bindvar map every time since VTGate.Execute() mutates it. + prepare := c.PrepareData[stmtID] + prepare.BindVars = make(map[string]*querypb.BindVariable, prepare.ParamsCount) }() - queryStart := time.Now() - stmtID, _, err := c.parseComStmtExecute(c.PrepareData, data) - c.recycleReadPacket() - - if stmtID != uint32(0) { - defer func() { - // Allocate a new bindvar map every time since VTGate.Execute() mutates it. - prepare := c.PrepareData[stmtID] - prepare.BindVars = make(map[string]*querypb.BindVariable, prepare.ParamsCount) - }() - } + } - if err != nil { - if werr := c.writeErrorPacketFromError(err); werr != nil { - // If we can't even write the error, we're done. - log.Error("Error writing query error to client %v: %v", c.ConnectionID, werr) - return werr - } - return nil + if err != nil { + if werr := c.writeErrorPacketFromError(err); werr != nil { + // If we can't even write the error, we're done. + log.Error("Error writing query error to client %v: %v", c.ConnectionID, werr) + return false } + return true + } - fieldSent := false - // sendFinished is set if the response should just be an OK packet. - sendFinished := false - prepare := c.PrepareData[stmtID] - err = handler.ComStmtExecute(c, prepare, func(qr *sqltypes.Result) error { - if sendFinished { - // Failsafe: Unreachable if server is well-behaved. - return io.EOF - } - - if !fieldSent { - fieldSent = true - - if len(qr.Fields) == 0 { - sendFinished = true - // We should not send any more packets after this. - return c.writeOKPacket(qr.RowsAffected, qr.InsertID, c.StatusFlags, 0) - } - if err := c.writeFields(qr); err != nil { - return err - } - } - - return c.writeBinaryRows(qr) - }) + fieldSent := false + // sendFinished is set if the response should just be an OK packet. + sendFinished := false + prepare := c.PrepareData[stmtID] + err = handler.ComStmtExecute(c, prepare, func(qr *sqltypes.Result) error { + if sendFinished { + // Failsafe: Unreachable if server is well-behaved. + return io.EOF + } - // If no field was sent, we expect an error. if !fieldSent { - // This is just a failsafe. Should never happen. - if err == nil || err == io.EOF { - err = NewSQLErrorFromError(errors.New("unexpected: query ended without no results and no error")) - } - if werr := c.writeErrorPacketFromError(err); werr != nil { - // If we can't even write the error, we're done. - log.Errorf("Error writing query error to %s: %v", c, werr) - return werr + fieldSent = true + + if len(qr.Fields) == 0 { + sendFinished = true + // We should not send any more packets after this. + return c.writeOKPacket(qr.RowsAffected, qr.InsertID, c.StatusFlags, 0) } - } else { - if err != nil { - // We can't send an error in the middle of a stream. - // All we can do is abort the send, which will cause a 2013. - log.Errorf("Error in the middle of a stream to %s: %v", c, err) + if err := c.writeFields(qr); err != nil { return err } + } - // Send the end packet only sendFinished is false (results were streamed). - // In this case the affectedRows and lastInsertID are always 0 since it - // was a read operation. - if !sendFinished { - if err := c.writeEndResult(false, 0, 0, handler.WarningCount(c)); err != nil { - log.Errorf("Error writing result to %s: %v", c, err) - return err - } - } + return c.writeBinaryRows(qr) + }) + + // If no field was sent, we expect an error. + if !fieldSent { + // This is just a failsafe. Should never happen. + if err == nil || err == io.EOF { + err = NewSQLErrorFromError(errors.New("unexpected: query ended without no results and no error")) + } + if werr := c.writeErrorPacketFromError(err); werr != nil { + // If we can't even write the error, we're done. + log.Errorf("Error writing query error to %s: %v", c, werr) + return false + } + } else { + if err != nil { + // We can't send an error in the middle of a stream. + // All we can do is abort the send, which will cause a 2013. + log.Errorf("Error in the middle of a stream to %s: %v", c, err) + return false } - timings.Record(queryTimingKey, queryStart) - return nil - }() - if err != nil { - return err + // Send the end packet only sendFinished is false (results were streamed). + // In this case the affectedRows and lastInsertID are always 0 since it + // was a read operation. + if !sendFinished { + if err := c.writeEndResult(false, 0, 0, handler.WarningCount(c)); err != nil { + log.Errorf("Error writing result to %s: %v", c, err) + return false + } + } } + + timings.Record(queryTimingKey, queryStart) + case ComStmtSendLongData: stmtID, paramID, chunkData, ok := c.parseComStmtSendLongData(data) c.recycleReadPacket() if !ok { - err := fmt.Errorf("error parsing statement send long data from client %v, returning error: %v", c.ConnectionID, data) - log.Error(err.Error()) - return err + err = fmt.Errorf("error parsing statement send long data from client %v, returning error: %v", c.ConnectionID, data) + if werr := c.writeErrorPacketFromError(err); werr != nil { + // If we can't even write the error, we're done. + log.Error("Error writing query error to client %v: %v", c.ConnectionID, werr) + return false + } + return true } prepare, ok := c.PrepareData[stmtID] if !ok { - err := fmt.Errorf("got wrong statement id from client %v, statement ID(%v) is not found from record", c.ConnectionID, stmtID) - log.Error(err.Error()) - return err + err = fmt.Errorf("got wrong statement id from client %v, statement ID(%v) is not found from record", c.ConnectionID, stmtID) + if werr := c.writeErrorPacketFromError(err); werr != nil { + // If we can't even write the error, we're done. + log.Error("Error writing query error to client %v: %v", c.ConnectionID, werr) + return false + } + return true } if prepare.BindVars == nil || prepare.ParamsCount == uint16(0) || paramID >= prepare.ParamsCount { - err := fmt.Errorf("invalid parameter Number from client %v, statement: %v", c.ConnectionID, prepare.PrepareStmt) - log.Error(err.Error()) - return err + err = fmt.Errorf("invalid parameter Number from client %v, statement: %v", c.ConnectionID, prepare.PrepareStmt) + if werr := c.writeErrorPacketFromError(err); werr != nil { + // If we can't even write the error, we're done. + log.Error("Error writing query error to client %v: %v", c.ConnectionID, werr) + return false + } + return true } chunk := make([]byte, len(chunkData)) @@ -1080,7 +1096,7 @@ func (c *Conn) handleNextCommand(handler Handler) error { log.Error("Got unhandled packet from client %v, returning error: %v", c.ConnectionID, data) if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); err != nil { log.Error("Error writing error packet to client: %v", err) - return err + return false } } @@ -1089,7 +1105,7 @@ func (c *Conn) handleNextCommand(handler Handler) error { log.Error("Commands were executed in an improper order from client %v, packet: %v", c.ConnectionID, data) if err := c.writeErrorPacket(CRCommandsOutOfSync, SSUnknownComError, "commands were executed in an improper order: %v", data); err != nil { log.Error("Error writing error packet to client: %v", err) - return err + return false } } @@ -1101,7 +1117,7 @@ func (c *Conn) handleNextCommand(handler Handler) error { if err := c.writeOKPacket(0, 0, c.StatusFlags, 0); err != nil { log.Error("Error writing ComStmtReset OK packet to client %v: %v", c.ConnectionID, err) - return err + return false } case ComResetConnection: @@ -1120,14 +1136,14 @@ func (c *Conn) handleNextCommand(handler Handler) error { c.recycleReadPacket() if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "command handling not implemented yet: %v", data[0]); err != nil { log.Errorf("Error writing error packet to %s: %s", c, err) - return err + return false } } - return nil + return true } -func (c *Conn) execQuery(query string, handler Handler, more bool) error { +func (c *Conn) execQuery(query string, handler Handler, more bool) execResult { fieldSent := false // sendFinished is set if the response should just be an OK packet. sendFinished := false @@ -1173,28 +1189,28 @@ func (c *Conn) execQuery(query string, handler Handler, more bool) error { if werr := c.writeErrorPacketFromError(err); werr != nil { // If we can't even write the error, we're done. log.Errorf("Error writing query error to %s: %v", c, werr) - return werr - } - } else { - if err != nil { - // We can't send an error in the middle of a stream. - // All we can do is abort the send, which will cause a 2013. - log.Errorf("Error in the middle of a stream to %s: %v", c, err) - return err + return connErr } + return execErr + } + if err != nil { + // We can't send an error in the middle of a stream. + // All we can do is abort the send, which will cause a 2013. + log.Errorf("Error in the middle of a stream to %s: %v", c, err) + return connErr + } - // Send the end packet only sendFinished is false (results were streamed). - // In this case the affectedRows and lastInsertID are always 0 since it - // was a read operation. - if !sendFinished { - if err := c.writeEndResult(more, 0, 0, handler.WarningCount(c)); err != nil { - log.Errorf("Error writing result to %s: %v", c, err) - return err - } + // Send the end packet only sendFinished is false (results were streamed). + // In this case the affectedRows and lastInsertID are always 0 since it + // was a read operation. + if !sendFinished { + if err := c.writeEndResult(more, 0, 0, handler.WarningCount(c)); err != nil { + log.Errorf("Error writing result to %s: %v", c, err) + return connErr } } - return nil + return execSuccess } // diff --git a/go/mysql/conn_test.go b/go/mysql/conn_test.go index d8ab1a8526a..aa5bd2eaf45 100644 --- a/go/mysql/conn_test.go +++ b/go/mysql/conn_test.go @@ -19,12 +19,18 @@ package mysql import ( "bytes" crypto_rand "crypto/rand" + "fmt" "math/rand" "net" "reflect" + "runtime/debug" "sync" "testing" "time" + + "github.com/stretchr/testify/require" + "vitess.io/vitess/go/sqltypes" + querypb "vitess.io/vitess/go/vt/proto/query" ) func createSocketPair(t *testing.T) (net.Listener, *Conn, *Conn) { @@ -288,3 +294,89 @@ func TestEOFOrLengthEncodedIntFuzz(t *testing.T) { } } } + +func TestMultiStatementStopsOnError(t *testing.T) { + listener, sConn, cConn := createSocketPair(t) + sConn.Capabilities |= CapabilityClientMultiStatements + defer func() { + listener.Close() + sConn.Close() + cConn.Close() + }() + + err := cConn.WriteComQuery("select 1;select 2") + require.NoError(t, err) + + // this handler will return an error on the first run, and fail the test if it's run more times + handler := &singleRun{t: t, err: fmt.Errorf("execution failed")} + res := sConn.handleNextCommand(handler) + require.True(t, res, res, "we should not break the connection because of execution errors") + + data, err := cConn.ReadPacket() + require.NoError(t, err) + require.NotEmpty(t, data) + require.EqualValues(t, data[0], ErrPacket) // we should see the error here +} + +func TestInitDbAgainstWrongDbDoesNotDropConnection(t *testing.T) { + listener, sConn, cConn := createSocketPair(t) + sConn.Capabilities |= CapabilityClientMultiStatements + defer func() { + listener.Close() + sConn.Close() + cConn.Close() + }() + + err := cConn.writeComInitDB("database") + require.NoError(t, err) + + handler := &singleRun{t: t, err: fmt.Errorf("execution failed")} + res := sConn.handleNextCommand(handler) + require.True(t, res, "we should not break the connection because of execution errors") + + data, err := cConn.ReadPacket() + require.NoError(t, err) + require.NotEmpty(t, data) + require.EqualValues(t, data[0], ErrPacket) // we should see the error here +} + +type singleRun struct { + hasRun bool + t *testing.T + err error +} + +func (h *singleRun) NewConnection(*Conn) { + panic("implement me") +} + +func (h *singleRun) ConnectionClosed(*Conn) { + panic("implement me") +} + +func (h *singleRun) ComQuery(*Conn, string, func(*sqltypes.Result) error) error { + if h.hasRun { + debug.PrintStack() + h.t.Fatal("don't do this!") + } + h.hasRun = true + return h.err +} + +func (h *singleRun) ComPrepare(*Conn, string, map[string]*querypb.BindVariable) ([]*querypb.Field, error) { + panic("implement me") +} + +func (h *singleRun) ComStmtExecute(*Conn, *PrepareData, func(*sqltypes.Result) error) error { + panic("implement me") +} + +func (h *singleRun) WarningCount(*Conn) uint16 { + return 0 +} + +func (h *singleRun) ComResetConnection(*Conn) { + panic("implement me") +} + +var _ Handler = (*singleRun)(nil) diff --git a/go/mysql/query_test.go b/go/mysql/query_test.go index 8273759c38d..105271dc6d2 100644 --- a/go/mysql/query_test.go +++ b/go/mysql/query_test.go @@ -628,9 +628,9 @@ func checkQueryInternal(t *testing.T, query string, sConn, cConn *Conn, result * } for i := 0; i < count; i++ { - err := sConn.handleNextCommand(&handler) - if err != nil { - t.Fatalf("error handling command: %v", err) + kontinue := sConn.handleNextCommand(&handler) + if !kontinue { + t.Fatalf("error handling command: %d", i) } } diff --git a/go/mysql/server.go b/go/mysql/server.go index bc6e1e93266..65f0c5a3dd6 100644 --- a/go/mysql/server.go +++ b/go/mysql/server.go @@ -469,8 +469,8 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti } for { - err := c.handleNextCommand(l.handler) - if err != nil { + kontinue := c.handleNextCommand(l.handler) + if !kontinue { return } } diff --git a/go/mysql/server_test.go b/go/mysql/server_test.go index d84f6db8118..9b493ada6b0 100644 --- a/go/mysql/server_test.go +++ b/go/mysql/server_test.go @@ -580,9 +580,7 @@ func TestServer(t *testing.T) { }} defer authServer.close() l, err := NewListener("tcp", ":0", authServer, th, 0, 0, false) - if err != nil { - t.Fatalf("NewListener failed: %v", err) - } + require.NoError(t, err) l.SlowConnectWarnThreshold.Set(time.Duration(time.Nanosecond * 1)) defer l.Close() go l.Accept() @@ -597,94 +595,21 @@ func TestServer(t *testing.T) { Pass: "password1", } - initialTimingCounts := timings.Counts() - initialConnAccept := connAccept.Get() - initialConnSlow := connSlow.Get() - initialconnRefuse := connRefuse.Get() - - // Run an 'error' command. - th.SetErr(NewSQLError(ERUnknownComError, SSUnknownComError, "forced query error")) - output, ok := runMysql(t, params, "error") - if ok { - t.Fatalf("mysql should have failed: %v", output) - } - if !strings.Contains(output, "ERROR 1047 (08S01)") || - !strings.Contains(output, "forced query error") { - t.Errorf("Unexpected output for 'error': %v", output) - } - if connCount.Get() != 0 { - t.Errorf("Expected ConnCount=0, got %d", connCount.Get()) - } - if connAccept.Get()-initialConnAccept != 1 { - t.Errorf("Expected ConnAccept delta=1, got %d", connAccept.Get()-initialConnAccept) - } - if connSlow.Get()-initialConnSlow != 1 { - t.Errorf("Expected ConnSlow delta=1, got %d", connSlow.Get()-initialConnSlow) - } - if connRefuse.Get()-initialconnRefuse != 0 { - t.Errorf("Expected connRefuse delta=0, got %d", connRefuse.Get()-initialconnRefuse) - } - - expectedTimingDeltas := map[string]int64{ - "All": 2, - connectTimingKey: 1, - queryTimingKey: 1, - } - gotTimingCounts := timings.Counts() - for key, got := range gotTimingCounts { - expected := expectedTimingDeltas[key] - delta := got - initialTimingCounts[key] - if delta < expected { - t.Errorf("Expected Timing count delta %s should be >= %d, got %d", key, expected, delta) - } - } - - // Set the slow connect threshold to something high that we don't expect to trigger - l.SlowConnectWarnThreshold.Set(time.Duration(time.Second * 1)) - - // Run a 'panic' command, other side should panic, recover and - // close the connection. - output, ok = runMysql(t, params, "panic") - if ok { - t.Fatalf("mysql should have failed: %v", output) - } - if !strings.Contains(output, "ERROR 2013 (HY000)") || - !strings.Contains(output, "Lost connection to MySQL server during query") { - t.Errorf("Unexpected output for 'panic'") - } - if connCount.Get() != 0 { - t.Errorf("Expected ConnCount=0, got %d", connCount.Get()) - } - if connAccept.Get()-initialConnAccept != 2 { - t.Errorf("Expected ConnAccept delta=2, got %d", connAccept.Get()-initialConnAccept) - } - if connSlow.Get()-initialConnSlow != 1 { - t.Errorf("Expected ConnSlow delta=1, got %d", connSlow.Get()-initialConnSlow) - } - if connRefuse.Get()-initialconnRefuse != 0 { - t.Errorf("Expected connRefuse delta=0, got %d", connRefuse.Get()-initialconnRefuse) - } - // Run a 'select rows' command with results. - output, ok = runMysql(t, params, "select rows") - if !ok { - t.Fatalf("mysql failed: %v", output) - } + output, err := runMysqlWithErr(t, params, "select rows") + require.NoError(t, err) + if !strings.Contains(output, "nice name") || !strings.Contains(output, "nicer name") || !strings.Contains(output, "2 rows in set") { t.Errorf("Unexpected output for 'select rows'") } - if strings.Contains(output, "warnings") { - t.Errorf("Unexpected warnings in 'select rows'") - } + assert.NotContains(t, output, "warnings") // Run a 'select rows' command with warnings th.SetWarnings(13) - output, ok = runMysql(t, params, "select rows") - if !ok { - t.Fatalf("mysql failed: %v", output) - } + output, err = runMysqlWithErr(t, params, "select rows") + require.NoError(t, err) if !strings.Contains(output, "nice name") || !strings.Contains(output, "nicer name") || !strings.Contains(output, "2 rows in set") || @@ -696,39 +621,31 @@ func TestServer(t *testing.T) { // If there's an error after streaming has started, // we should get a 2013 th.SetErr(NewSQLError(ERUnknownComError, SSUnknownComError, "forced error after send")) - output, ok = runMysql(t, params, "error after send") - if ok { - t.Fatalf("mysql should have failed: %v", output) - } + output, err = runMysqlWithErr(t, params, "error after send") + require.Error(t, err) if !strings.Contains(output, "ERROR 2013 (HY000)") || !strings.Contains(output, "Lost connection to MySQL server during query") { t.Errorf("Unexpected output for 'panic'") } // Run an 'insert' command, no rows, but rows affected. - output, ok = runMysql(t, params, "insert") - if !ok { - t.Fatalf("mysql failed: %v", output) - } + output, err = runMysqlWithErr(t, params, "insert") + require.NoError(t, err) if !strings.Contains(output, "Query OK, 123 rows affected") { t.Errorf("Unexpected output for 'insert'") } // Run a 'schema echo' command, to make sure db name is right. params.DbName = "XXXfancyXXX" - output, ok = runMysql(t, params, "schema echo") - if !ok { - t.Fatalf("mysql failed: %v", output) - } + output, err = runMysqlWithErr(t, params, "schema echo") + require.NoError(t, err) if !strings.Contains(output, params.DbName) { t.Errorf("Unexpected output for 'schema echo'") } // Sanity check: make sure this didn't go through SSL - output, ok = runMysql(t, params, "ssl echo") - if !ok { - t.Fatalf("mysql failed: %v", output) - } + output, err = runMysqlWithErr(t, params, "ssl echo") + require.NoError(t, err) if !strings.Contains(output, "ssl_flag") || !strings.Contains(output, "OFF") || !strings.Contains(output, "1 row in set") { @@ -736,10 +653,8 @@ func TestServer(t *testing.T) { } // UserData check: checks the server user data is correct. - output, ok = runMysql(t, params, "userData echo") - if !ok { - t.Fatalf("mysql failed: %v", output) - } + output, err = runMysqlWithErr(t, params, "userData echo") + require.NoError(t, err) if !strings.Contains(output, "user1") || !strings.Contains(output, "user_data") || !strings.Contains(output, "userData1") { @@ -748,10 +663,8 @@ func TestServer(t *testing.T) { // Permissions check: check a bad password is rejected. params.Pass = "bad" - output, ok = runMysql(t, params, "select rows") - if ok { - t.Fatalf("mysql should have failed: %v", output) - } + output, err = runMysqlWithErr(t, params, "select rows") + require.Error(t, err) if !strings.Contains(output, "1045") || !strings.Contains(output, "28000") || !strings.Contains(output, "Access denied") { @@ -761,10 +674,8 @@ func TestServer(t *testing.T) { // Permissions check: check an unknown user is rejected. params.Pass = "password1" params.Uname = "user2" - output, ok = runMysql(t, params, "select rows") - if ok { - t.Fatalf("mysql should have failed: %v", output) - } + output, err = runMysqlWithErr(t, params, "select rows") + require.Error(t, err) if !strings.Contains(output, "1045") || !strings.Contains(output, "28000") || !strings.Contains(output, "Access denied") { @@ -776,6 +687,85 @@ func TestServer(t *testing.T) { // time.Sleep(60 * time.Minute) } +func TestServerStats(t *testing.T) { + th := &testHandler{} + + authServer := NewAuthServerStatic("", "", 0) + authServer.entries["user1"] = []*AuthServerStaticEntry{{ + Password: "password1", + UserData: "userData1", + }} + defer authServer.close() + l, err := NewListener("tcp", ":0", authServer, th, 0, 0, false) + if err != nil { + t.Fatalf("NewListener failed: %v", err) + } + l.SlowConnectWarnThreshold.Set(time.Duration(time.Nanosecond * 1)) + defer l.Close() + go l.Accept() + + host, port := getHostPort(t, l.Addr()) + + // Setup the right parameters. + params := &ConnParams{ + Host: host, + Port: port, + Uname: "user1", + Pass: "password1", + } + + timings.Reset() + connAccept.Reset() + connCount.Reset() + connSlow.Reset() + connRefuse.Reset() + + // Run an 'error' command. + th.SetErr(NewSQLError(ERUnknownComError, SSUnknownComError, "forced query error")) + output, ok := runMysql(t, params, "error") + if ok { + t.Fatalf("mysql should have failed: %v", output) + } + if !strings.Contains(output, "ERROR 1047 (08S01)") || + !strings.Contains(output, "forced query error") { + t.Errorf("Unexpected output for 'error': %v", output) + } + assert.EqualValues(t, 0, connCount.Get(), "connCount") + assert.EqualValues(t, 1, connAccept.Get(), "connAccept") + assert.EqualValues(t, 1, connSlow.Get(), "connSlow") + assert.EqualValues(t, 0, connRefuse.Get(), "connRefuse") + + expectedTimingDeltas := map[string]int64{ + "All": 2, + connectTimingKey: 1, + queryTimingKey: 1, + } + gotTimingCounts := timings.Counts() + for key, got := range gotTimingCounts { + expected := expectedTimingDeltas[key] + if got < expected { + t.Errorf("Expected Timing count delta %s should be >= %d, got %d", key, expected, got) + } + } + + // Set the slow connect threshold to something high that we don't expect to trigger + l.SlowConnectWarnThreshold.Set(time.Duration(time.Second * 1)) + + // Run a 'panic' command, other side should panic, recover and + // close the connection. + output, err = runMysqlWithErr(t, params, "panic") + require.Error(t, err) + if !strings.Contains(output, "ERROR 2013 (HY000)") || + !strings.Contains(output, "Lost connection to MySQL server during query") { + t.Errorf("Unexpected output for 'panic'") + } + + assert.EqualValues(t, 0, connCount.Get(), "connCount") + assert.EqualValues(t, 2, connAccept.Get(), "connAccept") + assert.EqualValues(t, 1, connSlow.Get(), "connSlow") + assert.EqualValues(t, 0, connRefuse.Get(), "connRefuse") +} + // TestClearTextServer creates a Server that needs clear text // passwords from the client. func TestClearTextServer(t *testing.T) { @@ -1219,6 +1209,14 @@ const enableCleartextPluginPrefix = "enable-cleartext-plugin: " // runMysql forks a mysql command line process connecting to the provided server. func runMysql(t *testing.T, params *ConnParams, command string) (string, bool) { + output, err := runMysqlWithErr(t, params, command) + if err != nil { + return output, false + } + return output, true + +} +func runMysqlWithErr(t *testing.T, params *ConnParams, command string) (string, error) { dir, err := vtenv.VtMysqlRoot() if err != nil { t.Fatalf("vtenv.VtMysqlRoot failed: %v", err) @@ -1277,9 +1275,9 @@ func runMysql(t *testing.T, params *ConnParams, command string) (string, bool) { out, err := cmd.CombinedOutput() output := string(out) if err != nil { - return output, false + return output, err } - return output, true + return output, nil } // binaryPath does a limited path lookup for a command, diff --git a/go/stats/timings.go b/go/stats/timings.go index 697963c4773..31243ea36a0 100644 --- a/go/stats/timings.go +++ b/go/stats/timings.go @@ -168,6 +168,13 @@ func (t *Timings) Label() string { return t.label } +// Reset will clear histograms: used during testing +func (t *Timings) Reset() { + t.mu.RLock() + t.histograms = make(map[string]*Histogram) + t.mu.RUnlock() +} + var bucketCutoffs = []int64{5e5, 1e6, 5e6, 1e7, 5e7, 1e8, 5e8, 1e9, 5e9, 1e10} var bucketLabels []string