From f4a39be405ff9b6fb1ba2303598c04539f3a16c4 Mon Sep 17 00:00:00 2001 From: Arenatlx <314806019@qq.com> Date: Wed, 24 Aug 2022 18:02:21 +0800 Subject: [PATCH] cherry pick #37117 to release-6.2 Signed-off-by: ti-srebot --- cmd/explaintest/r/subquery.result | 28 +++++++++++++ cmd/explaintest/t/subquery.test | 13 ++++++ expression/util.go | 56 +++++++++++++++++++------ expression/util_test.go | 4 +- planner/core/expression_rewriter.go | 13 +++++- planner/core/logical_plans.go | 64 +++++++++++++++++++++++++---- planner/core/rule_decorrelate.go | 20 ++++++++- 7 files changed, 174 insertions(+), 24 deletions(-) diff --git a/cmd/explaintest/r/subquery.result b/cmd/explaintest/r/subquery.result index 84bac87bb1d23..cfb170dabf0cd 100644 --- a/cmd/explaintest/r/subquery.result +++ b/cmd/explaintest/r/subquery.result @@ -46,3 +46,31 @@ create table t1(a int(11)); create table t2(a decimal(40,20) unsigned, b decimal(40,20)); select count(*) as x from t1 group by a having x not in (select a from t2 where x = t2.b); x +drop table if exists stu; +drop table if exists exam; +create table stu(id int, name varchar(100)); +insert into stu values(1, null); +create table exam(stu_id int, course varchar(100), grade int); +insert into exam values(1, 'math', 100); +set names utf8 collate utf8_general_ci; +explain format = 'brief' select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id); +id estRows task access object operator info +Apply 10000.00 root CARTESIAN anti semi join, other cond:eq(test.stu.name, Column#8) +├─TableReader(Build) 10000.00 root data:TableFullScan +│ └─TableFullScan 10000.00 cop[tikv] table:stu keep order:false, stats:pseudo +└─Projection(Probe) 10.00 root guo->Column#8 + └─TableReader 10.00 root data:Selection + └─Selection 10.00 cop[tikv] eq(test.exam.stu_id, test.stu.id) + └─TableFullScan 10000.00 cop[tikv] table:exam keep order:false, stats:pseudo +select * from stu where stu.name not in (select 'guo' from exam where 
exam.stu_id = stu.id); +id name +set names utf8mb4; +explain format = 'brief' select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id); +id estRows task access object operator info +HashJoin 8000.00 root anti semi join, equal:[eq(test.stu.id, test.exam.stu_id)], other cond:eq(test.stu.name, "guo") +├─TableReader(Build) 10000.00 root data:TableFullScan +│ └─TableFullScan 10000.00 cop[tikv] table:exam keep order:false, stats:pseudo +└─TableReader(Probe) 10000.00 root data:TableFullScan + └─TableFullScan 10000.00 cop[tikv] table:stu keep order:false, stats:pseudo +select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id); +id name diff --git a/cmd/explaintest/t/subquery.test b/cmd/explaintest/t/subquery.test index 6a3aa13e7e95a..5127c0e4260fa 100644 --- a/cmd/explaintest/t/subquery.test +++ b/cmd/explaintest/t/subquery.test @@ -20,3 +20,16 @@ drop table if exists t1, t2; create table t1(a int(11)); create table t2(a decimal(40,20) unsigned, b decimal(40,20)); select count(*) as x from t1 group by a having x not in (select a from t2 where x = t2.b); + +drop table if exists stu; +drop table if exists exam; +create table stu(id int, name varchar(100)); +insert into stu values(1, null); +create table exam(stu_id int, course varchar(100), grade int); +insert into exam values(1, 'math', 100); +set names utf8 collate utf8_general_ci; +explain format = 'brief' select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id); +select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id); +set names utf8mb4; +explain format = 'brief' select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id); +select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id); diff --git a/expression/util.go b/expression/util.go index f3bddcc0e8e52..67de867313915 100644 --- a/expression/util.go +++ 
b/expression/util.go @@ -365,7 +365,8 @@ func extractColumnSet(expr Expression, set *intsets.Sparse) { } } -func setExprColumnInOperand(expr Expression) Expression { +// SetExprColumnInOperand is used to set columns in expr as InOperand. +func SetExprColumnInOperand(expr Expression) Expression { switch v := expr.(type) { case *Column: col := v.Clone().(*Column) @@ -374,7 +375,7 @@ func setExprColumnInOperand(expr Expression) Expression { case *ScalarFunction: args := v.GetArgs() for i, arg := range args { - args[i] = setExprColumnInOperand(arg) + args[i] = SetExprColumnInOperand(arg) } } return expr @@ -383,44 +384,65 @@ func setExprColumnInOperand(expr Expression) Expression { // ColumnSubstitute substitutes the columns in filter to expressions in select fields. // e.g. select * from (select b as a from t) k where a < 10 => select * from (select b as a from t where b < 10) k. func ColumnSubstitute(expr Expression, schema *Schema, newExprs []Expression) Expression { - _, resExpr := ColumnSubstituteImpl(expr, schema, newExprs) + _, _, resExpr := ColumnSubstituteImpl(expr, schema, newExprs, false) return resExpr } +// ColumnSubstituteAll substitutes the columns just like ColumnSubstitute, but we don't accept partial substitution. +// Only accept: +// +// 1: substitute them all once find col in schema. +// 2: nothing in expr can be substituted. +func ColumnSubstituteAll(expr Expression, schema *Schema, newExprs []Expression) (bool, Expression) { + _, hasFail, resExpr := ColumnSubstituteImpl(expr, schema, newExprs, true) + return hasFail, resExpr +} + // ColumnSubstituteImpl tries to substitute column expr using newExprs, // the newFunctionInternal is only called if its child is substituted -func ColumnSubstituteImpl(expr Expression, schema *Schema, newExprs []Expression) (bool, Expression) { +// @return bool means whether the expr has changed. 
+// @return bool means whether the expr should change (has the dependency in schema, while the corresponding expr has some compatibility), but finally fallback. +// @return Expression, the original expr or the changed expr, it depends on the first @return bool. +func ColumnSubstituteImpl(expr Expression, schema *Schema, newExprs []Expression, fail1Return bool) (bool, bool, Expression) { switch v := expr.(type) { case *Column: id := schema.ColumnIndex(v) if id == -1 { - return false, v + return false, false, v } newExpr := newExprs[id] if v.InOperand { - newExpr = setExprColumnInOperand(newExpr) + newExpr = SetExprColumnInOperand(newExpr) } newExpr.SetCoercibility(v.Coercibility()) - return true, newExpr + return true, false, newExpr case *ScalarFunction: substituted := false + hasFail := false if v.FuncName.L == ast.Cast { newFunc := v.Clone().(*ScalarFunction) - substituted, newFunc.GetArgs()[0] = ColumnSubstituteImpl(newFunc.GetArgs()[0], schema, newExprs) + substituted, hasFail, newFunc.GetArgs()[0] = ColumnSubstituteImpl(newFunc.GetArgs()[0], schema, newExprs, fail1Return) + if fail1Return && hasFail { + return substituted, hasFail, newFunc + } if substituted { // Workaround for issue https://github.com/pingcap/tidb/issues/28804 e := NewFunctionInternal(v.GetCtx(), v.FuncName.L, v.RetType, newFunc.GetArgs()...) e.SetCoercibility(v.Coercibility()) - return true, e + return true, false, e } - return false, newFunc + return false, false, newFunc } // cowExprRef is a copy-on-write util, args array allocation happens only // when expr in args is changed refExprArr := cowExprRef{v.GetArgs(), nil} _, coll := DeriveCollationFromExprs(v.GetCtx(), v.GetArgs()...) 
for idx, arg := range v.GetArgs() { - changed, newFuncExpr := ColumnSubstituteImpl(arg, schema, newExprs) + changed, hasFail, newFuncExpr := ColumnSubstituteImpl(arg, schema, newExprs, fail1Return) + if fail1Return && hasFail { + return changed, hasFail, v + } + oldChanged := changed if collate.NewCollationEnabled() { // Make sure the collation used by the ScalarFunction isn't changed and its result collation is not weaker than the collation used by the ScalarFunction. if changed { @@ -433,16 +455,24 @@ func ColumnSubstituteImpl(expr Expression, schema *Schema, newExprs []Expression } } } + if fail1Return && oldChanged != changed { + // We only get here when oldChanged is true and changed is false. + // This means some dependency in this arg could be substituted with the + // given expressions, but the result has a collation incompatibility, so we + // fell back to the original args. (commonly hit in projection elimination, + // in which such a fallback is unacceptable) + return changed, true, v + } refExprArr.Set(idx, changed, newFuncExpr) if changed { substituted = true } } if substituted { - return true, NewFunctionInternal(v.GetCtx(), v.FuncName.L, v.RetType, refExprArr.Result()...) + return true, false, NewFunctionInternal(v.GetCtx(), v.FuncName.L, v.RetType, refExprArr.Result()...) 
} } - return false, expr + return false, false, expr } // checkCollationStrictness check collation strictness-ship between `coll` and `newFuncColl` diff --git a/expression/util_test.go b/expression/util_test.go index 689385a9c6869..5aa35d81ec4a9 100644 --- a/expression/util_test.go +++ b/expression/util_test.go @@ -193,12 +193,12 @@ func TestGetUint64FromConstant(t *testing.T) { func TestSetExprColumnInOperand(t *testing.T) { col := &Column{RetType: newIntFieldType()} - require.True(t, setExprColumnInOperand(col).(*Column).InOperand) + require.True(t, SetExprColumnInOperand(col).(*Column).InOperand) f, err := funcs[ast.Abs].getFunction(mock.NewContext(), []Expression{col}) require.NoError(t, err) fun := &ScalarFunction{Function: f} - setExprColumnInOperand(fun) + SetExprColumnInOperand(fun) require.True(t, f.getArgs()[0].(*Column).InOperand) } diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index a2a858262e740..9872da3b60000 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -479,6 +479,7 @@ func (er *expressionRewriter) buildSemiApplyFromEqualSubq(np LogicalPlan, l, r e rColCopy := *rCol rColCopy.InOperand = true r = &rColCopy + l = expression.SetExprColumnInOperand(l) } } else { rowFunc := r.(*expression.ScalarFunction) @@ -501,6 +502,7 @@ func (er *expressionRewriter) buildSemiApplyFromEqualSubq(np LogicalPlan, l, r e if er.err != nil { return } + l = expression.SetExprColumnInOperand(l) } } } @@ -912,11 +914,15 @@ func (er *expressionRewriter) handleInSubquery(ctx context.Context, v *ast.Patte // normal column equal condition, so we specially mark the inner operand here. if v.Not || asScalar { // If both input columns of `in` expression are not null, we can treat the expression - // as normal column equal condition instead. + // as normal column equal condition instead. Otherwise, mark the left and right side. 
+ // eg: for some optimization, the column substitute in right side in projection elimination + // will turn a case like `lcol EQ rcol(inOperand)` into `lcol EQ constant`, which is not + // a valid null-aware EQ. (null in lcol still needs to be null-aware) if !expression.ExprNotNull(lexpr) || !expression.ExprNotNull(rCol) { rColCopy := *rCol rColCopy.InOperand = true rexpr = &rColCopy + lexpr = expression.SetExprColumnInOperand(lexpr) } } } else { @@ -924,10 +930,15 @@ func (er *expressionRewriter) handleInSubquery(ctx context.Context, v *ast.Patte for i, col := range np.Schema().Columns { if v.Not || asScalar { larg := expression.GetFuncArg(lexpr, i) + // If both input columns of `in` expression are not null, we can treat the expression + // as normal column equal condition instead. Otherwise, mark the left and right side. if !expression.ExprNotNull(larg) || !expression.ExprNotNull(col) { rarg := *col rarg.InOperand = true col = &rarg + if larg != nil { + lexpr.(*expression.ScalarFunction).GetArgs()[i] = expression.SetExprColumnInOperand(larg) + } } } args = append(args, col) diff --git a/planner/core/logical_plans.go b/planner/core/logical_plans.go index 27315d316f45f..389e963ca456f 100644 --- a/planner/core/logical_plans.go +++ b/planner/core/logical_plans.go @@ -375,21 +375,70 @@ func (p *LogicalJoin) GetPotentialPartitionKeys() (leftKeys, rightKeys []*proper return } -func (p *LogicalJoin) columnSubstitute(schema *expression.Schema, exprs []expression.Expression) { +// decorrelate eliminates the correlated column if the col is in schema. 
+func (p *LogicalJoin) decorrelate(schema *expression.Schema) { for i, cond := range p.LeftConditions { - p.LeftConditions[i] = expression.ColumnSubstitute(cond, schema, exprs) + p.LeftConditions[i] = cond.Decorrelate(schema) } - for i, cond := range p.RightConditions { - p.RightConditions[i] = expression.ColumnSubstitute(cond, schema, exprs) + p.RightConditions[i] = cond.Decorrelate(schema) } - for i, cond := range p.OtherConditions { - p.OtherConditions[i] = expression.ColumnSubstitute(cond, schema, exprs) + p.OtherConditions[i] = cond.Decorrelate(schema) + } + for i, cond := range p.EqualConditions { + p.EqualConditions[i] = cond.Decorrelate(schema).(*expression.ScalarFunction) + } +} + +// columnSubstituteAll is used in projection elimination in apply de-correlation. +// Substitutions for all conditions should be successful, otherwise, we should keep all conditions unchanged. +func (p *LogicalJoin) columnSubstituteAll(schema *expression.Schema, exprs []expression.Expression) (hasFail bool) { + // make a copy of exprs for convenience of substitution (may change/partially change the expr tree) + cpLeftConditions := make(expression.CNFExprs, len(p.LeftConditions)) + cpRightConditions := make(expression.CNFExprs, len(p.RightConditions)) + cpOtherConditions := make(expression.CNFExprs, len(p.OtherConditions)) + cpEqualConditions := make([]*expression.ScalarFunction, len(p.EqualConditions)) + copy(cpLeftConditions, p.LeftConditions) + copy(cpRightConditions, p.RightConditions) + copy(cpOtherConditions, p.OtherConditions) + copy(cpEqualConditions, p.EqualConditions) + + // try to substitute columns in these condition. 
+ for i, cond := range cpLeftConditions { + if hasFail, cpLeftConditions[i] = expression.ColumnSubstituteAll(cond, schema, exprs); hasFail { + return + } + } + + for i, cond := range cpRightConditions { + if hasFail, cpRightConditions[i] = expression.ColumnSubstituteAll(cond, schema, exprs); hasFail { + return + } + } + + for i, cond := range cpOtherConditions { + if hasFail, cpOtherConditions[i] = expression.ColumnSubstituteAll(cond, schema, exprs); hasFail { + return + } } + for i, cond := range cpEqualConditions { + var tmp expression.Expression + if hasFail, tmp = expression.ColumnSubstituteAll(cond, schema, exprs); hasFail { + return + } + cpEqualConditions[i] = tmp.(*expression.ScalarFunction) + } + + // if all substituted, change them atomically here. + p.LeftConditions = cpLeftConditions + p.RightConditions = cpRightConditions + p.OtherConditions = cpOtherConditions + p.EqualConditions = cpEqualConditions + for i := len(p.EqualConditions) - 1; i >= 0; i-- { - newCond := expression.ColumnSubstitute(p.EqualConditions[i], schema, exprs).(*expression.ScalarFunction) + newCond := p.EqualConditions[i] // If the columns used in the new filter all come from the left child, // we can push this filter to it. @@ -420,6 +469,7 @@ func (p *LogicalJoin) columnSubstitute(schema *expression.Schema, exprs []expres p.EqualConditions[i] = newCond } + return false } // AttachOnConds extracts on conditions for join and set the `EqualConditions`, `LeftConditions`, `RightConditions` and diff --git a/planner/core/rule_decorrelate.go b/planner/core/rule_decorrelate.go index a09c55f1dac3e..ecb1d1c2f36bd 100644 --- a/planner/core/rule_decorrelate.go +++ b/planner/core/rule_decorrelate.go @@ -171,10 +171,28 @@ func (s *decorrelateSolver) optimize(ctx context.Context, p LogicalPlan, opt *lo // TODO: Actually, it can be optimized. We need to first push the projection down to the selection. And then the APPLY can be decorrelated. 
goto NoOptimize } + + // step1: substitute all the schema columns with new expressions (maybe including correlated columns, but that doesn't affect the collation inference inside) + // eg: projection: constant("guo") --> column8, once the upper layer substitution fails here, the lower layer behind + // the projection can't supply column8 anymore. + // + // upper OP (depend on column8) --> projection(constant "guo" --> column8) --> lower layer OP + // | ^ + // +-------------------------------------------------------+ + // + // upper OP (depend on column8) --> lower layer OP + // | ^ + // +-----------------------------+ // Fail: lower layer can't supply column8 anymore. + hasFail := apply.columnSubstituteAll(proj.Schema(), proj.Exprs) + if hasFail { + goto NoOptimize + } + // step2: when all of them can be substituted, we then just do the de-correlation (apply conditions included). for i, expr := range proj.Exprs { proj.Exprs[i] = expr.Decorrelate(outerPlan.Schema()) } - apply.columnSubstitute(proj.Schema(), proj.Exprs) + apply.decorrelate(outerPlan.Schema()) + innerPlan = proj.children[0] apply.SetChildren(outerPlan, innerPlan) if apply.JoinType != SemiJoin && apply.JoinType != LeftOuterSemiJoin && apply.JoinType != AntiSemiJoin && apply.JoinType != AntiLeftOuterSemiJoin {