Skip to content

Commit

Permalink
Merge pull request #44 from hanchuanchuan/update-check-items
Browse files Browse the repository at this point in the history
update: 优化审核细节,审核子查询,函数等表达式
  • Loading branch information
hanchuanchuan authored Jun 18, 2019
2 parents e8139c7 + 3b531ca commit ff3fa75
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 43 deletions.
206 changes: 163 additions & 43 deletions session/session_inception.go
Original file line number Diff line number Diff line change
Expand Up @@ -5360,64 +5360,184 @@ func (s *session) checkItem(expr ast.ExprNode, tables []*TableInfo) bool {
return true
}

// log.Infof("%#v", expr)

switch e := expr.(type) {
case *ast.ColumnNameExpr:
s.checkFieldItem(e.Name, tables)
if e.Refer != nil {
s.checkItem(e.Refer.Expr, tables)
}

case *ast.BinaryOperationExpr:
return s.checkItem(e.L, tables) && s.checkItem(e.R, tables)
case *ast.ColumnNameExpr:
found := false

db := e.Name.Schema.L
// if db == "" {
// db = s.DBName
// }
case *ast.UnaryOperationExpr:
return s.checkItem(e.V, tables)

for _, t := range tables {
var tName string
if t.AsName != "" {
tName = t.AsName
} else {
tName = t.Name
}
case *ast.FuncCallExpr:
return s.checkFuncItem(e, tables)

if e.Name.Table.L != "" && (db == "" || strings.EqualFold(t.Schema, db)) &&
(strings.EqualFold(tName, e.Name.Table.L)) ||
e.Name.Table.L == "" {
for _, field := range t.Fields {
if strings.EqualFold(field.Field, e.Name.Name.L) && !field.IsDeleted {
found = true
break
}
}
if found {
case *ast.FuncCastExpr:
return s.checkItem(e.Expr, tables)

case *ast.AggregateFuncExpr:
return s.checkAggregateFuncItem(e, tables)

case *ast.PatternInExpr:
s.checkItem(e.Expr, tables)
for _, expr := range e.List {
s.checkItem(expr, tables)
}
if e.Sel != nil {
s.checkItem(e.Sel, tables)
}
case *ast.PatternLikeExpr:
s.checkItem(e.Expr, tables)
case *ast.PatternRegexpExpr:
s.checkItem(e.Expr, tables)

case *ast.SubqueryExpr:
s.checkSelectItem(e.Query)

case *ast.CompareSubqueryExpr:
s.checkItem(e.L, tables)
s.checkItem(e.R, tables)

case *ast.ExistsSubqueryExpr:
s.checkSelectItem(e.Sel)

case *ast.IsNullExpr:
s.checkItem(e.Expr, tables)
case *ast.IsTruthExpr:
s.checkItem(e.Expr, tables)

case *ast.BetweenExpr:
s.checkItem(e.Expr, tables)
s.checkItem(e.Left, tables)
s.checkItem(e.Right, tables)

case *ast.CaseExpr:
s.checkItem(e.Value, tables)
for _, when := range e.WhenClauses {
s.checkItem(when.Expr, tables)
s.checkItem(when.Result, tables)
}
s.checkItem(e.ElseClause, tables)

case *ast.DefaultExpr:
s.checkFieldItem(e.Name, tables)

case *ast.ParenthesesExpr:
s.checkItem(e.Expr, tables)

case *ast.RowExpr:
for _, expr := range e.Values {
s.checkItem(expr, tables)
}

case *ast.ValuesExpr:
s.checkFieldItem(e.Column.Name, tables)

case *ast.VariableExpr:
s.checkItem(e.Value, tables)

case *ast.ValueExpr, *ast.ParamMarkerExpr, *ast.PositionExpr:
// pass

default:
log.Infof("checkItem: %#v", e)
}

return true
}

// checkFieldItem 检查字段
func (s *session) checkFieldItem(name *ast.ColumnName, tables []*TableInfo) bool {
found := false
db := name.Schema.L

for _, t := range tables {
var tName string
if t.AsName != "" {
tName = t.AsName
} else {
tName = t.Name
}

if name.Table.L != "" && (db == "" || strings.EqualFold(t.Schema, db)) &&
(strings.EqualFold(tName, name.Table.L)) ||
name.Table.L == "" {
for _, field := range t.Fields {
if strings.EqualFold(field.Field, name.Name.L) && !field.IsDeleted {
found = true
break
}
}
if found {
break
}
}
}

// log.Info(e.Name.Name, "--------", found)
// for _, t := range tables {
// log.Info(t.AsName, ",", t.Name)
// for _, f := range t.Fields {
// fmt.Print(f.Field, " ")
// }
// fmt.Println()
// }
// log.Info(name.Name, "--------", found)
// for _, t := range tables {
// log.Info(t.AsName, ",", t.Name)
// for _, f := range t.Fields {
// fmt.Print(f.Field, " ")
// }
// fmt.Println()
// }

if found {
return true
if found {
return true
} else {
if name.Table.L == "" {
s.AppendErrorNo(ER_COLUMN_NOT_EXISTED, name.Name.O)
} else {
if e.Name.Table.L == "" {
s.AppendErrorNo(ER_COLUMN_NOT_EXISTED, e.Name.Name.O)
} else {
s.AppendErrorNo(ER_COLUMN_NOT_EXISTED,
fmt.Sprintf("%s.%s", e.Name.Table.O, e.Name.Name.O))
}
return false
s.AppendErrorNo(ER_COLUMN_NOT_EXISTED,
fmt.Sprintf("%s.%s", name.Table.O, name.Name.O))
}
default:
// log.Infof("checkItem: %#v", e)
return true
return false
}
}

// checkFuncItem 检查函数的字段
func (s *session) checkFuncItem(f *ast.FuncCallExpr, tables []*TableInfo) bool {

for _, arg := range f.Args {
s.checkItem(arg, tables)
}

// log.Info(f.FnName.L)
// switch f.FnName.L {
// case ast.Nullif:
// log.Infof("%#v", f)
// for _, arg := range f.Args {
// log.Infof("%#v", arg)
// }
// }

return false
}

// checkFuncItem 检查聚合函数的字段
func (s *session) checkAggregateFuncItem(f *ast.AggregateFuncExpr, tables []*TableInfo) bool {

for _, arg := range f.Args {
s.checkItem(arg, tables)
}

// log.Info(f.F)
// switch f.FnName.L {
// case ast.Nullif:
// log.Infof("%#v", f)
// for _, arg := range f.Args {
// log.Infof("%#v", arg)
// }
// }

return false
}

func (s *session) checkDelete(node *ast.DeleteStmt, sql string) {
Expand Down
29 changes: 29 additions & 0 deletions session/session_inception_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1361,6 +1361,35 @@ insert into t2 select id from t1;`
s.testErrorCode(c, sql,
session.NewErr(session.ER_COLUMN_NOT_EXISTED, "now"))

sql = `insert into t1(id) values(nullif(a,'123'));`
s.testErrorCode(c, sql,
session.NewErr(session.ER_COLUMN_NOT_EXISTED, "a"))

sql = `insert into t1(id) values(now);`
s.testErrorCode(c, sql,
session.NewErr(session.ER_COLUMN_NOT_EXISTED, "now"))

sql = `insert into t1(id) values(now());`
s.testErrorCode(c, sql)

sql = `insert into t1(id) values(max(1));`
s.testErrorCode(c, sql)

sql = `insert into t1(id) values(max(a));`
s.testErrorCode(c, sql,
session.NewErr(session.ER_COLUMN_NOT_EXISTED, "a"))

sql = `insert into t1(id) values(abs(-1));`
s.testErrorCode(c, sql)

sql = `insert into t1(id) values(cast(a as signed));`
s.testErrorCode(c, sql,
session.NewErr(session.ER_COLUMN_NOT_EXISTED, "a"))

sql = `drop table if exists tt1;create table tt1(id int,c1 int);insert into tt1(id) select max(id) from tt1 where id in (select id1 from tt1);`
s.testErrorCode(c, sql,
session.NewErr(session.ER_COLUMN_NOT_EXISTED, "id1"))

}

func (s *testSessionIncSuite) TestUpdate(c *C) {
Expand Down

0 comments on commit ff3fa75

Please sign in to comment.