Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: avoid unnecessary warnings/errors when folding constants in control expr (#19675) #19910

Merged
merged 2 commits into from
Sep 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 21 additions & 44 deletions expression/builtin_control.go
Original file line number Diff line number Diff line change
Expand Up @@ -537,12 +537,10 @@ func (b *builtinIfIntSig) evalInt(row chunk.Row) (ret int64, isNull bool, err er
if err != nil {
return 0, true, err
}
arg1, isNull1, err := b.args[1].EvalInt(b.ctx, row)
if (!isNull0 && arg0 != 0) || err != nil {
return arg1, isNull1, err
if !isNull0 && arg0 != 0 {
return b.args[1].EvalInt(b.ctx, row)
}
arg2, isNull2, err := b.args[2].EvalInt(b.ctx, row)
return arg2, isNull2, err
return b.args[2].EvalInt(b.ctx, row)
}

type builtinIfRealSig struct {
Expand All @@ -560,12 +558,10 @@ func (b *builtinIfRealSig) evalReal(row chunk.Row) (ret float64, isNull bool, er
if err != nil {
return 0, true, err
}
arg1, isNull1, err := b.args[1].EvalReal(b.ctx, row)
if (!isNull0 && arg0 != 0) || err != nil {
return arg1, isNull1, err
if !isNull0 && arg0 != 0 {
return b.args[1].EvalReal(b.ctx, row)
}
arg2, isNull2, err := b.args[2].EvalReal(b.ctx, row)
return arg2, isNull2, err
return b.args[2].EvalReal(b.ctx, row)
}

type builtinIfDecimalSig struct {
Expand All @@ -583,12 +579,10 @@ func (b *builtinIfDecimalSig) evalDecimal(row chunk.Row) (ret *types.MyDecimal,
if err != nil {
return nil, true, err
}
arg1, isNull1, err := b.args[1].EvalDecimal(b.ctx, row)
if (!isNull0 && arg0 != 0) || err != nil {
return arg1, isNull1, err
if !isNull0 && arg0 != 0 {
return b.args[1].EvalDecimal(b.ctx, row)
}
arg2, isNull2, err := b.args[2].EvalDecimal(b.ctx, row)
return arg2, isNull2, err
return b.args[2].EvalDecimal(b.ctx, row)
}

type builtinIfStringSig struct {
Expand All @@ -606,12 +600,10 @@ func (b *builtinIfStringSig) evalString(row chunk.Row) (ret string, isNull bool,
if err != nil {
return "", true, err
}
arg1, isNull1, err := b.args[1].EvalString(b.ctx, row)
if (!isNull0 && arg0 != 0) || err != nil {
return arg1, isNull1, err
if !isNull0 && arg0 != 0 {
return b.args[1].EvalString(b.ctx, row)
}
arg2, isNull2, err := b.args[2].EvalString(b.ctx, row)
return arg2, isNull2, err
return b.args[2].EvalString(b.ctx, row)
}

type builtinIfTimeSig struct {
Expand All @@ -629,12 +621,10 @@ func (b *builtinIfTimeSig) evalTime(row chunk.Row) (ret types.Time, isNull bool,
if err != nil {
return ret, true, err
}
arg1, isNull1, err := b.args[1].EvalTime(b.ctx, row)
if (!isNull0 && arg0 != 0) || err != nil {
return arg1, isNull1, err
if !isNull0 && arg0 != 0 {
return b.args[1].EvalTime(b.ctx, row)
}
arg2, isNull2, err := b.args[2].EvalTime(b.ctx, row)
return arg2, isNull2, err
return b.args[2].EvalTime(b.ctx, row)
}

type builtinIfDurationSig struct {
Expand All @@ -652,12 +642,10 @@ func (b *builtinIfDurationSig) evalDuration(row chunk.Row) (ret types.Duration,
if err != nil {
return ret, true, err
}
arg1, isNull1, err := b.args[1].EvalDuration(b.ctx, row)
if (!isNull0 && arg0 != 0) || err != nil {
return arg1, isNull1, err
if !isNull0 && arg0 != 0 {
return b.args[1].EvalDuration(b.ctx, row)
}
arg2, isNull2, err := b.args[2].EvalDuration(b.ctx, row)
return arg2, isNull2, err
return b.args[2].EvalDuration(b.ctx, row)
}

type builtinIfJSONSig struct {
Expand All @@ -675,21 +663,10 @@ func (b *builtinIfJSONSig) evalJSON(row chunk.Row) (ret json.BinaryJSON, isNull
if err != nil {
return ret, true, err
}
arg1, isNull1, err := b.args[1].EvalJSON(b.ctx, row)
if err != nil {
return ret, true, err
}
arg2, isNull2, err := b.args[2].EvalJSON(b.ctx, row)
if err != nil {
return ret, true, err
}
switch {
case isNull0 || arg0 == 0:
ret, isNull = arg2, isNull2
case arg0 != 0:
ret, isNull = arg1, isNull1
if !isNull0 && arg0 != 0 {
return b.args[1].EvalJSON(b.ctx, row)
}
return
return b.args[2].EvalJSON(b.ctx, row)
}

type ifNullFunctionClass struct {
Expand Down
35 changes: 10 additions & 25 deletions expression/constant_fold.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,8 @@ func ifFoldHandler(expr *ScalarFunction) (Expression, bool) {
}
return foldConstant(args[2])
}
var isDeferred, isDeferredConst bool
expr.GetArgs()[1], isDeferred = foldConstant(args[1])
isDeferredConst = isDeferredConst || isDeferred
expr.GetArgs()[2], isDeferred = foldConstant(args[2])
isDeferredConst = isDeferredConst || isDeferred
return expr, isDeferredConst
// if the condition is not const, which branch is unknown to run, so directly return.
return expr, false
}

func ifNullFoldHandler(expr *ScalarFunction) (Expression, bool) {
Expand All @@ -76,18 +72,17 @@ func ifNullFoldHandler(expr *ScalarFunction) (Expression, bool) {
}
return constArg, isDeferred
}
var isDeferredConst bool
expr.GetArgs()[1], isDeferredConst = foldConstant(args[1])
return expr, isDeferredConst
// if the condition is not const, which branch is unknown to run, so directly return.
return expr, false
}

func caseWhenHandler(expr *ScalarFunction) (Expression, bool) {
args, l := expr.GetArgs(), len(expr.GetArgs())
var isDeferred, isDeferredConst, hasNonConstCondition bool
var isDeferred, isDeferredConst bool
for i := 0; i < l-1; i += 2 {
expr.GetArgs()[i], isDeferred = foldConstant(args[i])
isDeferredConst = isDeferredConst || isDeferred
if _, isConst := expr.GetArgs()[i].(*Constant); isConst && !hasNonConstCondition {
if _, isConst := expr.GetArgs()[i].(*Constant); isConst {
// If the condition is const and true, and the previous conditions
// has no expr, then the folded execution body is returned, otherwise
// the arguments of the casewhen are folded and replaced.
Expand All @@ -105,20 +100,14 @@ func caseWhenHandler(expr *ScalarFunction) (Expression, bool) {
return BuildCastFunction(expr.GetCtx(), foldedExpr, foldedExpr.GetType()), isDeferredConst
}
} else {
hasNonConstCondition = true
// for no-const, here should return directly, because the following branches are unknown to be run or not
return expr, false
}
expr.GetArgs()[i+1], isDeferred = foldConstant(args[i+1])
isDeferredConst = isDeferredConst || isDeferred
}

if l%2 == 0 {
return expr, isDeferredConst
}

// If the number of arguments in casewhen is odd, and the previous conditions
// is const and false, then the folded else execution body is returned. otherwise
// is false, then the folded else execution body is returned. otherwise
// the execution body of the else are folded and replaced.
if !hasNonConstCondition {
if l%2 == 1 {
foldedExpr, isDeferred := foldConstant(args[l-1])
isDeferredConst = isDeferredConst || isDeferred
if _, isConst := foldedExpr.(*Constant); isConst {
Expand All @@ -127,10 +116,6 @@ func caseWhenHandler(expr *ScalarFunction) (Expression, bool) {
}
return BuildCastFunction(expr.GetCtx(), foldedExpr, foldedExpr.GetType()), isDeferredConst
}

expr.GetArgs()[l-1], isDeferred = foldConstant(args[l-1])
isDeferredConst = isDeferredConst || isDeferred

return expr, isDeferredConst
}

Expand Down
9 changes: 9 additions & 0 deletions expression/function_traits.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,15 @@ var DisableFoldFunctions = map[string]struct{}{
ast.Benchmark: {},
}

// TryFoldFunctions stores functions which try to fold constant in child scope functions if without errors/warnings,
// otherwise, the child functions do not fold constant.
// Note: the function itself should fold constant.
var TryFoldFunctions = map[string]struct{}{
ast.If: {},
ast.Ifnull: {},
ast.Case: {},
}

// IllegalFunctions4GeneratedColumns stores functions that is illegal for generated columns.
// See https://github.com/mysql/mysql-server/blob/5.7/mysql-test/suite/gcol/inc/gcol_blocked_sql_funcs_main.inc for details
var IllegalFunctions4GeneratedColumns = map[string]struct{}{
Expand Down
18 changes: 18 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2808,6 +2808,24 @@ func (s *testIntegrationSuite2) TestBuiltin(c *C) {
tk.MustQuery("select ifnull(b, b/0) from t")
tk.MustQuery("show warnings").Check(testkit.Rows())

tk.MustQuery("select case when 1 then 1 else 1/0 end")
tk.MustQuery("show warnings").Check(testkit.Rows())
tk.MustQuery(" select if(1,1,1/0)")
tk.MustQuery("show warnings").Check(testkit.Rows())
tk.MustQuery("select ifnull(1, 1/0)")
tk.MustQuery("show warnings").Check(testkit.Rows())

tk.MustExec("delete from t")
tk.MustExec("insert t values ('str2', 0)")
tk.MustQuery("select case when b < 1 then 1 else 1/0 end from t")
tk.MustQuery("show warnings").Check(testkit.Rows())
tk.MustQuery("select case when b < 1 then 1 when 1/0 then b else 1/0 end from t")
tk.MustQuery("show warnings").Check(testkit.Rows())
tk.MustQuery("select if(b < 1 , 1, 1/0) from t")
tk.MustQuery("show warnings").Check(testkit.Rows())
tk.MustQuery("select ifnull(b, 1/0) from t")
tk.MustQuery("show warnings").Check(testkit.Rows())

tk.MustQuery("select case 2.0 when 2.0 then 3.0 when 3.0 then 2.0 end").Check(testkit.Rows("3.0"))
tk.MustQuery("select case 2.0 when 3.0 then 2.0 when 4.0 then 3.0 else 5.0 end").Check(testkit.Rows("5.0"))
tk.MustQuery("select case cast('2011-01-01' as date) when cast('2011-01-01' as date) then cast('2011-02-02' as date) end").Check(testkit.Rows("2011-02-02"))
Expand Down
26 changes: 22 additions & 4 deletions expression/scalar_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,9 @@ func typeInferForNull(args []Expression) {
}

// newFunctionImpl creates a new scalar function or constant.
func newFunctionImpl(ctx sessionctx.Context, fold bool, funcName string, retType *types.FieldType, args ...Expression) (Expression, error) {
// fold: 1 means folding constants, while 0 means not,
// -1 means try to fold constants if without errors/warnings, otherwise not.
func newFunctionImpl(ctx sessionctx.Context, fold int, funcName string, retType *types.FieldType, args ...Expression) (Expression, error) {
if retType == nil {
return nil, errors.Errorf("RetType cannot be nil for ScalarFunction.")
}
Expand Down Expand Up @@ -210,20 +212,36 @@ func newFunctionImpl(ctx sessionctx.Context, fold bool, funcName string, retType
RetType: retType,
Function: f,
}
if fold {
if fold == 1 {
return FoldConstant(sf), nil
} else if fold == -1 {
// try to fold constants, and return the original function if errors/warnings occur
sc := ctx.GetSessionVars().StmtCtx
beforeWarns := sc.WarningCount()
newSf := FoldConstant(sf)
afterWarns := sc.WarningCount()
if afterWarns > beforeWarns {
sc.TruncateWarnings(int(beforeWarns))
return sf, nil
}
return newSf, nil
}
return sf, nil
}

// NewFunction creates a new scalar function or constant via a constant folding.
func NewFunction(ctx sessionctx.Context, funcName string, retType *types.FieldType, args ...Expression) (Expression, error) {
return newFunctionImpl(ctx, true, funcName, retType, args...)
return newFunctionImpl(ctx, 1, funcName, retType, args...)
}

// NewFunctionBase creates a new scalar function with no constant folding.
func NewFunctionBase(ctx sessionctx.Context, funcName string, retType *types.FieldType, args ...Expression) (Expression, error) {
return newFunctionImpl(ctx, false, funcName, retType, args...)
return newFunctionImpl(ctx, 0, funcName, retType, args...)
}

// NewFunctionTryFold creates a new scalar function with trying constant folding.
func NewFunctionTryFold(ctx sessionctx.Context, funcName string, retType *types.FieldType, args ...Expression) (Expression, error) {
return newFunctionImpl(ctx, -1, funcName, retType, args...)
}

// NewFunctionInternal is similar to NewFunction, but do not returns error, should only be used internally.
Expand Down
24 changes: 24 additions & 0 deletions planner/core/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ func (b *PlanBuilder) getExpressionRewriter(ctx context.Context, p LogicalPlan)
rewriter.preprocess = nil
rewriter.insertPlan = nil
rewriter.disableFoldCounter = 0
rewriter.tryFoldCounter = 0
rewriter.ctxStack = rewriter.ctxStack[:0]
rewriter.ctxNameStk = rewriter.ctxNameStk[:0]
rewriter.ctx = ctx
Expand Down Expand Up @@ -226,6 +227,7 @@ type expressionRewriter struct {
// leaving the scope(enable again), the counter will -1.
// NOTE: This value can be changed during expression rewritten.
disableFoldCounter int
tryFoldCounter int
}

func (er *expressionRewriter) ctxStackLen() int {
Expand Down Expand Up @@ -401,6 +403,16 @@ func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) {
if _, ok := expression.DisableFoldFunctions[v.FnName.L]; ok {
er.disableFoldCounter++
}
if _, ok := expression.TryFoldFunctions[v.FnName.L]; ok {
er.tryFoldCounter++
}
case *ast.CaseExpr:
if _, ok := expression.DisableFoldFunctions["case"]; ok {
er.disableFoldCounter++
}
if _, ok := expression.TryFoldFunctions["case"]; ok {
er.tryFoldCounter++
}
case *ast.SetCollationExpr:
// Do nothing
default:
Expand Down Expand Up @@ -944,6 +956,9 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok
case *ast.VariableExpr:
er.rewriteVariable(v)
case *ast.FuncCallExpr:
if _, ok := expression.TryFoldFunctions[v.FnName.L]; ok {
er.tryFoldCounter--
}
er.funcCallToExpression(v)
if _, ok := expression.DisableFoldFunctions[v.FnName.L]; ok {
er.disableFoldCounter--
Expand All @@ -959,7 +974,13 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok
case *ast.BetweenExpr:
er.betweenToExpression(v)
case *ast.CaseExpr:
if _, ok := expression.TryFoldFunctions["case"]; ok {
er.tryFoldCounter--
}
er.caseToExpression(v)
if _, ok := expression.DisableFoldFunctions["case"]; ok {
er.disableFoldCounter--
}
case *ast.FuncCastExpr:
arg := er.ctxStack[len(er.ctxStack)-1]
er.err = expression.CheckArgsNotMultiColumnRow(arg)
Expand Down Expand Up @@ -1052,6 +1073,9 @@ func (er *expressionRewriter) newFunction(funcName string, retType *types.FieldT
if er.disableFoldCounter > 0 {
return expression.NewFunctionBase(er.sctx, funcName, retType, args...)
}
if er.tryFoldCounter > 0 {
return expression.NewFunctionTryFold(er.sctx, funcName, retType, args...)
}
return expression.NewFunction(er.sctx, funcName, retType, args...)
}

Expand Down