Skip to content

Commit

Permalink
server: start to listen after init stats complete (#51472) (#51512)
Browse files Browse the repository at this point in the history
close #51473
  • Loading branch information
ti-chi-bot authored Mar 6, 2024
1 parent 5461751 commit 1dc6edf
Show file tree
Hide file tree
Showing 13 changed files with 128 additions and 74 deletions.
5 changes: 4 additions & 1 deletion br/pkg/mock/mock_cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ func NewCluster() (*Cluster, error) {
// Start runs a mock cluster.
func (mock *Cluster) Start() error {
server.RunInGoTest = true
server.RunInGoTestChan = make(chan struct{})
mock.TiDBDriver = server.NewTiDBDriver(mock.Storage)
cfg := config.NewConfig()
// let tidb random select a port
Expand All @@ -104,6 +105,7 @@ func (mock *Cluster) Start() error {
panic(err1)
}
}()
<-server.RunInGoTestChan
mock.DSN = waitUntilServerOnline("127.0.0.1", cfg.Status.StatusPort)
return nil
}
Expand Down Expand Up @@ -178,7 +180,8 @@ func waitUntilServerOnline(addr string, statusPort uint) string {
}
if retry == retryTime {
log.Panic("failed to connect HTTP status in every 10 ms",
zap.Int("retryTime", retryTime))
zap.Int("retryTime", retryTime),
zap.String("url", statusURL))
}
return strings.SplitAfter(dsn, "/")[0]
}
1 change: 1 addition & 0 deletions domain/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (

func TestMain(m *testing.M) {
server.RunInGoTest = true
server.RunInGoTestChan = make(chan struct{})
testsetup.SetupForCommonTest()
opts := []goleak.Option{
goleak.IgnoreTopFunction("github.com/golang/glog.(*loggingT).flushDaemon"),
Expand Down
6 changes: 3 additions & 3 deletions server/extract_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,13 @@ func TestExtractHandler(t *testing.T) {
server, err := NewServer(cfg, driver)
require.NoError(t, err)
defer server.Close()

client.port = getPortFromTCPAddr(server.listener.Addr())
client.statusPort = getPortFromTCPAddr(server.statusListener.Addr())
go func() {
err := server.Run(nil)
require.NoError(t, err)
}()
<-RunInGoTestChan
client.port = getPortFromTCPAddr(server.listener.Addr())
client.statusPort = getPortFromTCPAddr(server.statusListener.Addr())
client.waitUntilServerOnline()
startTime := time.Now()
time.Sleep(time.Second)
Expand Down
9 changes: 5 additions & 4 deletions server/http_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -454,16 +454,17 @@ func (ts *basicHTTPHandlerTestSuite) startServer(t *testing.T) {
cfg.Port = 0
cfg.Status.StatusPort = 0
cfg.Status.ReportStatus = true

RunInGoTestChan = make(chan struct{})
server, err := NewServer(cfg, ts.tidbdrv)
require.NoError(t, err)
ts.port = getPortFromTCPAddr(server.listener.Addr())
ts.statusPort = getPortFromTCPAddr(server.statusListener.Addr())
ts.server = server
go func() {
err := server.Run(ts.domain)
require.NoError(t, err)
}()
<-RunInGoTestChan
ts.port = getPortFromTCPAddr(server.listener.Addr())
ts.statusPort = getPortFromTCPAddr(server.statusListener.Addr())
ts.server = server
ts.waitUntilServerOnline()

do, err := session.GetDomain(ts.store)
Expand Down
7 changes: 6 additions & 1 deletion server/http_status.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,13 @@ import (

const defaultStatusPort = 10080

func (s *Server) startStatusHTTP() {
func (s *Server) startStatusHTTP() error {
err := s.initHTTPListener()
if err != nil {
return err
}
go s.startHTTPServer()
return nil
}

func serveError(w http.ResponseWriter, status int, txt string) {
Expand Down
1 change: 1 addition & 0 deletions server/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
var testDataMap = make(testdata.BookKeeper, 1)

func TestMain(m *testing.M) {
RunInGoTestChan = make(chan struct{})
testsetup.SetupForCommonTest()

RunInGoTest = true // flag for NewServer to known it is running in test environment
Expand Down
8 changes: 4 additions & 4 deletions server/optimize_trace_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,17 @@ func TestDumpOptimizeTraceAPI(t *testing.T) {
cfg.Port = client.port
cfg.Status.StatusPort = client.statusPort
cfg.Status.ReportStatus = true

RunInGoTestChan = make(chan struct{})
server, err := NewServer(cfg, driver)
require.NoError(t, err)
defer server.Close()

client.port = getPortFromTCPAddr(server.listener.Addr())
client.statusPort = getPortFromTCPAddr(server.statusListener.Addr())
go func() {
err := server.Run(nil)
require.NoError(t, err)
}()
<-RunInGoTestChan
client.port = getPortFromTCPAddr(server.listener.Addr())
client.statusPort = getPortFromTCPAddr(server.statusListener.Addr())
client.waitUntilServerOnline()

dom, err := session.GetDomain(store)
Expand Down
8 changes: 4 additions & 4 deletions server/plan_replayer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,21 +78,21 @@ func TestDumpPlanReplayerAPI(t *testing.T) {
cfg.Port = client.port
cfg.Status.StatusPort = client.statusPort
cfg.Status.ReportStatus = true

RunInGoTestChan = make(chan struct{})
server, err := NewServer(cfg, driver)
require.NoError(t, err)
defer server.Close()

dom, err := session.GetDomain(store)
require.NoError(t, err)
server.SetDomain(dom)

client.port = getPortFromTCPAddr(server.listener.Addr())
client.statusPort = getPortFromTCPAddr(server.statusListener.Addr())
go func() {
err := server.Run(nil)
require.NoError(t, err)
}()
<-RunInGoTestChan
client.port = getPortFromTCPAddr(server.listener.Addr())
client.statusPort = getPortFromTCPAddr(server.statusListener.Addr())
client.waitUntilServerOnline()
filename, fileNameFromCapture := prepareData4PlanReplayer(t, client, dom)

Expand Down
55 changes: 46 additions & 9 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ import (

"github.com/blacktear23/go-proxyprotocol"
"github.com/pingcap/errors"
"github.com/pingcap/log"
autoid "github.com/pingcap/tidb/autoid_service"
"github.com/pingcap/tidb/config"
"github.com/pingcap/tidb/domain"
Expand Down Expand Up @@ -83,6 +84,8 @@ var (
osVersion string
// RunInGoTest represents whether we are run code in test.
RunInGoTest bool
// RunInGoTestChan is used to control the RunInGoTest.
RunInGoTestChan chan struct{}
)

func init() {
Expand Down Expand Up @@ -258,15 +261,19 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) {
if s.tlsConfig != nil {
s.capability |= mysql.ClientSSL
}
variable.RegisterStatistics(s)
return s, nil
}

func (s *Server) initTiDBListener() (err error) {
if s.cfg.Host != "" && (s.cfg.Port != 0 || RunInGoTest) {
addr := net.JoinHostPort(s.cfg.Host, strconv.Itoa(int(s.cfg.Port)))
tcpProto := "tcp"
if s.cfg.EnableTCP4Only {
tcpProto = "tcp4"
}
if s.listener, err = net.Listen(tcpProto, addr); err != nil {
return nil, errors.Trace(err)
return errors.Trace(err)
}
logutil.BgLogger().Info("server is running MySQL protocol", zap.String("addr", addr))
if RunInGoTest && s.cfg.Port == 0 {
Expand All @@ -276,18 +283,18 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) {

if s.cfg.Socket != "" {
if err := cleanupStaleSocket(s.cfg.Socket); err != nil {
return nil, errors.Trace(err)
return errors.Trace(err)
}

if s.socket, err = net.Listen("unix", s.cfg.Socket); err != nil {
return nil, errors.Trace(err)
return errors.Trace(err)
}
logutil.BgLogger().Info("server is running MySQL protocol", zap.String("socket", s.cfg.Socket))
}

if s.socket == nil && s.listener == nil {
err = errors.New("Server not configured to listen on either -socket or -host and -port")
return nil, errors.Trace(err)
return errors.Trace(err)
}

if s.cfg.ProxyProtocol.Networks != "" {
Expand All @@ -299,7 +306,7 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) {
int(s.cfg.ProxyProtocol.HeaderTimeout), s.cfg.ProxyProtocol.Fallbackable)
if err != nil {
logutil.BgLogger().Error("ProxyProtocol networks parameter invalid")
return nil, errors.Trace(err)
return errors.Trace(err)
}
if s.listener != nil {
s.listener = ppListener
Expand All @@ -309,10 +316,13 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) {
logutil.BgLogger().Info("server is running MySQL protocol (through PROXY protocol)", zap.String("socket", s.cfg.Socket))
}
}
return nil
}

func (s *Server) initHTTPListener() (err error) {
if s.cfg.Status.ReportStatus {
if err = s.listenStatusHTTPServer(); err != nil {
return nil, errors.Trace(err)
return errors.Trace(err)
}
}

Expand All @@ -339,7 +349,7 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) {

variable.RegisterStatistics(s)

return s, nil
return nil
}

func cleanupStaleSocket(socket string) error {
Expand Down Expand Up @@ -398,23 +408,50 @@ func (s *Server) Run(dom *domain.Domain) error {

// Start HTTP API to report tidb info such as TPS.
if s.cfg.Status.ReportStatus {
s.startStatusHTTP()
err := s.startStatusHTTP()
if err != nil {
log.Error("failed to create the server", zap.Error(err), zap.Stack("stack"))
return err
}
}
if config.GetGlobalConfig().Performance.ForceInitStats && dom != nil {
<-dom.StatsHandle().InitStatsDone
}
// If error should be reported and exit the server it can be sent on this
// channel. Otherwise, end with sending a nil error to signal "done"
errChan := make(chan error, 2)
err := s.initTiDBListener()
if err != nil {
log.Error("failed to create the server", zap.Error(err), zap.Stack("stack"))
return err
}
// Register error API is not thread-safe, the caller MUST NOT register errors after initialization.
// To prevent misuse, set a flag to indicate that register new error will panic immediately.
// For regression of issue like https://github.com/pingcap/tidb/issues/28190
terror.RegisterFinish()
go s.startNetworkListener(s.listener, false, errChan)
go s.startNetworkListener(s.socket, true, errChan)
err := <-errChan
if RunInGoTest && !isClosed(RunInGoTestChan) {
close(RunInGoTestChan)
}
err = <-errChan
if err != nil {
return err
}
return <-errChan
}

// isClosed is to check if the channel is closed
func isClosed(ch chan struct{}) bool {
select {
case <-ch:
return true
default:
}

return false
}

func (s *Server) startNetworkListener(listener net.Listener, isUnixSocket bool, errChan chan error) {
if listener == nil {
errChan <- nil
Expand Down
6 changes: 3 additions & 3 deletions server/statistics_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,13 @@ func TestDumpStatsAPI(t *testing.T) {
server, err := NewServer(cfg, driver)
require.NoError(t, err)
defer server.Close()

client.port = getPortFromTCPAddr(server.listener.Addr())
client.statusPort = getPortFromTCPAddr(server.statusListener.Addr())
go func() {
err := server.Run(nil)
require.NoError(t, err)
}()
<-RunInGoTestChan
client.port = getPortFromTCPAddr(server.listener.Addr())
client.statusPort = getPortFromTCPAddr(server.statusListener.Addr())
client.waitUntilServerOnline()

dom, err := session.GetDomain(store)
Expand Down
26 changes: 16 additions & 10 deletions server/tidb_serial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,15 @@ func TestTLSAuto(t *testing.T) {
cfg.Security.RSAKeySize = 528 // Reduces unittest runtime
err := os.MkdirAll(cfg.TempStoragePath, 0700)
require.NoError(t, err)
RunInGoTestChan = make(chan struct{})
server, err := NewServer(cfg, ts.tidbdrv)
require.NoError(t, err)
cli.port = getPortFromTCPAddr(server.listener.Addr())
go func() {
err := server.Run(nil)
require.NoError(t, err)
}()
time.Sleep(time.Millisecond * 100)
<-RunInGoTestChan
cli.port = getPortFromTCPAddr(server.listener.Addr())
err = cli.runTestTLSConnection(t, connOverrider) // Relying on automatically created TLS certificates
require.NoError(t, err)

Expand Down Expand Up @@ -203,14 +204,15 @@ func TestTLSBasic(t *testing.T) {
SSLCert: fileName("server-cert.pem"),
SSLKey: fileName("server-key.pem"),
}
RunInGoTestChan = make(chan struct{})
server, err := NewServer(cfg, ts.tidbdrv)
require.NoError(t, err)
cli.port = getPortFromTCPAddr(server.listener.Addr())
go func() {
err := server.Run(nil)
require.NoError(t, err)
}()
time.Sleep(time.Millisecond * 100)
<-RunInGoTestChan
cli.port = getPortFromTCPAddr(server.listener.Addr())
err = cli.runTestTLSConnection(t, connOverrider) // We should establish connection successfully.
require.NoError(t, err)
cli.runTestRegression(t, connOverrider, "TLSRegression")
Expand Down Expand Up @@ -266,15 +268,16 @@ func TestTLSVerify(t *testing.T) {
SSLCert: fileName("server-cert.pem"),
SSLKey: fileName("server-key.pem"),
}
RunInGoTestChan = make(chan struct{})
server, err := NewServer(cfg, ts.tidbdrv)
require.NoError(t, err)
defer server.Close()
cli.port = getPortFromTCPAddr(server.listener.Addr())
go func() {
err := server.Run(nil)
require.NoError(t, err)
}()
time.Sleep(time.Millisecond * 100)
<-RunInGoTestChan
cli.port = getPortFromTCPAddr(server.listener.Addr())
// The client does not provide a certificate, the connection should succeed.
err = cli.runTestTLSConnection(t, nil)
require.NoError(t, err)
Expand Down Expand Up @@ -372,13 +375,14 @@ func TestErrorNoRollback(t *testing.T) {
}
server, err := NewServer(cfg, ts.tidbdrv)
require.NoError(t, err)
cli.port = getPortFromTCPAddr(server.listener.Addr())
RunInGoTestChan = make(chan struct{})
go func() {
err := server.Run(nil)
require.NoError(t, err)
}()
defer server.Close()
time.Sleep(time.Millisecond * 100)
<-RunInGoTestChan
cli.port = getPortFromTCPAddr(server.listener.Addr())
connOverrider := func(config *mysql.Config) {
config.TLSConfig = "client-cert-rollback-test"
}
Expand Down Expand Up @@ -514,14 +518,16 @@ func TestReloadTLS(t *testing.T) {
SSLCert: "/tmp/server-cert-reload.pem",
SSLKey: "/tmp/server-key-reload.pem",
}
RunInGoTestChan = make(chan struct{})
server, err := NewServer(cfg, ts.tidbdrv)
require.NoError(t, err)
cli.port = getPortFromTCPAddr(server.listener.Addr())

go func() {
err := server.Run(nil)
require.NoError(t, err)
}()
time.Sleep(time.Millisecond * 100)
<-RunInGoTestChan
cli.port = getPortFromTCPAddr(server.listener.Addr())
// The client provides a valid certificate.
connOverrider := func(config *mysql.Config) {
config.TLSConfig = "client-certificate-reload"
Expand Down
Loading

0 comments on commit 1dc6edf

Please sign in to comment.