From 60c0b2f1d8d5b646b5cffcc6f5c7b342d84db8c5 Mon Sep 17 00:00:00 2001 From: Zhuhe Fang Date: Fri, 13 Nov 2020 15:42:19 +0800 Subject: [PATCH] cherry pick #19797 to release-4.0 Signed-off-by: ti-srebot --- expression/function_traits.go | 10 +++++++--- expression/integration_test.go | 9 +++++++++ planner/core/expression_rewriter.go | 10 ++++++++++ 3 files changed, 26 insertions(+), 3 deletions(-) diff --git a/expression/function_traits.go b/expression/function_traits.go index 0da50400e9647..27f677e90b968 100644 --- a/expression/function_traits.go +++ b/expression/function_traits.go @@ -60,9 +60,13 @@ var DisableFoldFunctions = map[string]struct{}{ // otherwise, the child functions do not fold constant. // Note: the function itself should fold constant. var TryFoldFunctions = map[string]struct{}{ - ast.If: {}, - ast.Ifnull: {}, - ast.Case: {}, + ast.If: {}, + ast.Ifnull: {}, + ast.Case: {}, + ast.LogicAnd: {}, + ast.LogicOr: {}, + ast.Coalesce: {}, + ast.Interval: {}, } // IllegalFunctions4GeneratedColumns stores functions that is illegal for generated columns. diff --git a/expression/integration_test.go b/expression/integration_test.go index 86b92118dc8be..4e1d1d3466717 100755 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -2838,6 +2838,15 @@ func (s *testIntegrationSuite2) TestBuiltin(c *C) { tk.MustQuery("select 1 or b/0 from t") tk.MustQuery("show warnings").Check(testkit.Rows()) + tk.MustQuery("select 1 or 1/0") + tk.MustQuery("show warnings").Check(testkit.Rows()) + tk.MustQuery("select 0 and 1/0") + tk.MustQuery("show warnings").Check(testkit.Rows()) + tk.MustQuery("select COALESCE(1, 1/0)") + tk.MustQuery("show warnings").Check(testkit.Rows()) + tk.MustQuery("select interval(1,0,1,2,1/0)") + tk.MustQuery("show warnings").Check(testkit.Rows()) + tk.MustQuery("select case 2.0 when 2.0 then 3.0 when 3.0 then 2.0 end").Check(testkit.Rows("3.0")) tk.MustQuery("select case 2.0 when 3.0 then 2.0 when 4.0 then 3.0 else 5.0 end").Check(testkit.Rows("5.0")) tk.MustQuery("select case cast('2011-01-01' as date) when cast('2011-01-01' as date) then cast('2011-02-02' as date) end").Check(testkit.Rows("2011-02-02")) diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 834e4c01c2c97..90cd0646cf9be 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -400,6 +400,7 @@ func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) { er.ctxStackAppend(er.schema.Columns[index], er.names[index]) return inNode, true case *ast.FuncCallExpr: + er.asScalar = true if _, ok := expression.DisableFoldFunctions[v.FnName.L]; ok { er.disableFoldCounter++ } @@ -407,12 +408,18 @@ func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) { er.tryFoldCounter++ } case *ast.CaseExpr: + er.asScalar = true if _, ok := expression.DisableFoldFunctions["case"]; ok { er.disableFoldCounter++ } if _, ok := expression.TryFoldFunctions["case"]; ok { er.tryFoldCounter++ } + case *ast.BinaryOperationExpr: + er.asScalar = true + if v.Op == opcode.LogicAnd || v.Op == opcode.LogicOr { + er.tryFoldCounter++ + } case *ast.SetCollationExpr: // Do nothing default: @@ -972,6 +979,9 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok case *ast.UnaryOperationExpr: er.unaryOpToExpression(v) case *ast.BinaryOperationExpr: + if v.Op == opcode.LogicAnd || v.Op == opcode.LogicOr { + er.tryFoldCounter-- + } er.binaryOpToExpression(v) case *ast.BetweenExpr: er.betweenToExpression(v)