Skip to content

Commit

Permalink
server: Unix Domain Socket not setting 'localhost' as host (#25914)
Browse files Browse the repository at this point in the history
  • Loading branch information
mjonss authored Jul 6, 2021
1 parent 787772d commit 383fb9a
Show file tree
Hide file tree
Showing 4 changed files with 365 additions and 78 deletions.
3 changes: 2 additions & 1 deletion server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ type clientConn struct {
collation uint8 // collation used by client, may be different from the collation used by database.
lastActive time.Time // last active time
authPlugin string // default authentication plugin
isUnixSocket bool // connection is Unix Socket file

// mu is used for cancelling the execution of current transaction.
mu struct {
Expand Down Expand Up @@ -837,7 +838,7 @@ func (cc *clientConn) PeerHost(hasPassword string) (host, port string, err error
return cc.peerHost, "", nil
}
host = variable.DefHostname
if cc.server.isUnixSocket() {
if cc.isUnixSocket {
cc.peerHost = host
return
}
Expand Down
135 changes: 66 additions & 69 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ import (
"crypto/tls"
"flag"
"fmt"
"io"
"math/rand"
"net"
"net/http"
Expand Down Expand Up @@ -184,44 +183,6 @@ func (s *Server) newConn(conn net.Conn) *clientConn {
return cc
}

// isUnixSocket should ideally be a function of clientConnection!
// But currently since unix-socket connections are forwarded to TCP when the server listens on both, it can really only be accurate on a server-level.
// If the server is listening on both, it *must* return FALSE for remote-host authentication to be performed correctly. See #23460.
func (s *Server) isUnixSocket() bool {
return s.cfg.Socket != "" && s.cfg.Port == 0
}

func (s *Server) forwardUnixSocketToTCP() {
addr := fmt.Sprintf("%s:%d", s.cfg.Host, s.cfg.Port)
for {
if s.listener == nil {
return // server shutdown has started
}
if uconn, err := s.socket.Accept(); err == nil {
logutil.BgLogger().Info("server socket forwarding", zap.String("from", s.cfg.Socket), zap.String("to", addr))
go s.handleForwardedConnection(uconn, addr)
} else if s.listener != nil {
logutil.BgLogger().Error("server failed to forward", zap.String("from", s.cfg.Socket), zap.String("to", addr), zap.Error(err))
}
}
}

func (s *Server) handleForwardedConnection(uconn net.Conn, addr string) {
defer terror.Call(uconn.Close)
if tconn, err := net.Dial("tcp", addr); err == nil {
go func() {
if _, err := io.Copy(uconn, tconn); err != nil {
logutil.BgLogger().Warn("copy server to socket failed", zap.Error(err))
}
}()
if _, err := io.Copy(tconn, uconn); err != nil {
logutil.BgLogger().Warn("socket forward copy failed", zap.Error(err))
}
} else {
logutil.BgLogger().Warn("socket forward failed: could not connect", zap.String("addr", addr), zap.Error(err))
}
}

// NewServer creates a new Server.
func NewServer(cfg *config.Config, driver IDriver) (*Server, error) {
s := &Server{
Expand Down Expand Up @@ -257,42 +218,52 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) {
if s.cfg.EnableTCP4Only {
tcpProto = "tcp4"
}
if s.listener, err = net.Listen(tcpProto, addr); err == nil {
logutil.BgLogger().Info("server is running MySQL protocol", zap.String("addr", addr))
if cfg.Socket != "" {
if s.socket, err = net.Listen("unix", s.cfg.Socket); err == nil {
logutil.BgLogger().Info("server redirecting", zap.String("from", s.cfg.Socket), zap.String("to", addr))
go s.forwardUnixSocketToTCP()
}
}
if runInGoTest && s.cfg.Port == 0 {
s.cfg.Port = uint(s.listener.Addr().(*net.TCPAddr).Port)
}
if s.listener, err = net.Listen(tcpProto, addr); err != nil {
return nil, errors.Trace(err)
}
} else if cfg.Socket != "" {
if s.listener, err = net.Listen("unix", cfg.Socket); err == nil {
logutil.BgLogger().Info("server is running MySQL protocol", zap.String("socket", cfg.Socket))
logutil.BgLogger().Info("server is running MySQL protocol", zap.String("addr", addr))
if runInGoTest && s.cfg.Port == 0 {
s.cfg.Port = uint(s.listener.Addr().(*net.TCPAddr).Port)
}
} else {
}

if s.cfg.Socket != "" {
if s.socket, err = net.Listen("unix", s.cfg.Socket); err != nil {
return nil, errors.Trace(err)
}
logutil.BgLogger().Info("server is running MySQL protocol", zap.String("socket", s.cfg.Socket))
}

if s.socket == nil && s.listener == nil {
err = errors.New("Server not configured to listen on either -socket or -host and -port")
return nil, errors.Trace(err)
}

if cfg.ProxyProtocol.Networks != "" {
pplistener, errProxy := proxyprotocol.NewListener(s.listener, cfg.ProxyProtocol.Networks,
int(cfg.ProxyProtocol.HeaderTimeout))
if errProxy != nil {
if s.cfg.ProxyProtocol.Networks != "" {
proxyTarget := s.listener
if proxyTarget == nil {
proxyTarget = s.socket
}
pplistener, err := proxyprotocol.NewListener(proxyTarget, s.cfg.ProxyProtocol.Networks,
int(s.cfg.ProxyProtocol.HeaderTimeout))
if err != nil {
logutil.BgLogger().Error("ProxyProtocol networks parameter invalid")
return nil, errors.Trace(errProxy)
return nil, errors.Trace(err)
}
if s.listener != nil {
s.listener = pplistener
logutil.BgLogger().Info("server is running MySQL protocol (through PROXY protocol)", zap.String("host", s.cfg.Host))
} else {
s.socket = pplistener
logutil.BgLogger().Info("server is running MySQL protocol (through PROXY protocol)", zap.String("socket", s.cfg.Socket))
}
logutil.BgLogger().Info("server is running MySQL protocol (through PROXY protocol)", zap.String("host", s.cfg.Host))
s.listener = pplistener
}

if s.cfg.Status.ReportStatus && err == nil {
if s.cfg.Status.ReportStatus {
err = s.listenStatusHTTPServer()
}
if err != nil {
return nil, errors.Trace(err)
if err != nil {
return nil, errors.Trace(err)
}
}

// Init rand seed for randomBuf()
Expand Down Expand Up @@ -336,12 +307,34 @@ func (s *Server) Run() error {
if s.cfg.Status.ReportStatus {
s.startStatusHTTP()
}
// If error should be reported and exit the server it can be sent on this
// channel. Otherwise end with sending a nil error to signal "done"
errChan := make(chan error)
go s.startNetworkListener(s.listener, false, errChan)
go s.startNetworkListener(s.socket, true, errChan)
err := <-errChan
if err != nil {
return err
}
return <-errChan
}

func (s *Server) startNetworkListener(listener net.Listener, isUnixSocket bool, errChan chan error) {
if listener == nil {
errChan <- nil
return
}
for {
conn, err := s.listener.Accept()
conn, err := listener.Accept()
if err != nil {
if opErr, ok := err.(*net.OpError); ok {
if opErr.Err.Error() == "use of closed network connection" {
return nil
if s.inShutdownMode {
errChan <- nil
} else {
errChan <- err
}
return
}
}

Expand All @@ -352,10 +345,14 @@ func (s *Server) Run() error {
}

logutil.BgLogger().Error("accept failed", zap.Error(err))
return errors.Trace(err)
errChan <- err
return
}

clientConn := s.newConn(conn)
if isUnixSocket {
clientConn.isUnixSocket = true
}

err = plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error {
authPlugin := plugin.DeclareAuditManifest(p.Manifest)
Expand Down Expand Up @@ -506,7 +503,7 @@ func (s *Server) onConn(conn *clientConn) {

func (cc *clientConn) connectInfo() *variable.ConnectionInfo {
connType := "Socket"
if cc.server.isUnixSocket() {
if cc.isUnixSocket {
connType = "UnixSocket"
} else if cc.tlsConn != nil {
connType = "SSL/TLS"
Expand Down
5 changes: 0 additions & 5 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,14 +131,9 @@ func (cli *testServerClient) runTests(c *C, overrider configOverrider, tests ...
c.Assert(err, IsNil)
}()

_, err = db.Exec("DROP TABLE IF EXISTS test")
c.Assert(err, IsNil)

dbt := &DBTest{c, db}
for _, test := range tests {
test(dbt)
// fixed query error
_, _ = dbt.db.Exec("DROP TABLE IF EXISTS test")
}
}

Expand Down
Loading

0 comments on commit 383fb9a

Please sign in to comment.