diff --git a/server/conn.go b/server/conn.go index ffad5d84263b8..7ef95d7bd8ad7 100644 --- a/server/conn.go +++ b/server/conn.go @@ -71,6 +71,7 @@ import ( "github.com/pingcap/tidb/metrics" plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/plugin" + "github.com/pingcap/tidb/privilege" "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/stmtctx" @@ -251,6 +252,18 @@ func (cc *clientConn) handshake(ctx context.Context) error { } return err } + + // MySQL supports an "init_connect" query, which can be run on initial connection. + // The query must return a non-error or the client is disconnected. + if err := cc.initConnect(ctx); err != nil { + logutil.Logger(ctx).Warn("init_connect failed", zap.Error(err)) + initErr := errNewAbortingConnection.FastGenByArgs(cc.connectionID, "unconnected", cc.user, cc.peerHost, "init_connect command failed") + if err1 := cc.writeError(ctx, initErr); err1 != nil { + terror.Log(err1) + } + return initErr + } + data := cc.alloc.AllocWithLen(4, 32) data = append(data, mysql.OKHeader) data = append(data, 0, 0) @@ -722,6 +735,57 @@ func (cc *clientConn) PeerHost(hasPassword string) (host, port string, err error return } +// skipInitConnect follows MySQL's rules of when init-connect should be skipped. +// In 5.7 it is any user with SUPER privilege, but in 8.0 it is: +// - SUPER or the CONNECTION_ADMIN dynamic privilege. +// - (additional exception) users with expired passwords (not yet supported) +func (cc *clientConn) skipInitConnect() bool { + checker := privilege.GetPrivilegeManager(cc.ctx.Session) + activeRoles := cc.ctx.GetSessionVars().ActiveRoles + return checker != nil && checker.RequestVerification(activeRoles, "", "", "", mysql.SuperPriv) +} + +// initConnect runs the initConnect SQL statement if it has been specified. +// The semantics are MySQL compatible. +func (cc *clientConn) initConnect(ctx context.Context) error { + val, err := cc.ctx.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.InitConnect) + if err != nil { + return err + } + if val == "" || cc.skipInitConnect() { + return nil + } + logutil.Logger(ctx).Debug("init_connect starting") + stmts, err := cc.ctx.Parse(ctx, val) + if err != nil { + return err + } + for _, stmt := range stmts { + rs, err := cc.ctx.ExecuteStmt(ctx, stmt) + if err != nil { + return err + } + // init_connect does not care about the results, + // but they need to be drained because of lazy loading. + if rs != nil { + req := rs.NewChunk() + for { + if err = rs.Next(ctx, req); err != nil { + return err + } + if req.NumRows() == 0 { + break + } + } + if err := rs.Close(); err != nil { + return err + } + } + } + logutil.Logger(ctx).Debug("init_connect complete") + return nil +} + // Run reads client query and writes query result to client in for loop, if there is a panic during query handling, // it will be recovered and log the panic error. // This function returns and the connection is closed if there is an IO error or there is a panic. diff --git a/server/server.go b/server/server.go index 0f839cb559029..1b1929a9854d8 100644 --- a/server/server.go +++ b/server/server.go @@ -99,6 +99,7 @@ var ( errConCount = dbterror.ClassServer.NewStd(errno.ErrConCount) errSecureTransportRequired = dbterror.ClassServer.NewStd(errno.ErrSecureTransportRequired) errMultiStatementDisabled = dbterror.ClassServer.NewStd(errno.ErrMultiStatementDisabled) + errNewAbortingConnection = dbterror.ClassServer.NewStd(errno.ErrNewAbortingConnection) ) // DefaultCapability is the capability of the server when it is created using the default configuration. diff --git a/server/server_test.go b/server/server_test.go index 48eb0b7639ed7..a2f3c8d03244b 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1864,6 +1864,54 @@ func (cli *testServerClient) waitUntilServerOnline() { } } +func (cli *testServerClient) runTestInitConnect(c *C) { + + cli.runTests(c, nil, func(dbt *DBTest) { + dbt.mustExec(`SET GLOBAL init_connect="insert into test.ts VALUES (NOW());SET @a=1;"`) + dbt.mustExec(`CREATE USER init_nonsuper`) + dbt.mustExec(`CREATE USER init_super`) + dbt.mustExec(`GRANT SELECT, INSERT, DROP ON test.* TO init_nonsuper`) + dbt.mustExec(`GRANT SELECT, INSERT, DROP, SUPER ON *.* TO init_super`) + dbt.mustExec(`CREATE TABLE ts (a TIMESTAMP)`) + }) + + // test init_nonsuper + cli.runTests(c, func(config *mysql.Config) { + config.User = "init_nonsuper" + }, func(dbt *DBTest) { + rows := dbt.mustQuery(`SELECT @a`) + c.Assert(rows.Next(), IsTrue) + var a int + err := rows.Scan(&a) + c.Assert(err, IsNil) + dbt.Check(a, Equals, 1) + c.Assert(rows.Close(), IsNil) + }) + + // test init_super + cli.runTests(c, func(config *mysql.Config) { + config.User = "init_super" + }, func(dbt *DBTest) { + rows := dbt.mustQuery(`SELECT IFNULL(@a,"")`) + c.Assert(rows.Next(), IsTrue) + var a string + err := rows.Scan(&a) + c.Assert(err, IsNil) + dbt.Check(a, Equals, "") // null + c.Assert(rows.Close(), IsNil) + // change the init-connect to invalid. + dbt.mustExec(`SET GLOBAL init_connect="invalidstring"`) + }) + + db, err := sql.Open("mysql", cli.getDSN(func(config *mysql.Config) { + config.User = "init_nonsuper" + })) + c.Assert(err, IsNil, Commentf("Error connecting")) // doesn't fail because of lazy loading + defer db.Close() // may already be closed + _, err = db.Exec("SELECT 1") // fails because of init sql + c.Assert(err, NotNil) +} + // Client errors are only incremented when using the TiDB Server protocol, // and not internal SQL statements. Thus, this test is in the server-test suite. func (cli *testServerClient) runTestInfoschemaClientErrors(t *C) { diff --git a/server/tidb_test.go b/server/tidb_test.go index 7bea9bd655d9d..9c5b76bf7d251 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -952,6 +952,10 @@ func (ts *tidbTestSuite) TestClientErrors(c *C) { ts.runTestInfoschemaClientErrors(c) } +func (ts *tidbTestSuite) TestInitConnect(c *C) { + ts.runTestInitConnect(c) +} + func (ts *tidbTestSuite) TestSumAvg(c *C) { c.Parallel() ts.runTestSumAvg(c) diff --git a/sessionctx/variable/noop.go b/sessionctx/variable/noop.go index 1eba30079ee2d..7ba11357b81fb 100644 --- a/sessionctx/variable/noop.go +++ b/sessionctx/variable/noop.go @@ -153,7 +153,6 @@ var noopSysVars = []*SysVar{ {Scope: ScopeGlobal | ScopeSession, Name: QueryCacheType, Value: BoolOff, Type: TypeEnum, PossibleValues: []string{BoolOff, BoolOn, "DEMAND"}}, {Scope: ScopeNone, Name: "innodb_rollback_on_timeout", Value: "0"}, {Scope: ScopeGlobal | ScopeSession, Name: "query_alloc_block_size", Value: "8192"}, - {Scope: ScopeGlobal | ScopeSession, Name: InitConnect, Value: ""}, {Scope: ScopeNone, Name: "have_compress", Value: "YES"}, {Scope: ScopeNone, Name: "thread_concurrency", Value: "10"}, {Scope: ScopeGlobal | ScopeSession, Name: "query_prealloc_size", Value: "8192"}, diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index 6eb8ce417e0d3..da010df34ac24 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -579,6 +579,7 @@ var defaultSysVars = []*SysVar{ } return oracle.LocalTxnScope }()}, + {Scope: ScopeGlobal, Name: InitConnect, Value: ""}, /* TiDB specific variables */ {Scope: ScopeGlobal | ScopeSession, Name: TiDBAllowMPPExecution, Type: TypeBool, Value: BoolToOnOff(DefTiDBAllowMPPExecution)}, {Scope: ScopeSession, Name: TiDBEnforceMPPExecution, Type: TypeBool, Value: BoolToOnOff(config.GetGlobalConfig().Performance.EnforceMPP)},