Skip to content

Commit

Permalink
server, tidb-server: improve unix socket handling (#8836)
Browse files Browse the repository at this point in the history
  • Loading branch information
morgo authored and jackysp committed Jan 9, 2019
1 parent c68ee73 commit 692693a
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 6 deletions.
61 changes: 55 additions & 6 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"io/ioutil"
"math/rand"
"net"
Expand Down Expand Up @@ -80,6 +81,7 @@ type Server struct {
tlsConfig *tls.Config
driver IDriver
listener net.Listener
socket net.Listener
rwlock *sync.RWMutex
concurrentLimiter *TokenLimiter
clients map[uint32]*clientConn
Expand Down Expand Up @@ -133,6 +135,39 @@ func (s *Server) isUnixSocket() bool {
return s.cfg.Socket != ""
}

func (s *Server) forwardUnixSocketToTCP() {
addr := fmt.Sprintf("%s:%d", s.cfg.Host, s.cfg.Port)
for {
if s.listener == nil {
return // server shutdown has started
}
if uconn, err := s.socket.Accept(); err == nil {
log.Infof("server socket forwarding from [%s] to [%s]", s.cfg.Socket, addr)
go s.handleForwardedConnection(uconn, addr)
} else {
if s.listener != nil {
log.Errorf("server failed to forward from [%s] to [%s], err: %s", s.cfg.Socket, addr, err)
}
}
}
}

func (s *Server) handleForwardedConnection(uconn net.Conn, addr string) {
defer terror.Call(uconn.Close)
if tconn, err := net.Dial("tcp", addr); err == nil {
go func() {
if _, err := io.Copy(uconn, tconn); err != nil {
log.Warningf("copy server to socket failed: %s", err)
}
}()
if _, err := io.Copy(tconn, uconn); err != nil {
log.Warningf("socket forward copy failed: %s", err)
}
} else {
log.Warningf("socket forward failed: could not connect to [%s], err: %s", addr, err)
}
}

// NewServer creates a new Server.
func NewServer(cfg *config.Config, driver IDriver) (*Server, error) {
s := &Server{
Expand All @@ -151,15 +186,24 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) {
}

var err error
if cfg.Socket != "" {
if s.listener, err = net.Listen("unix", cfg.Socket); err == nil {
log.Infof("Server is running MySQL Protocol through Socket [%s]", cfg.Socket)
}
} else {

if s.cfg.Host != "" && s.cfg.Port != 0 {
addr := fmt.Sprintf("%s:%d", s.cfg.Host, s.cfg.Port)
if s.listener, err = net.Listen("tcp", addr); err == nil {
log.Infof("Server is running MySQL Protocol at [%s]", addr)
if cfg.Socket != "" {
if s.socket, err = net.Listen("unix", s.cfg.Socket); err == nil {
log.Infof("Server redirecting [%s] to [%s]", s.cfg.Socket, addr)
go s.forwardUnixSocketToTCP()
}
}
}
} else if cfg.Socket != "" {
if s.listener, err = net.Listen("unix", cfg.Socket); err == nil {
log.Infof("Server is running MySQL Protocol through Socket [%s]", cfg.Socket)
}
} else {
err = errors.New("Server not configured to listen on either -socket or -host and -port")
}

if cfg.ProxyProtocol.Networks != "" {
Expand Down Expand Up @@ -292,6 +336,11 @@ func (s *Server) Close() {
terror.Log(errors.Trace(err))
s.listener = nil
}
if s.socket != nil {
err := s.socket.Close()
terror.Log(errors.Trace(err))
s.socket = nil
}
if s.statusServer != nil {
err := s.statusServer.Close()
terror.Log(errors.Trace(err))
Expand Down Expand Up @@ -419,7 +468,7 @@ func (s *Server) kickIdleConnection() {
for _, cc := range conns {
err := cc.Close()
if err != nil {
log.Error("close connection error:", err)
log.Errorf("close connection error: %s", err)
}
}
}
Expand Down
26 changes: 26 additions & 0 deletions server/tidb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,34 @@ func (ts *TidbTestSuite) TestMultiStatements(c *C) {
runTestMultiStatements(c)
}

func (ts *TidbTestSuite) TestSocketForwarding(c *C) {
cfg := config.NewConfig()
cfg.Socket = "/tmp/tidbtest.sock"
cfg.Port = 3999
os.Remove(cfg.Socket)
cfg.Status.ReportStatus = false

server, err := NewServer(cfg, ts.tidbdrv)
c.Assert(err, IsNil)
go server.Run()
time.Sleep(time.Millisecond * 100)
defer server.Close()

runTestRegression(c, func(config *mysql.Config) {
config.User = "root"
config.Net = "unix"
config.Addr = "/tmp/tidbtest.sock"
config.DBName = "test"
config.Strict = true
}, "SocketRegression")
}

func (ts *TidbTestSuite) TestSocket(c *C) {
cfg := config.NewConfig()
cfg.Socket = "/tmp/tidbtest.sock"
cfg.Port = 0
os.Remove(cfg.Socket)
cfg.Host = ""
cfg.Status.ReportStatus = false

server, err := NewServer(cfg, ts.tidbdrv)
Expand All @@ -178,6 +203,7 @@ func (ts *TidbTestSuite) TestSocket(c *C) {
config.DBName = "test"
config.Strict = true
}, "SocketRegression")

}

// generateCert generates a private key and a certificate in PEM format based on parameters.
Expand Down

0 comments on commit 692693a

Please sign in to comment.