From 44314aca6b1e4305d0019b334ee6eaaf3b675fe5 Mon Sep 17 00:00:00 2001 From: wjHuang Date: Sat, 19 Sep 2020 10:59:08 +0800 Subject: [PATCH] expression: fix a bug that DML using caseWhen may cause schema change (#20095) Signed-off-by: wjhuang2016 --- expression/constant_fold.go | 2 +- expression/integration_test.go | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/expression/constant_fold.go b/expression/constant_fold.go index 8045838d594ef..3f1913a0ce91e 100644 --- a/expression/constant_fold.go +++ b/expression/constant_fold.go @@ -97,7 +97,7 @@ func caseWhenHandler(expr *ScalarFunction) (Expression, bool) { foldedExpr.GetType().Decimal = expr.GetType().Decimal return foldedExpr, isDeferredConst } - return BuildCastFunction(expr.GetCtx(), foldedExpr, foldedExpr.GetType()), isDeferredConst + return foldedExpr, isDeferredConst } } else { // for no-const, here should return directly, because the following branches are unknown to be run or not diff --git a/expression/integration_test.go b/expression/integration_test.go index e80190808d0f6..80b627487db88 100755 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -3592,6 +3592,7 @@ func (s *testIntegrationSuite) TestAggregationBuiltin(c *C) { defer s.cleanEnv(c) tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") + tk.MustExec("drop table if exists t") tk.MustExec("create table t(a decimal(7, 6))") tk.MustExec("insert into t values(1.123456), (1.123456)") result := tk.MustQuery("select avg(a) from t") @@ -3618,6 +3619,19 @@ func (s *testIntegrationSuite) TestAggregationBuiltin(c *C) { result.Check(testkit.Rows("18446744073709551615")) } +func (s *testIntegrationSuite) Test19387(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("USE test;") + + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t(a decimal(16, 2));") + tk.MustExec("select sum(case when 1 then a end) from t group by a;") + res := tk.MustQuery("show create table t") + c.Assert(len(res.Rows()), Equals, 1) + str := res.Rows()[0][1].(string) + c.Assert(strings.Contains(str, "decimal(16,2)"), IsTrue) +} + func (s *testIntegrationSuite) TestAggregationBuiltinBitOr(c *C) { defer s.cleanEnv(c) tk := testkit.NewTestKit(c, s.store)