diff --git a/executor/executor_test.go b/executor/executor_test.go index 19592632b1ff1..9324729e3c146 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -1080,6 +1080,19 @@ func (s *testSuite) TestBuiltin(c *C) { result = tk.MustQuery("select cast(a as signed) from t") result.Check(testkit.Rows("130000")) + // fixed issue #3762 + result = tk.MustQuery("select -9223372036854775809;") + result.Check(testkit.Rows("-9223372036854775809")) + result = tk.MustQuery("select --9223372036854775809;") + result.Check(testkit.Rows("9223372036854775809")) + result = tk.MustQuery("select -9223372036854775808;") + result.Check(testkit.Rows("-9223372036854775808")) + + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a bigint(30));") + _, err := tk.Exec("insert into t values(-9223372036854775809)") + c.Assert(err, NotNil) + // test unhex and hex result = tk.MustQuery("select unhex('4D7953514C')") result.Check(testkit.Rows("MySQL")) diff --git a/executor/prepared.go b/executor/prepared.go index 550bca73f031f..71e2d2e6b8921 100644 --- a/executor/prepared.go +++ b/executor/prepared.go @@ -336,6 +336,7 @@ func ResetStmtCtx(ctx context.Context, s ast.StmtNode) { switch s.(type) { case *ast.UpdateStmt, *ast.InsertStmt, *ast.DeleteStmt: sc.IgnoreTruncate = false + sc.IgnoreOverflow = false sc.TruncateAsWarning = !sessVars.StrictSQLMode if _, ok := s.(*ast.InsertStmt); !ok { sc.InUpdateOrDeleteStmt = true @@ -343,16 +344,20 @@ func ResetStmtCtx(ctx context.Context, s ast.StmtNode) { case *ast.CreateTableStmt, *ast.AlterTableStmt: // Make sure the sql_mode is strict when checking column default value. sc.IgnoreTruncate = false + sc.IgnoreOverflow = false sc.TruncateAsWarning = false case *ast.LoadDataStmt: sc.IgnoreTruncate = false + sc.IgnoreOverflow = false sc.TruncateAsWarning = !sessVars.StrictSQLMode case *ast.SelectStmt: + sc.IgnoreOverflow = true // Return warning for truncate error in selection. sc.IgnoreTruncate = false sc.TruncateAsWarning = true default: sc.IgnoreTruncate = true + sc.IgnoreOverflow = false if show, ok := s.(*ast.ShowStmt); ok { if show.Tp == ast.ShowWarnings { sc.InShowWarning = true diff --git a/expression/builtin.go b/expression/builtin.go index 27e2018f0d655..2cf0e0c120faf 100644 --- a/expression/builtin.go +++ b/expression/builtin.go @@ -786,7 +786,7 @@ var funcs = map[string]functionClass{ ast.UnaryNot: &unaryOpFunctionClass{baseFunctionClass{ast.UnaryNot, 1, 1}, opcode.Not}, ast.BitNeg: &unaryOpFunctionClass{baseFunctionClass{ast.BitNeg, 1, 1}, opcode.BitNeg}, ast.UnaryPlus: &unaryOpFunctionClass{baseFunctionClass{ast.UnaryPlus, 1, 1}, opcode.Plus}, - ast.UnaryMinus: &unaryOpFunctionClass{baseFunctionClass{ast.UnaryMinus, 1, 1}, opcode.Minus}, + ast.UnaryMinus: &unaryMinusFunctionClass{baseFunctionClass{ast.UnaryMinus, 1, 1}}, ast.In: &inFunctionClass{baseFunctionClass{ast.In, 1, -1}}, ast.IsTruth: &isTrueOpFunctionClass{baseFunctionClass{ast.IsTruth, 1, 1}, opcode.IsTruth}, ast.IsFalsity: &isTrueOpFunctionClass{baseFunctionClass{ast.IsFalsity, 1, 1}, opcode.IsFalsity}, diff --git a/expression/builtin_op.go b/expression/builtin_op.go index 52712459ca8d4..3e84311169b26 100644 --- a/expression/builtin_op.go +++ b/expression/builtin_op.go @@ -14,8 +14,12 @@ package expression import ( + "fmt" + "math" + "github.com/juju/errors" "github.com/pingcap/tidb/context" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/util/types" ) @@ -27,6 +31,7 @@ var ( _ functionClass = &bitOpFunctionClass{} _ functionClass = &isTrueOpFunctionClass{} _ functionClass = &unaryOpFunctionClass{} + _ functionClass = &unaryMinusFunctionClass{} _ functionClass = &isNullFunctionClass{} ) @@ -37,6 +42,7 @@ var ( _ builtinFunc = &builtinBitOpSig{} _ builtinFunc = &builtinIsTrueOpSig{} _ builtinFunc = &builtinUnaryOpSig{} + _ builtinFunc = &builtinUnaryMinusIntSig{} _ builtinFunc = &builtinIsNullSig{} ) @@ -335,49 +341,154 @@ func (b *builtinUnaryOpSig) eval(row []types.Datum) (d types.Datum, err error) { default: return d, errInvalidOperation.Gen("Unsupported type %v for op.Plus", aDatum.Kind()) } - case opcode.Minus: - switch aDatum.Kind() { - case types.KindInt64: - d.SetInt64(-aDatum.GetInt64()) - case types.KindUint64: - d.SetInt64(-int64(aDatum.GetUint64())) - case types.KindFloat64: - d.SetFloat64(-aDatum.GetFloat64()) - case types.KindFloat32: - d.SetFloat32(-aDatum.GetFloat32()) - case types.KindMysqlDuration: - dec := new(types.MyDecimal) - err = types.DecimalSub(new(types.MyDecimal), aDatum.GetMysqlDuration().ToNumber(), dec) - d.SetMysqlDecimal(dec) - case types.KindMysqlTime: - dec := new(types.MyDecimal) - err = types.DecimalSub(new(types.MyDecimal), aDatum.GetMysqlTime().ToNumber(), dec) - d.SetMysqlDecimal(dec) - case types.KindString, types.KindBytes: - f, err1 := types.StrToFloat(sc, aDatum.GetString()) - err = errors.Trace(err1) - d.SetFloat64(-f) - case types.KindMysqlDecimal: - dec := new(types.MyDecimal) - err = types.DecimalSub(new(types.MyDecimal), aDatum.GetMysqlDecimal(), dec) - d.SetMysqlDecimal(dec) - case types.KindMysqlHex: - d.SetFloat64(-aDatum.GetMysqlHex().ToNumber()) - case types.KindMysqlBit: - d.SetFloat64(-aDatum.GetMysqlBit().ToNumber()) - case types.KindMysqlEnum: - d.SetFloat64(-aDatum.GetMysqlEnum().ToNumber()) - case types.KindMysqlSet: - d.SetFloat64(-aDatum.GetMysqlSet().ToNumber()) - default: - return d, errInvalidOperation.Gen("Unsupported type %v for op.Minus", aDatum.Kind()) - } default: return d, errInvalidOperation.Gen("Unsupported op %v for unary op", b.op) } return } +type unaryMinusFunctionClass struct { + baseFunctionClass +} + +func (b *unaryMinusFunctionClass) handleIntOverflow(arg *Constant) (overflow bool) { + if mysql.HasUnsignedFlag(arg.GetType().Flag) { + uval := arg.Value.GetUint64() + // -math.MinInt64 is 9223372036854775808, so if uval is more than 9223372036854775808, like + // 9223372036854775809, -9223372036854775809 is less than math.MinInt64, overflow occurs. + if uval > uint64(-math.MinInt64) { + return true + } + } else { + val := arg.Value.GetInt64() + // The math.MinInt64 is -9223372036854775808, the math.MaxInt64 is 9223372036854775807, + // which is less than abs(-9223372036854775808). When val == math.MinInt64, overflow occurs. + if val == math.MinInt64 { + return true + } + } + return false +} + +// typeInfer infers unaryMinus function return type. when the arg is an int constant and overflow, +// typerInfer will infers the return type as tpDecimal, not tpInt. +func (b *unaryMinusFunctionClass) typeInfer(argExpr Expression, ctx context.Context) (evalTp, bool) { + tp := tpInt + switch argExpr.GetTypeClass() { + case types.ClassString, types.ClassReal: + tp = tpReal + case types.ClassDecimal: + tp = tpDecimal + } + + sc := ctx.GetSessionVars().StmtCtx + overflow := false + // TODO: Handle float overflow. + if arg, ok := argExpr.(*Constant); sc.IgnoreOverflow && ok && + arg.GetTypeClass() == types.ClassInt { + overflow = b.handleIntOverflow(arg) + if overflow { + tp = tpDecimal + } + } + return tp, overflow +} + +func (b *unaryMinusFunctionClass) getFunction(args []Expression, ctx context.Context) (sig builtinFunc, err error) { + err = b.verifyArgs(args) + if err != nil { + return nil, errors.Trace(err) + } + + argExpr := args[0] + retTp, intOverflow := b.typeInfer(argExpr, ctx) + + var bf baseBuiltinFunc + switch argExpr.GetTypeClass() { + case types.ClassInt: + if intOverflow { + bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal) + sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, true} + } else { + bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpInt) + sig = &builtinUnaryMinusIntSig{baseIntBuiltinFunc{bf}} + } + case types.ClassDecimal: + bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal) + sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, false} + case types.ClassReal: + bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpReal) + sig = &builtinUnaryMinusRealSig{baseRealBuiltinFunc{bf}} + case types.ClassString: + tp := argExpr.GetType().Tp + if types.IsTypeTime(tp) || tp == mysql.TypeDuration { + bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal) + sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, false} + } else { + bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpReal) + sig = &builtinUnaryMinusRealSig{baseRealBuiltinFunc{bf}} + } + } + + return sig.setSelf(sig), errors.Trace(err) +} + +type builtinUnaryMinusIntSig struct { + baseIntBuiltinFunc +} + +func (b *builtinUnaryMinusIntSig) evalInt(row []types.Datum) (res int64, isNull bool, err error) { + var val int64 + val, isNull, err = b.args[0].EvalInt(row, b.getCtx().GetSessionVars().StmtCtx) + if err != nil || isNull { + return val, isNull, errors.Trace(err) + } + + if mysql.HasUnsignedFlag(b.args[0].GetType().Flag) { + uval := uint64(val) + if uval > uint64(-math.MinInt64) { + return 0, false, types.ErrOverflow.GenByArgs("BIGINT", fmt.Sprintf("-%v", uval)) + } else if uval == uint64(-math.MinInt64) { + return math.MinInt64, false, nil + } + } else if val == math.MinInt64 { + return 0, false, types.ErrOverflow.GenByArgs("BIGINT", fmt.Sprintf("-%v", val)) + } + return -val, false, errors.Trace(err) +} + +type builtinUnaryMinusDecimalSig struct { + baseDecimalBuiltinFunc + + constantArgOverflow bool +} + +func (b *builtinUnaryMinusDecimalSig) evalDecimal(row []types.Datum) (*types.MyDecimal, bool, error) { + sc := b.getCtx().GetSessionVars().StmtCtx + + var dec *types.MyDecimal + dec, isNull, err := b.args[0].EvalDecimal(row, sc) + if err != nil || isNull { + return dec, isNull, errors.Trace(err) + } + + to := new(types.MyDecimal) + err = types.DecimalSub(new(types.MyDecimal), dec, to) + return to, false, errors.Trace(err) +} + +type builtinUnaryMinusRealSig struct { + baseRealBuiltinFunc +} + +func (b *builtinUnaryMinusRealSig) evalReal(row []types.Datum) (res float64, isNull bool, err error) { + sc := b.getCtx().GetSessionVars().StmtCtx + var val float64 + val, isNull, err = b.args[0].EvalReal(row, sc) + res = -val + return +} + type isNullFunctionClass struct { baseFunctionClass } diff --git a/expression/builtin_op_test.go b/expression/builtin_op_test.go index 02264731a5326..4b91b2786ac52 100644 --- a/expression/builtin_op_test.go +++ b/expression/builtin_op_test.go @@ -14,6 +14,8 @@ package expression import ( + "math" + "github.com/juju/errors" . "github.com/pingcap/check" "github.com/pingcap/tidb/ast" @@ -21,6 +23,47 @@ import ( "github.com/pingcap/tidb/util/types" ) +func (s *testEvaluatorSuite) TestUnary(c *C) { + defer testleak.AfterTest(c)() + cases := []struct { + args interface{} + expected interface{} + overflow bool + getErr bool + }{ + {uint64(9223372036854775809), "-9223372036854775809", true, false}, + {uint64(9223372036854775810), "-9223372036854775810", true, false}, + {uint64(9223372036854775808), int64(-9223372036854775808), false, false}, + {int64(math.MinInt64), "9223372036854775808", true, false}, // --9223372036854775808 + } + sc := s.ctx.GetSessionVars().StmtCtx + origin := sc.IgnoreOverflow + sc.IgnoreOverflow = true + defer func() { + sc.IgnoreOverflow = origin + }() + + for _, t := range cases { + f, err := newFunctionForTest(s.ctx, ast.UnaryMinus, primitiveValsToConstants([]interface{}{t.args})...) + c.Assert(err, IsNil) + d, err := f.Eval(nil) + if t.getErr == false { + c.Assert(err, IsNil) + if !t.overflow { + c.Assert(d.GetValue(), Equals, t.expected) + } else { + c.Assert(d.GetMysqlDecimal().String(), Equals, t.expected) + } + } else { + c.Assert(err, NotNil) + } + } + + f, err := funcs[ast.UnaryMinus].getFunction([]Expression{Zero}, s.ctx) + c.Assert(err, IsNil) + c.Assert(f.isDeterministic(), IsTrue) +} + func (s *testEvaluatorSuite) TestAndAnd(c *C) { defer testleak.AfterTest(c)() diff --git a/expression/scalar_function.go b/expression/scalar_function.go index 920079d55313a..afb00ab2a5fc9 100644 --- a/expression/scalar_function.go +++ b/expression/scalar_function.go @@ -195,6 +195,7 @@ func (sf *ScalarFunction) Eval(row []types.Datum) (d types.Datum, err error) { res, isNull, err = sf.EvalString(row, sc) } } + if isNull || err != nil { d.SetValue(nil) return d, errors.Trace(err) diff --git a/expression/typeinferer_test.go b/expression/typeinferer_test.go index 979d260865ec9..5dc8be27f73e8 100644 --- a/expression/typeinferer_test.go +++ b/expression/typeinferer_test.go @@ -343,6 +343,10 @@ func (ts *testTypeInferrerSuite) TestInferType(c *C) { {`json_insert('{"a": 1}', '$.a', 3)`, mysql.TypeJSON, charset.CharsetUTF8, 0}, {`json_replace('{"a": 1}', '$.a', 3)`, mysql.TypeJSON, charset.CharsetUTF8, 0}, {`json_merge('{"a": 1}', '3')`, mysql.TypeJSON, charset.CharsetUTF8, 0}, + {"-9223372036854775809", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag}, + {"-9223372036854775808", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag}, + {"--9223372036854775809", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag}, + {"--9223372036854775808", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag}, } for _, tt := range tests { ctx := testKit.Se.(context.Context) diff --git a/plan/expression_rewriter.go b/plan/expression_rewriter.go index e78a169820196..6baca080bc25f 100644 --- a/plan/expression_rewriter.go +++ b/plan/expression_rewriter.go @@ -93,6 +93,7 @@ func (b *planBuilder) rewriteWithPreprocess(expr ast.ExprNode, p LogicalPlan, ag if getRowLen(er.ctxStack[0]) != 1 { return nil, nil, ErrOperandColumns.GenByArgs(1) } + return er.ctxStack[0], er.p, nil } @@ -142,7 +143,11 @@ func popRowArg(ctx context.Context, e expression.Expression) (ret expression.Exp return ret, errors.Trace(err) } c, _ := e.(*expression.Constant) - ret = &expression.Constant{Value: types.NewDatum(c.Value.GetRow()[1:]), RetType: c.GetType()} + if getRowLen(c) == 2 { + ret = &expression.Constant{Value: c.Value.GetRow()[1], RetType: c.GetType()} + } else { + ret = &expression.Constant{Value: types.NewDatum(c.Value.GetRow()[1:]), RetType: c.GetType()} + } return } diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index b7c585d01420a..2cbc200823ead 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -319,6 +319,7 @@ type StatementContext struct { // Set the following variables before execution InUpdateOrDeleteStmt bool + IgnoreOverflow bool IgnoreTruncate bool TruncateAsWarning bool InShowWarning bool