Skip to content

Commit

Permalink
server, sessionctx: improved mysql compatibility with support for ini…
Browse files Browse the repository at this point in the history
…t_connect (#23713)
  • Loading branch information
morgo authored Apr 2, 2021
1 parent 32f6e33 commit 9c75cfa
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 1 deletion.
65 changes: 65 additions & 0 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ import (
"github.com/pingcap/tidb/metrics"
plannercore "github.com/pingcap/tidb/planner/core"
"github.com/pingcap/tidb/plugin"
"github.com/pingcap/tidb/privilege"
"github.com/pingcap/tidb/session"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/stmtctx"
Expand Down Expand Up @@ -251,6 +252,18 @@ func (cc *clientConn) handshake(ctx context.Context) error {
}
return err
}

// MySQL supports an "init_connect" query, which can be run on initial connection.
// The query must return a non-error or the client is disconnected.
if err := cc.initConnect(ctx); err != nil {
logutil.Logger(ctx).Warn("init_connect failed", zap.Error(err))
initErr := errNewAbortingConnection.FastGenByArgs(cc.connectionID, "unconnected", cc.user, cc.peerHost, "init_connect command failed")
if err1 := cc.writeError(ctx, initErr); err1 != nil {
terror.Log(err1)
}
return initErr
}

data := cc.alloc.AllocWithLen(4, 32)
data = append(data, mysql.OKHeader)
data = append(data, 0, 0)
Expand Down Expand Up @@ -722,6 +735,58 @@ func (cc *clientConn) PeerHost(hasPassword string) (host, port string, err error
return
}

// skipInitConnect follows MySQL's rules of when init-connect should be skipped.
// In 5.7 it is any user with SUPER privilege, but in 8.0 it is:
// - SUPER or the CONNECTION_ADMIN dynamic privilege.
// - (additional exception) users with expired passwords (not yet supported)
// In TiDB CONNECTION_ADMIN is satisfied by SUPER, so we only need to check once.
func (cc *clientConn) skipInitConnect() bool {
checker := privilege.GetPrivilegeManager(cc.ctx.Session)
activeRoles := cc.ctx.GetSessionVars().ActiveRoles
return checker != nil && checker.RequestDynamicVerification(activeRoles, "CONNECTION_ADMIN", false)
}

// initConnect runs the initConnect SQL statement if it has been specified.
// The semantics are MySQL compatible.
func (cc *clientConn) initConnect(ctx context.Context) error {
val, err := cc.ctx.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.InitConnect)
if err != nil {
return err
}
if val == "" || cc.skipInitConnect() {
return nil
}
logutil.Logger(ctx).Debug("init_connect starting")
stmts, err := cc.ctx.Parse(ctx, val)
if err != nil {
return err
}
for _, stmt := range stmts {
rs, err := cc.ctx.ExecuteStmt(ctx, stmt)
if err != nil {
return err
}
// init_connect does not care about the results,
// but they need to be drained because of lazy loading.
if rs != nil {
req := rs.NewChunk()
for {
if err = rs.Next(ctx, req); err != nil {
return err
}
if req.NumRows() == 0 {
break
}
}
if err := rs.Close(); err != nil {
return err
}
}
}
logutil.Logger(ctx).Debug("init_connect complete")
return nil
}

// Run reads client query and writes query result to client in for loop, if there is a panic during query handling,
// it will be recovered and log the panic error.
// This function returns and the connection is closed if there is an IO error or there is a panic.
Expand Down
1 change: 1 addition & 0 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ var (
errConCount = dbterror.ClassServer.NewStd(errno.ErrConCount)
errSecureTransportRequired = dbterror.ClassServer.NewStd(errno.ErrSecureTransportRequired)
errMultiStatementDisabled = dbterror.ClassServer.NewStd(errno.ErrMultiStatementDisabled)
errNewAbortingConnection = dbterror.ClassServer.NewStd(errno.ErrNewAbortingConnection)
)

// DefaultCapability is the capability of the server when it is created using the default configuration.
Expand Down
48 changes: 48 additions & 0 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1979,6 +1979,54 @@ func (cli *testServerClient) waitUntilServerOnline() {
}
}

