From 52dcfa05765507d038877c2abda30199fec90a41 Mon Sep 17 00:00:00 2001 From: Haibin Xie Date: Mon, 26 Aug 2019 18:49:09 +0800 Subject: [PATCH] address comments --- parser.go | 9 ++++++--- parser.y | 3 +++ parser_test.go | 40 ++++++++++++++++++++++++++++++++++++++++ yy_parser.go | 17 ++++++++++++----- 4 files changed, 61 insertions(+), 8 deletions(-) diff --git a/parser.go b/parser.go index 2efe4ff36..02464aaa9 100644 --- a/parser.go +++ b/parser.go @@ -12539,9 +12539,10 @@ yynewstate: case 1080: { st := &ast.SelectStmt{ - SelectStmtOpts: yyS[yypt-1].item.(*ast.SelectStmtOpts), - Distinct: yyS[yypt-1].item.(*ast.SelectStmtOpts).Distinct, - Fields: yyS[yypt-0].item.(*ast.FieldList), + SelectStmtOpts: yyS[yypt-1].item.(*ast.SelectStmtOpts), + Distinct: yyS[yypt-1].item.(*ast.SelectStmtOpts).Distinct, + Fields: yyS[yypt-0].item.(*ast.FieldList), + QueryBlockOffset: parser.queryBlockOffset(), } if st.SelectStmtOpts.TableHints != nil { st.TableHints = st.SelectStmtOpts.TableHints @@ -14580,6 +14581,7 @@ yynewstate: s.SetText(lexer.stmtText()) } parser.result = append(parser.result, s) + parser.blockOffset = 0 } } case 1521: @@ -14590,6 +14592,7 @@ yynewstate: s.SetText(lexer.stmtText()) } parser.result = append(parser.result, s) + parser.blockOffset = 0 } } case 1522: diff --git a/parser.y b/parser.y index 3f1eac45c..dc06d87f1 100644 --- a/parser.y +++ b/parser.y @@ -5738,6 +5738,7 @@ SelectStmtBasic: SelectStmtOpts: $2.(*ast.SelectStmtOpts), Distinct: $2.(*ast.SelectStmtOpts).Distinct, Fields: $3.(*ast.FieldList), + QueryBlockOffset: parser.queryBlockOffset(), } if st.SelectStmtOpts.TableHints != nil { st.TableHints = st.SelectStmtOpts.TableHints @@ -8102,6 +8103,7 @@ StatementList: s.SetText(lexer.stmtText()) } parser.result = append(parser.result, s) + parser.blockOffset = 0 } } | StatementList ';' Statement @@ -8112,6 +8114,7 @@ StatementList: s.SetText(lexer.stmtText()) } parser.result = append(parser.result, s) + parser.blockOffset = 0 } } diff --git a/parser_test.go b/parser_test.go index e39a8e370..7b5d0d9e2 100644 --- a/parser_test.go +++ b/parser_test.go @@ -4290,3 +4290,43 @@ func (checker *nodeTextCleaner) Enter(in ast.Node) (out ast.Node, skipChildren b func (checker *nodeTextCleaner) Leave(in ast.Node) (out ast.Node, ok bool) { return in, true } + +type queryBlockOffsetChecker struct { + curOffset int + mismatch bool +} + +func (checker *queryBlockOffsetChecker) Enter(in ast.Node) (ast.Node, bool) { + sel, ok := in.(*ast.SelectStmt) + if !ok { + return in, false + } + checker.curOffset++ + if sel.QueryBlockOffset != checker.curOffset { + checker.mismatch = true + } + return in, false +} + +func (checker *queryBlockOffsetChecker) Leave(in ast.Node) (out ast.Node, ok bool) { + return in, true +} + +func (s *testParserSuite) SelectStmtOffset(c *C) { + parser := parser.New() + sqls := []string{ + "select * from t; select * from t", + "select a, (select count(*) from t t1 where t1.b > t.a) from t where b > (select b from t t2 where t2.b = t.a limit 1)", + "select count(*) from t t1 where t1.a < (select count(*) from t t2 where t1.a > t2.a)", + } + checker := &queryBlockOffsetChecker{} + for _, sql := range sqls { + stmts, _, err := parser.Parse(sql, "", "") + c.Assert(err, IsNil) + for _, stmt := range stmts { + checker.curOffset = 0 + stmt.Accept(checker) + c.Assert(checker.mismatch, IsFalse) + } + } +} diff --git a/yy_parser.go b/yy_parser.go index 46e61000a..77bb7e6e3 100644 --- a/yy_parser.go +++ b/yy_parser.go @@ -91,11 +91,12 @@ func TrimComment(txt string) string { // Parser represents a parser instance. Some temporary objects are stored in it to reduce object allocation during Parse function. type Parser struct { - charset string - collation string - result []ast.StmtNode - src string - lexer Scanner + charset string + collation string + result []ast.StmtNode + src string + lexer Scanner + blockOffset int // the following fields are used by yyParse to reduce allocation. cache []yySymType @@ -134,6 +135,7 @@ func (parser *Parser) Parse(sql, charset, collation string) (stmt []ast.StmtNode parser.collation = collation parser.src = sql parser.result = parser.result[:0] + parser.blockOffset = 0 var l yyLexer parser.lexer.reset(sql) @@ -217,6 +219,11 @@ func (parser *Parser) endOffset(v *yySymType) int { return offset } +func (parser *Parser) queryBlockOffset() int { + parser.blockOffset++ + return parser.blockOffset +} + func toInt(l yyLexer, lval *yySymType, str string) int { n, err := strconv.ParseUint(str, 10, 64) if err != nil {