diff --git a/parser/lexer.go b/parser/lexer.go index c489246e..80d7e4d2 100644 --- a/parser/lexer.go +++ b/parser/lexer.go @@ -127,7 +127,9 @@ func (s *Scanner) Errorf(format string, a ...interface{}) (err error) { if len(val) > 2048 { val = val[:2048] } - err = fmt.Errorf("line %d column %d near \"%s\"%s (total length %d)", s.r.p.Line, s.r.p.Col, val, str, len(s.r.s)) + // err = fmt.Errorf("line %d column %d near \"%s\"%s (total length %d)", s.r.p.Line, s.r.p.Col, val, str, len(s.r.s)) + err = fmt.Errorf("You have an error in your SQL syntax, near '%-.200s'%s at line %d", + val, str, s.r.p.Line) return } @@ -207,8 +209,8 @@ func (s *Scanner) GetSQLMode() mysql.SQLMode { } // EnableWindowFunc enables the scanner to recognize the keywords of window function. -func (s *Scanner) EnableWindowFunc() { - s.supportWindowFunc = true +func (s *Scanner) EnableWindowFunc(val bool) { + s.supportWindowFunc = val } // NewScanner returns a new scanner object. diff --git a/parser/yy_parser.go b/parser/yy_parser.go index e5998f27..34e188fe 100644 --- a/parser/yy_parser.go +++ b/parser/yy_parser.go @@ -28,23 +28,27 @@ import ( "github.com/pingcap/errors" ) -const ( - codeErrParse = terror.ErrCode(mysql.ErrParse) - codeErrSyntax = terror.ErrCode(mysql.ErrSyntax) - codeErrUnknownAlterLock = terror.ErrCode(mysql.ErrUnknownAlterLock) - codeErrUnknownAlterAlgorithm = terror.ErrCode(mysql.ErrUnknownAlterAlgorithm) -) - var ( // ErrSyntax returns for sql syntax error. - ErrSyntax = terror.ClassParser.New(codeErrSyntax, mysql.MySQLErrName[mysql.ErrSyntax]) + ErrSyntax = terror.ClassParser.New(mysql.ErrSyntax, mysql.MySQLErrName[mysql.ErrSyntax]) // ErrParse returns for sql parse error. - ErrParse = terror.ClassParser.New(codeErrParse, mysql.MySQLErrName[mysql.ErrParse]) + ErrParse = terror.ClassParser.New(mysql.ErrParse, mysql.MySQLErrName[mysql.ErrParse]) + // ErrUnknownCharacterSet returns for no character set found error. + ErrUnknownCharacterSet = terror.ClassParser.New(mysql.ErrUnknownCharacterSet, mysql.MySQLErrName[mysql.ErrUnknownCharacterSet]) + // ErrInvalidYearColumnLength returns for illegal column length for year type. + ErrInvalidYearColumnLength = terror.ClassParser.New(mysql.ErrInvalidYearColumnLength, mysql.MySQLErrName[mysql.ErrInvalidYearColumnLength]) + // ErrWrongArguments returns for illegal argument. + ErrWrongArguments = terror.ClassParser.New(mysql.ErrWrongArguments, mysql.MySQLErrName[mysql.ErrWrongArguments]) + // ErrWrongFieldTerminators returns for illegal field terminators. + ErrWrongFieldTerminators = terror.ClassParser.New(mysql.ErrWrongFieldTerminators, mysql.MySQLErrName[mysql.ErrWrongFieldTerminators]) + // ErrTooBigDisplayWidth returns for data display width exceed limit . + ErrTooBigDisplayWidth = terror.ClassParser.New(mysql.ErrTooBigDisplaywidth, mysql.MySQLErrName[mysql.ErrTooBigDisplaywidth]) + // ErrTooBigPrecision returns for data precision exceed limit. + ErrTooBigPrecision = terror.ClassParser.New(mysql.ErrTooBigPrecision, mysql.MySQLErrName[mysql.ErrTooBigPrecision]) // ErrUnknownAlterLock returns for no alter lock type found error. - ErrUnknownAlterLock = terror.ClassParser.New(codeErrUnknownAlterLock, mysql.MySQLErrName[mysql.ErrUnknownAlterLock]) + ErrUnknownAlterLock = terror.ClassParser.New(mysql.ErrUnknownAlterLock, mysql.MySQLErrName[mysql.ErrUnknownAlterLock]) // ErrUnknownAlterAlgorithm returns for no alter algorithm found error. - ErrUnknownAlterAlgorithm = terror.ClassParser.New(codeErrUnknownAlterAlgorithm, mysql.MySQLErrName[mysql.ErrUnknownAlterAlgorithm]) - + ErrUnknownAlterAlgorithm = terror.ClassParser.New(mysql.ErrUnknownAlterAlgorithm, mysql.MySQLErrName[mysql.ErrUnknownAlterAlgorithm]) // SpecFieldPattern special result field pattern SpecFieldPattern = regexp.MustCompile(`(\/\*!(M?[0-9]{5,6})?|\*\/)`) specCodePattern = regexp.MustCompile(`\/\*!(M?[0-9]{5,6})?([^*]|\*+[^*/])*\*+\/`) @@ -54,10 +58,16 @@ var ( func init() { parserMySQLErrCodes := map[terror.ErrCode]uint16{ - codeErrSyntax: mysql.ErrSyntax, - codeErrParse: mysql.ErrParse, - codeErrUnknownAlterLock: mysql.ErrUnknownAlterLock, - codeErrUnknownAlterAlgorithm: mysql.ErrUnknownAlterAlgorithm, + mysql.ErrSyntax: mysql.ErrSyntax, + mysql.ErrParse: mysql.ErrParse, + mysql.ErrUnknownCharacterSet: mysql.ErrUnknownCharacterSet, + mysql.ErrInvalidYearColumnLength: mysql.ErrInvalidYearColumnLength, + mysql.ErrWrongArguments: mysql.ErrWrongArguments, + mysql.ErrWrongFieldTerminators: mysql.ErrWrongFieldTerminators, + mysql.ErrTooBigDisplaywidth: mysql.ErrTooBigDisplaywidth, + mysql.ErrUnknownAlterLock: mysql.ErrUnknownAlterLock, + mysql.ErrUnknownAlterAlgorithm: mysql.ErrUnknownAlterAlgorithm, + mysql.ErrTooBigPrecision: mysql.ErrTooBigPrecision, } terror.ErrClassToMySQLCodes[terror.ClassParser] = parserMySQLErrCodes } @@ -127,6 +137,14 @@ func (parser *Parser) Parse(sql, charset, collation string) (stmt []ast.StmtNode return parser.result, warns, nil } +func (parser *Parser) lastErrorAsWarn() { + if len(parser.lexer.errs) == 0 { + return + } + parser.lexer.warns = append(parser.lexer.warns, parser.lexer.errs[len(parser.lexer.errs)-1]) + parser.lexer.errs = parser.lexer.errs[:len(parser.lexer.errs)-1] +} + // ParseOneStmt parses a query and returns an ast.StmtNode. // The query must have one statement, otherwise ErrSyntax is returned. func (parser *Parser) ParseOneStmt(sql, charset, collation string) (ast.StmtNode, error) { @@ -146,6 +164,11 @@ func (parser *Parser) SetSQLMode(mode mysql.SQLMode) { parser.lexer.SetSQLMode(mode) } +// EnableWindowFunc controls whether the parser to parse syntax related with window function. +func (parser *Parser) EnableWindowFunc(val bool) { + parser.lexer.EnableWindowFunc(val) +} + // ParseErrorWith returns "You have a syntax error near..." error message compatible with mysql. func ParseErrorWith(errstr string, lineno int) error { if len(errstr) > mysql.ErrTextLength { @@ -191,12 +214,12 @@ func toInt(l yyLexer, lval *yySymType, str string) int { // get value 99999999999999999999999999999999999999999999999999999999999999999 return toDecimal(l, lval, str) } - l.Errorf("integer literal: %v", err) + l.AppendError(l.Errorf("integer literal: %v", err)) return int(unicode.ReplacementChar) } switch { - case n < math.MaxInt64: + case n <= math.MaxInt64: lval.item = int64(n) default: lval.item = n diff --git a/session/session_inception_print_test.go b/session/session_inception_print_test.go index ce3cd467..ae12ed12 100644 --- a/session/session_inception_print_test.go +++ b/session/session_inception_print_test.go @@ -141,7 +141,8 @@ func (s *testSessionPrintSuite) TestInsert(c *C) { res = s.makeSQL("insert into t1 values;") row = res.Rows()[int(s.tk.Se.AffectedRows())-1] c.Assert(row[2], Equals, "2", Commentf("%v", row)) - c.Assert(row[4], Equals, "line 1 column 21 near \"\" (total length 21)", Commentf("%v", row)) + // c.Assert(row[4], Equals, "line 1 column 21 near \"\" (total length 21)", Commentf("%v", row)) + c.Assert(row[4], Equals, "You have an error in your SQL syntax, near '' at line 1", Commentf("%v", row)) } diff --git a/session/session_inception_split_test.go b/session/session_inception_split_test.go index 9dc67ebf..60d857e0 100644 --- a/session/session_inception_split_test.go +++ b/session/session_inception_split_test.go @@ -135,7 +135,8 @@ inception_magic_commit;` c.Assert(int(s.tk.Se.AffectedRows()), Equals, 1) row := res.Rows()[s.tk.Se.AffectedRows()-1] - c.Assert(row[3], Equals, "line 1 column 3 near \"\" (total length 3)") + // c.Assert(row[3], Equals, "line 1 column 3 near \"\" (total length 3)") + c.Assert(row[3], Equals, "You have an error in your SQL syntax, near '' at line 1") } func (s *testSessionSplitSuite) TestInsert(c *C) {