Skip to content

Commit

Permalink
Merge branch 'release-6.1' into pick-40080
Browse files Browse the repository at this point in the history
  • Loading branch information
qw4990 authored Jan 18, 2023
2 parents 75e80d4 + 3af286a commit 50e6719
Show file tree
Hide file tree
Showing 9 changed files with 157 additions and 23 deletions.
28 changes: 28 additions & 0 deletions cmd/explaintest/r/subquery.result
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,31 @@ create table t1(a int(11));
create table t2(a decimal(40,20) unsigned, b decimal(40,20));
select count(*) as x from t1 group by a having x not in (select a from t2 where x = t2.b);
x
drop table if exists stu;
drop table if exists exam;
create table stu(id int, name varchar(100));
insert into stu values(1, null);
create table exam(stu_id int, course varchar(100), grade int);
insert into exam values(1, 'math', 100);
set names utf8 collate utf8_general_ci;
explain format = 'brief' select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id);
id estRows task access object operator info
Apply 10000.00 root CARTESIAN anti semi join, other cond:eq(test.stu.name, Column#8)
├─TableReader(Build) 10000.00 root data:TableFullScan
│ └─TableFullScan 10000.00 cop[tikv] table:stu keep order:false, stats:pseudo
└─Projection(Probe) 10.00 root guo->Column#8
└─TableReader 10.00 root data:Selection
└─Selection 10.00 cop[tikv] eq(test.exam.stu_id, test.stu.id)
└─TableFullScan 10000.00 cop[tikv] table:exam keep order:false, stats:pseudo
select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id);
id name
set names utf8mb4;
explain format = 'brief' select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id);
id estRows task access object operator info
HashJoin 8000.00 root anti semi join, equal:[eq(test.stu.id, test.exam.stu_id)], other cond:eq(test.stu.name, "guo")
├─TableReader(Build) 10000.00 root data:TableFullScan
│ └─TableFullScan 10000.00 cop[tikv] table:exam keep order:false, stats:pseudo
└─TableReader(Probe) 10000.00 root data:TableFullScan
└─TableFullScan 10000.00 cop[tikv] table:stu keep order:false, stats:pseudo
select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id);
id name
13 changes: 13 additions & 0 deletions cmd/explaintest/t/subquery.test
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,16 @@ drop table if exists t1, t2;
create table t1(a int(11));
create table t2(a decimal(40,20) unsigned, b decimal(40,20));
select count(*) as x from t1 group by a having x not in (select a from t2 where x = t2.b);

drop table if exists stu;
drop table if exists exam;
create table stu(id int, name varchar(100));
insert into stu values(1, null);
create table exam(stu_id int, course varchar(100), grade int);
insert into exam values(1, 'math', 100);
set names utf8 collate utf8_general_ci;
explain format = 'brief' select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id);
select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id);
set names utf8mb4;
explain format = 'brief' select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id);
select * from stu where stu.name not in (select 'guo' from exam where exam.stu_id = stu.id);
34 changes: 24 additions & 10 deletions expression/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,8 @@ func extractColumnSet(expr Expression, set *intsets.Sparse) {
}
}

