Skip to content

Commit

Permalink
server, privileges: make tidb_auth_session_token compatible with pass…
Browse files Browse the repository at this point in the history
…word expiration and resource group (#40735)

ref #40614
  • Loading branch information
djshow832 authored Jan 20, 2023
1 parent de856d9 commit 556c267
Show file tree
Hide file tree
Showing 8 changed files with 97 additions and 14 deletions.
1 change: 1 addition & 0 deletions parser/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ type UserIdentity struct {
CurrentUser bool
AuthUsername string // Username matched in privileges system
AuthHostname string // Match in privs system (i.e. could be a wildcard)
AuthPlugin string // The plugin specified in handshake, only used during authentication.
}

// Restore implements Node interface.
Expand Down
2 changes: 2 additions & 0 deletions privilege/privileges/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ go_library(
"//parser/terror",
"//privilege",
"//sessionctx",
"//sessionctx/sessionstates",
"//sessionctx/variable",
"//types",
"//util",
Expand Down Expand Up @@ -66,6 +67,7 @@ go_test(
"//privilege",
"//session",
"//sessionctx",
"//sessionctx/sessionstates",
"//sessionctx/variable",
"//testkit",
"//testkit/testsetup",
Expand Down
15 changes: 13 additions & 2 deletions privilege/privileges/privileges.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/privilege"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/sessionstates"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util"
Expand Down Expand Up @@ -544,7 +545,13 @@ func (p *UserPrivileges) ConnectionVerification(user *auth.UserIdentity, authUse
return info, ErrAccessDenied.FastGenByArgs(user.Username, user.Hostname, hasPassword)
}

if record.AuthPlugin == mysql.AuthTiDBAuthToken {
// If the user uses session token to log in, skip checking record.AuthPlugin.
if user.AuthPlugin == mysql.AuthTiDBSessionToken {
if err = sessionstates.ValidateSessionToken(authentication, user.Username); err != nil {
logutil.BgLogger().Warn("verify session token failed", zap.String("username", user.Username), zap.Error(err))
return info, ErrAccessDenied.FastGenByArgs(user.Username, user.Hostname, hasPassword)
}
} else if record.AuthPlugin == mysql.AuthTiDBAuthToken {
if len(authentication) == 0 {
logutil.BgLogger().Error("empty authentication")
return info, ErrAccessDenied.FastGenByArgs(user.Username, user.Hostname, hasPassword)
Expand Down Expand Up @@ -617,7 +624,11 @@ func (p *UserPrivileges) ConnectionVerification(user *auth.UserIdentity, authUse
} else {
info.ResourceGroupName = record.ResourceGroup
}
info.InSandBoxMode, err = p.CheckPasswordExpired(sessionVars, record)
// Skip checking password expiration if the session is migrated from another session.
// Otherwise, the user cannot log in or execute statements after migration.
if user.AuthPlugin != mysql.AuthTiDBSessionToken {
info.InSandBoxMode, err = p.CheckPasswordExpired(sessionVars, record)
}
return
}

Expand Down
57 changes: 57 additions & 0 deletions privilege/privileges/privileges_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ import (
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/json"
"fmt"
"net/url"
"os"
"path/filepath"
"strings"
"testing"
"time"
Expand All @@ -39,6 +41,7 @@ import (
"github.com/pingcap/tidb/privilege/privileges"
"github.com/pingcap/tidb/session"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/sessionstates"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/testkit"
"github.com/pingcap/tidb/testkit/testutil"
Expand Down Expand Up @@ -3148,3 +3151,57 @@ func TestPasswordExpireWithSandBoxMode(t *testing.T) {
require.NoError(t, err)
require.False(t, tk.Session().InSandBoxMode())
}

func TestVerificationInfoWithSessionTokenPlugin(t *testing.T) {
// prepare signing certs
tempDir := t.TempDir()
certPath := filepath.Join(tempDir, "test1_cert.pem")
keyPath := filepath.Join(tempDir, "test1_key.pem")
err := util.CreateCertificates(certPath, keyPath, 4096, x509.RSA, x509.UnknownSignatureAlgorithm)
require.NoError(t, err)
sessionstates.SetKeyPath(keyPath)
sessionstates.SetCertPath(certPath)

// prepare user
store := createStoreAndPrepareDB(t)
rootTk := testkit.NewTestKit(t, store)
rootTk.MustExec(`CREATE USER 'testuser'@'localhost' PASSWORD EXPIRE`)
// prepare session token
token, err := sessionstates.CreateSessionToken("testuser")
require.NoError(t, err)
tokenBytes, err := json.Marshal(token)
require.NoError(t, err)

// Test password expiration without sandbox.
user := &auth.UserIdentity{Username: "testuser", Hostname: "localhost", AuthPlugin: mysql.AuthTiDBSessionToken}
tk := testkit.NewTestKit(t, store)
err = tk.Session().Auth(user, tokenBytes, nil)
require.NoError(t, err)
require.False(t, tk.Session().InSandBoxMode())

// Test password expiration with sandbox.
variable.IsSandBoxModeEnabled.Store(true)
err = tk.Session().Auth(user, tokenBytes, nil)
require.NoError(t, err)
require.False(t, tk.Session().InSandBoxMode())

// Disable resource group.
require.Equal(t, "", tk.Session().GetSessionVars().ResourceGroupName)

// Enable resource group.
variable.EnableResourceControl.Store(true)
err = tk.Session().Auth(user, tokenBytes, nil)
require.NoError(t, err)
require.Equal(t, "default", tk.Session().GetSessionVars().ResourceGroupName)

// Non-default resource group.
rootTk.MustExec("CREATE RESOURCE GROUP rg1 WRU_PER_SEC = 999")
rootTk.MustExec(`ALTER USER 'testuser'@'localhost' RESOURCE GROUP rg1`)
err = tk.Session().Auth(user, tokenBytes, nil)
require.NoError(t, err)
require.Equal(t, "rg1", tk.Session().GetSessionVars().ResourceGroupName)

// Wrong token
err = tk.Session().Auth(user, nil, nil)
require.ErrorContains(t, err, "Access denied")
}
13 changes: 2 additions & 11 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ import (
"github.com/pingcap/tidb/privilege"
"github.com/pingcap/tidb/session"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/sessionstates"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/sessiontxn"
Expand Down Expand Up @@ -833,16 +832,8 @@ func (cc *clientConn) openSessionAndDoAuth(authData []byte, authPlugin string) e
return errAccessDeniedNoPassword.FastGenByArgs(cc.user, host)
}

userIdentity := &auth.UserIdentity{Username: cc.user, Hostname: host}
if authPlugin == mysql.AuthTiDBSessionToken {
if !cc.ctx.AuthWithoutVerification(userIdentity) {
return errAccessDenied.FastGenByArgs(cc.user, host, hasPassword)
}
if err = sessionstates.ValidateSessionToken(authData, cc.user); err != nil {
logutil.BgLogger().Warn("verify session token failed", zap.String("username", cc.user), zap.Error(err))
return errAccessDenied.FastGenByArgs(cc.user, host, hasPassword)
}
} else if err = cc.ctx.Auth(userIdentity, authData, cc.salt); err != nil {
userIdentity := &auth.UserIdentity{Username: cc.user, Hostname: host, AuthPlugin: authPlugin}
if err = cc.ctx.Auth(userIdentity, authData, cc.salt); err != nil {
return err
}
cc.ctx.SetPort(port)
Expand Down
9 changes: 8 additions & 1 deletion server/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1484,7 +1484,7 @@ func TestAuthPlugin2(t *testing.T) {
require.NoError(t, err)
}

func TestAuthTokenPlugin(t *testing.T) {
func TestAuthSessionTokenPlugin(t *testing.T) {
// create the cert
tempDir := t.TempDir()
certPath := filepath.Join(tempDir, "test1_cert.pem")
Expand Down Expand Up @@ -1555,6 +1555,13 @@ func TestAuthTokenPlugin(t *testing.T) {
err = cc.openSessionAndDoAuth(resp.Auth, resp.AuthPlugin)
require.NoError(t, err)

// login succeeds even if the password expires now
tk.MustExec("ALTER USER auth_session_token PASSWORD EXPIRE")
err = cc.openSessionAndDoAuth([]byte{}, mysql.AuthNativePassword)
require.ErrorContains(t, err, "Your password has expired")
err = cc.openSessionAndDoAuth(resp.Auth, resp.AuthPlugin)
require.NoError(t, err)

// wrong token should fail
tokenBytes[0] ^= 0xff
err = cc.openSessionAndDoAuth(resp.Auth, resp.AuthPlugin)
Expand Down
4 changes: 4 additions & 0 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -4138,6 +4138,10 @@ func (s *session) EncodeSessionStates(ctx context.Context, sctx sessionctx.Conte
if len(s.lockedTables) > 0 {
return sessionstates.ErrCannotMigrateSession.GenWithStackByArgs("session has locked tables")
}
// It's insecure to migrate sandBoxMode because users can fake it.
if s.InSandBoxMode() {
return sessionstates.ErrCannotMigrateSession.GenWithStackByArgs("session is in sandbox mode")
}

if err := s.sessionVars.EncodeSessionStates(ctx, sessionStates); err != nil {
return err
Expand Down
10 changes: 10 additions & 0 deletions sessionctx/sessionstates/session_states_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1252,6 +1252,16 @@ func TestShowStateFail(t *testing.T) {
tk.MustExec("drop table test.t1")
},
},
{
// enable sandbox mode
setFunc: func(tk *testkit.TestKit, conn server.MockConn) {
tk.Session().EnableSandBoxMode()
},
showErr: errno.ErrCannotMigrateSession,
cleanFunc: func(tk *testkit.TestKit) {
tk.Session().DisableSandBoxMode()
},
},
{
// after COM_STMT_SEND_LONG_DATA
setFunc: func(tk *testkit.TestKit, conn server.MockConn) {
Expand Down

0 comments on commit 556c267

Please sign in to comment.