func (cli *testServerClient) runTestInitConnect(c *C) {

cli.runTests(c, nil, func(dbt *DBTest) {
dbt.mustExec(`SET GLOBAL init_connect="insert into test.ts VALUES (NOW());SET @a=1;"`)
dbt.mustExec(`CREATE USER init_nonsuper`)
dbt.mustExec(`CREATE USER init_super`)
dbt.mustExec(`GRANT SELECT, INSERT, DROP ON test.* TO init_nonsuper`)
dbt.mustExec(`GRANT SELECT, INSERT, DROP, SUPER ON *.* TO init_super`)
dbt.mustExec(`CREATE TABLE ts (a TIMESTAMP)`)
})

// test init_nonsuper
cli.runTests(c, func(config *mysql.Config) {
config.User = "init_nonsuper"
}, func(dbt *DBTest) {
rows := dbt.mustQuery(`SELECT @a`)
c.Assert(rows.Next(), IsTrue)
var a int
err := rows.Scan(&a)
c.Assert(err, IsNil)
dbt.Check(a, Equals, 1)
c.Assert(rows.Close(), IsNil)
})

// test init_super
cli.runTests(c, func(config *mysql.Config) {
config.User = "init_super"
}, func(dbt *DBTest) {
rows := dbt.mustQuery(`SELECT IFNULL(@a,"")`)
c.Assert(rows.Next(), IsTrue)
var a string
err := rows.Scan(&a)
c.Assert(err, IsNil)
dbt.Check(a, Equals, "") // null
c.Assert(rows.Close(), IsNil)
// change the init-connect to invalid.
dbt.mustExec(`SET GLOBAL init_connect="invalidstring"`)
})

db, err := sql.Open("mysql", cli.getDSN(func(config *mysql.Config) {
config.User = "init_nonsuper"
}))
c.Assert(err, IsNil, Commentf("Error connecting")) // doesn't fail because of lazy loading
defer db.Close() // may already be closed
_, err = db.Exec("SELECT 1") // fails because of init sql
c.Assert(err, NotNil)
}

// Client errors are only incremented when using the TiDB Server protocol,
// and not internal SQL statements. Thus, this test is in the server-test suite.
func (cli *testServerClient) runTestInfoschemaClientErrors(t *C) {
Expand Down
4 changes: 4 additions & 0 deletions server/tidb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1007,6 +1007,10 @@ func (ts *tidbTestSuite) TestClientErrors(c *C) {
ts.runTestInfoschemaClientErrors(c)
}

func (ts *tidbTestSuite) TestInitConnect(c *C) {
ts.runTestInitConnect(c)
}

func (ts *tidbTestSuite) TestSumAvg(c *C) {
c.Parallel()
ts.runTestSumAvg(c)
Expand Down
1 change: 0 additions & 1 deletion sessionctx/variable/noop.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ var noopSysVars = []*SysVar{
{Scope: ScopeGlobal | ScopeSession, Name: QueryCacheType, Value: BoolOff, Type: TypeEnum, PossibleValues: []string{BoolOff, BoolOn, "DEMAND"}},
{Scope: ScopeNone, Name: "innodb_rollback_on_timeout", Value: "0"},
{Scope: ScopeGlobal | ScopeSession, Name: "query_alloc_block_size", Value: "8192"},
{Scope: ScopeGlobal | ScopeSession, Name: InitConnect, Value: ""},
{Scope: ScopeNone, Name: "have_compress", Value: "YES"},
{Scope: ScopeNone, Name: "thread_concurrency", Value: "10"},
{Scope: ScopeGlobal | ScopeSession, Name: "query_prealloc_size", Value: "8192"},
Expand Down
1 change: 1 addition & 0 deletions sessionctx/variable/sysvar.go
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,7 @@ var defaultSysVars = []*SysVar{
}
return oracle.LocalTxnScope
}()},
{Scope: ScopeGlobal, Name: InitConnect, Value: ""},
/* TiDB specific variables */
{Scope: ScopeGlobal | ScopeSession, Name: TiDBAllowMPPExecution, Type: TypeBool, Value: BoolToOnOff(DefTiDBAllowMPPExecution), SetSession: func(s *SessionVars, val string) error {
s.AllowMPPExecution = TiDBOptOn(val)
Expand Down

0 comments on commit 9c75cfa

Please sign in to comment.