From e3a9c6dac8e497fd00499bf80d89a0d6e1be9931 Mon Sep 17 00:00:00 2001 From: exialin Date: Mon, 14 Jan 2019 09:47:52 +0800 Subject: [PATCH] *: fix the lower bound when converting numbers less than 0 to unsigned integers (#9028) --- executor/distsql.go | 2 +- executor/executor.go | 7 +++--- executor/executor_test.go | 1 + executor/load_data.go | 2 +- executor/write_test.go | 41 +++++++++++++++++++++++++++++++++++ expression/builtin_cast.go | 12 ++++++---- expression/errors.go | 4 ++-- sessionctx/stmtctx/stmtctx.go | 22 ++++++++++++++++++- types/convert.go | 15 +++++++++---- types/datum.go | 26 +++++++++++----------- 10 files changed, 103 insertions(+), 29 deletions(-) diff --git a/executor/distsql.go b/executor/distsql.go index 9b8c5600d3f8e..22e326e2390b0 100644 --- a/executor/distsql.go +++ b/executor/distsql.go @@ -122,7 +122,7 @@ func statementContextToFlags(sc *stmtctx.StatementContext) uint64 { var flags uint64 if sc.InInsertStmt { flags |= model.FlagInInsertStmt - } else if sc.InUpdateOrDeleteStmt { + } else if sc.InUpdateStmt || sc.InDeleteStmt { flags |= model.FlagInUpdateOrDeleteStmt } else if sc.InSelectStmt { flags |= model.FlagInSelectStmt diff --git a/executor/executor.go b/executor/executor.go index 8d47a4bab4968..91a58b8ea6ab7 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -1245,7 +1245,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { // pushing them down to TiKV as flags. switch stmt := s.(type) { case *ast.UpdateStmt: - sc.InUpdateOrDeleteStmt = true + sc.InUpdateStmt = true sc.DupKeyAsWarning = stmt.IgnoreErr sc.BadNullAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr sc.TruncateAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr @@ -1253,7 +1253,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc.IgnoreZeroInDate = !vars.StrictSQLMode || stmt.IgnoreErr sc.Priority = stmt.Priority case *ast.DeleteStmt: - sc.InUpdateOrDeleteStmt = true + sc.InDeleteStmt = true sc.DupKeyAsWarning = stmt.IgnoreErr sc.BadNullAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr sc.TruncateAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr @@ -1274,6 +1274,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc.DupKeyAsWarning = true sc.BadNullAsWarning = true sc.TruncateAsWarning = !vars.StrictSQLMode + sc.InLoadDataStmt = true case *ast.SelectStmt: sc.InSelectStmt = true @@ -1314,7 +1315,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc.PrevLastInsertID = vars.StmtCtx.PrevLastInsertID } sc.PrevAffectedRows = 0 - if vars.StmtCtx.InUpdateOrDeleteStmt || vars.StmtCtx.InInsertStmt { + if vars.StmtCtx.InUpdateStmt || vars.StmtCtx.InDeleteStmt || vars.StmtCtx.InInsertStmt { sc.PrevAffectedRows = int64(vars.StmtCtx.AffectedRows()) } else if vars.StmtCtx.InSelectStmt { sc.PrevAffectedRows = -1 diff --git a/executor/executor_test.go b/executor/executor_test.go index 5c6a53005713b..cd07756f560e3 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -329,6 +329,7 @@ func checkCases(tests []testCase, ld *executor.LoadDataInfo, c.Assert(ctx.NewTxn(), IsNil) ctx.GetSessionVars().StmtCtx.DupKeyAsWarning = true ctx.GetSessionVars().StmtCtx.BadNullAsWarning = true + ctx.GetSessionVars().StmtCtx.InLoadDataStmt = true data, reachLimit, err1 := ld.InsertData(tt.data1, tt.data2) c.Assert(err1, IsNil) c.Assert(reachLimit, IsFalse) diff --git a/executor/load_data.go b/executor/load_data.go index 258b5cfbd233d..74a5cff129fde 100644 --- a/executor/load_data.go +++ b/executor/load_data.go @@ -201,7 +201,7 @@ func (e *LoadDataInfo) getLine(prevData, curData []byte) ([]byte, []byte, bool) } // InsertData inserts data into specified table according to the specified format. -// If it has the rest of data isn't completed the processing, then is returns without completed data. +// If it has the rest of data isn't completed the processing, then it returns without completed data. // If the number of inserted rows reaches the batchRows, then the second return value is true. // If prevData isn't nil and curData is nil, there are no other data to deal with and the isEOF is true. func (e *LoadDataInfo) InsertData(prevData, curData []byte) ([]byte, bool, error) { diff --git a/executor/write_test.go b/executor/write_test.go index 8ea7a5a572128..045c61158d9e7 100644 --- a/executor/write_test.go +++ b/executor/write_test.go @@ -227,6 +227,27 @@ func (s *testSuite) TestInsert(c *C) { tk.MustExec("insert into test values(2, 3)") tk.MustQuery("select * from test use index (id) where id = 2").Check(testkit.Rows("2 2", "2 3")) + // issue 6360 + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t(a bigint unsigned);") + tk.MustExec(" set @orig_sql_mode = @@sql_mode; set @@sql_mode = 'strict_all_tables';") + _, err = tk.Exec("insert into t value (-1);") + c.Assert(types.ErrWarnDataOutOfRange.Equal(err), IsTrue) + tk.MustExec("set @@sql_mode = '';") + tk.MustExec("insert into t value (-1);") + // TODO: the following warning messages are not consistent with MySQL, fix them in the future PRs + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1690 constant -1 overflows bigint")) + tk.MustExec("insert into t select -1;") + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1690 constant -1 overflows bigint")) + tk.MustExec("insert into t select cast(-1 as unsigned);") + tk.MustExec("insert into t value (-1.111);") + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1690 constant -1 overflows bigint")) + tk.MustExec("insert into t value ('-1.111');") + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1690 BIGINT UNSIGNED value is out of range in '-1'")) + r = tk.MustQuery("select * from t;") + r.Check(testkit.Rows("0", "0", "18446744073709551615", "0", "0")) + tk.MustExec("set @@sql_mode = @orig_sql_mode;") + // issue 6424 tk.MustExec("drop table if exists t") tk.MustExec("create table t(a time(6))") @@ -1699,6 +1720,26 @@ func (s *testSuite) TestLoadDataIgnoreLines(c *C) { checkCases(tests, ld, c, tk, ctx, selectSQL, deleteSQL) } +// related to issue 6360 +func (s *testSuite) TestLoadDataOverflowBigintUnsigned(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test; drop table if exists load_data_test;") + tk.MustExec("CREATE TABLE load_data_test (a bigint unsigned);") + tk.MustExec("load data local infile '/tmp/nonexistence.csv' into table load_data_test") + ctx := tk.Se.(sessionctx.Context) + ld, ok := ctx.Value(executor.LoadDataVarKey).(*executor.LoadDataInfo) + c.Assert(ok, IsTrue) + defer ctx.SetValue(executor.LoadDataVarKey, nil) + c.Assert(ld, NotNil) + tests := []testCase{ + {nil, []byte("-1\n-18446744073709551615\n-18446744073709551616\n"), []string{"0", "0", "0"}, nil}, + {nil, []byte("-9223372036854775809\n18446744073709551616\n"), []string{"0", "18446744073709551615"}, nil}, + } + deleteSQL := "delete from load_data_test" + selectSQL := "select * from load_data_test;" + checkCases(tests, ld, c, tk, ctx, selectSQL, deleteSQL) +} + func (s *testSuite) TestBatchInsertDelete(c *C) { originLimit := atomic.LoadUint64(&kv.TxnEntryCountLimit) defer func() { diff --git a/expression/builtin_cast.go b/expression/builtin_cast.go index c61d97f6a4924..403ba83457172 100644 --- a/expression/builtin_cast.go +++ b/expression/builtin_cast.go @@ -464,7 +464,8 @@ func (b *builtinCastIntAsRealSig) evalReal(row chunk.Row) (res float64, isNull b res = 0 } else { var uVal uint64 - uVal, err = types.ConvertIntToUint(val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong) + sc := b.ctx.GetSessionVars().StmtCtx + uVal, err = types.ConvertIntToUint(sc, val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong) res = float64(uVal) } return res, false, errors.Trace(err) @@ -491,7 +492,8 @@ func (b *builtinCastIntAsDecimalSig) evalDecimal(row chunk.Row) (res *types.MyDe res = &types.MyDecimal{} } else { var uVal uint64 - uVal, err = types.ConvertIntToUint(val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong) + sc := b.ctx.GetSessionVars().StmtCtx + uVal, err = types.ConvertIntToUint(sc, val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong) if err != nil { return res, false, errors.Trace(err) } @@ -520,7 +522,8 @@ func (b *builtinCastIntAsStringSig) evalString(row chunk.Row) (res string, isNul res = strconv.FormatInt(val, 10) } else { var uVal uint64 - uVal, err = types.ConvertIntToUint(val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong) + sc := b.ctx.GetSessionVars().StmtCtx + uVal, err = types.ConvertIntToUint(sc, val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong) if err != nil { return res, false, errors.Trace(err) } @@ -747,7 +750,8 @@ func (b *builtinCastRealAsIntSig) evalInt(row chunk.Row) (res int64, isNull bool res = 0 } else { var uintVal uint64 - uintVal, err = types.ConvertFloatToUint(val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeDouble) + sc := b.ctx.GetSessionVars().StmtCtx + uintVal, err = types.ConvertFloatToUint(sc, val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeDouble) res = int64(uintVal) } return res, isNull, errors.Trace(err) diff --git a/expression/errors.go b/expression/errors.go index 4b1a3163e1453..ce35d125bddbf 100644 --- a/expression/errors.go +++ b/expression/errors.go @@ -70,7 +70,7 @@ func handleInvalidTimeError(ctx sessionctx.Context, err error) error { return err } sc := ctx.GetSessionVars().StmtCtx - if ctx.GetSessionVars().StrictSQLMode && (sc.InInsertStmt || sc.InUpdateOrDeleteStmt) { + if ctx.GetSessionVars().StrictSQLMode && (sc.InInsertStmt || sc.InUpdateStmt || sc.InDeleteStmt) { return err } sc.AppendWarning(err) @@ -80,7 +80,7 @@ func handleInvalidTimeError(ctx sessionctx.Context, err error) error { // handleDivisionByZeroError reports error or warning depend on the context. func handleDivisionByZeroError(ctx sessionctx.Context) error { sc := ctx.GetSessionVars().StmtCtx - if sc.InInsertStmt || sc.InUpdateOrDeleteStmt { + if sc.InInsertStmt || sc.InUpdateStmt || sc.InDeleteStmt { if !ctx.GetSessionVars().SQLMode.HasErrorForDivisionByZeroMode() { return nil } diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go index eecb56d1ac05c..6819b2eaeb8d2 100644 --- a/sessionctx/stmtctx/stmtctx.go +++ b/sessionctx/stmtctx/stmtctx.go @@ -47,8 +47,10 @@ type StatementContext struct { // If IsDDLJobInQueue is true, it means the DDL job is in the queue of storage, and it can be handled by the DDL worker. IsDDLJobInQueue bool InInsertStmt bool - InUpdateOrDeleteStmt bool + InUpdateStmt bool + InDeleteStmt bool InSelectStmt bool + InLoadDataStmt bool IgnoreTruncate bool IgnoreZeroInDate bool DupKeyAsWarning bool @@ -276,3 +278,21 @@ func (sc *StatementContext) GetExecDetails() execdetails.ExecDetails { sc.mu.Unlock() return details } + +// ShouldClipToZero indicates whether values less than 0 should be clipped to 0 for unsigned integer types. +// This is the case for `insert`, `update`, `alter table` and `load data infile` statements, when not in strict SQL mode. +// see https://dev.mysql.com/doc/refman/5.7/en/out-of-range-and-overflow.html +func (sc *StatementContext) ShouldClipToZero() bool { + // TODO: Currently altering column of integer to unsigned integer is not supported. + // If it is supported one day, that case should be added here. + return sc.InInsertStmt || sc.InLoadDataStmt +} + +// ShouldIgnoreOverflowError indicates whether we should ignore the error when type conversion overflows, +// so we can leave it for further processing like clipping values less than 0 to 0 for unsigned integer types. +func (sc *StatementContext) ShouldIgnoreOverflowError() bool { + if (sc.InInsertStmt && sc.TruncateAsWarning) || sc.InLoadDataStmt { + return true + } + return false +} diff --git a/types/convert.go b/types/convert.go index 27df51b771997..87e3cd82e24a8 100644 --- a/types/convert.go +++ b/types/convert.go @@ -106,7 +106,11 @@ func ConvertUintToInt(val uint64, upperBound int64, tp byte) (int64, error) { } // ConvertIntToUint converts an int value to an uint value. -func ConvertIntToUint(val int64, upperBound uint64, tp byte) (uint64, error) { +func ConvertIntToUint(sc *stmtctx.StatementContext, val int64, upperBound uint64, tp byte) (uint64, error) { + if sc.ShouldClipToZero() && val < 0 { + return 0, overflow(val, tp) + } + if uint64(val) > upperBound { return upperBound, overflow(val, tp) } @@ -124,9 +128,12 @@ func ConvertUintToUint(val uint64, upperBound uint64, tp byte) (uint64, error) { } // ConvertFloatToUint converts a float value to an uint value. -func ConvertFloatToUint(fval float64, upperBound uint64, tp byte) (uint64, error) { +func ConvertFloatToUint(sc *stmtctx.StatementContext, fval float64, upperBound uint64, tp byte) (uint64, error) { val := RoundFloat(fval) if val < 0 { + if sc.ShouldClipToZero() { + return 0, overflow(val, tp) + } return uint64(int64(val)), overflow(val, tp) } @@ -400,7 +407,7 @@ func ConvertJSONToInt(sc *stmtctx.StatementContext, j json.BinaryJSON, unsigned return ConvertFloatToInt(f, lBound, uBound, mysql.TypeDouble) } bound := UnsignedUpperBound[mysql.TypeLonglong] - u, err := ConvertFloatToUint(f, bound, mysql.TypeDouble) + u, err := ConvertFloatToUint(sc, f, bound, mysql.TypeDouble) return int64(u), errors.Trace(err) case json.TypeCodeString: return StrToInt(sc, hack.String(j.GetString())) @@ -423,7 +430,7 @@ func ConvertJSONToFloat(sc *stmtctx.StatementContext, j json.BinaryJSON) (float6 case json.TypeCodeInt64: return float64(j.GetInt64()), nil case json.TypeCodeUint64: - u, err := ConvertIntToUint(j.GetInt64(), UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong) + u, err := ConvertIntToUint(sc, j.GetInt64(), UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong) return float64(u), errors.Trace(err) case json.TypeCodeFloat64: return j.GetFloat64(), nil diff --git a/types/datum.go b/types/datum.go index 86d94792f2762..33558c416cfdb 100644 --- a/types/datum.go +++ b/types/datum.go @@ -865,21 +865,21 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) ( ) switch d.k { case KindInt64: - val, err = ConvertIntToUint(d.GetInt64(), upperBound, tp) + val, err = ConvertIntToUint(sc, d.GetInt64(), upperBound, tp) case KindUint64: val, err = ConvertUintToUint(d.GetUint64(), upperBound, tp) case KindFloat32, KindFloat64: - val, err = ConvertFloatToUint(d.GetFloat64(), upperBound, tp) + val, err = ConvertFloatToUint(sc, d.GetFloat64(), upperBound, tp) case KindString, KindBytes: - val, err = StrToUint(sc, d.GetString()) - if err != nil { - return ret, errors.Trace(err) + uval, err1 := StrToUint(sc, d.GetString()) + if err1 != nil && ErrOverflow.Equal(err1) && !sc.ShouldIgnoreOverflowError() { + return ret, errors.Trace(err1) } - val, err = ConvertUintToUint(val, upperBound, tp) + val, err = ConvertUintToUint(uval, upperBound, tp) if err != nil { return ret, errors.Trace(err) } - ret.SetUint64(val) + err = err1 case KindMysqlTime: dec := d.GetMysqlTime().ToNumber() err = dec.Round(dec, 0, ModeHalfEven) @@ -887,7 +887,7 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) ( if err == nil { err = err1 } - val, err1 = ConvertIntToUint(ival, upperBound, tp) + val, err1 = ConvertIntToUint(sc, ival, upperBound, tp) if err == nil { err = err1 } @@ -896,18 +896,18 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) ( err = dec.Round(dec, 0, ModeHalfEven) ival, err1 := dec.ToInt() if err1 == nil { - val, err = ConvertIntToUint(ival, upperBound, tp) + val, err = ConvertIntToUint(sc, ival, upperBound, tp) } case KindMysqlDecimal: fval, err1 := d.GetMysqlDecimal().ToFloat64() - val, err = ConvertFloatToUint(fval, upperBound, tp) + val, err = ConvertFloatToUint(sc, fval, upperBound, tp) if err == nil { err = err1 } case KindMysqlEnum: - val, err = ConvertFloatToUint(d.GetMysqlEnum().ToNumber(), upperBound, tp) + val, err = ConvertFloatToUint(sc, d.GetMysqlEnum().ToNumber(), upperBound, tp) case KindMysqlSet: - val, err = ConvertFloatToUint(d.GetMysqlSet().ToNumber(), upperBound, tp) + val, err = ConvertFloatToUint(sc, d.GetMysqlSet().ToNumber(), upperBound, tp) case KindBinaryLiteral, KindMysqlBit: val, err = d.GetBinaryLiteral().ToInt(sc) case KindMysqlJSON: @@ -1137,7 +1137,7 @@ func ProduceDecWithSpecifiedTp(dec *MyDecimal, tp *FieldType, sc *stmtctx.Statem return nil, errors.Trace(err) } if !dec.IsZero() && frac > decimal && dec.Compare(&old) != 0 { - if sc.InInsertStmt || sc.InUpdateOrDeleteStmt { + if sc.InInsertStmt || sc.InUpdateStmt || sc.InDeleteStmt { // fix https://github.com/pingcap/tidb/issues/3895 // fix https://github.com/pingcap/tidb/issues/5532 sc.AppendWarning(ErrTruncated)