func setExprColumnInOperand(expr Expression) Expression {
// SetExprColumnInOperand is used to set columns in expr as InOperand.
func SetExprColumnInOperand(expr Expression) Expression {
switch v := expr.(type) {
case *Column:
col := v.Clone().(*Column)
Expand All @@ -374,7 +375,7 @@ func setExprColumnInOperand(expr Expression) Expression {
case *ScalarFunction:
args := v.GetArgs()
for i, arg := range args {
args[i] = setExprColumnInOperand(arg)
args[i] = SetExprColumnInOperand(arg)
}
}
return expr
Expand All @@ -383,13 +384,26 @@ func setExprColumnInOperand(expr Expression) Expression {
// ColumnSubstitute substitutes the columns in filter to expressions in select fields.
// e.g. select * from (select b as a from t) k where a < 10 => select * from (select b as a from t where b < 10) k.
func ColumnSubstitute(expr Expression, schema *Schema, newExprs []Expression) Expression {
_, _, resExpr := ColumnSubstituteImpl(expr, schema, newExprs)
_, _, resExpr := ColumnSubstituteImpl(expr, schema, newExprs, false)
return resExpr
}

// ColumnSubstituteAll substitutes the columns just like ColumnSubstitute, but we don't accept partial substitution.
// Only accept:
//
// 1: substitute them all once find col in schema.
// 2: nothing in expr can be substituted.
func ColumnSubstituteAll(expr Expression, schema *Schema, newExprs []Expression) (bool, Expression) {
_, hasFail, resExpr := ColumnSubstituteImpl(expr, schema, newExprs, true)
return hasFail, resExpr
}

// ColumnSubstituteImpl tries to substitute column expr using newExprs,
// the newFunctionInternal is only called if its child is substituted
func ColumnSubstituteImpl(expr Expression, schema *Schema, newExprs []Expression) (bool, bool, Expression) {
// @return bool means whether the expr has changed.
// @return bool means whether the expr should change (has the dependency in schema, while the corresponding expr has some compatibility), but finally fallback.
// @return Expression, the original expr or the changed expr, it depends on the first @return bool.
func ColumnSubstituteImpl(expr Expression, schema *Schema, newExprs []Expression, fail1Return bool) (bool, bool, Expression) {
switch v := expr.(type) {
case *Column:
id := schema.ColumnIndex(v)
Expand All @@ -398,7 +412,7 @@ func ColumnSubstituteImpl(expr Expression, schema *Schema, newExprs []Expression
}
newExpr := newExprs[id]
if v.InOperand {
newExpr = setExprColumnInOperand(newExpr)
newExpr = SetExprColumnInOperand(newExpr)
}
newExpr.SetCoercibility(v.Coercibility())
return true, false, newExpr
Expand All @@ -409,8 +423,8 @@ func ColumnSubstituteImpl(expr Expression, schema *Schema, newExprs []Expression
var (
newArg Expression
)
substituted, hasFail, newArg = ColumnSubstituteImpl(v.GetArgs()[0], schema, newExprs)
if hasFail {
substituted, hasFail, newArg = ColumnSubstituteImpl(v.GetArgs()[0], schema, newExprs, fail1Return)
if fail1Return && hasFail {
return substituted, hasFail, v
}
if substituted {
Expand All @@ -429,8 +443,8 @@ func ColumnSubstituteImpl(expr Expression, schema *Schema, newExprs []Expression
tmpArgForCollCheck = make([]Expression, len(v.GetArgs()))
}
for idx, arg := range v.GetArgs() {
changed, failed, newFuncExpr := ColumnSubstituteImpl(arg, schema, newExprs)
if failed {
changed, failed, newFuncExpr := ColumnSubstituteImpl(arg, schema, newExprs, fail1Return)
if fail1Return && failed {
return changed, failed, v
}
oldChanged := changed
Expand All @@ -445,7 +459,7 @@ func ColumnSubstituteImpl(expr Expression, schema *Schema, newExprs []Expression
}
}
hasFail = hasFail || failed || oldChanged != changed
if oldChanged != changed {
if fail1Return && oldChanged != changed {
// Only when the oldChanged is true and changed is false, we will get here.
// And this means there some dependency in this arg can be substituted with
// given expressions, while it has some collation compatibility, finally we
Expand Down
4 changes: 2 additions & 2 deletions expression/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,12 @@ func TestGetUint64FromConstant(t *testing.T) {

func TestSetExprColumnInOperand(t *testing.T) {
col := &Column{RetType: newIntFieldType()}
require.True(t, setExprColumnInOperand(col).(*Column).InOperand)
require.True(t, SetExprColumnInOperand(col).(*Column).InOperand)

f, err := funcs[ast.Abs].getFunction(mock.NewContext(), []Expression{col})
require.NoError(t, err)
fun := &ScalarFunction{Function: f}
setExprColumnInOperand(fun)
SetExprColumnInOperand(fun)
require.True(t, f.getArgs()[0].(*Column).InOperand)
}

Expand Down
2 changes: 1 addition & 1 deletion planner/cascades/transformation_rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,7 @@ func (r *PushSelDownProjection) OnTransform(old *memo.ExprIter) (newExprs []*mem
canBePushed := make([]expression.Expression, 0, len(sel.Conditions))
canNotBePushed := make([]expression.Expression, 0, len(sel.Conditions))
for _, cond := range sel.Conditions {
substituted, hasFailed, newFilter := expression.ColumnSubstituteImpl(cond, projSchema, proj.Exprs)
substituted, hasFailed, newFilter := expression.ColumnSubstituteImpl(cond, projSchema, proj.Exprs, false)
if substituted && !hasFailed && !expression.HasGetSetVarFunc(newFilter) {
canBePushed = append(canBePushed, newFilter)
} else {
Expand Down
13 changes: 12 additions & 1 deletion planner/core/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,7 @@ func (er *expressionRewriter) buildSemiApplyFromEqualSubq(np LogicalPlan, l, r e
rColCopy := *rCol
rColCopy.InOperand = true
r = &rColCopy
l = expression.SetExprColumnInOperand(l)
}
} else {
rowFunc := r.(*expression.ScalarFunction)
Expand All @@ -502,6 +503,7 @@ func (er *expressionRewriter) buildSemiApplyFromEqualSubq(np LogicalPlan, l, r e
if er.err != nil {
return
}
l = expression.SetExprColumnInOperand(l)
}
}
}
Expand Down Expand Up @@ -912,22 +914,31 @@ func (er *expressionRewriter) handleInSubquery(ctx context.Context, v *ast.Patte
// normal column equal condition, so we specially mark the inner operand here.
if v.Not || asScalar {
// If both input columns of `in` expression are not null, we can treat the expression
// as normal column equal condition instead.
// as normal column equal condition instead. Otherwise, mark the left and right side.
// eg: for some optimization, the column substitute in right side in projection elimination
// will cause case like <lcol EQ rcol(inOperand)> as <lcol EQ constant> which is not
// a valid null-aware EQ. (null in lcol still need to be null-aware)
if !expression.ExprNotNull(lexpr) || !expression.ExprNotNull(rCol) {
rColCopy := *rCol
rColCopy.InOperand = true
rexpr = &rColCopy
lexpr = expression.SetExprColumnInOperand(lexpr)
}
}
} else {
args := make([]expression.Expression, 0, np.Schema().Len())
for i, col := range np.Schema().Columns {
if v.Not || asScalar {
larg := expression.GetFuncArg(lexpr, i)
// If both input columns of `in` expression are not null, we can treat the expression
// as normal column equal condition instead. Otherwise, mark the left and right side.
if !expression.ExprNotNull(larg) || !expression.ExprNotNull(col) {
rarg := *col
rarg.InOperand = true
col = &rarg
if larg != nil {
lexpr.(*expression.ScalarFunction).GetArgs()[i] = expression.SetExprColumnInOperand(larg)
}
}
}
args = append(args, col)
Expand Down
64 changes: 57 additions & 7 deletions planner/core/logical_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -375,21 +375,70 @@ func (p *LogicalJoin) GetPotentialPartitionKeys() (leftKeys, rightKeys []*proper
return
}

func (p *LogicalJoin) columnSubstitute(schema *expression.Schema, exprs []expression.Expression) {
// decorrelate eliminate the correlated column with if the col is in schema.
func (p *LogicalJoin) decorrelate(schema *expression.Schema) {
for i, cond := range p.LeftConditions {
p.LeftConditions[i] = expression.ColumnSubstitute(cond, schema, exprs)
p.LeftConditions[i] = cond.Decorrelate(schema)
}

for i, cond := range p.RightConditions {
p.RightConditions[i] = expression.ColumnSubstitute(cond, schema, exprs)
p.RightConditions[i] = cond.Decorrelate(schema)
}

for i, cond := range p.OtherConditions {
p.OtherConditions[i] = expression.ColumnSubstitute(cond, schema, exprs)
p.OtherConditions[i] = cond.Decorrelate(schema)
}
for i, cond := range p.EqualConditions {
p.EqualConditions[i] = cond.Decorrelate(schema).(*expression.ScalarFunction)
}
}

// columnSubstituteAll is used in projection elimination in apply de-correlation.
// Substitutions for all conditions should be successful, otherwise, we should keep all conditions unchanged.
func (p *LogicalJoin) columnSubstituteAll(schema *expression.Schema, exprs []expression.Expression) (hasFail bool) {
// make a copy of exprs for convenience of substitution (may change/partially change the expr tree)
cpLeftConditions := make(expression.CNFExprs, len(p.LeftConditions))
cpRightConditions := make(expression.CNFExprs, len(p.RightConditions))
cpOtherConditions := make(expression.CNFExprs, len(p.OtherConditions))
cpEqualConditions := make([]*expression.ScalarFunction, len(p.EqualConditions))
copy(cpLeftConditions, p.LeftConditions)
copy(cpRightConditions, p.RightConditions)
copy(cpOtherConditions, p.OtherConditions)
copy(cpEqualConditions, p.EqualConditions)

// try to substitute columns in these condition.
for i, cond := range cpLeftConditions {
if hasFail, cpLeftConditions[i] = expression.ColumnSubstituteAll(cond, schema, exprs); hasFail {
return
}
}

for i, cond := range cpRightConditions {
if hasFail, cpRightConditions[i] = expression.ColumnSubstituteAll(cond, schema, exprs); hasFail {
return
}
}

for i, cond := range cpOtherConditions {
if hasFail, cpOtherConditions[i] = expression.ColumnSubstituteAll(cond, schema, exprs); hasFail {
return
}
}

for i, cond := range cpEqualConditions {
var tmp expression.Expression
if hasFail, tmp = expression.ColumnSubstituteAll(cond, schema, exprs); hasFail {
return
}
cpEqualConditions[i] = tmp.(*expression.ScalarFunction)
}

// if all substituted, change them atomically here.
p.LeftConditions = cpLeftConditions
p.RightConditions = cpRightConditions
p.OtherConditions = cpOtherConditions
p.EqualConditions = cpEqualConditions

for i := len(p.EqualConditions) - 1; i >= 0; i-- {
newCond := expression.ColumnSubstitute(p.EqualConditions[i], schema, exprs).(*expression.ScalarFunction)
newCond := p.EqualConditions[i]

// If the columns used in the new filter all come from the left child,
// we can push this filter to it.
Expand Down Expand Up @@ -420,6 +469,7 @@ func (p *LogicalJoin) columnSubstitute(schema *expression.Schema, exprs []expres

p.EqualConditions[i] = newCond
}
return false
}

// AttachOnConds extracts on conditions for join and set the `EqualConditions`, `LeftConditions`, `RightConditions` and
Expand Down
20 changes: 19 additions & 1 deletion planner/core/rule_decorrelate.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,10 +169,28 @@ func (s *decorrelateSolver) optimize(ctx context.Context, p LogicalPlan, opt *lo
// TODO: Actually, it can be optimized. We need to first push the projection down to the selection. And then the APPLY can be decorrelated.
goto NoOptimize
}

// step1: substitute the all the schema with new expressions (including correlated column maybe, but it doesn't affect the collation infer inside)
// eg: projection: constant("guo") --> column8, once upper layer substitution failed here, the lower layer behind
// projection can't supply column8 anymore.
//
// upper OP (depend on column8) --> projection(constant "guo" --> column8) --> lower layer OP
// | ^
// +-------------------------------------------------------+
//
// upper OP (depend on column8) --> lower layer OP
// | ^
// +-----------------------------+ // Fail: lower layer can't supply column8 anymore.
hasFail := apply.columnSubstituteAll(proj.Schema(), proj.Exprs)
if hasFail {
goto NoOptimize
}
// step2: when it can be substituted all, we then just do the de-correlation (apply conditions included).
for i, expr := range proj.Exprs {
proj.Exprs[i] = expr.Decorrelate(outerPlan.Schema())
}
apply.columnSubstitute(proj.Schema(), proj.Exprs)
apply.decorrelate(outerPlan.Schema())

innerPlan = proj.children[0]
apply.SetChildren(outerPlan, innerPlan)
if apply.JoinType != SemiJoin && apply.JoinType != LeftOuterSemiJoin && apply.JoinType != AntiSemiJoin && apply.JoinType != AntiLeftOuterSemiJoin {
Expand Down
2 changes: 1 addition & 1 deletion planner/core/rule_predicate_push_down.go
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ func (p *LogicalProjection) PredicatePushDown(predicates []expression.Expression
}
}
for _, cond := range predicates {
substituted, hasFailed, newFilter := expression.ColumnSubstituteImpl(cond, p.Schema(), p.Exprs)
substituted, hasFailed, newFilter := expression.ColumnSubstituteImpl(cond, p.Schema(), p.Exprs, true)
if substituted && !hasFailed && !expression.HasGetSetVarFunc(newFilter) {
canBePushed = append(canBePushed, newFilter)
} else {
Expand Down

0 comments on commit 50e6719

Please sign in to comment.