Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

添加用户鉴权模块 #132

Merged
merged 2 commits into from
Dec 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ type Config struct {
Ghost Ghost `toml:"ghost" json:"ghost"`
IncLevel IncLevel `toml:"inc_level" json:"inc_level"`
CompatibleKillQuery bool `toml:"compatible-kill-query" json:"compatible-kill-query"`

// 是否跳过用户权限校验
SkipGrantTable bool `toml:"skip_grant_table" json:"skip_grant_table"`
}

// Log is the log section of config.
Expand All @@ -98,7 +101,7 @@ type Log struct {

// Security is the security section of the config.
type Security struct {
SkipGrantTable bool `toml:"skip-grant-table" json:"skip-grant-table"`
SkipGrantTable bool `toml:"skip_grant_table" json:"skip_grant_table"`
SSLCA string `toml:"ssl-ca" json:"ssl-ca"`
SSLCert string `toml:"ssl-cert" json:"ssl-cert"`
SSLKey string `toml:"ssl-key" json:"ssl-key"`
Expand Down Expand Up @@ -348,6 +351,8 @@ type Inc struct {
// 建表必须创建的列. 可指定多个列,以逗号分隔.列类型可选. 格式: 列名 [列类型,可选],...
MustHaveColumns string `toml:"must_have_columns" json:"must_have_columns"`

// 是否跳过用户权限校验
SkipGrantTable bool `toml:"skip_grant_table" json:"skip_grant_table"`
// 要跳过的sql语句, 多个时以分号分隔
SkipSqls string `toml:"skip_sqls" json:"skip_sqls"`

Expand Down Expand Up @@ -709,6 +714,8 @@ var defaultConf = Config{
WriteTimeout: "15s",
},
// 默认跳过权限校验 2019-1-26
// 为配置方便,在config节点也添加相同参数
SkipGrantTable: true,
Security: Security{
SkipGrantTable: true,
},
Expand All @@ -734,6 +741,9 @@ var defaultConf = Config{
DefaultCharset: "utf8mb4",
MaxAllowedPacket: 4194304,
ExplainRule: "first",

// 为配置方便,在config节点也添加相同参数
SkipGrantTable: true,
// Version: &mysql.TiDBReleaseVersion,
},
Osc: Osc{
Expand Down
6 changes: 5 additions & 1 deletion config/config.toml.example
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ lower-case-table-names = 2
# turn on this option when TiDB server is behind a proxy.
compatible-kill-query = false

skip_grant_table = true

[log]
# Log level: debug, info, warn, error, fatal.
level = "info"
Expand Down Expand Up @@ -88,7 +90,7 @@ max-backups = 0
log-rotate = true

[security]
skip-grant-table = true
skip_grant_table = true

# Path of file that contains list of trusted SSL CAs for connection with mysql client.
ssl-ca = ""
Expand Down Expand Up @@ -279,6 +281,8 @@ explain_rule = "first"
# 1 表示开启安全更新
sql_safe_updates = -1

skip_grant_table = true

support_charset = "utf8,utf8mb4"
support_engine = "innodb"

Expand Down
14 changes: 7 additions & 7 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -638,13 +638,13 @@ func (cc *clientConn) dispatch(data []byte) error {
return cc.handleQuery(ctx1, hack.String(data))
case mysql.ComPing:
return cc.writeOK()
// case mysql.ComInitDB:
// if err := cc.useDB(ctx1, hack.String(data)); err != nil {
// return errors.Trace(err)
// }
// return cc.writeOK()
// case mysql.ComFieldList:
// return cc.handleFieldList(hack.String(data))
case mysql.ComInitDB:
if err := cc.useDB(ctx1, hack.String(data)); err != nil {
return errors.Trace(err)
}
return cc.writeOK()
case mysql.ComFieldList:
return cc.handleFieldList(hack.String(data))
// case mysql.ComStmtPrepare:
// return cc.handleStmtPrepare(hack.String(data))
// case mysql.ComStmtExecute:
Expand Down
88 changes: 86 additions & 2 deletions session/session_inception.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,9 @@ func (s *session) ExecuteInc(ctx context.Context, sql string) (recordSets []sqle
}
}

if strings.HasPrefix(lowerSql, "select high_priority") {
if lowerSql == "select database()" {
return s.execute(ctx, sql)
} else if strings.HasPrefix(lowerSql, "select high_priority") {
return s.execute(ctx, sql)
} else if strings.HasPrefix(lowerSql,
`select variable_value from mysql.tidb where variable_name = "system_tz"`) {
Expand Down Expand Up @@ -367,6 +369,9 @@ func (s *session) ExecuteInc(ctx context.Context, sql string) (recordSets []sqle
func (s *session) executeInc(ctx context.Context, sql string) (recordSets []sqlexec.RecordSet, err error) {
sqlList := strings.Split(sql, "\n")

// tidb执行的SQL关闭general日志
logging := s.Inc.GeneralLog

defer func() {
if s.sessionVars.StmtCtx.AffectedRows() == 0 {
if s.opt != nil && s.opt.Print {
Expand All @@ -377,9 +382,13 @@ func (s *session) executeInc(ctx context.Context, sql string) (recordSets []sqle
s.sessionVars.StmtCtx.AddAffectedRows(uint64(s.recordSets.rc.count))
}
}

if logging {
logQuery(sql, s.sessionVars)
}
}()

defer logQuery(sql, s.sessionVars)
// defer logQuery(sql, s.sessionVars)

s.PrepareTxnCtx(ctx)
connID := s.sessionVars.ConnectionID
Expand Down Expand Up @@ -565,6 +574,20 @@ func (s *session) executeInc(ctx context.Context, sql string) (recordSets []sqle
s.executeCommit(ctx)
return s.makeResult()
default:
// TiDB原生执行器
if !s.haveBegin {
istidb, isFlush := s.isRunToTiDB(stmtNode)
if istidb {
r, err := s.execute(ctx, currentSql)
if isFlush {
// 权限模块的SQL在执行后自动刷新
s.execute(ctx, "FLUSH PRIVILEGES")
}
logging = false
return r, err
}
}

need := s.needDataSource(stmtNode)

if !s.haveBegin && need {
Expand Down Expand Up @@ -697,6 +720,64 @@ func (s *session) makeResult() (recordSets []sqlexec.RecordSet, err error) {
}
}

func (s *session) isRunToTiDB(stmtNode ast.StmtNode) (is bool, isFlush bool) {

switch node := stmtNode.(type) {
case *ast.UseStmt:
return true, false

case *ast.ExplainStmt:
return true, false

case *ast.UnionStmt:
return true, false

case *ast.SelectStmt:
return true, false

if node.From != nil {
join := node.From.TableRefs
if join.Right == nil {
switch x := node.From.TableRefs.Left.(type) {
case *ast.TableSource:
if s, ok := x.Source.(*ast.TableName); ok {
// log.Infof("%#v", s)
if s.Name.L == "user" {
return true, false
}
return false, false
}
default:
log.Infof("%T", x)
// log.Infof("%#v", x)
}
}
} else {
return true, false
}

case *ast.CreateUserStmt, *ast.AlterUserStmt, *ast.DropUserStmt,
*ast.GrantStmt, *ast.RevokeStmt,
*ast.SetPwdStmt:
return true, true
case *ast.FlushStmt:
return true, false

case *ast.ShowStmt:
if !node.IsInception {
// 添加部分命令支持
switch node.Tp {
case ast.ShowDatabases, ast.ShowTables,
ast.ShowTableStatus, ast.ShowColumns,
ast.ShowWarnings, ast.ShowGrants:
return true, false
}
}
}

return false, false
}

func (s *session) needDataSource(stmtNode ast.StmtNode) bool {
switch node := stmtNode.(type) {
case *ast.ShowStmt:
Expand Down Expand Up @@ -5440,6 +5521,9 @@ func (s *session) getSubSelectColumns(node ast.ResultSetNode) []string {
switch e := f.Expr.(type) {
case *ast.ColumnNameExpr:
columns = append(columns, e.Name.Name.String())
// case *ast.VariableExpr:
// todo ...
// log.Infof("con:%d %#v", s.sessionVars.ConnectionID, e)
default:
log.Infof("con:%d %T", s.sessionVars.ConnectionID, e)
}
Expand Down
7 changes: 6 additions & 1 deletion tidb-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,12 @@ func setGlobalVars() {
ddl.RunWorker = cfg.RunDDL
ddl.EnableSplitTableRegion = cfg.SplitTable
plannercore.AllowCartesianProduct = cfg.Performance.CrossJoin
privileges.SkipWithGrant = cfg.Security.SkipGrantTable
// 权限参数冗余设置,开启任一鉴权即可,默认跳过鉴权
skip := cfg.SkipGrantTable && cfg.Security.SkipGrantTable && cfg.Inc.SkipGrantTable
cfg.Security.SkipGrantTable = skip
cfg.Inc.SkipGrantTable = skip
cfg.SkipGrantTable = skip
privileges.SkipWithGrant = skip
variable.ForcePriority = int32(mysql.Str2Priority(cfg.Performance.ForcePriority))

variable.SysVars[variable.TIDBMemQuotaQuery].Value = strconv.FormatInt(cfg.MemQuotaQuery, 10)
Expand Down