Skip to content

Commit

Permalink
fix: 修复join语法的ON子句审核不准确的问题
Browse files Browse the repository at this point in the history
  • Loading branch information
hanchuanchuan committed Jan 3, 2020
1 parent 7780bee commit a7ea2c1
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 19 deletions.
35 changes: 22 additions & 13 deletions session/session_inception.go
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,7 @@ func (s *session) processCommand(ctx context.Context, stmtNode ast.StmtNode,
}
}
}
s.checkSelectItem(node)
s.checkSelectItem(node, false)
if s.opt.execute {
s.AppendErrorNo(ER_NOT_SUPPORTED_YET)
}
Expand All @@ -826,7 +826,7 @@ func (s *session) processCommand(ctx context.Context, stmtNode ast.StmtNode,
}
}
}
s.checkSelectItem(node)
s.checkSelectItem(node, false)
if s.opt.execute {
s.AppendErrorNo(ER_NOT_SUPPORTED_YET)
}
Expand Down Expand Up @@ -5563,7 +5563,7 @@ func (s *session) checkInsert(node *ast.InsertStmt, sql string) {
if !s.hasError() {
// 如果不是新建表时,则直接explain
if haveNewTable {
s.checkSelectItem(x.Select)
s.checkSelectItem(x.Select, sel.Where != nil)
} else {
var selectSql string
if table.IsNew || table.IsNewColumns || s.DBVersion < 50600 {
Expand Down Expand Up @@ -7118,7 +7118,10 @@ func (s *session) checkUpdate(node *ast.UpdateStmt, sql string) {
s.checkItem(l.Expr, tableInfoList)
}

s.checkSelectItem(node.TableRefs.TableRefs)
// log.Infof("%#v", node.TableRefs)
// log.Infof("%#v", node.TableRefs.TableRefs)

s.checkSelectItem(node.TableRefs.TableRefs, node.Where != nil)
// if node.TableRefs.TableRefs.On != nil {
// s.checkItem(node.TableRefs.TableRefs.On.Expr, tableInfoList)
// }
Expand Down Expand Up @@ -7230,14 +7233,14 @@ func (s *session) checkItem(expr ast.ExprNode, tables []*TableInfo) bool {
s.checkItem(e.Expr, tables)

case *ast.SubqueryExpr:
s.checkSelectItem(e.Query)
s.checkSelectItem(e.Query, false)

case *ast.CompareSubqueryExpr:
s.checkItem(e.L, tables)
s.checkItem(e.R, tables)

case *ast.ExistsSubqueryExpr:
s.checkSelectItem(e.Sel)
s.checkSelectItem(e.Sel, false)

case *ast.IsNullExpr:
s.checkItem(e.Expr, tables)
Expand Down Expand Up @@ -8082,7 +8085,8 @@ func (s *session) copyTableInfo(t *TableInfo) *TableInfo {
return p
}

func (s *session) checkSelectItem(node ast.ResultSetNode) []*TableInfo {
// checkSelectItem 子句递归检查
func (s *session) checkSelectItem(node ast.ResultSetNode, hasWhere bool) []*TableInfo {
if node == nil {
return nil
}
Expand All @@ -8107,18 +8111,23 @@ func (s *session) checkSelectItem(node ast.ResultSetNode) []*TableInfo {
return s.checkSubSelectItem(x)

case *ast.Join:
tableInfoList := s.checkSelectItem(x.Left)
tableInfoList = append(tableInfoList, s.checkSelectItem(x.Right)...)
tableInfoList := s.checkSelectItem(x.Left, false)
tableInfoList = append(tableInfoList, s.checkSelectItem(x.Right, false)...)

// b, _ := json.MarshalIndent(x, "", " ")
// log.Info(string(b))

// log.Infof("%#v", x.Left)
// log.Infof("%#v", x.Right)
// log.Infof("%#v", x)

if x.On != nil {
s.checkItem(x.On.Expr, tableInfoList)
} else if x.Right != nil {
s.AppendErrorNo(ErrJoinNoOnCondition)
// 没有任何where条件时
if !hasWhere && !x.NaturalJoin && !x.StraightJoin && x.Using == nil {
s.AppendErrorNo(ErrJoinNoOnCondition)
}
}
return tableInfoList
case *ast.TableSource:
Expand Down Expand Up @@ -8153,7 +8162,7 @@ func (s *session) checkSelectItem(node ast.ResultSetNode) []*TableInfo {
}

case *ast.UnionStmt:
s.checkSelectItem(tblSource)
s.checkSelectItem(tblSource, false)

cols := s.getSubSelectColumns(tblSource)
if cols != nil {
Expand All @@ -8170,7 +8179,7 @@ func (s *session) checkSelectItem(node ast.ResultSetNode) []*TableInfo {
}

default:
return s.checkSelectItem(tblSource)
return s.checkSelectItem(tblSource, false)
// log.Infof("%T", x)
// log.Infof("%#v", x)
}
Expand Down Expand Up @@ -8236,7 +8245,7 @@ func (s *session) checkSubSelectItem(node *ast.SelectStmt) []*TableInfo {
}
default:
log.Infof("con:%d %T", s.sessionVars.ConnectionID, x)
tableInfoList = append(tableInfoList, s.checkSelectItem(tblSource)...)
tableInfoList = append(tableInfoList, s.checkSelectItem(tblSource, false)...)
}
}

Expand Down
32 changes: 26 additions & 6 deletions session/session_inception_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1459,13 +1459,33 @@ func (s *testSessionIncSuite) TestUpdate(c *C) {
sql = "create table t1(id int,c1 int);update t1 set c1 = 1;"
s.testErrorCode(c, sql,
session.NewErr(session.ER_NO_WHERE_CONDITION))
config.GetGlobalConfig().Inc.CheckDMLWhere = false

// where
config.GetGlobalConfig().Inc.CheckDMLWhere = true
sql = "create table t1(id int,c1 int);create table t2(id int,c1 int);update t1 join t2 set t1.c1 = 1 where t1.id=1;"
sql = `create table t1(id int,c1 int);
create table t2(id int,c1 int);`
s.mustRunExec(c, sql)

sql = `update t1 join t2 set t1.c1 = 1;`
s.testErrorCode(c, sql,
session.NewErr(session.ErrJoinNoOnCondition))
session.NewErr(session.ErrJoinNoOnCondition),
session.NewErr(session.ER_NO_WHERE_CONDITION))

sql = `update t1,t2 set t1.c1 = 1;`
s.testErrorCode(c, sql,
session.NewErr(session.ErrJoinNoOnCondition),
session.NewErr(session.ER_NO_WHERE_CONDITION))

sql = `update t1 NATURAL join t2 set t1.c1 = 1 ;`
s.testErrorCode(c, sql,
session.NewErr(session.ER_NO_WHERE_CONDITION))

sql = `create table t1(id int,c1 int);
create table t2(id int,c1 int);
update t1 join t2 using(id) set t1.c1 = 1 ;`
s.testErrorCode(c, sql,
session.NewErr(session.ER_NO_WHERE_CONDITION))

sql = `update t1,t2 set t1.c1 = 1 where t1.id=1;`
s.testErrorCode(c, sql)
config.GetGlobalConfig().Inc.CheckDMLWhere = false

// limit
Expand All @@ -1484,7 +1504,7 @@ func (s *testSessionIncSuite) TestUpdate(c *C) {

// 受影响行数
s.realRowCount = false
res := s.runCheck("create table t1(id int,c1 int);update t1 set c1 = 1;")
res := s.runCheck("drop table if exists t1,t2;create table t1(id int,c1 int);update t1 set c1 = 1;")
row := res.Rows()[int(s.tk.Se.AffectedRows())-1]
c.Assert(row[2], Equals, "0")
c.Assert(row[6], Equals, "0")
Expand Down

0 comments on commit a7ea2c1

Please sign in to comment.