From 2f03962d1f9d8d72770222230ffd1188fd8aeadf Mon Sep 17 00:00:00 2001 From: hanchuanchuan Date: Tue, 10 Dec 2019 22:34:25 +0800 Subject: [PATCH 1/2] =?UTF-8?q?feature:=20=E6=B7=BB=E5=8A=A0=E5=AE=89?= =?UTF-8?q?=E5=85=A8=E6=A8=A1=E5=9D=97,=E5=AE=9E=E7=8E=B0=E7=94=A8?= =?UTF-8?q?=E6=88=B7=E7=AE=A1=E7=90=86,=E5=AE=89=E5=85=A8=E8=BF=9E?= =?UTF-8?q?=E6=8E=A5=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/conn.go | 14 +++++----- session/session_inception.go | 51 +++++++++++++++++++++++++++++++++++- 2 files changed, 57 insertions(+), 8 deletions(-) diff --git a/server/conn.go b/server/conn.go index 5c1ee96f..c16d1e15 100644 --- a/server/conn.go +++ b/server/conn.go @@ -638,13 +638,13 @@ func (cc *clientConn) dispatch(data []byte) error { return cc.handleQuery(ctx1, hack.String(data)) case mysql.ComPing: return cc.writeOK() - // case mysql.ComInitDB: - // if err := cc.useDB(ctx1, hack.String(data)); err != nil { - // return errors.Trace(err) - // } - // return cc.writeOK() - // case mysql.ComFieldList: - // return cc.handleFieldList(hack.String(data)) + case mysql.ComInitDB: + if err := cc.useDB(ctx1, hack.String(data)); err != nil { + return errors.Trace(err) + } + return cc.writeOK() + case mysql.ComFieldList: + return cc.handleFieldList(hack.String(data)) // case mysql.ComStmtPrepare: // return cc.handleStmtPrepare(hack.String(data)) // case mysql.ComStmtExecute: diff --git a/session/session_inception.go b/session/session_inception.go index 4456fdbf..3298ab54 100644 --- a/session/session_inception.go +++ b/session/session_inception.go @@ -303,7 +303,9 @@ func (s *session) ExecuteInc(ctx context.Context, sql string) (recordSets []sqle } } - if strings.HasPrefix(lowerSql, "select high_priority") { + if lowerSql == "select database()" { + return s.execute(ctx, sql) + } else if strings.HasPrefix(lowerSql, "select high_priority") { return s.execute(ctx, sql) } else if strings.HasPrefix(lowerSql, `select variable_value from mysql.tidb where variable_name = "system_tz"`) { @@ -565,6 +567,11 @@ func (s *session) executeInc(ctx context.Context, sql string) (recordSets []sqle s.executeCommit(ctx) return s.makeResult() default: + // TiDB原生执行器 + if !s.haveBegin && s.isRunToTiDB(stmtNode) { + return s.execute(ctx, currentSql) + } + need := s.needDataSource(stmtNode) if !s.haveBegin && need { @@ -697,6 +704,48 @@ func (s *session) makeResult() (recordSets []sqlexec.RecordSet, err error) { } } +func (s *session) isRunToTiDB(stmtNode ast.StmtNode) bool { + switch node := stmtNode.(type) { + case *ast.UseStmt: + return true + case *ast.SelectStmt: + + if node.From != nil { + join := node.From.TableRefs + if join.Right == nil { + switch x := node.From.TableRefs.Left.(type) { + case *ast.TableSource: + if s, ok := x.Source.(*ast.TableName); ok { + // log.Infof("%#v", s) + if s.Name.L == "user" { + return true + } + return false + } + default: + log.Infof("%T", x) + // log.Infof("%#v", x) + } + } + } + + case *ast.ShowStmt: + if node.IsInception { + return false + } else { + // 添加部分命令支持 + switch node.Tp { + case ast.ShowDatabases, ast.ShowTables, + ast.ShowTableStatus, ast.ShowColumns, + ast.ShowWarnings, ast.ShowGrants: + return true + } + } + } + + return false +} + func (s *session) needDataSource(stmtNode ast.StmtNode) bool { switch node := stmtNode.(type) { case *ast.ShowStmt: From 286570974cefc3b6fd2241f9f8f8f0963caa9a31 Mon Sep 17 00:00:00 2001 From: hanchuanchuan Date: Wed, 11 Dec 2019 11:31:08 +0800 Subject: [PATCH 2/2] =?UTF-8?q?update:=20=E5=BC=80=E6=94=BE=E6=94=AF?= =?UTF-8?q?=E6=8C=81=E8=AF=AD=E6=B3=95,=E4=BC=98=E5=8C=96=E5=8F=82?= =?UTF-8?q?=E6=95=B0=E8=AE=BE=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/config.go | 12 +++++++- config/config.toml.example | 6 +++- session/session_inception.go | 59 ++++++++++++++++++++++++++++-------- tidb-server/main.go | 7 ++++- 4 files changed, 69 insertions(+), 15 deletions(-) diff --git a/config/config.go b/config/config.go index 95500173..ab3fcc4c 100644 --- a/config/config.go +++ b/config/config.go @@ -77,6 +77,9 @@ type Config struct { Ghost Ghost `toml:"ghost" json:"ghost"` IncLevel IncLevel `toml:"inc_level" json:"inc_level"` CompatibleKillQuery bool `toml:"compatible-kill-query" json:"compatible-kill-query"` + + // 是否跳过用户权限校验 + SkipGrantTable bool `toml:"skip_grant_table" json:"skip_grant_table"` } // Log is the log section of config. @@ -98,7 +101,7 @@ type Log struct { // Security is the security section of the config. type Security struct { - SkipGrantTable bool `toml:"skip-grant-table" json:"skip-grant-table"` + SkipGrantTable bool `toml:"skip_grant_table" json:"skip_grant_table"` SSLCA string `toml:"ssl-ca" json:"ssl-ca"` SSLCert string `toml:"ssl-cert" json:"ssl-cert"` SSLKey string `toml:"ssl-key" json:"ssl-key"` @@ -348,6 +351,8 @@ type Inc struct { // 建表必须创建的列. 可指定多个列,以逗号分隔.列类型可选. 格式: 列名 [列类型,可选],... MustHaveColumns string `toml:"must_have_columns" json:"must_have_columns"` + // 是否跳过用户权限校验 + SkipGrantTable bool `toml:"skip_grant_table" json:"skip_grant_table"` // 要跳过的sql语句, 多个时以分号分隔 SkipSqls string `toml:"skip_sqls" json:"skip_sqls"` @@ -709,6 +714,8 @@ var defaultConf = Config{ WriteTimeout: "15s", }, // 默认跳过权限校验 2019-1-26 + // 为配置方便,在config节点也添加相同参数 + SkipGrantTable: true, Security: Security{ SkipGrantTable: true, }, @@ -734,6 +741,9 @@ var defaultConf = Config{ DefaultCharset: "utf8mb4", MaxAllowedPacket: 4194304, ExplainRule: "first", + + // 为配置方便,在config节点也添加相同参数 + SkipGrantTable: true, // Version: &mysql.TiDBReleaseVersion, }, Osc: Osc{ diff --git a/config/config.toml.example b/config/config.toml.example index a3ef5ca8..1ea5801b 100644 --- a/config/config.toml.example +++ b/config/config.toml.example @@ -48,6 +48,8 @@ lower-case-table-names = 2 # turn on this option when TiDB server is behind a proxy. compatible-kill-query = false +skip_grant_table = true + [log] # Log level: debug, info, warn, error, fatal. level = "info" @@ -88,7 +90,7 @@ max-backups = 0 log-rotate = true [security] -skip-grant-table = true +skip_grant_table = true # Path of file that contains list of trusted SSL CAs for connection with mysql client. ssl-ca = "" @@ -279,6 +281,8 @@ explain_rule = "first" # 1 表示开启安全更新 sql_safe_updates = -1 +skip_grant_table = true + support_charset = "utf8,utf8mb4" support_engine = "innodb" diff --git a/session/session_inception.go b/session/session_inception.go index 3298ab54..434237c6 100644 --- a/session/session_inception.go +++ b/session/session_inception.go @@ -369,6 +369,9 @@ func (s *session) ExecuteInc(ctx context.Context, sql string) (recordSets []sqle func (s *session) executeInc(ctx context.Context, sql string) (recordSets []sqlexec.RecordSet, err error) { sqlList := strings.Split(sql, "\n") + // tidb执行的SQL关闭general日志 + logging := s.Inc.GeneralLog + defer func() { if s.sessionVars.StmtCtx.AffectedRows() == 0 { if s.opt != nil && s.opt.Print { @@ -379,9 +382,13 @@ func (s *session) executeInc(ctx context.Context, sql string) (recordSets []sqle s.sessionVars.StmtCtx.AddAffectedRows(uint64(s.recordSets.rc.count)) } } + + if logging { + logQuery(sql, s.sessionVars) + } }() - defer logQuery(sql, s.sessionVars) + // defer logQuery(sql, s.sessionVars) s.PrepareTxnCtx(ctx) connID := s.sessionVars.ConnectionID @@ -568,8 +575,17 @@ func (s *session) executeInc(ctx context.Context, sql string) (recordSets []sqle return s.makeResult() default: // TiDB原生执行器 - if !s.haveBegin && s.isRunToTiDB(stmtNode) { - return s.execute(ctx, currentSql) + if !s.haveBegin { + istidb, isFlush := s.isRunToTiDB(stmtNode) + if istidb { + r, err := s.execute(ctx, currentSql) + if isFlush { + // 权限模块的SQL在执行后自动刷新 + s.execute(ctx, "FLUSH PRIVILEGES") + } + logging = false + return r, err + } } need := s.needDataSource(stmtNode) @@ -704,11 +720,20 @@ func (s *session) makeResult() (recordSets []sqlexec.RecordSet, err error) { } } -func (s *session) isRunToTiDB(stmtNode ast.StmtNode) bool { +func (s *session) isRunToTiDB(stmtNode ast.StmtNode) (is bool, isFlush bool) { + switch node := stmtNode.(type) { case *ast.UseStmt: - return true + return true, false + + case *ast.ExplainStmt: + return true, false + + case *ast.UnionStmt: + return true, false + case *ast.SelectStmt: + return true, false if node.From != nil { join := node.From.TableRefs @@ -718,32 +743,39 @@ func (s *session) isRunToTiDB(stmtNode ast.StmtNode) bool { if s, ok := x.Source.(*ast.TableName); ok { // log.Infof("%#v", s) if s.Name.L == "user" { - return true + return true, false } - return false + return false, false } default: log.Infof("%T", x) // log.Infof("%#v", x) } } + } else { + return true, false } + case *ast.CreateUserStmt, *ast.AlterUserStmt, *ast.DropUserStmt, + *ast.GrantStmt, *ast.RevokeStmt, + *ast.SetPwdStmt: + return true, true + case *ast.FlushStmt: + return true, false + case *ast.ShowStmt: - if node.IsInception { - return false - } else { + if !node.IsInception { // 添加部分命令支持 switch node.Tp { case ast.ShowDatabases, ast.ShowTables, ast.ShowTableStatus, ast.ShowColumns, ast.ShowWarnings, ast.ShowGrants: - return true + return true, false } } } - return false + return false, false } func (s *session) needDataSource(stmtNode ast.StmtNode) bool { @@ -5489,6 +5521,9 @@ func (s *session) getSubSelectColumns(node ast.ResultSetNode) []string { switch e := f.Expr.(type) { case *ast.ColumnNameExpr: columns = append(columns, e.Name.Name.String()) + // case *ast.VariableExpr: + // todo ... + // log.Infof("con:%d %#v", s.sessionVars.ConnectionID, e) default: log.Infof("con:%d %T", s.sessionVars.ConnectionID, e) } diff --git a/tidb-server/main.go b/tidb-server/main.go index 4840bcf7..ee1dc3db 100644 --- a/tidb-server/main.go +++ b/tidb-server/main.go @@ -346,7 +346,12 @@ func setGlobalVars() { ddl.RunWorker = cfg.RunDDL ddl.EnableSplitTableRegion = cfg.SplitTable plannercore.AllowCartesianProduct = cfg.Performance.CrossJoin - privileges.SkipWithGrant = cfg.Security.SkipGrantTable + // 权限参数冗余设置,开启任一鉴权即可,默认跳过鉴权 + skip := cfg.SkipGrantTable && cfg.Security.SkipGrantTable && cfg.Inc.SkipGrantTable + cfg.Security.SkipGrantTable = skip + cfg.Inc.SkipGrantTable = skip + cfg.SkipGrantTable = skip + privileges.SkipWithGrant = skip variable.ForcePriority = int32(mysql.Str2Priority(cfg.Performance.ForcePriority)) variable.SysVars[variable.TIDBMemQuotaQuery].Value = strconv.FormatInt(cfg.MemQuotaQuery, 10)