Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server, sessionctx: improved mysql compatibility with support for init_connect (#23713) #26072

Merged
merged 4 commits into from
Aug 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ import (
"github.com/pingcap/tidb/metrics"
plannercore "github.com/pingcap/tidb/planner/core"
"github.com/pingcap/tidb/plugin"
"github.com/pingcap/tidb/privilege"
"github.com/pingcap/tidb/session"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/stmtctx"
Expand Down Expand Up @@ -251,6 +252,18 @@ func (cc *clientConn) handshake(ctx context.Context) error {
}
return err
}

// MySQL supports an "init_connect" query, which can be run on initial connection.
// The query must return a non-error or the client is disconnected.
if err := cc.initConnect(ctx); err != nil {
logutil.Logger(ctx).Warn("init_connect failed", zap.Error(err))
initErr := errNewAbortingConnection.FastGenByArgs(cc.connectionID, "unconnected", cc.user, cc.peerHost, "init_connect command failed")
if err1 := cc.writeError(ctx, initErr); err1 != nil {
terror.Log(err1)
}
return initErr
}

data := cc.alloc.AllocWithLen(4, 32)
data = append(data, mysql.OKHeader)
data = append(data, 0, 0)
Expand Down Expand Up @@ -722,6 +735,57 @@ func (cc *clientConn) PeerHost(hasPassword string) (host, port string, err error
return
}

// skipInitConnect follows MySQL's rules of when init-connect should be skipped.
// In 5.7 it is any user with SUPER privilege, but in 8.0 it is:
// - SUPER or the CONNECTION_ADMIN dynamic privilege.
// - (additional exception) users with expired passwords (not yet supported)
func (cc *clientConn) skipInitConnect() bool {
checker := privilege.GetPrivilegeManager(cc.ctx.Session)
activeRoles := cc.ctx.GetSessionVars().ActiveRoles
return checker != nil && checker.RequestVerification(activeRoles, "", "", "", mysql.SuperPriv)
}

// initConnect runs the initConnect SQL statement if it has been specified.
// The semantics are MySQL compatible.
func (cc *clientConn) initConnect(ctx context.Context) error {
val, err := cc.ctx.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.InitConnect)
if err != nil {
return err
}
if val == "" || cc.skipInitConnect() {
return nil
}
logutil.Logger(ctx).Debug("init_connect starting")
stmts, err := cc.ctx.Parse(ctx, val)
if err != nil {
return err
}
for _, stmt := range stmts {
rs, err := cc.ctx.ExecuteStmt(ctx, stmt)
if err != nil {
return err
}
// init_connect does not care about the results,
// but they need to be drained because of lazy loading.
if rs != nil {
req := rs.NewChunk()
for {
if err = rs.Next(ctx, req); err != nil {
return err
}
if req.NumRows() == 0 {
break
}
}
if err := rs.Close(); err != nil {
return err
}
}
}
logutil.Logger(ctx).Debug("init_connect complete")
return nil
}

// Run reads client query and writes query result to client in for loop, if there is a panic during query handling,
// it will be recovered and log the panic error.
// This function returns and the connection is closed if there is an IO error or there is a panic.
Expand Down
1 change: 1 addition & 0 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ var (
errConCount = dbterror.ClassServer.NewStd(errno.ErrConCount)
errSecureTransportRequired = dbterror.ClassServer.NewStd(errno.ErrSecureTransportRequired)
errMultiStatementDisabled = dbterror.ClassServer.NewStd(errno.ErrMultiStatementDisabled)
errNewAbortingConnection = dbterror.ClassServer.NewStd(errno.ErrNewAbortingConnection)
)

// DefaultCapability is the capability of the server when it is created using the default configuration.
Expand Down
48 changes: 48 additions & 0 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1864,6 +1864,54 @@ func (cli *testServerClient) waitUntilServerOnline() {
}
}

