diff --git a/cmd/explaintest/r/subquery.result b/cmd/explaintest/r/subquery.result new file mode 100644 index 0000000000000..934101ce4d320 --- /dev/null +++ b/cmd/explaintest/r/subquery.result @@ -0,0 +1,11 @@ +drop table if exists t1; +drop table if exists t2; +create table t1(a bigint, b bigint); +create table t2(a bigint, b bigint); +explain select * from t1 where t1.a in (select t1.b + t2.b from t2); +id count task operator info +HashLeftJoin_8 8000.00 root semi join, inner:TableReader_12, other cond:eq(test.t1.a, plus(test.t1.b, test.t2.b)) +├─TableReader_10 10000.00 root data:TableScan_9 +│ └─TableScan_9 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo +└─TableReader_12 10000.00 root data:TableScan_11 + └─TableScan_11 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo diff --git a/cmd/explaintest/t/subquery.test b/cmd/explaintest/t/subquery.test new file mode 100644 index 0000000000000..de17ee3e25b5c --- /dev/null +++ b/cmd/explaintest/t/subquery.test @@ -0,0 +1,5 @@ +drop table if exists t1; +drop table if exists t2; +create table t1(a bigint, b bigint); +create table t2(a bigint, b bigint); +explain select * from t1 where t1.a in (select t1.b + t2.b from t2); diff --git a/expression/schema.go b/expression/schema.go index bd6714f95ca79..c3bfdfc60c56b 100644 --- a/expression/schema.go +++ b/expression/schema.go @@ -90,8 +90,20 @@ func (s *Schema) Clone() *Schema { // ExprFromSchema checks if all columns of this expression are from the same schema. func ExprFromSchema(expr Expression, schema *Schema) bool { - cols := ExtractColumns(expr) - return len(schema.ColumnsIndices(cols)) > 0 + switch v := expr.(type) { + case *Column: + return schema.Contains(v) + case *ScalarFunction: + for _, arg := range v.GetArgs() { + if !ExprFromSchema(arg, schema) { + return false + } + } + return true + case *CorrelatedColumn, *Constant: + return true + } + return false } // FindColumn finds an Column from schema for a ast.ColumnName. It compares the db/table/column names. diff --git a/expression/util_test.go b/expression/util_test.go index 7e810180dc264..f1e03ec971226 100644 --- a/expression/util_test.go +++ b/expression/util_test.go @@ -106,3 +106,21 @@ func BenchmarkExtractColumns(b *testing.B) { } b.ReportAllocs() } + +func BenchmarkExprFromSchema(b *testing.B) { + conditions := []Expression{ + newFunction(ast.EQ, newColumn(0), newColumn(1)), + newFunction(ast.EQ, newColumn(1), newColumn(2)), + newFunction(ast.EQ, newColumn(2), newColumn(3)), + newFunction(ast.EQ, newColumn(3), newLonglong(1)), + newFunction(ast.LogicOr, newLonglong(1), newColumn(0)), + } + expr := ComposeCNFCondition(mock.NewContext(), conditions...) + schema := &Schema{Columns: ExtractColumns(expr)} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + ExprFromSchema(expr, schema) + } + b.ReportAllocs() +} diff --git a/planner/core/logical_plans.go b/planner/core/logical_plans.go index 6237833c32d3e..7d433529bbdb3 100644 --- a/planner/core/logical_plans.go +++ b/planner/core/logical_plans.go @@ -125,25 +125,49 @@ type LogicalJoin struct { } func (p *LogicalJoin) columnSubstitute(schema *expression.Schema, exprs []expression.Expression) { + for i, cond := range p.LeftConditions { + p.LeftConditions[i] = expression.ColumnSubstitute(cond, schema, exprs) + } + + for i, cond := range p.RightConditions { + p.RightConditions[i] = expression.ColumnSubstitute(cond, schema, exprs) + } + + for i, cond := range p.OtherConditions { + p.OtherConditions[i] = expression.ColumnSubstitute(cond, schema, exprs) + } + for i := len(p.EqualConditions) - 1; i >= 0; i-- { - p.EqualConditions[i] = expression.ColumnSubstitute(p.EqualConditions[i], schema, exprs).(*expression.ScalarFunction) - // After the column substitute, the equal condition may become single side condition. - if p.children[0].Schema().Contains(p.EqualConditions[i].GetArgs()[1].(*expression.Column)) { - p.LeftConditions = append(p.LeftConditions, p.EqualConditions[i]) + newCond := expression.ColumnSubstitute(p.EqualConditions[i], schema, exprs).(*expression.ScalarFunction) + + // If the columns used in the new filter all come from the left child, + // we can push this filter to it. + if expression.ExprFromSchema(newCond, p.children[0].Schema()) { + p.LeftConditions = append(p.LeftConditions, newCond) p.EqualConditions = append(p.EqualConditions[:i], p.EqualConditions[i+1:]...) - } else if p.children[1].Schema().Contains(p.EqualConditions[i].GetArgs()[0].(*expression.Column)) { - p.RightConditions = append(p.RightConditions, p.EqualConditions[i]) + continue + } + + // If the columns used in the new filter all come from the right + // child, we can push this filter to it. + if expression.ExprFromSchema(newCond, p.children[1].Schema()) { + p.RightConditions = append(p.RightConditions, newCond) p.EqualConditions = append(p.EqualConditions[:i], p.EqualConditions[i+1:]...) + continue } - } - for i, fun := range p.LeftConditions { - p.LeftConditions[i] = expression.ColumnSubstitute(fun, schema, exprs) - } - for i, fun := range p.RightConditions { - p.RightConditions[i] = expression.ColumnSubstitute(fun, schema, exprs) - } - for i, fun := range p.OtherConditions { - p.OtherConditions[i] = expression.ColumnSubstitute(fun, schema, exprs) + + _, lhsIsCol := newCond.GetArgs()[0].(*expression.Column) + _, rhsIsCol := newCond.GetArgs()[1].(*expression.Column) + + // If the columns used in the new filter are not all expression.Column, + // we can not use it as join's equal condition. + if !(lhsIsCol && rhsIsCol) { + p.OtherConditions = append(p.OtherConditions, newCond) + p.EqualConditions = append(p.EqualConditions[:i], p.EqualConditions[i+1:]...) + continue + } + + p.EqualConditions[i] = newCond } }