diff --git a/br/pkg/mock/mock_cluster.go b/br/pkg/mock/mock_cluster.go index 0845d684dec8b..d42aca23e8ce7 100644 --- a/br/pkg/mock/mock_cluster.go +++ b/br/pkg/mock/mock_cluster.go @@ -85,6 +85,7 @@ func NewCluster() (*Cluster, error) { // Start runs a mock cluster. func (mock *Cluster) Start() error { server.RunInGoTest = true + server.RunInGoTestChan = make(chan struct{}) mock.TiDBDriver = server.NewTiDBDriver(mock.Storage) cfg := config.NewConfig() // let tidb random select a port @@ -104,6 +105,7 @@ func (mock *Cluster) Start() error { panic(err1) } }() + <-server.RunInGoTestChan mock.DSN = waitUntilServerOnline("127.0.0.1", cfg.Status.StatusPort) return nil } @@ -178,7 +180,8 @@ func waitUntilServerOnline(addr string, statusPort uint) string { } if retry == retryTime { log.Panic("failed to connect HTTP status in every 10 ms", - zap.Int("retryTime", retryTime)) + zap.Int("retryTime", retryTime), + zap.String("url", statusURL)) } return strings.SplitAfter(dsn, "/")[0] } diff --git a/domain/main_test.go b/domain/main_test.go index f236b8461fa12..7606a91ae89e0 100644 --- a/domain/main_test.go +++ b/domain/main_test.go @@ -24,6 +24,7 @@ import ( func TestMain(m *testing.M) { server.RunInGoTest = true + server.RunInGoTestChan = make(chan struct{}) testsetup.SetupForCommonTest() opts := []goleak.Option{ goleak.IgnoreTopFunction("github.com/golang/glog.(*loggingT).flushDaemon"), diff --git a/server/http_handler_test.go b/server/http_handler_test.go index 03598e6301481..beb996a8acda8 100644 --- a/server/http_handler_test.go +++ b/server/http_handler_test.go @@ -457,16 +457,17 @@ func (ts *basicHTTPHandlerTestSuite) startServer(t *testing.T) { cfg.Port = 0 cfg.Status.StatusPort = 0 cfg.Status.ReportStatus = true - + RunInGoTestChan = make(chan struct{}) server, err := NewServer(cfg, ts.tidbdrv) require.NoError(t, err) - ts.port = getPortFromTCPAddr(server.listener.Addr()) - ts.statusPort = getPortFromTCPAddr(server.statusListener.Addr()) ts.server = server go func() { err := server.Run(ts.domain) require.NoError(t, err) }() + <-RunInGoTestChan + ts.port = getPortFromTCPAddr(server.listener.Addr()) + ts.statusPort = getPortFromTCPAddr(server.statusListener.Addr()) ts.waitUntilServerOnline() do, err := session.GetDomain(ts.store) diff --git a/server/http_status.go b/server/http_status.go index 8070bd91e2b99..65b8289b9454f 100644 --- a/server/http_status.go +++ b/server/http_status.go @@ -61,8 +61,13 @@ import ( const defaultStatusPort = 10080 -func (s *Server) startStatusHTTP() { +func (s *Server) startStatusHTTP() error { + err := s.initHTTPListener() + if err != nil { + return err + } go s.startHTTPServer() + return nil } func serveError(w http.ResponseWriter, status int, txt string) { diff --git a/server/main_test.go b/server/main_test.go index cdf12551934db..aefd3ba4078a2 100644 --- a/server/main_test.go +++ b/server/main_test.go @@ -31,6 +31,7 @@ import ( ) func TestMain(m *testing.M) { + RunInGoTestChan = make(chan struct{}) testsetup.SetupForCommonTest() RunInGoTest = true // flag for NewServer to known it is running in test environment diff --git a/server/optimize_trace_test.go b/server/optimize_trace_test.go index f08f324f86dc7..558c38912295e 100644 --- a/server/optimize_trace_test.go +++ b/server/optimize_trace_test.go @@ -39,13 +39,13 @@ func TestDumpOptimizeTraceAPI(t *testing.T) { server, err := NewServer(cfg, driver) require.NoError(t, err) defer server.Close() - - client.port = getPortFromTCPAddr(server.listener.Addr()) - client.statusPort = getPortFromTCPAddr(server.statusListener.Addr()) go func() { err := server.Run(nil) require.NoError(t, err) }() + <-RunInGoTestChan + client.port = getPortFromTCPAddr(server.listener.Addr()) + client.statusPort = getPortFromTCPAddr(server.statusListener.Addr()) client.waitUntilServerOnline() dom, err := session.GetDomain(store) diff --git a/server/plan_replayer_test.go b/server/plan_replayer_test.go index 633188625e7c5..a2083dec35126 100644 --- a/server/plan_replayer_test.go +++ b/server/plan_replayer_test.go @@ -40,17 +40,17 @@ func TestDumpPlanReplayerAPI(t *testing.T) { cfg.Port = client.port cfg.Status.StatusPort = client.statusPort cfg.Status.ReportStatus = true - + RunInGoTestChan = make(chan struct{}) server, err := NewServer(cfg, driver) require.NoError(t, err) defer server.Close() - - client.port = getPortFromTCPAddr(server.listener.Addr()) - client.statusPort = getPortFromTCPAddr(server.statusListener.Addr()) go func() { err := server.Run(nil) require.NoError(t, err) }() + <-RunInGoTestChan + client.port = getPortFromTCPAddr(server.listener.Addr()) + client.statusPort = getPortFromTCPAddr(server.statusListener.Addr()) client.waitUntilServerOnline() dom, err := session.GetDomain(store) diff --git a/server/server.go b/server/server.go index 68923a382f9dc..41036c4bb31b0 100644 --- a/server/server.go +++ b/server/server.go @@ -50,6 +50,7 @@ import ( "github.com/blacktear23/go-proxyprotocol" "github.com/pingcap/errors" + "github.com/pingcap/log" autoid "github.com/pingcap/tidb/autoid_service" "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/domain" @@ -82,6 +83,8 @@ var ( osVersion string // RunInGoTest represents whether we are run code in test. RunInGoTest bool + // RunInGoTestChan is used to control the RunInGoTest. + RunInGoTestChan chan struct{} ) func init() { @@ -252,7 +255,11 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) { if s.tlsConfig != nil { s.capability |= mysql.ClientSSL } + variable.RegisterStatistics(s) + return s, nil +} +func (s *Server) initTiDBListener() (err error) { if s.cfg.Host != "" && (s.cfg.Port != 0 || RunInGoTest) { addr := net.JoinHostPort(s.cfg.Host, strconv.Itoa(int(s.cfg.Port))) tcpProto := "tcp" @@ -260,7 +267,7 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) { tcpProto = "tcp4" } if s.listener, err = net.Listen(tcpProto, addr); err != nil { - return nil, errors.Trace(err) + return errors.Trace(err) } logutil.BgLogger().Info("server is running MySQL protocol", zap.String("addr", addr)) if RunInGoTest && s.cfg.Port == 0 { @@ -270,18 +277,18 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) { if s.cfg.Socket != "" { if err := cleanupStaleSocket(s.cfg.Socket); err != nil { - return nil, errors.Trace(err) + return errors.Trace(err) } if s.socket, err = net.Listen("unix", s.cfg.Socket); err != nil { - return nil, errors.Trace(err) + return errors.Trace(err) } logutil.BgLogger().Info("server is running MySQL protocol", zap.String("socket", s.cfg.Socket)) } if s.socket == nil && s.listener == nil { err = errors.New("Server not configured to listen on either -socket or -host and -port") - return nil, errors.Trace(err) + return errors.Trace(err) } if s.cfg.ProxyProtocol.Networks != "" { @@ -293,7 +300,7 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) { int(s.cfg.ProxyProtocol.HeaderTimeout), s.cfg.ProxyProtocol.Fallbackable) if err != nil { logutil.BgLogger().Error("ProxyProtocol networks parameter invalid") - return nil, errors.Trace(err) + return errors.Trace(err) } if s.listener != nil { s.listener = ppListener @@ -303,10 +310,13 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) { logutil.BgLogger().Info("server is running MySQL protocol (through PROXY protocol)", zap.String("socket", s.cfg.Socket)) } } + return nil +} +func (s *Server) initHTTPListener() (err error) { if s.cfg.Status.ReportStatus { if err = s.listenStatusHTTPServer(); err != nil { - return nil, errors.Trace(err) + return errors.Trace(err) } } @@ -333,7 +343,7 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) { variable.RegisterStatistics(s) - return s, nil + return nil } func cleanupStaleSocket(socket string) error { @@ -392,23 +402,50 @@ func (s *Server) Run(dom *domain.Domain) error { // Start HTTP API to report tidb info such as TPS. if s.cfg.Status.ReportStatus { - s.startStatusHTTP() + err := s.startStatusHTTP() + if err != nil { + log.Error("failed to create the server", zap.Error(err), zap.Stack("stack")) + return err + } } if config.GetGlobalConfig().Performance.ForceInitStats && dom != nil { <-dom.StatsHandle().InitStatsDone } // If error should be reported and exit the server it can be sent on this // channel. Otherwise, end with sending a nil error to signal "done" - errChan := make(chan error) + errChan := make(chan error, 2) + err := s.initTiDBListener() + if err != nil { + log.Error("failed to create the server", zap.Error(err), zap.Stack("stack")) + return err + } + // Register error API is not thread-safe, the caller MUST NOT register errors after initialization. + // To prevent misuse, set a flag to indicate that register new error will panic immediately. + // For regression of issue like https://github.com/pingcap/tidb/issues/28190 + terror.RegisterFinish() go s.startNetworkListener(s.listener, false, errChan) go s.startNetworkListener(s.socket, true, errChan) - err := <-errChan + if RunInGoTest && !isClosed(RunInGoTestChan) { + close(RunInGoTestChan) + } + err = <-errChan if err != nil { return err } return <-errChan } +// isClosed is to check if the channel is closed +func isClosed(ch chan struct{}) bool { + select { + case <-ch: + return true + default: + } + + return false +} + func (s *Server) startNetworkListener(listener net.Listener, isUnixSocket bool, errChan chan error) { if listener == nil { errChan <- nil diff --git a/server/statistics_handler_test.go b/server/statistics_handler_test.go index 017b2743e1495..24092168ccf9c 100644 --- a/server/statistics_handler_test.go +++ b/server/statistics_handler_test.go @@ -41,17 +41,17 @@ func TestDumpStatsAPI(t *testing.T) { cfg.Status.StatusPort = client.statusPort cfg.Status.ReportStatus = true cfg.Socket = fmt.Sprintf("/tmp/tidb-mock-%d.sock", time.Now().UnixNano()) - + RunInGoTestChan = make(chan struct{}) server, err := NewServer(cfg, driver) require.NoError(t, err) defer server.Close() - - client.port = getPortFromTCPAddr(server.listener.Addr()) - client.statusPort = getPortFromTCPAddr(server.statusListener.Addr()) go func() { err := server.Run(nil) require.NoError(t, err) }() + <-RunInGoTestChan + client.port = getPortFromTCPAddr(server.listener.Addr()) + client.statusPort = getPortFromTCPAddr(server.statusListener.Addr()) client.waitUntilServerOnline() dom, err := session.GetDomain(store) diff --git a/server/tidb_serial_test.go b/server/tidb_serial_test.go index 53b7106a9eebc..5c0c098066a0f 100644 --- a/server/tidb_serial_test.go +++ b/server/tidb_serial_test.go @@ -119,14 +119,15 @@ func TestTLSAuto(t *testing.T) { cfg.Security.RSAKeySize = 528 // Reduces unittest runtime err := os.MkdirAll(cfg.TempStoragePath, 0700) require.NoError(t, err) + RunInGoTestChan = make(chan struct{}) server, err := NewServer(cfg, ts.tidbdrv) require.NoError(t, err) - cli.port = getPortFromTCPAddr(server.listener.Addr()) go func() { err := server.Run(nil) require.NoError(t, err) }() - time.Sleep(time.Millisecond * 100) + <-RunInGoTestChan + cli.port = getPortFromTCPAddr(server.listener.Addr()) err = cli.runTestTLSConnection(t, connOverrider) // Relying on automatically created TLS certificates require.NoError(t, err) @@ -164,14 +165,15 @@ func TestTLSBasic(t *testing.T) { SSLCert: fileName("server-cert.pem"), SSLKey: fileName("server-key.pem"), } + RunInGoTestChan = make(chan struct{}) server, err := NewServer(cfg, ts.tidbdrv) require.NoError(t, err) - cli.port = getPortFromTCPAddr(server.listener.Addr()) go func() { err := server.Run(nil) require.NoError(t, err) }() - time.Sleep(time.Millisecond * 100) + <-RunInGoTestChan + cli.port = getPortFromTCPAddr(server.listener.Addr()) err = cli.runTestTLSConnection(t, connOverrider) // We should establish connection successfully. require.NoError(t, err) cli.runTestRegression(t, connOverrider, "TLSRegression") @@ -227,15 +229,16 @@ func TestTLSVerify(t *testing.T) { SSLCert: fileName("server-cert.pem"), SSLKey: fileName("server-key.pem"), } + RunInGoTestChan = make(chan struct{}) server, err := NewServer(cfg, ts.tidbdrv) require.NoError(t, err) defer server.Close() - cli.port = getPortFromTCPAddr(server.listener.Addr()) go func() { err := server.Run(nil) require.NoError(t, err) }() - time.Sleep(time.Millisecond * 100) + <-RunInGoTestChan + cli.port = getPortFromTCPAddr(server.listener.Addr()) // The client does not provide a certificate, the connection should succeed. err = cli.runTestTLSConnection(t, nil) require.NoError(t, err) @@ -331,15 +334,16 @@ func TestErrorNoRollback(t *testing.T) { SSLCert: "/tmp/server-cert-rollback.pem", SSLKey: "/tmp/server-key-rollback.pem", } + RunInGoTestChan = make(chan struct{}) server, err := NewServer(cfg, ts.tidbdrv) require.NoError(t, err) - cli.port = getPortFromTCPAddr(server.listener.Addr()) go func() { err := server.Run(nil) require.NoError(t, err) }() defer server.Close() - time.Sleep(time.Millisecond * 100) + <-RunInGoTestChan + cli.port = getPortFromTCPAddr(server.listener.Addr()) connOverrider := func(config *mysql.Config) { config.TLSConfig = "client-cert-rollback-test" } @@ -475,14 +479,15 @@ func TestReloadTLS(t *testing.T) { SSLCert: "/tmp/server-cert-reload.pem", SSLKey: "/tmp/server-key-reload.pem", } + RunInGoTestChan = make(chan struct{}) server, err := NewServer(cfg, ts.tidbdrv) require.NoError(t, err) - cli.port = getPortFromTCPAddr(server.listener.Addr()) go func() { err := server.Run(nil) require.NoError(t, err) }() - time.Sleep(time.Millisecond * 100) + <-RunInGoTestChan + cli.port = getPortFromTCPAddr(server.listener.Addr()) // The client provides a valid certificate. connOverrider := func(config *mysql.Config) { config.TLSConfig = "client-certificate-reload" diff --git a/server/tidb_test.go b/server/tidb_test.go index 78ee35fbfabb8..259942d929a52 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -98,11 +98,9 @@ func createTidbTestSuite(t *testing.T) *tidbTestSuite { cfg.Status.ReportStatus = true cfg.Status.StatusPort = ts.statusPort cfg.Performance.TCPKeepAlive = true - + RunInGoTestChan = make(chan struct{}) server, err := NewServer(cfg, ts.tidbdrv) require.NoError(t, err) - ts.port = getPortFromTCPAddr(server.listener.Addr()) - ts.statusPort = getPortFromTCPAddr(server.statusListener.Addr()) ts.server = server ts.server.SetDomain(ts.domain) ts.server.InitGlobalConnID(ts.domain.ServerID) @@ -111,6 +109,9 @@ func createTidbTestSuite(t *testing.T) *tidbTestSuite { err := ts.server.Run(nil) require.NoError(t, err) }() + <-RunInGoTestChan + ts.port = getPortFromTCPAddr(server.listener.Addr()) + ts.statusPort = getPortFromTCPAddr(server.statusListener.Addr()) ts.waitUntilServerOnline() t.Cleanup(func() { @@ -245,8 +246,9 @@ func TestStatusPort(t *testing.T) { cfg.Performance.TCPKeepAlive = true server, err := NewServer(cfg, ts.tidbdrv) + require.NoError(t, err) + err = server.Run(ts.domain) require.Error(t, err) - require.Nil(t, server) } func TestStatusAPIWithTLS(t *testing.T) { @@ -271,16 +273,17 @@ func TestStatusAPIWithTLS(t *testing.T) { cfg.Security.ClusterSSLCA = fileName("ca-cert-2.pem") cfg.Security.ClusterSSLCert = fileName("server-cert-2.pem") cfg.Security.ClusterSSLKey = fileName("server-key-2.pem") + RunInGoTestChan = make(chan struct{}) server, err := NewServer(cfg, ts.tidbdrv) require.NoError(t, err) - cli.port = getPortFromTCPAddr(server.listener.Addr()) - cli.statusPort = getPortFromTCPAddr(server.statusListener.Addr()) + go func() { err := server.Run(nil) require.NoError(t, err) }() - time.Sleep(time.Millisecond * 100) - + <-RunInGoTestChan + cli.port = getPortFromTCPAddr(server.listener.Addr()) + cli.statusPort = getPortFromTCPAddr(server.statusListener.Addr()) // https connection should work. ts.runTestStatusAPI(t) @@ -328,18 +331,17 @@ func TestStatusAPIWithTLSCNCheck(t *testing.T) { cfg.Security.ClusterSSLCert = serverCertPath cfg.Security.ClusterSSLKey = serverKeyPath cfg.Security.ClusterVerifyCN = []string{"tidb-client-2"} + RunInGoTestChan = make(chan struct{}) server, err := NewServer(cfg, ts.tidbdrv) require.NoError(t, err) - - cli.port = getPortFromTCPAddr(server.listener.Addr()) - cli.statusPort = getPortFromTCPAddr(server.statusListener.Addr()) go func() { err := server.Run(nil) require.NoError(t, err) }() defer server.Close() - time.Sleep(time.Millisecond * 100) - + <-RunInGoTestChan + cli.port = getPortFromTCPAddr(server.listener.Addr()) + cli.statusPort = getPortFromTCPAddr(server.statusListener.Addr()) hc := newTLSHttpClient(t, caPath, client1CertPath, client1KeyPath, @@ -392,15 +394,16 @@ func TestSocketForwarding(t *testing.T) { cfg.Port = cli.port os.Remove(cfg.Socket) cfg.Status.ReportStatus = false - + RunInGoTestChan = make(chan struct{}) server, err := NewServer(cfg, ts.tidbdrv) require.NoError(t, err) - cli.port = getPortFromTCPAddr(server.listener.Addr()) + server.SetDomain(ts.domain) go func() { err := server.Run(nil) require.NoError(t, err) }() - time.Sleep(time.Millisecond * 100) + <-RunInGoTestChan + cli.port = getPortFromTCPAddr(server.listener.Addr()) defer server.Close() cli.runTestRegression(t, func(config *mysql.Config) { @@ -424,14 +427,14 @@ func TestSocket(t *testing.T) { cfg.Status.ReportStatus = false ts := createTidbTestSuite(t) - + RunInGoTestChan = make(chan struct{}) server, err := NewServer(cfg, ts.tidbdrv) require.NoError(t, err) go func() { err := server.Run(nil) require.NoError(t, err) }() - time.Sleep(time.Millisecond * 100) + <-RunInGoTestChan defer server.Close() confFunc := func(config *mysql.Config) { @@ -458,14 +461,17 @@ func TestSocketAndIp(t *testing.T) { cfg.Status.ReportStatus = false ts := createTidbTestSuite(t) - + RunInGoTestChan = make(chan struct{}) server, err := NewServer(cfg, ts.tidbdrv) require.NoError(t, err) - cli.port = getPortFromTCPAddr(server.listener.Addr()) + server.SetDomain(ts.domain) + go func() { err := server.Run(nil) require.NoError(t, err) }() + <-RunInGoTestChan + cli.port = getPortFromTCPAddr(server.listener.Addr()) cli.waitUntilServerCanConnect() defer server.Close() @@ -620,16 +626,15 @@ func TestOnlySocket(t *testing.T) { cfg.Socket = socketFile cfg.Host = "" // No network interface listening for mysql traffic cfg.Status.ReportStatus = false - ts := createTidbTestSuite(t) - + RunInGoTestChan = make(chan struct{}) server, err := NewServer(cfg, ts.tidbdrv) require.NoError(t, err) go func() { err := server.Run(nil) require.NoError(t, err) }() - time.Sleep(time.Millisecond * 100) + <-RunInGoTestChan defer server.Close() require.Nil(t, server.listener) require.NotNil(t, server.socket) @@ -1236,17 +1241,17 @@ func TestGracefulShutdown(t *testing.T) { cfg.Status.StatusPort = 0 cfg.Status.ReportStatus = true cfg.Performance.TCPKeepAlive = true + RunInGoTestChan = make(chan struct{}) server, err := NewServer(cfg, ts.tidbdrv) require.NoError(t, err) require.NotNil(t, server) - cli.port = getPortFromTCPAddr(server.listener.Addr()) - cli.statusPort = getPortFromTCPAddr(server.statusListener.Addr()) go func() { err := server.Run(nil) require.NoError(t, err) }() - time.Sleep(time.Millisecond * 100) - + <-RunInGoTestChan + cli.port = getPortFromTCPAddr(server.listener.Addr()) + cli.statusPort = getPortFromTCPAddr(server.statusListener.Addr()) resp, err := cli.fetchStatus("/status") // server is up require.NoError(t, err) require.Nil(t, resp.Body.Close()) @@ -2508,17 +2513,18 @@ func TestLocalhostClientMapping(t *testing.T) { cfg.Status.ReportStatus = false ts := createTidbTestSuite(t) - + RunInGoTestChan = make(chan struct{}) server, err := NewServer(cfg, ts.tidbdrv) require.NoError(t, err) - cli.port = getPortFromTCPAddr(server.listener.Addr()) + server.SetDomain(ts.domain) go func() { err := server.Run(nil) require.NoError(t, err) }() defer server.Close() + <-RunInGoTestChan + cli.port = getPortFromTCPAddr(server.listener.Addr()) cli.waitUntilServerCanConnect() - cli.port = getPortFromTCPAddr(server.listener.Addr()) // Create a db connection for root db, err := sql.Open("mysql", cli.getDSN(func(config *mysql.Config) { diff --git a/tidb-server/main.go b/tidb-server/main.go index 991d67b0311d0..cc6cc23bfb01a 100644 --- a/tidb-server/main.go +++ b/tidb-server/main.go @@ -219,11 +219,6 @@ func main() { storage, dom := createStoreAndDomain() svr := createServer(storage, dom) - // Register error API is not thread-safe, the caller MUST NOT register errors after initialization. - // To prevent misuse, set a flag to indicate that register new error will panic immediately. - // For regression of issue like https://github.com/pingcap/tidb/issues/28190 - terror.RegisterFinish() - exited := make(chan struct{}) signal.SetupSignalHandler(func(graceful bool) { svr.Close() @@ -232,7 +227,6 @@ func main() { close(exited) }) topsql.SetupTopSQL() - terror.MustNil(svr.Run(dom)) <-exited syncLog()