func (cli *testServerClient) runTestInitConnect(c *C) {

cli.runTests(c, nil, func(dbt *DBTest) {
dbt.mustExec(`SET GLOBAL init_connect="insert into test.ts VALUES (NOW());SET @a=1;"`)
dbt.mustExec(`CREATE USER init_nonsuper`)
dbt.mustExec(`CREATE USER init_super`)
dbt.mustExec(`GRANT SELECT, INSERT, DROP ON test.* TO init_nonsuper`)
dbt.mustExec(`GRANT SELECT, INSERT, DROP, SUPER ON *.* TO init_super`)
dbt.mustExec(`CREATE TABLE ts (a TIMESTAMP)`)
})

// test init_nonsuper
cli.runTests(c, func(config *mysql.Config) {
config.User = "init_nonsuper"
}, func(dbt *DBTest) {
rows := dbt.mustQuery(`SELECT @a`)
c.Assert(rows.Next(), IsTrue)
var a int
err := rows.Scan(&a)
c.Assert(err, IsNil)
dbt.Check(a, Equals, 1)
c.Assert(rows.Close(), IsNil)
})

// test init_super
cli.runTests(c, func(config *mysql.Config) {
config.User = "init_super"
}, func(dbt *DBTest) {
rows := dbt.mustQuery(`SELECT IFNULL(@a,"")`)
c.Assert(rows.Next(), IsTrue)
var a string
err := rows.Scan(&a)
c.Assert(err, IsNil)
dbt.Check(a, Equals, "") // null
c.Assert(rows.Close(), IsNil)
// change the init-connect to invalid.
dbt.mustExec(`SET GLOBAL init_connect="invalidstring"`)
})

db, err := sql.Open("mysql", cli.getDSN(func(config *mysql.Config) {
config.User = "init_nonsuper"
}))
c.Assert(err, IsNil, Commentf("Error connecting")) // doesn't fail because of lazy loading
defer db.Close() // may already be closed
_, err = db.Exec("SELECT 1") // fails because of init sql
c.Assert(err, NotNil)
}

// Client errors are only incremented when using the TiDB Server protocol,
// and not internal SQL statements. Thus, this test is in the server-test suite.
func (cli *testServerClient) runTestInfoschemaClientErrors(t *C) {
Expand Down
4 changes: 4 additions & 0 deletions server/tidb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -952,6 +952,10 @@ func (ts *tidbTestSuite) TestClientErrors(c *C) {
ts.runTestInfoschemaClientErrors(c)
}

func (ts *tidbTestSuite) TestInitConnect(c *C) {
ts.runTestInitConnect(c)
}

func (ts *tidbTestSuite) TestSumAvg(c *C) {
c.Parallel()
ts.runTestSumAvg(c)
Expand Down
1 change: 0 additions & 1 deletion sessionctx/variable/noop.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,6 @@ var noopSysVars = []*SysVar{
{Scope: ScopeGlobal | ScopeSession, Name: QueryCacheType, Value: BoolOff, Type: TypeEnum, PossibleValues: []string{BoolOff, BoolOn, "DEMAND"}},
{Scope: ScopeNone, Name: "innodb_rollback_on_timeout", Value: "0"},
{Scope: ScopeGlobal | ScopeSession, Name: "query_alloc_block_size", Value: "8192"},
{Scope: ScopeGlobal | ScopeSession, Name: InitConnect, Value: ""},
{Scope: ScopeNone, Name: "have_compress", Value: "YES"},
{Scope: ScopeNone, Name: "thread_concurrency", Value: "10"},
{Scope: ScopeGlobal | ScopeSession, Name: "query_prealloc_size", Value: "8192"},
Expand Down
1 change: 1 addition & 0 deletions sessionctx/variable/sysvar.go
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,7 @@ var defaultSysVars = []*SysVar{
}
return oracle.LocalTxnScope
}()},
{Scope: ScopeGlobal, Name: InitConnect, Value: ""},
/* TiDB specific variables */
{Scope: ScopeGlobal | ScopeSession, Name: TiDBAllowMPPExecution, Type: TypeBool, Value: BoolToOnOff(DefTiDBAllowMPPExecution)},
{Scope: ScopeSession, Name: TiDBEnforceMPPExecution, Type: TypeBool, Value: BoolToOnOff(config.GetGlobalConfig().Performance.EnforceMPP)},
Expand Down