Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

planner, executor: fix cast not check error #21064

Closed
wants to merge 13 commits into from
Closed
9 changes: 0 additions & 9 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2315,7 +2315,6 @@ func (s *testIntegrationSuite2) TestBuiltin(c *C) {
defer s.cleanEnv(c)
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
ctx := context.Background()

// for is true && is false
tk.MustExec("drop table if exists t")
Expand Down Expand Up @@ -2753,14 +2752,6 @@ func (s *testIntegrationSuite2) TestBuiltin(c *C) {
_, err = tk.Exec("insert into t values(-9223372036854775809)")
c.Assert(err, NotNil)

// test case decimal precision less than the scale.
rs, err := tk.Exec("select cast(12.1 as decimal(3, 4));")
c.Assert(err, IsNil)
_, err = session.GetRows4Test(ctx, tk.Se, rs)
c.Assert(err, NotNil)
c.Assert(err.Error(), Equals, "[types:1427]For float(M,D), double(M,D) or decimal(M,D), M must be >= D (column '').")
c.Assert(rs.Close(), IsNil)

// test unhex and hex
result = tk.MustQuery("select unhex('4D7953514C')")
result.Check(testkit.Rows("MySQL"))
Expand Down
71 changes: 57 additions & 14 deletions planner/core/preprocess.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func TryAddExtraLimit(ctx sessionctx.Context, node ast.StmtNode) ast.StmtNode {

// Preprocess resolves table names of the node, and checks some statements validation.
func Preprocess(ctx sessionctx.Context, node ast.Node, is infoschema.InfoSchema, preprocessOpt ...PreprocessOpt) error {
v := preprocessor{is: is, ctx: ctx, tableAliasInJoin: make([]map[string]interface{}, 0)}
v := preprocessor{is: is, ctx: ctx, sql: node.Text(), tableAliasInJoin: make([]map[string]interface{}, 0)}
for _, optFn := range preprocessOpt {
optFn(&v)
}
Expand Down Expand Up @@ -110,6 +110,7 @@ const (
type preprocessor struct {
is infoschema.InfoSchema
ctx sessionctx.Context
sql string
err error
flag preprocessorFlag

Expand Down Expand Up @@ -184,6 +185,8 @@ func (p *preprocessor) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
if node.FnName.L == ast.NextVal || node.FnName.L == ast.LastVal || node.FnName.L == ast.SetVal {
p.flag |= inSequenceFunction
}
case *ast.FuncCastExpr:
p.checkCastGrammar(node)
case *ast.BRIEStmt:
if node.Kind == ast.BRIEKindRestore {
p.flag |= inCreateOrDropTable
Expand Down Expand Up @@ -802,18 +805,21 @@ func checkColumn(colDef *ast.ColumnDef) error {
}

// Check column type.
tp := colDef.Tp
return checkTp(colDef.Tp, colDef.Name.Name.O, "")
}

func checkTp(tp *types.FieldType, colName, val string) error {
if tp == nil {
return nil
}
if tp.Flen > math.MaxUint32 {
return types.ErrTooBigDisplayWidth.GenWithStack("Display width out of range for column '%s' (max = %d)", colDef.Name.Name.O, math.MaxUint32)
return types.ErrTooBigDisplayWidth.GenWithStack("Display width out of range for column '%s' (max = %d)", colName, math.MaxUint32)
}

switch tp.Tp {
case mysql.TypeString:
if tp.Flen != types.UnspecifiedLength && tp.Flen > mysql.MaxFieldCharLength {
return types.ErrTooBigFieldLength.GenWithStack("Column length too big for column '%s' (max = %d); use BLOB or TEXT instead", colDef.Name.Name.O, mysql.MaxFieldCharLength)
return types.ErrTooBigFieldLength.GenWithStack("Column length too big for column '%s' (max = %d); use BLOB or TEXT instead", colName, mysql.MaxFieldCharLength)
}
case mysql.TypeVarchar:
if len(tp.Charset) == 0 {
Expand All @@ -822,7 +828,7 @@ func checkColumn(colDef *ast.ColumnDef) error {
// return nil, to make the check in the ddl.CreateTable.
return nil
}
err := ddl.IsTooBigFieldLength(colDef.Tp.Flen, colDef.Name.Name.O, tp.Charset)
err := ddl.IsTooBigFieldLength(tp.Flen, colName, tp.Charset)
if err != nil {
return err
}
Expand All @@ -835,41 +841,58 @@ func checkColumn(colDef *ast.ColumnDef) error {
// For Double type Flen and Decimal check is moved to parser component
default:
if tp.Flen > mysql.MaxDoublePrecisionLength {
return types.ErrWrongFieldSpec.GenWithStackByArgs(colDef.Name.Name.O)
return types.ErrWrongFieldSpec.GenWithStackByArgs(colName)
}
}
} else {
if tp.Decimal > mysql.MaxFloatingTypeScale {
return types.ErrTooBigScale.GenWithStackByArgs(tp.Decimal, colDef.Name.Name.O, mysql.MaxFloatingTypeScale)
return types.ErrTooBigScale.GenWithStackByArgs(tp.Decimal, colName, mysql.MaxFloatingTypeScale)
}
if tp.Flen > mysql.MaxFloatingTypeWidth {
return types.ErrTooBigDisplayWidth.GenWithStackByArgs(colDef.Name.Name.O, mysql.MaxFloatingTypeWidth)
return types.ErrTooBigDisplayWidth.GenWithStackByArgs(colName, mysql.MaxFloatingTypeWidth)
}
}
case mysql.TypeSet:
if len(tp.Elems) > mysql.MaxTypeSetMembers {
return types.ErrTooBigSet.GenWithStack("Too many strings for column %s and SET", colDef.Name.Name.O)
return types.ErrTooBigSet.GenWithStack("Too many strings for column %s and SET", colName)
}
// Check set elements. See https://dev.mysql.com/doc/refman/5.7/en/set.html.
for _, str := range colDef.Tp.Elems {
for _, str := range tp.Elems {
if strings.Contains(str, ",") {
return types.ErrIllegalValueForType.GenWithStackByArgs(types.TypeStr(tp.Tp), str)
}
}
case mysql.TypeNewDecimal:
if tp.Decimal > mysql.MaxDecimalScale {
return types.ErrTooBigScale.GenWithStackByArgs(tp.Decimal, colDef.Name.Name.O, mysql.MaxDecimalScale)
var arg string
if colName == "" {
arg = val
} else {
arg = colName
}
return types.ErrTooBigScale.GenWithStackByArgs(tp.Decimal, arg, mysql.MaxDecimalScale)

}

if tp.Flen > mysql.MaxDecimalWidth {
return types.ErrTooBigPrecision.GenWithStackByArgs(tp.Flen, colDef.Name.Name.O, mysql.MaxDecimalWidth)
var arg string
if colName == "" {
arg = val
} else {
arg = colName
}
return types.ErrTooBigPrecision.GenWithStackByArgs(tp.Flen, arg, mysql.MaxDecimalWidth)
}

if tp.Flen < tp.Decimal {
return types.ErrMBiggerThanD.GenWithStackByArgs(colName)
}
case mysql.TypeBit:
if tp.Flen <= 0 {
return types.ErrInvalidFieldSize.GenWithStackByArgs(colDef.Name.Name.O)
return types.ErrInvalidFieldSize.GenWithStackByArgs(colName)
}
if tp.Flen > mysql.MaxBitDisplayWidth {
return types.ErrTooBigDisplayWidth.GenWithStackByArgs(colDef.Name.Name.O, mysql.MaxBitDisplayWidth)
return types.ErrTooBigDisplayWidth.GenWithStackByArgs(colName, mysql.MaxBitDisplayWidth)
}
default:
// TODO: Add more types.
Expand Down Expand Up @@ -1063,3 +1086,23 @@ func (p *preprocessor) resolveCreateSequenceStmt(stmt *ast.CreateSequenceStmt) {
return
}
}
func (p *preprocessor) checkCastGrammar(node *ast.FuncCastExpr) {
var val string
switch x := node.Expr.(type) {
case ast.ValueExpr:
val = x.GetDatumString()
if val == "" {
val = fmt.Sprintf("%v", x.GetValue())
} else {
wrapChar := p.sql[x.OriginTextPosition() : x.OriginTextPosition()+1]
val = wrapChar + val + wrapChar
}
case *ast.ColumnNameExpr:
val = x.Name.Name.O
default:
}
if err := checkTp(node.Tp, "", val); err != nil {
p.err = err
return
}
}
25 changes: 25 additions & 0 deletions planner/core/preprocess_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,31 @@ func (s *testValidatorSuite) TestValidator(c *C) {
{"CREATE TABLE t (a int, index(a));", false, nil},
{"CREATE INDEX `` on t (a);", true, errors.New("[ddl:1280]Incorrect index name ''")},
{"CREATE INDEX `` on t ((lower(a)));", true, errors.New("[ddl:1280]Incorrect index name ''")},

// for ErrTooBigScale
{`select convert('0.0', Decimal(41, 40))`, false, errors.New(`[types:1425]Too big scale 40 specified for column ''0.0''. Maximum is 30` + ".")},
// for cast decimal ErrTooBigPrecision
{`select * from t where d = cast(d as decimal(1000,20))`, false, errors.New(`[types:1426]Too big precision 1000 specified for column 'd'. Maximum is 65` + ".")},
{`select * from t where d = cast(111 as decimal(1000,20))`, false, errors.New(`[types:1426]Too big precision 1000 specified for column '111'. Maximum is 65` + ".")},
{`select * from t where d = cast("abc" as decimal(1000,20))`, false, errors.New(`[types:1426]Too big precision 1000 specified for column '"abc"'. Maximum is 65` + ".")},
{`select * from t where d = cast('d' as decimal(1000,20))`, false, errors.New(`[types:1426]Too big precision 1000 specified for column ''d''. Maximum is 65` + ".")},
{`select cast(d as decimal(1000,20))`, false, errors.New(`[types:1426]Too big precision 1000 specified for column 'd'. Maximum is 65` + ".")},
{`select cast(111 as decimal(1000,20))`, false, errors.New(`[types:1426]Too big precision 1000 specified for column '111'. Maximum is 65` + ".")},
{`select cast("abc" as decimal(1000,20))`, false, errors.New(`[types:1426]Too big precision 1000 specified for column '"abc"'. Maximum is 65` + ".")},
{`select cast("'d'" as decimal(1000,20))`, false, errors.New(`[types:1426]Too big precision 1000 specified for column '"'d'"'. Maximum is 65` + ".")},
// for cast decimal ErrMBiggerThanD
{`select * from t where d = cast(d as decimal(10,20))`, false, errors.New(`[types:1427]For float(M,D), double(M,D) or decimal(M,D), M must be >= D (column '')` + ".")},
{`select * from t where d = cast("d" as decimal(10,20))`, false, errors.New(`[types:1427]For float(M,D), double(M,D) or decimal(M,D), M must be >= D (column '')` + ".")},
{`select * from t where d = cast("'d'" as decimal(10,20))`, false, errors.New(`[types:1427]For float(M,D), double(M,D) or decimal(M,D), M must be >= D (column '')` + ".")},
// for convert decimal ErrTooBigPrecision
{`select * from t where d = convert(d, decimal(1000,20))`, false, errors.New(`[types:1426]Too big precision 1000 specified for column 'd'. Maximum is 65` + ".")},
{`select * from t where d = convert(111, decimal(1000,20))`, false, errors.New(`[types:1426]Too big precision 1000 specified for column '111'. Maximum is 65` + ".")},
{`select * from t where d = convert("abc", decimal(1000,20))`, false, errors.New(`[types:1426]Too big precision 1000 specified for column '"abc"'. Maximum is 65` + ".")},
{`select * from t where d = convert('d', decimal(1000,20))`, false, errors.New(`[types:1426]Too big precision 1000 specified for column ''d''. Maximum is 65` + ".")},
// for convert decimal ErrMBiggerThanD
{`select * from t where d = convert(d , decimal(10,20))`, false, errors.New(`[types:1427]For float(M,D), double(M,D) or decimal(M,D), M must be >= D (column '')` + ".")},
{`select * from t where d = convert("d", decimal(10,20))`, false, errors.New(`[types:1427]For float(M,D), double(M,D) or decimal(M,D), M must be >= D (column '')` + ".")},
{`select * from t where d = convert("'d'", decimal(10,20))`, false, errors.New(`[types:1427]For float(M,D), double(M,D) or decimal(M,D), M must be >= D (column '')` + ".")},
Comment on lines +294 to +296
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why all the columns names are empty when error M bigger than D happened?

}

_, err := s.se.Execute(context.Background(), "use test")
Expand Down