From 08a56a0343dbe367b95c0f4085c32648602b5201 Mon Sep 17 00:00:00 2001 From: sylzd Date: Thu, 2 Dec 2021 13:11:53 +0800 Subject: [PATCH] cherry pick #30121 to release-5.0 Signed-off-by: ti-srebot --- expression/builtin_compare.go | 28 ++++++++++++++++++++++++++++ expression/builtin_compare_test.go | 10 ++++++++++ expression/integration_test.go | 9 +++++++++ expression/typeinfer_test.go | 7 +++++++ 4 files changed, 54 insertions(+) diff --git a/expression/builtin_compare.go b/expression/builtin_compare.go index 5d8afe1696916..fb52dcdd132e6 100644 --- a/expression/builtin_compare.go +++ b/expression/builtin_compare.go @@ -470,6 +470,14 @@ func (c *greatestFunctionClass) getFunction(ctx sessionctx.Context, args []Expre } switch tp { case types.ETInt: + // adjust unsigned flag + greastInitUnsignedFlag := false + if isEqualsInitUnsignedFlag(greastInitUnsignedFlag, args) { + bf.tp.Flag &= ^mysql.UnsignedFlag + } else { + bf.tp.Flag |= mysql.UnsignedFlag + } + sig = &builtinGreatestIntSig{bf} sig.setPbCode(tipb.ScalarFuncSig_GreatestInt) case types.ETReal: @@ -701,6 +709,14 @@ func (c *leastFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi } switch tp { case types.ETInt: + // adjust unsigned flag + leastInitUnsignedFlag := true + if isEqualsInitUnsignedFlag(leastInitUnsignedFlag, args) { + bf.tp.Flag |= mysql.UnsignedFlag + } else { + bf.tp.Flag &= ^mysql.UnsignedFlag + } + sig = &builtinLeastIntSig{bf} sig.setPbCode(tipb.ScalarFuncSig_LeastInt) case types.ETReal: @@ -2770,3 +2786,15 @@ func CompareJSON(sctx sessionctx.Context, lhsArg, rhsArg Expression, lhsRow, rhs } return int64(json.CompareBinary(arg0, arg1)), false, nil } + +// isEqualsInitUnsignedFlag can adjust unsigned flag for greatest/least function. +// For greatest, returns unsigned result if there is at least one argument is unsigned. +// For least, returns signed result if there is at least one argument is signed. +func isEqualsInitUnsignedFlag(initUnsigned bool, args []Expression) bool { + for _, arg := range args { + if initUnsigned != mysql.HasUnsignedFlag(arg.GetType().Flag) { + return false + } + } + return true +} diff --git a/expression/builtin_compare_test.go b/expression/builtin_compare_test.go index 9d11d8b5ad18a..833d7e8db79a3 100644 --- a/expression/builtin_compare_test.go +++ b/expression/builtin_compare_test.go @@ -263,6 +263,8 @@ func (s *testEvaluatorSuite) TestGreatestLeastFunc(c *C) { sc := s.ctx.GetSessionVars().StmtCtx originIgnoreTruncate := sc.IgnoreTruncate sc.IgnoreTruncate = true + decG := &types.MyDecimal{} + decL := &types.MyDecimal{} defer func() { sc.IgnoreTruncate = originIgnoreTruncate }() @@ -274,6 +276,14 @@ func (s *testEvaluatorSuite) TestGreatestLeastFunc(c *C) { isNil bool getErr bool }{ + { + []interface{}{int64(-9223372036854775808), uint64(9223372036854775809)}, + decG.FromUint(9223372036854775809), decL.FromInt(-9223372036854775808), false, false, + }, + { + []interface{}{uint64(9223372036854775808), uint64(9223372036854775809)}, + uint64(9223372036854775809), uint64(9223372036854775808), false, false, + }, { []interface{}{1, 2, 3, 4}, int64(4), int64(1), false, false, diff --git a/expression/integration_test.go b/expression/integration_test.go index f9557b1c574cd..f43070d56f2ee 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -9274,3 +9274,12 @@ func (s *testIntegrationSuite) TestConstPropNullFunctions(c *C) { tk.MustExec("insert into t2 values (0, 'c', null), (1, null, 0.1), (3, 'b', 0.01), (2, 'q', 0.12), (null, 'a', -0.1), (null, null, null)") tk.MustQuery("select * from t2 where t2.i2=((select count(1) from t1 where t1.i1=t2.i2))").Check(testkit.Rows("1 0.1")) } + +func (s *testIntegrationSuite) TestIssue30101(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t1;") + tk.MustExec("create table t1(c1 bigint unsigned, c2 bigint unsigned);") + tk.MustExec("insert into t1 values(9223372036854775808, 9223372036854775809);") + tk.MustQuery("select greatest(c1, c2) from t1;").Sort().Check(testkit.Rows("9223372036854775809")) +} diff --git a/expression/typeinfer_test.go b/expression/typeinfer_test.go index 10e146d6189ab..5c5788d98d939 100644 --- a/expression/typeinfer_test.go +++ b/expression/typeinfer_test.go @@ -1031,6 +1031,13 @@ func (s *testInferTypeSuite) createTestCase4CompareFuncs() []typeInferTestCase { {"interval(c_int_d, c_int_d, c_int_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, {"interval(c_int_d, c_float_d, c_double_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, + + {"greatest(c_bigint_d, c_ubigint_d, c_int_d)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, + {"greatest(c_ubigint_d, c_ubigint_d, c_uint_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag, mysql.MaxIntWidth, 0}, + {"greatest(c_uint_d, c_int_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag, 11, 0}, + {"least(c_bigint_d, c_ubigint_d, c_int_d)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, + {"least(c_ubigint_d, c_ubigint_d, c_uint_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag, mysql.MaxIntWidth, 0}, + {"least(c_uint_d, c_int_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 11, 0}, } }