diff --git a/session/session_inception.go b/session/session_inception.go index 4b769ff4..b117bd81 100644 --- a/session/session_inception.go +++ b/session/session_inception.go @@ -813,7 +813,7 @@ func (s *session) processCommand(ctx context.Context, stmtNode ast.StmtNode, } } } - s.checkSelectItem(node) + s.checkSelectItem(node, false) if s.opt.execute { s.AppendErrorNo(ER_NOT_SUPPORTED_YET) } @@ -826,7 +826,7 @@ func (s *session) processCommand(ctx context.Context, stmtNode ast.StmtNode, } } } - s.checkSelectItem(node) + s.checkSelectItem(node, false) if s.opt.execute { s.AppendErrorNo(ER_NOT_SUPPORTED_YET) } @@ -5563,7 +5563,7 @@ func (s *session) checkInsert(node *ast.InsertStmt, sql string) { if !s.hasError() { // 如果不是新建表时,则直接explain if haveNewTable { - s.checkSelectItem(x.Select) + s.checkSelectItem(x.Select, sel.Where != nil) } else { var selectSql string if table.IsNew || table.IsNewColumns || s.DBVersion < 50600 { @@ -7118,7 +7118,10 @@ func (s *session) checkUpdate(node *ast.UpdateStmt, sql string) { s.checkItem(l.Expr, tableInfoList) } - s.checkSelectItem(node.TableRefs.TableRefs) + // log.Infof("%#v", node.TableRefs) + // log.Infof("%#v", node.TableRefs.TableRefs) + + s.checkSelectItem(node.TableRefs.TableRefs, node.Where != nil) // if node.TableRefs.TableRefs.On != nil { // s.checkItem(node.TableRefs.TableRefs.On.Expr, tableInfoList) // } @@ -7230,14 +7233,14 @@ func (s *session) checkItem(expr ast.ExprNode, tables []*TableInfo) bool { s.checkItem(e.Expr, tables) case *ast.SubqueryExpr: - s.checkSelectItem(e.Query) + s.checkSelectItem(e.Query, false) case *ast.CompareSubqueryExpr: s.checkItem(e.L, tables) s.checkItem(e.R, tables) case *ast.ExistsSubqueryExpr: - s.checkSelectItem(e.Sel) + s.checkSelectItem(e.Sel, false) case *ast.IsNullExpr: s.checkItem(e.Expr, tables) @@ -8082,7 +8085,8 @@ func (s *session) copyTableInfo(t *TableInfo) *TableInfo { return p } -func (s *session) checkSelectItem(node ast.ResultSetNode) []*TableInfo { +// checkSelectItem 子句递归检查 +func (s *session) checkSelectItem(node ast.ResultSetNode, hasWhere bool) []*TableInfo { if node == nil { return nil } @@ -8107,18 +8111,23 @@ func (s *session) checkSelectItem(node ast.ResultSetNode) []*TableInfo { return s.checkSubSelectItem(x) case *ast.Join: - tableInfoList := s.checkSelectItem(x.Left) - tableInfoList = append(tableInfoList, s.checkSelectItem(x.Right)...) + tableInfoList := s.checkSelectItem(x.Left, false) + tableInfoList = append(tableInfoList, s.checkSelectItem(x.Right, false)...) // b, _ := json.MarshalIndent(x, "", " ") // log.Info(string(b)) // log.Infof("%#v", x.Left) // log.Infof("%#v", x.Right) + // log.Infof("%#v", x) + if x.On != nil { s.checkItem(x.On.Expr, tableInfoList) } else if x.Right != nil { - s.AppendErrorNo(ErrJoinNoOnCondition) + // 没有任何where条件时 + if !hasWhere && !x.NaturalJoin && !x.StraightJoin && x.Using == nil { + s.AppendErrorNo(ErrJoinNoOnCondition) + } } return tableInfoList case *ast.TableSource: @@ -8153,7 +8162,7 @@ func (s *session) checkSelectItem(node ast.ResultSetNode) []*TableInfo { } case *ast.UnionStmt: - s.checkSelectItem(tblSource) + s.checkSelectItem(tblSource, false) cols := s.getSubSelectColumns(tblSource) if cols != nil { @@ -8170,7 +8179,7 @@ func (s *session) checkSelectItem(node ast.ResultSetNode) []*TableInfo { } default: - return s.checkSelectItem(tblSource) + return s.checkSelectItem(tblSource, false) // log.Infof("%T", x) // log.Infof("%#v", x) } @@ -8236,7 +8245,7 @@ func (s *session) checkSubSelectItem(node *ast.SelectStmt) []*TableInfo { } default: log.Infof("con:%d %T", s.sessionVars.ConnectionID, x) - tableInfoList = append(tableInfoList, s.checkSelectItem(tblSource)...) + tableInfoList = append(tableInfoList, s.checkSelectItem(tblSource, false)...) } } diff --git a/session/session_inception_test.go b/session/session_inception_test.go index 82b75d17..ab054166 100644 --- a/session/session_inception_test.go +++ b/session/session_inception_test.go @@ -1459,13 +1459,33 @@ func (s *testSessionIncSuite) TestUpdate(c *C) { sql = "create table t1(id int,c1 int);update t1 set c1 = 1;" s.testErrorCode(c, sql, session.NewErr(session.ER_NO_WHERE_CONDITION)) - config.GetGlobalConfig().Inc.CheckDMLWhere = false - // where - config.GetGlobalConfig().Inc.CheckDMLWhere = true - sql = "create table t1(id int,c1 int);create table t2(id int,c1 int);update t1 join t2 set t1.c1 = 1 where t1.id=1;" + sql = `create table t1(id int,c1 int); + create table t2(id int,c1 int);` + s.mustRunExec(c, sql) + + sql = `update t1 join t2 set t1.c1 = 1;` s.testErrorCode(c, sql, - session.NewErr(session.ErrJoinNoOnCondition)) + session.NewErr(session.ErrJoinNoOnCondition), + session.NewErr(session.ER_NO_WHERE_CONDITION)) + + sql = `update t1,t2 set t1.c1 = 1;` + s.testErrorCode(c, sql, + session.NewErr(session.ErrJoinNoOnCondition), + session.NewErr(session.ER_NO_WHERE_CONDITION)) + + sql = `update t1 NATURAL join t2 set t1.c1 = 1 ;` + s.testErrorCode(c, sql, + session.NewErr(session.ER_NO_WHERE_CONDITION)) + + sql = `create table t1(id int,c1 int); + create table t2(id int,c1 int); + update t1 join t2 using(id) set t1.c1 = 1 ;` + s.testErrorCode(c, sql, + session.NewErr(session.ER_NO_WHERE_CONDITION)) + + sql = `update t1,t2 set t1.c1 = 1 where t1.id=1;` + s.testErrorCode(c, sql) config.GetGlobalConfig().Inc.CheckDMLWhere = false // limit @@ -1484,7 +1504,7 @@ func (s *testSessionIncSuite) TestUpdate(c *C) { // 受影响行数 s.realRowCount = false - res := s.runCheck("create table t1(id int,c1 int);update t1 set c1 = 1;") + res := s.runCheck("drop table if exists t1,t2;create table t1(id int,c1 int);update t1 set c1 = 1;") row := res.Rows()[int(s.tk.Se.AffectedRows())-1] c.Assert(row[2], Equals, "0") c.Assert(row[6], Equals, "0")