Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

planner: mark the both side operand of NAAJ & refuse partial column substitute in projection elimination of Apply de-correlation (#37117) #37357

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions cmd/explaintest/r/subquery.result
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,31 @@ create table t1(a int(11));
create table t2(a decimal(40,20) unsigned, b decimal(40,20));
select count(*) as x from t1 group by a having x not in (select a from t2 where x = t2.b);
x
drop table if exists stu;
drop table if exists exam;
create table stu(id int, name varchar(100));
insert into stu values(1, null);
create table exam(stu_id int, course varchar(100), grade int);
insert into exam values(1, 'math', 100);
set names utf8 collate utf8_general_ci;
explain format = 'brief' select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id);
id estRows task access object operator info
Apply 10000.00 root CARTESIAN anti semi join, other cond:eq(test.stu.name, Column#8)
├─TableReader(Build) 10000.00 root data:TableFullScan
│ └─TableFullScan 10000.00 cop[tikv] table:stu keep order:false, stats:pseudo
└─Projection(Probe) 10.00 root guo->Column#8
└─TableReader 10.00 root data:Selection
└─Selection 10.00 cop[tikv] eq(test.exam.stu_id, test.stu.id)
└─TableFullScan 10000.00 cop[tikv] table:exam keep order:false, stats:pseudo
select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id);
id name
set names utf8mb4;
explain format = 'brief' select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id);
id estRows task access object operator info
HashJoin 8000.00 root anti semi join, equal:[eq(test.stu.id, test.exam.stu_id)], other cond:eq(test.stu.name, "guo")
├─TableReader(Build) 10000.00 root data:TableFullScan
│ └─TableFullScan 10000.00 cop[tikv] table:exam keep order:false, stats:pseudo
└─TableReader(Probe) 10000.00 root data:TableFullScan
└─TableFullScan 10000.00 cop[tikv] table:stu keep order:false, stats:pseudo
select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id);
id name
13 changes: 13 additions & 0 deletions cmd/explaintest/t/subquery.test
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,16 @@ drop table if exists t1, t2;
create table t1(a int(11));
create table t2(a decimal(40,20) unsigned, b decimal(40,20));
select count(*) as x from t1 group by a having x not in (select a from t2 where x = t2.b);

drop table if exists stu;
drop table if exists exam;
create table stu(id int, name varchar(100));
insert into stu values(1, null);
create table exam(stu_id int, course varchar(100), grade int);
insert into exam values(1, 'math', 100);
set names utf8 collate utf8_general_ci;
explain format = 'brief' select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id);
select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id);
set names utf8mb4;
explain format = 'brief' select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id);
select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id);
56 changes: 43 additions & 13 deletions expression/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,8 @@ func extractColumnSet(expr Expression, set *intsets.Sparse) {
}
}

