Skip to content

Commit

Permalink
Merge pull request #132 from hanchuanchuan/feature-secure
Browse files Browse the repository at this point in the history
添加用户鉴权模块
  • Loading branch information
hanchuanchuan authored Dec 11, 2019
2 parents 65787dc + 2865709 commit 2a30e04
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 12 deletions.
12 changes: 11 additions & 1 deletion config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ type Config struct {
Ghost Ghost `toml:"ghost" json:"ghost"`
IncLevel IncLevel `toml:"inc_level" json:"inc_level"`
CompatibleKillQuery bool `toml:"compatible-kill-query" json:"compatible-kill-query"`

// 是否跳过用户权限校验
SkipGrantTable bool `toml:"skip_grant_table" json:"skip_grant_table"`
}

// Log is the log section of config.
Expand All @@ -98,7 +101,7 @@ type Log struct {

// Security is the security section of the config.
type Security struct {
SkipGrantTable bool `toml:"skip-grant-table" json:"skip-grant-table"`
SkipGrantTable bool `toml:"skip_grant_table" json:"skip_grant_table"`
SSLCA string `toml:"ssl-ca" json:"ssl-ca"`
SSLCert string `toml:"ssl-cert" json:"ssl-cert"`
SSLKey string `toml:"ssl-key" json:"ssl-key"`
Expand Down Expand Up @@ -348,6 +351,8 @@ type Inc struct {
// 建表必须创建的列. 可指定多个列,以逗号分隔.列类型可选. 格式: 列名 [列类型,可选],...
MustHaveColumns string `toml:"must_have_columns" json:"must_have_columns"`

// 是否跳过用户权限校验
SkipGrantTable bool `toml:"skip_grant_table" json:"skip_grant_table"`
// 要跳过的sql语句, 多个时以分号分隔
SkipSqls string `toml:"skip_sqls" json:"skip_sqls"`

Expand Down Expand Up @@ -709,6 +714,8 @@ var defaultConf = Config{
WriteTimeout: "15s",
},
// 默认跳过权限校验 2019-1-26
// 为配置方便,在config节点也添加相同参数
SkipGrantTable: true,
Security: Security{
SkipGrantTable: true,
},
Expand All @@ -734,6 +741,9 @@ var defaultConf = Config{
DefaultCharset: "utf8mb4",
MaxAllowedPacket: 4194304,
ExplainRule: "first",

// 为配置方便,在config节点也添加相同参数
SkipGrantTable: true,
// Version: &mysql.TiDBReleaseVersion,
},
Osc: Osc{
Expand Down
6 changes: 5 additions & 1 deletion config/config.toml.example
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ lower-case-table-names = 2
# turn on this option when TiDB server is behind a proxy.
compatible-kill-query = false

skip_grant_table = true

[log]
# Log level: debug, info, warn, error, fatal.
level = "info"
Expand Down Expand Up @@ -88,7 +90,7 @@ max-backups = 0
log-rotate = true

[security]
skip-grant-table = true
skip_grant_table = true

# Path of file that contains list of trusted SSL CAs for connection with mysql client.
ssl-ca = ""
Expand Down Expand Up @@ -279,6 +281,8 @@ explain_rule = "first"
# 1 表示开启安全更新
sql_safe_updates = -1

skip_grant_table = true

support_charset = "utf8,utf8mb4"
support_engine = "innodb"

Expand Down
14 changes: 7 additions & 7 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -638,13 +638,13 @@ func (cc *clientConn) dispatch(data []byte) error {
return cc.handleQuery(ctx1, hack.String(data))
case mysql.ComPing:
return cc.writeOK()
// case mysql.ComInitDB:
// if err := cc.useDB(ctx1, hack.String(data)); err != nil {
// return errors.Trace(err)
// }
// return cc.writeOK()
// case mysql.ComFieldList:
// return cc.handleFieldList(hack.String(data))
case mysql.ComInitDB:
if err := cc.useDB(ctx1, hack.String(data)); err != nil {
return errors.Trace(err)
}
return cc.writeOK()
case mysql.ComFieldList:
return cc.handleFieldList(hack.String(data))
// case mysql.ComStmtPrepare:
// return cc.handleStmtPrepare(hack.String(data))
// case mysql.ComStmtExecute:
Expand Down
88 changes: 86 additions & 2 deletions session/session_inception.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,9 @@ func (s *session) ExecuteInc(ctx context.Context, sql string) (recordSets []sqle
}
}

if strings.HasPrefix(lowerSql, "select high_priority") {
if lowerSql == "select database()" {
return s.execute(ctx, sql)
} else if strings.HasPrefix(lowerSql, "select high_priority") {
return s.execute(ctx, sql)
} else if strings.HasPrefix(lowerSql,
`select variable_value from mysql.tidb where variable_name = "system_tz"`) {
Expand Down Expand Up @@ -367,6 +369,9 @@ func (s *session) ExecuteInc(ctx context.Context, sql string) (recordSets []sqle
func (s *session) executeInc(ctx context.Context, sql string) (recordSets []sqlexec.RecordSet, err error) {
sqlList := strings.Split(sql, "\n")

// tidb执行的SQL关闭general日志
logging := s.Inc.GeneralLog

defer func() {
if s.sessionVars.StmtCtx.AffectedRows() == 0 {
if s.opt != nil && s.opt.Print {
Expand All @@ -377,9 +382,13 @@ func (s *session) executeInc(ctx context.Context, sql string) (recordSets []sqle
s.sessionVars.StmtCtx.AddAffectedRows(uint64(s.recordSets.rc.count))
}
}

if logging {
logQuery(sql, s.sessionVars)
}
}()

defer logQuery(sql, s.sessionVars)
// defer logQuery(sql, s.sessionVars)

s.PrepareTxnCtx(ctx)
connID := s.sessionVars.ConnectionID
Expand Down Expand Up @@ -565,6 +574,20 @@ func (s *session) executeInc(ctx context.Context, sql string) (recordSets []sqle
s.executeCommit(ctx)
return s.makeResult()
default:
// TiDB原生执行器
if !s.haveBegin {
istidb, isFlush := s.isRunToTiDB(stmtNode)
if istidb {
r, err := s.execute(ctx, currentSql)
if isFlush {
// 权限模块的SQL在执行后自动刷新
s.execute(ctx, "FLUSH PRIVILEGES")
}
logging = false
return r, err
}
}

need := s.needDataSource(stmtNode)

if !s.haveBegin && need {
Expand Down Expand Up @@ -697,6 +720,64 @@ func (s *session) makeResult() (recordSets []sqlexec.RecordSet, err error) {
}
}

func (s *session) isRunToTiDB(stmtNode ast.StmtNode) (is bool, isFlush bool) {

switch node := stmtNode.(type) {
case *ast.UseStmt:
return true, false

case *ast.ExplainStmt:
return true, false

case *ast.UnionStmt:
return true, false

case *ast.SelectStmt:
return true, false

if node.From != nil {
join := node.From.TableRefs
if join.Right == nil {
switch x := node.From.TableRefs.Left.(type) {
case *ast.TableSource:
if s, ok := x.Source.(*ast.TableName); ok {
// log.Infof("%#v", s)
if s.Name.L == "user" {
return true, false
}
return false, false
}
default:
log.Infof("%T", x)
// log.Infof("%#v", x)
}
}
} else {
return true, false
}

case *ast.CreateUserStmt, *ast.AlterUserStmt, *ast.DropUserStmt,
*ast.GrantStmt, *ast.RevokeStmt,
*ast.SetPwdStmt:
return true, true
case *ast.FlushStmt:
return true, false

case *ast.ShowStmt:
if !node.IsInception {
// 添加部分命令支持
switch node.Tp {
case ast.ShowDatabases, ast.ShowTables,
ast.ShowTableStatus, ast.ShowColumns,
ast.ShowWarnings, ast.ShowGrants:
return true, false
}
}
}

return false, false
}

func (s *session) needDataSource(stmtNode ast.StmtNode) bool {
switch node := stmtNode.(type) {
case *ast.ShowStmt:
Expand Down Expand Up @@ -5440,6 +5521,9 @@ func (s *session) getSubSelectColumns(node ast.ResultSetNode) []string {
switch e := f.Expr.(type) {
case *ast.ColumnNameExpr:
columns = append(columns, e.Name.Name.String())
// case *ast.VariableExpr:
// todo ...
// log.Infof("con:%d %#v", s.sessionVars.ConnectionID, e)
default:
log.Infof("con:%d %T", s.sessionVars.ConnectionID, e)
}
Expand Down
7 changes: 6 additions & 1 deletion tidb-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,12 @@ func setGlobalVars() {
ddl.RunWorker = cfg.RunDDL
ddl.EnableSplitTableRegion = cfg.SplitTable
plannercore.AllowCartesianProduct = cfg.Performance.CrossJoin
privileges.SkipWithGrant = cfg.Security.SkipGrantTable
// 权限参数冗余设置,开启任一鉴权即可,默认跳过鉴权
skip := cfg.SkipGrantTable && cfg.Security.SkipGrantTable && cfg.Inc.SkipGrantTable
cfg.Security.SkipGrantTable = skip
cfg.Inc.SkipGrantTable = skip
cfg.SkipGrantTable = skip
privileges.SkipWithGrant = skip
variable.ForcePriority = int32(mysql.Str2Priority(cfg.Performance.ForcePriority))

variable.SysVars[variable.TIDBMemQuotaQuery].Value = strconv.FormatInt(cfg.MemQuotaQuery, 10)
Expand Down

0 comments on commit 2a30e04

Please sign in to comment.