From 16944f140c16014702cef2c01d1238db9a2da293 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Tue, 12 Apr 2022 16:34:35 +0800 Subject: [PATCH] server: fix bug https://asktug.com/t/topic/213082/11 (#29577) (#30046) close pingcap/tidb#29709 --- server/conn.go | 21 +++++++++++++++------ server/conn_test.go | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 6 deletions(-) diff --git a/server/conn.go b/server/conn.go index 561dae86c7b80..9c50fd58c816f 100644 --- a/server/conn.go +++ b/server/conn.go @@ -693,7 +693,7 @@ func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Con func (cc *clientConn) handleAuthPlugin(ctx context.Context, resp *handshakeResponse41) error { if resp.Capability&mysql.ClientPluginAuth > 0 { - newAuth, err := cc.checkAuthPlugin(ctx, &resp.AuthPlugin) + newAuth, err := cc.checkAuthPlugin(ctx, resp) if err != nil { logutil.Logger(ctx).Warn("failed to check the user authplugin", zap.Error(err)) } @@ -810,7 +810,7 @@ func (cc *clientConn) openSessionAndDoAuth(authData []byte) error { } // Check if the Authentication Plugin of the server, client and user configuration matches -func (cc *clientConn) checkAuthPlugin(ctx context.Context, authPlugin *string) ([]byte, error) { +func (cc *clientConn) checkAuthPlugin(ctx context.Context, resp *handshakeResponse41) ([]byte, error) { // Open a context unless this was done before. if cc.ctx == nil { err := cc.openSession() @@ -819,12 +819,21 @@ func (cc *clientConn) checkAuthPlugin(ctx context.Context, authPlugin *string) ( } } - userplugin, err := cc.ctx.AuthPluginForUser(&auth.UserIdentity{Username: cc.user, Hostname: cc.peerHost}) + authData := resp.Auth + hasPassword := "YES" + if len(authData) == 0 { + hasPassword = "NO" + } + host, _, err := cc.PeerHost(hasPassword) + if err != nil { + return nil, err + } + userplugin, err := cc.ctx.AuthPluginForUser(&auth.UserIdentity{Username: cc.user, Hostname: host}) if err != nil { return nil, err } if len(userplugin) == 0 { - *authPlugin = mysql.AuthNativePassword + resp.AuthPlugin = mysql.AuthNativePassword return nil, nil } @@ -833,12 +842,12 @@ func (cc *clientConn) checkAuthPlugin(ctx context.Context, authPlugin *string) ( // or if the authentication method send by the server doesn't match the authentication // method send by the client (*authPlugin) then we need to switch the authentication // method to match the one configured for that specific user. - if (cc.authPlugin != userplugin) || (cc.authPlugin != *authPlugin) { + if (cc.authPlugin != userplugin) || (cc.authPlugin != resp.AuthPlugin) { authData, err := cc.authSwitchRequest(ctx, userplugin) if err != nil { return nil, err } - *authPlugin = userplugin + resp.AuthPlugin = userplugin return authData, nil } diff --git a/server/conn_test.go b/server/conn_test.go index 0c99937eac746..19dc2094dad74 100644 --- a/server/conn_test.go +++ b/server/conn_test.go @@ -914,3 +914,43 @@ func (ts *ConnTestSuite) TestHandleAuthPlugin(c *C) { err = cc.handleAuthPlugin(ctx, &resp) c.Assert(err, IsNil) } + +func (ts *ConnTestSuite) TestAuthPlugin2(c *C) { + + c.Parallel() + + cfg := newTestConfig() + cfg.Socket = "" + cfg.Port = 0 + cfg.Status.StatusPort = 0 + + drv := NewTiDBDriver(ts.store) + srv, err := NewServer(cfg, drv) + c.Assert(err, IsNil) + + cc := &clientConn{ + connectionID: 1, + alloc: arena.NewAllocator(1024), + pkt: &packetIO{ + bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)), + }, + server: srv, + user: "root", + } + ctx := context.Background() + se, _ := session.CreateSession4Test(ts.store) + tc := &TiDBContext{ + Session: se, + stmts: make(map[int]*TiDBStatement), + } + cc.ctx = tc + + resp := handshakeResponse41{ + Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth, + } + + cc.isUnixSocket = true + _, err = cc.checkAuthPlugin(ctx, &resp) + c.Assert(err, IsNil) + +}