func setExprColumnInOperand(expr Expression) Expression {
// SetExprColumnInOperand is used to set columns in expr as InOperand.
func SetExprColumnInOperand(expr Expression) Expression {
switch v := expr.(type) {
case *Column:
col := v.Clone().(*Column)
Expand All @@ -374,7 +375,7 @@ func setExprColumnInOperand(expr Expression) Expression {
case *ScalarFunction:
args := v.GetArgs()
for i, arg := range args {
args[i] = setExprColumnInOperand(arg)
args[i] = SetExprColumnInOperand(arg)
}
}
return expr
Expand All @@ -383,44 +384,65 @@ func setExprColumnInOperand(expr Expression) Expression {
// ColumnSubstitute substitutes the columns in filter to expressions in select fields.
// e.g. select * from (select b as a from t) k where a < 10 => select * from (select b as a from t where b < 10) k.
func ColumnSubstitute(expr Expression, schema *Schema, newExprs []Expression) Expression {
_, resExpr := ColumnSubstituteImpl(expr, schema, newExprs)
_, _, resExpr := ColumnSubstituteImpl(expr, schema, newExprs, false)
return resExpr
}

// ColumnSubstituteAll works like ColumnSubstitute, except that a partial
// substitution is not accepted. Only two outcomes count as success:
//
//  1. every column of expr that is found in schema is substituted;
//  2. nothing in expr can be substituted.
//
// It delegates to ColumnSubstituteImpl with fail1Return set to true and
// returns that call's failure flag together with the resulting expression.
func ColumnSubstituteAll(expr Expression, schema *Schema, newExprs []Expression) (bool, Expression) {
	_, failed, substituted := ColumnSubstituteImpl(expr, schema, newExprs, true)
	return failed, substituted
}

// ColumnSubstituteImpl tries to substitute column expr using newExprs,
// the newFunctionInternal is only called if its child is substituted
func ColumnSubstituteImpl(expr Expression, schema *Schema, newExprs []Expression) (bool, Expression) {
// @return bool: whether the expr has been changed.
// @return bool: whether the expr should have been changed (it depends on a column in schema) but the substitution finally fell back because the corresponding new expr has a collation incompatibility.
// @return Expression: the original expr or the changed expr, depending on the first @return bool.
func ColumnSubstituteImpl(expr Expression, schema *Schema, newExprs []Expression, fail1Return bool) (bool, bool, Expression) {
switch v := expr.(type) {
case *Column:
id := schema.ColumnIndex(v)
if id == -1 {
return false, v
return false, false, v
}
newExpr := newExprs[id]
if v.InOperand {
newExpr = setExprColumnInOperand(newExpr)
newExpr = SetExprColumnInOperand(newExpr)
}
newExpr.SetCoercibility(v.Coercibility())
return true, newExpr
return true, false, newExpr
case *ScalarFunction:
substituted := false
hasFail := false
if v.FuncName.L == ast.Cast {
newFunc := v.Clone().(*ScalarFunction)
substituted, newFunc.GetArgs()[0] = ColumnSubstituteImpl(newFunc.GetArgs()[0], schema, newExprs)
substituted, hasFail, newFunc.GetArgs()[0] = ColumnSubstituteImpl(newFunc.GetArgs()[0], schema, newExprs, fail1Return)
if fail1Return && hasFail {
return substituted, hasFail, newFunc
}
if substituted {
// Workaround for issue https://github.com/pingcap/tidb/issues/28804
e := NewFunctionInternal(v.GetCtx(), v.FuncName.L, v.RetType, newFunc.GetArgs()...)
e.SetCoercibility(v.Coercibility())
return true, e
return true, false, e
}
return false, newFunc
return false, false, newFunc
}
// cowExprRef is a copy-on-write util, args array allocation happens only
// when expr in args is changed
refExprArr := cowExprRef{v.GetArgs(), nil}
_, coll := DeriveCollationFromExprs(v.GetCtx(), v.GetArgs()...)
for idx, arg := range v.GetArgs() {
changed, newFuncExpr := ColumnSubstituteImpl(arg, schema, newExprs)
changed, hasFail, newFuncExpr := ColumnSubstituteImpl(arg, schema, newExprs, fail1Return)
if fail1Return && hasFail {
return changed, hasFail, v
}
oldChanged := changed
if collate.NewCollationEnabled() {
// Make sure the collation used by the ScalarFunction isn't changed and its result collation is not weaker than the collation used by the ScalarFunction.
if changed {
Expand All @@ -433,16 +455,24 @@ func ColumnSubstituteImpl(expr Expression, schema *Schema, newExprs []Expression
}
}
}
if fail1Return && oldChanged != changed {
// We only get here when oldChanged is true and changed is false.
// This means some dependency in this arg can be substituted with the
// given expressions, but because of a collation incompatibility we finally
// fell back to using the original args. (commonly used in projection elimination,
// in which such a fallback is unacceptable)
return changed, true, v
}
refExprArr.Set(idx, changed, newFuncExpr)
if changed {
substituted = true
}
}
if substituted {
return true, NewFunctionInternal(v.GetCtx(), v.FuncName.L, v.RetType, refExprArr.Result()...)
return true, false, NewFunctionInternal(v.GetCtx(), v.FuncName.L, v.RetType, refExprArr.Result()...)
}
}
return false, expr
return false, false, expr
}

// checkCollationStrictness check collation strictness-ship between `coll` and `newFuncColl`
Expand Down
4 changes: 2 additions & 2 deletions expression/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,12 +193,12 @@ func TestGetUint64FromConstant(t *testing.T) {

func TestSetExprColumnInOperand(t *testing.T) {
col := &Column{RetType: newIntFieldType()}
require.True(t, setExprColumnInOperand(col).(*Column).InOperand)
require.True(t, SetExprColumnInOperand(col).(*Column).InOperand)

f, err := funcs[ast.Abs].getFunction(mock.NewContext(), []Expression{col})
require.NoError(t, err)
fun := &ScalarFunction{Function: f}
setExprColumnInOperand(fun)
SetExprColumnInOperand(fun)
require.True(t, f.getArgs()[0].(*Column).InOperand)
}

Expand Down
13 changes: 12 additions & 1 deletion planner/core/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,7 @@ func (er *expressionRewriter) buildSemiApplyFromEqualSubq(np LogicalPlan, l, r e
rColCopy := *rCol
rColCopy.InOperand = true
r = &rColCopy
l = expression.SetExprColumnInOperand(l)
}
} else {
rowFunc := r.(*expression.ScalarFunction)
Expand All @@ -501,6 +502,7 @@ func (er *expressionRewriter) buildSemiApplyFromEqualSubq(np LogicalPlan, l, r e
if er.err != nil {
return
}
l = expression.SetExprColumnInOperand(l)
}
}
}
Expand Down Expand Up @@ -912,22 +914,31 @@ func (er *expressionRewriter) handleInSubquery(ctx context.Context, v *ast.Patte
// normal column equal condition, so we specially mark the inner operand here.
if v.Not || asScalar {
// If both input columns of `in` expression are not null, we can treat the expression
// as normal column equal condition instead.
// as normal column equal condition instead. Otherwise, mark both the left and right sides.
// eg: in some optimizations, the column substitution on the right side during projection elimination
// turns a case like <lcol EQ rcol(inOperand)> into <lcol EQ constant>, which is not
// a valid null-aware EQ. (a null in lcol still needs to be handled null-aware)
if !expression.ExprNotNull(lexpr) || !expression.ExprNotNull(rCol) {
rColCopy := *rCol
rColCopy.InOperand = true
rexpr = &rColCopy
lexpr = expression.SetExprColumnInOperand(lexpr)
}
}
} else {
args := make([]expression.Expression, 0, np.Schema().Len())
for i, col := range np.Schema().Columns {
if v.Not || asScalar {
larg := expression.GetFuncArg(lexpr, i)
// If both input columns of `in` expression are not null, we can treat the expression
// as normal column equal condition instead. Otherwise, mark the left and right side.
if !expression.ExprNotNull(larg) || !expression.ExprNotNull(col) {
rarg := *col
rarg.InOperand = true
col = &rarg
if larg != nil {
lexpr.(*expression.ScalarFunction).GetArgs()[i] = expression.SetExprColumnInOperand(larg)
}
}
}
args = append(args, col)
Expand Down
64 changes: 57 additions & 7 deletions planner/core/logical_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -375,21 +375,70 @@ func (p *LogicalJoin) GetPotentialPartitionKeys() (leftKeys, rightKeys []*proper
return
}

func (p *LogicalJoin) columnSubstitute(schema *expression.Schema, exprs []expression.Expression) {
// decorrelate eliminates the correlated columns in the join conditions if the columns are in schema.
func (p *LogicalJoin) decorrelate(schema *expression.Schema) {
for i, cond := range p.LeftConditions {
p.LeftConditions[i] = expression.ColumnSubstitute(cond, schema, exprs)
p.LeftConditions[i] = cond.Decorrelate(schema)
}

for i, cond := range p.RightConditions {
p.RightConditions[i] = expression.ColumnSubstitute(cond, schema, exprs)
p.RightConditions[i] = cond.Decorrelate(schema)
}

for i, cond := range p.OtherConditions {
p.OtherConditions[i] = expression.ColumnSubstitute(cond, schema, exprs)
p.OtherConditions[i] = cond.Decorrelate(schema)
}
for i, cond := range p.EqualConditions {
p.EqualConditions[i] = cond.Decorrelate(schema).(*expression.ScalarFunction)
}
}

// columnSubstituteAll is used in projection elimination in apply de-correlation.
// Substitutions for all conditions should be successful, otherwise, we should keep all conditions unchanged.
func (p *LogicalJoin) columnSubstituteAll(schema *expression.Schema, exprs []expression.Expression) (hasFail bool) {
// make a copy of exprs for convenience of substitution (may change/partially change the expr tree)
cpLeftConditions := make(expression.CNFExprs, len(p.LeftConditions))
cpRightConditions := make(expression.CNFExprs, len(p.RightConditions))
cpOtherConditions := make(expression.CNFExprs, len(p.OtherConditions))
cpEqualConditions := make([]*expression.ScalarFunction, len(p.EqualConditions))
copy(cpLeftConditions, p.LeftConditions)
copy(cpRightConditions, p.RightConditions)
copy(cpOtherConditions, p.OtherConditions)
copy(cpEqualConditions, p.EqualConditions)

// try to substitute columns in these condition.
for i, cond := range cpLeftConditions {
if hasFail, cpLeftConditions[i] = expression.ColumnSubstituteAll(cond, schema, exprs); hasFail {
return
}
}

for i, cond := range cpRightConditions {
if hasFail, cpRightConditions[i] = expression.ColumnSubstituteAll(cond, schema, exprs); hasFail {
return
}
}

for i, cond := range cpOtherConditions {
if hasFail, cpOtherConditions[i] = expression.ColumnSubstituteAll(cond, schema, exprs); hasFail {
return
}
}

for i, cond := range cpEqualConditions {
var tmp expression.Expression
if hasFail, tmp = expression.ColumnSubstituteAll(cond, schema, exprs); hasFail {
return
}
cpEqualConditions[i] = tmp.(*expression.ScalarFunction)
}

// if all substituted, change them atomically here.
p.LeftConditions = cpLeftConditions
p.RightConditions = cpRightConditions
p.OtherConditions = cpOtherConditions
p.EqualConditions = cpEqualConditions

for i := len(p.EqualConditions) - 1; i >= 0; i-- {
newCond := expression.ColumnSubstitute(p.EqualConditions[i], schema, exprs).(*expression.ScalarFunction)
newCond := p.EqualConditions[i]

// If the columns used in the new filter all come from the left child,
// we can push this filter to it.
Expand Down Expand Up @@ -420,6 +469,7 @@ func (p *LogicalJoin) columnSubstitute(schema *expression.Schema, exprs []expres

p.EqualConditions[i] = newCond
}
return false
}

// AttachOnConds extracts on conditions for join and set the `EqualConditions`, `LeftConditions`, `RightConditions` and
Expand Down
20 changes: 19 additions & 1 deletion planner/core/rule_decorrelate.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,28 @@ func (s *decorrelateSolver) optimize(ctx context.Context, p LogicalPlan, opt *lo
// TODO: Actually, it can be optimized. We need to first push the projection down to the selection. And then the APPLY can be decorrelated.
goto NoOptimize
}

// step1: substitute all the schema columns with the new expressions (these may include correlated columns, but that doesn't affect the collation inference inside)
// eg: projection: constant("guo") --> column8. If the upper layer substitution fails here, the lower layer behind
// the projection can't supply column8 anymore.
//
// upper OP (depend on column8) --> projection(constant "guo" --> column8) --> lower layer OP
// | ^
// +-------------------------------------------------------+
//
// upper OP (depend on column8) --> lower layer OP
// | ^
// +-----------------------------+ // Fail: lower layer can't supply column8 anymore.
hasFail := apply.columnSubstituteAll(proj.Schema(), proj.Exprs)
if hasFail {
goto NoOptimize
}
// step2: once everything can be substituted, we then just do the de-correlation (apply conditions included).
for i, expr := range proj.Exprs {
proj.Exprs[i] = expr.Decorrelate(outerPlan.Schema())
}
apply.columnSubstitute(proj.Schema(), proj.Exprs)
apply.decorrelate(outerPlan.Schema())

innerPlan = proj.children[0]
apply.SetChildren(outerPlan, innerPlan)
if apply.JoinType != SemiJoin && apply.JoinType != LeftOuterSemiJoin && apply.JoinType != AntiSemiJoin && apply.JoinType != AntiLeftOuterSemiJoin {
Expand Down