diff --git a/bench_test.go b/bench_test.go index eb40debc26b25..0c3cfc025d6de 100644 --- a/bench_test.go +++ b/bench_test.go @@ -6,6 +6,7 @@ import ( "github.com/ngaut/log" "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/optimizer/plan" ) var smallCount = 100 @@ -34,6 +35,16 @@ func prepareBenchData(se Session, colType string, valueFormat string, valueCount mustExecute(se, "commit") } +func prepareJoinBenchData(se Session, colType string, valueFormat string, valueCount int) { + mustExecute(se, "drop table if exists t") + mustExecute(se, fmt.Sprintf("create table t (pk int primary key auto_increment, col %s)", colType)) + mustExecute(se, "begin") + for i := 0; i < valueCount; i++ { + mustExecute(se, "insert t (col) values ("+fmt.Sprintf(valueFormat, i)+")") + } + mustExecute(se, "commit") +} + func readResult(rs ast.RecordSet, count int) { for count > 0 { x, err := rs.Next() @@ -192,3 +203,63 @@ func BenchmarkInsertNoIndex(b *testing.B) { mustExecute(se, fmt.Sprintf("insert t values (%d, %d)", i, i)) } } + +func BenchmarkJoin(b *testing.B) { + b.StopTimer() + se := prepareBenchSession() + prepareJoinBenchData(se, "int", "%v", smallCount) + b.StartTimer() + for i := 0; i < b.N; i++ { + rs, err := se.Execute("select * from t a join t b on a.col = b.col") + if err != nil { + b.Fatal(err) + } + readResult(rs[0], 100) + } +} + +func BenchmarkNewJoin(b *testing.B) { + b.StopTimer() + se := prepareBenchSession() + prepareJoinBenchData(se, "int", "%v", smallCount) + b.StartTimer() + plan.UseNewPlanner = true + for i := 0; i < b.N; i++ { + rs, err := se.Execute("select * from t a join t b on a.col = b.col") + if err != nil { + b.Fatal(err) + } + readResult(rs[0], 100) + } + plan.UseNewPlanner = false +} + +func BenchmarkJoinLimit(b *testing.B) { + b.StopTimer() + se := prepareBenchSession() + prepareJoinBenchData(se, "int", "%v", smallCount) + b.StartTimer() + for i := 0; i < b.N; i++ { + rs, err := se.Execute("select * from t a join t b on a.col = b.col limit 1") + if err != nil { + b.Fatal(err) + } + readResult(rs[0], 1) + } +} + +func BenchmarkNewJoinLimit(b *testing.B) { + b.StopTimer() + se := prepareBenchSession() + prepareJoinBenchData(se, "int", "%v", smallCount) + b.StartTimer() + plan.UseNewPlanner = true + for i := 0; i < b.N; i++ { + rs, err := se.Execute("select * from t a join t b on a.col = b.col limit 1") + if err != nil { + b.Fatal(err) + } + readResult(rs[0], 1) + } + plan.UseNewPlanner = false +} diff --git a/evaluator/evaluator_binop.go b/evaluator/evaluator_binop.go index 73ab8bc6768fa..0a89b971bdcfe 100644 --- a/evaluator/evaluator_binop.go +++ b/evaluator/evaluator_binop.go @@ -91,7 +91,6 @@ func (e *Evaluator) handleAndAnd(o *ast.BinaryOperationExpr) bool { func (e *Evaluator) handleOrOr(o *ast.BinaryOperationExpr) bool { leftDatum := o.L.GetDatum() - righDatum := o.R.GetDatum() if leftDatum.Kind() != types.KindNull { x, err := leftDatum.ToBool() if err != nil { @@ -103,6 +102,7 @@ func (e *Evaluator) handleOrOr(o *ast.BinaryOperationExpr) bool { return true } } + righDatum := o.R.GetDatum() if righDatum.Kind() != types.KindNull { y, err := righDatum.ToBool() if err != nil { diff --git a/executor/builder.go b/executor/builder.go index bc991341bb73e..1971c406b2226 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -103,12 +103,85 @@ func (b *executorBuilder) build(p plan.Plan) Executor { return b.buildUnion(v) case *plan.Update: return b.buildUpdate(v) + case *plan.Join: + return b.buildJoin(v) default: b.err = ErrUnknownPlan.Gen("Unknown Plan %T", p) return nil } } +// compose CNF items into a balance deep CNF tree, which benefits a lot for pb decoder/encoder. +func composeCondition(conditions []ast.ExprNode) ast.ExprNode { + length := len(conditions) + if length == 0 { + return nil + } else if length == 1 { + return conditions[0] + } else { + return &ast.BinaryOperationExpr{Op: opcode.AndAnd, L: composeCondition(conditions[:length/2]), R: composeCondition(conditions[length/2:])} + } +} + +//TODO: select join algorithm during cbo phase. +func (b *executorBuilder) buildJoin(v *plan.Join) Executor { + e := &HashJoinExec{ + otherFilter: composeCondition(v.OtherConditions), + prepared: false, + fields: v.Fields(), + ctx: b.ctx, + } + var leftHashKey, rightHashKey []ast.ExprNode + for _, eqCond := range v.EqualConditions { + binop, ok := eqCond.(*ast.BinaryOperationExpr) + if ok && binop.Op == opcode.EQ { + ln, lOK := binop.L.(*ast.ColumnNameExpr) + rn, rOK := binop.R.(*ast.ColumnNameExpr) + if lOK && rOK { + leftHashKey = append(leftHashKey, ln) + rightHashKey = append(rightHashKey, rn) + continue + } + } + b.err = ErrUnknownPlan.Gen("Invalid Join Equal Condition !!") + } + switch v.JoinType { + case plan.LeftOuterJoin: + e.outter = true + e.leftSmall = false + e.smallFilter = composeCondition(v.RightConditions) + e.bigFilter = composeCondition(v.LeftConditions) + e.smallHashKey = rightHashKey + e.bigHashKey = leftHashKey + case plan.RightOuterJoin: + e.outter = true + e.leftSmall = true + e.smallFilter = composeCondition(v.LeftConditions) + e.bigFilter = composeCondition(v.RightConditions) + e.smallHashKey = leftHashKey + e.bigHashKey = rightHashKey + case plan.InnerJoin: + //TODO: assume right table is the small one before cbo is realized. + e.outter = false + e.leftSmall = false + e.smallFilter = composeCondition(v.RightConditions) + e.bigFilter = composeCondition(v.LeftConditions) + e.smallHashKey = rightHashKey + e.bigHashKey = leftHashKey + default: + b.err = ErrUnknownPlan.Gen("Unknown Join Type !!") + return nil + } + if e.leftSmall { + e.smallExec = b.build(v.GetChildByIndex(0)) + e.bigExec = b.build(v.GetChildByIndex(1)) + } else { + e.smallExec = b.build(v.GetChildByIndex(1)) + e.bigExec = b.build(v.GetChildByIndex(0)) + } + return e +} + func (b *executorBuilder) buildFilter(src Executor, conditions []ast.ExprNode) Executor { if len(conditions) == 0 { return src diff --git a/executor/executor_test.go b/executor/executor_test.go index 212791481471f..e627601a7dfef 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -28,6 +28,7 @@ import ( "github.com/pingcap/tidb/inspectkv" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/model" + "github.com/pingcap/tidb/optimizer/plan" "github.com/pingcap/tidb/store/tikv" "github.com/pingcap/tidb/util/testkit" "github.com/pingcap/tidb/util/testleak" @@ -718,6 +719,8 @@ func (s *testSuite) TestSelectHaving(c *C) { r.Check(testkit.Rows(rowStr)) tk.MustExec("commit") + r = tk.MustQuery("select * from select_having_test group by id having null is not null;") + tk.MustExec("drop table select_having_test") } @@ -925,7 +928,6 @@ func (s *testSuite) TestUnion(c *C) { defer testleak.AfterTest(c)() tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") - testSQL := `select 1 union select 0;` tk.MustExec(testSQL) @@ -1098,6 +1100,73 @@ func (s *testSuite) TestJoin(c *C) { result.Check(testkit.Rows(" 5 5", " 9 9", "1 1 1 1 1 1")) } +func (s *testSuite) TestNewJoin(c *C) { + plan.UseNewPlanner = true + defer testleak.AfterTest(c)() + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (c int)") + tk.MustExec("insert t values (1)") + cases := []struct { + sql string + result [][]interface{} + }{ + { + "select 1 from t as a left join t as b on 0", + testkit.Rows("1"), + }, + { + "select 1 from t as a join t as b on 1", + testkit.Rows("1"), + }, + } + for _, ca := range cases { + result := tk.MustQuery(ca.sql) + result.Check(ca.result) + } + + tk.MustExec("drop table if exists t") + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t(c1 int, c2 int)") + tk.MustExec("create table t1(c1 int, c2 int)") + tk.MustExec("insert into t values(1,1),(2,2)") + tk.MustExec("insert into t1 values(2,3),(4,4)") + result := tk.MustQuery("select * from t left outer join t1 on t.c1 = t1.c1 where t.c1 = 1 or t1.c2 > 20") + result.Check(testkit.Rows("1 1 ")) + result = tk.MustQuery("select * from t1 right outer join t on t.c1 = t1.c1 where t.c1 = 1 or t1.c2 > 20") + result.Check(testkit.Rows(" 1 1")) + result = tk.MustQuery("select * from t right outer join t1 on t.c1 = t1.c1 where t.c1 = 1 or t1.c2 > 20") + result.Check(testkit.Rows()) + result = tk.MustQuery("select * from t left outer join t1 on t.c1 = t1.c1 where t1.c1 = 3 or false") + result.Check(testkit.Rows()) + result = tk.MustQuery("select * from t left outer join t1 on t.c1 = t1.c1 and t.c1 != 1") + result.Check(testkit.Rows("1 1 ", "2 2 2 3")) + + tk.MustExec("drop table if exists t1") + tk.MustExec("drop table if exists t2") + tk.MustExec("drop table if exists t3") + + tk.MustExec("create table t1 (c1 int, c2 int)") + tk.MustExec("create table t2 (c1 int, c2 int)") + tk.MustExec("create table t3 (c1 int, c2 int)") + + tk.MustExec("insert into t1 values (1,1), (2,2), (3,3)") + tk.MustExec("insert into t2 values (1,1), (3,3), (5,5)") + tk.MustExec("insert into t3 values (1,1), (5,5), (9,9)") + + result = tk.MustQuery("select * from t1 left join t2 on t1.c1 = t2.c1 right join t3 on t2.c1 = t3.c1 order by t1.c1, t1.c2, t2.c1, t2.c2, t3.c1, t3.c2;") + result.Check(testkit.Rows(" 5 5", " 9 9", "1 1 1 1 1 1")) + + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t1 (c1 int)") + tk.MustExec("insert into t1 values (1), (1), (1)") + result = tk.MustQuery("select * from t1 a join t1 b on a.c1 = b.c1;") + result.Check(testkit.Rows("1 1", "1 1", "1 1", "1 1", "1 1", "1 1", "1 1", "1 1", "1 1")) + + plan.UseNewPlanner = false +} + func (s *testSuite) TestIndexScan(c *C) { defer testleak.AfterTest(c)() tk := testkit.NewTestKit(c, s.store) diff --git a/executor/new_executor.go b/executor/new_executor.go new file mode 100644 index 0000000000000..1d5f104b31ae3 --- /dev/null +++ b/executor/new_executor.go @@ -0,0 +1,232 @@ +// Copyright 2016 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import ( + "github.com/juju/errors" + "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/context" + "github.com/pingcap/tidb/evaluator" + "github.com/pingcap/tidb/util/codec" + "github.com/pingcap/tidb/util/types" +) + +// HashJoinExec implements the hash join algorithm +type HashJoinExec struct { + hashTable map[string][]*Row + smallHashKey []ast.ExprNode + bigHashKey []ast.ExprNode + smallExec Executor + bigExec Executor + prepared bool + fields []*ast.ResultField + ctx context.Context + smallFilter ast.ExprNode + bigFilter ast.ExprNode + otherFilter ast.ExprNode + outter bool + leftSmall bool + matchedRows []*Row + cursor int +} + +func joinTwoRow(a *Row, b *Row) *Row { + ret := &Row{ + RowKeys: make([]*RowKeyEntry, 0, len(a.RowKeys)+len(b.RowKeys)), + Data: make([]types.Datum, 0, len(a.Data)+len(b.Data)), + } + ret.RowKeys = append(ret.RowKeys, a.RowKeys...) + ret.RowKeys = append(ret.RowKeys, b.RowKeys...) + ret.Data = append(ret.Data, a.Data...) + ret.Data = append(ret.Data, b.Data...) + return ret +} + +func (e *HashJoinExec) getHashKey(exprs []ast.ExprNode) ([]byte, error) { + vals := make([]types.Datum, 0, len(exprs)) + for _, expr := range exprs { + v, err := evaluator.Eval(e.ctx, expr) + if err != nil { + return nil, errors.Trace(err) + } + vals = append(vals, v) + } + if len(vals) == 0 { + return []byte{}, nil + } + result, err := codec.EncodeValue([]byte{}, vals...) + if err != nil { + return nil, errors.Trace(err) + } + return result, err +} + +// Fields implements Executor Fields interface. +func (e *HashJoinExec) Fields() []*ast.ResultField { + return e.fields +} + +// Close implements Executor Close interface. +func (e *HashJoinExec) Close() error { + e.hashTable = nil + e.matchedRows = nil + return nil +} + +func (e *HashJoinExec) prepare() error { + e.hashTable = make(map[string][]*Row) + e.cursor = 0 + for { + row, err := e.smallExec.Next() + if err != nil { + return errors.Trace(err) + } + if row == nil { + e.smallExec.Close() + break + } + matched := true + if e.smallFilter != nil { + matched, err = evaluator.EvalBool(e.ctx, e.smallFilter) + if err != nil { + return errors.Trace(err) + } + if !matched { + continue + } + } + hashcode, err := e.getHashKey(e.smallHashKey) + if err != nil { + return err + } + if rows, ok := e.hashTable[string(hashcode)]; !ok { + e.hashTable[string(hashcode)] = []*Row{row} + } else { + e.hashTable[string(hashcode)] = append(rows, row) + } + } + e.prepared = true + return nil +} + +func (e *HashJoinExec) constructMatchedRows(bigRow *Row) (matchedRows []*Row, err error) { + hashcode, err := e.getHashKey(e.bigHashKey) + if err != nil { + return nil, errors.Trace(err) + } + // match eq condition + if rows, ok := e.hashTable[string(hashcode)]; ok { + for _, smallRow := range rows { + //TODO: remove result fields in order to reduce memory copy cost. + otherMatched := true + if e.otherFilter != nil { + startKey := 0 + if !e.leftSmall { + startKey = len(bigRow.Data) + } + for i, data := range smallRow.Data { + e.fields[i+startKey].Expr.SetValue(data.GetValue()) + } + otherMatched, err = evaluator.EvalBool(e.ctx, e.otherFilter) + } + if err != nil { + return nil, errors.Trace(err) + } + if otherMatched { + if e.leftSmall { + matchedRows = append(matchedRows, joinTwoRow(smallRow, bigRow)) + } else { + matchedRows = append(matchedRows, joinTwoRow(bigRow, smallRow)) + } + } + } + } + return matchedRows, nil +} + +func (e *HashJoinExec) fillNullRow(bigRow *Row) (returnRow *Row, err error) { + smallRow := &Row{ + RowKeys: make([]*RowKeyEntry, len(e.smallExec.Fields())), + Data: make([]types.Datum, len(e.smallExec.Fields())), + } + for _, data := range smallRow.Data { + data.SetNull() + } + if e.leftSmall { + returnRow = joinTwoRow(smallRow, bigRow) + } else { + returnRow = joinTwoRow(bigRow, smallRow) + } + for i, data := range returnRow.Data { + e.fields[i].Expr.SetValue(data.GetValue()) + } + return returnRow, nil +} + +func (e *HashJoinExec) returnRecord() (ret *Row, ok bool) { + if e.cursor >= len(e.matchedRows) { + return nil, false + } + for i, data := range e.matchedRows[e.cursor].Data { + e.fields[i].Expr.SetValue(data.GetValue()) + } + e.cursor++ + return e.matchedRows[e.cursor-1], true +} + +// Next implements Executor Next interface. +func (e *HashJoinExec) Next() (*Row, error) { + if !e.prepared { + err := e.prepare() + if err != nil { + return nil, err + } + } + row, ok := e.returnRecord() + if ok { + return row, nil + } + for { + bigRow, err := e.bigExec.Next() + if err != nil { + return nil, err + } + if bigRow == nil { + e.bigExec.Close() + return nil, nil + } + var matchedRows []*Row + bigMatched := true + if e.bigFilter != nil { + bigMatched, err = evaluator.EvalBool(e.ctx, e.bigFilter) + if err != nil { + return nil, errors.Trace(err) + } + } + if bigMatched { + matchedRows, err = e.constructMatchedRows(bigRow) + if err != nil { + return nil, errors.Trace(err) + } + } + e.matchedRows = matchedRows + e.cursor = 0 + row, ok := e.returnRecord() + if ok { + return row, nil + } else if e.outter { + return e.fillNullRow(bigRow) + } + } +} diff --git a/infoschema/infoschema.go b/infoschema/infoschema.go index 46f78cf54932e..c65b26e500e0b 100644 --- a/infoschema/infoschema.go +++ b/infoschema/infoschema.go @@ -101,6 +101,23 @@ type infoSchema struct { schemaMetaVersion int64 } +// MockInfoSchema only serves for test. +func MockInfoSchema(tbList []*model.TableInfo) InfoSchema { + result := &infoSchema{} + result.schemaNameToID = make(map[string]int64) + result.tableNameToID = make(map[tableName]int64) + result.schemas = make(map[int64]*model.DBInfo) + result.tables = make(map[int64]table.Table) + + result.schemaNameToID["test"] = 0 + result.schemas[0] = &model.DBInfo{ID: 0, Name: model.NewCIStr("test"), Tables: tbList} + for i, tb := range tbList { + result.tableNameToID[tableName{schema: "test", table: tb.Name.L}] = int64(i) + result.tables[int64(i)] = table.MockTableFromMeta(tb) + } + return result +} + var _ InfoSchema = (*infoSchema)(nil) type tableName struct { diff --git a/optimizer/new_plan_test.go b/optimizer/new_plan_test.go new file mode 100644 index 0000000000000..f3344d05f3d59 --- /dev/null +++ b/optimizer/new_plan_test.go @@ -0,0 +1,140 @@ +// Copyright 2016 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package optimizer + +import ( + "testing" + + . "github.com/pingcap/check" + "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/infoschema" + "github.com/pingcap/tidb/model" + "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/optimizer/plan" + "github.com/pingcap/tidb/parser" + "github.com/pingcap/tidb/util/testleak" +) + +var _ = Suite(&testPlanSuite{}) + +func TestT(t *testing.T) { + TestingT(t) +} + +type testPlanSuite struct{} + +func newMockResolve(node ast.Node) error { + indices := []*model.IndexInfo{ + { + Name: model.NewCIStr("b"), + Columns: []*model.IndexColumn{ + { + Name: model.NewCIStr("b"), + }, + }, + }, + { + Name: model.NewCIStr("c_d_e"), + Columns: []*model.IndexColumn{ + { + Name: model.NewCIStr("c"), + }, + { + Name: model.NewCIStr("d"), + }, + { + Name: model.NewCIStr("e"), + }, + }, + }, + } + pkColumn := &model.ColumnInfo{ + State: model.StatePublic, + Name: model.NewCIStr("a"), + } + col1 := &model.ColumnInfo{ + State: model.StatePublic, + Name: model.NewCIStr("d"), + } + pkColumn.Flag = mysql.PriKeyFlag + table := &model.TableInfo{ + Columns: []*model.ColumnInfo{pkColumn, col1}, + Indices: indices, + Name: model.NewCIStr("t"), + PKIsHandle: true, + } + is := infoschema.MockInfoSchema([]*model.TableInfo{table}) + return MockResolveName(node, is, "test") +} + +func (s *testPlanSuite) TestPredicatePushDown(c *C) { + plan.UseNewPlanner = true + defer testleak.AfterTest(c)() + cases := []struct { + sql string + first string + best string + }{ + { + sql: "select a from (select a from t where d = 0) k where k.a = 5", + first: "Table(t)->Fields->Filter->Fields", + best: "Range(t)->Fields->Fields", + }, + { + sql: "select a from (select 1+2 as a from t where d = 0) k where k.a = 5", + first: "Table(t)->Fields->Filter->Fields", + best: "Table(t)->Fields->Filter->Fields", + }, + { + sql: "select a from (select d as a from t where d = 0) k where k.a = 5", + first: "Table(t)->Fields->Filter->Fields", + best: "Table(t)->Fields->Fields", + }, + { + sql: "select * from t ta join t tb on ta.d = tb.d and ta.d > 1 where tb.a = 0", + first: "Join{Table(t)->Table(t)}->Filter->Fields", + best: "Join{Table(t)->Range(t)}->Fields", + }, + { + sql: "select * from t ta left outer join t tb on ta.d = tb.d and ta.d > 1 where tb.a = 0", + first: "Join{Table(t)->Table(t)}->Filter->Fields", + best: "Join{Table(t)->Table(t)}->Filter->Fields", + }, + { + sql: "select * from t ta right outer join t tb on ta.d = tb.d and ta.a > 1 where tb.a = 0", + first: "Join{Table(t)->Table(t)}->Filter->Fields", + best: "Join{Range(t)->Range(t)}->Fields", + }, + } + for _, ca := range cases { + comment := Commentf("for %s", ca.sql) + stmt, err := parser.ParseOneStmt(ca.sql, "", "") + c.Assert(err, IsNil, comment) + ast.SetFlag(stmt) + + err = newMockResolve(stmt) + c.Assert(err, IsNil) + + p, err := plan.BuildPlan(stmt, nil) + c.Assert(err, IsNil) + c.Assert(plan.ToString(p), Equals, ca.first, Commentf("for %s", ca.sql)) + + _, err = plan.PredicatePushDown(p, []ast.ExprNode{}) + c.Assert(err, IsNil) + err = plan.Refine(p) + c.Assert(err, IsNil) + c.Assert(plan.ToString(p), Equals, ca.best, Commentf("for %s", ca.sql)) + } + plan.UseNewPlanner = false +} diff --git a/optimizer/optimizer.go b/optimizer/optimizer.go index 04bf748689a8e..41ff17dc8e6db 100644 --- a/optimizer/optimizer.go +++ b/optimizer/optimizer.go @@ -37,6 +37,12 @@ func Optimize(ctx context.Context, node ast.Node, sb plan.SubQueryBuilder) (plan if err != nil { return nil, errors.Trace(err) } + if plan.UseNewPlanner { + _, err = plan.PredicatePushDown(p, []ast.ExprNode{}) + if err != nil { + return nil, errors.Trace(err) + } + } err = plan.Refine(p) if err != nil { return nil, errors.Trace(err) diff --git a/optimizer/plan/new_plans.go b/optimizer/plan/new_plans.go new file mode 100644 index 0000000000000..a594bb47cfefe --- /dev/null +++ b/optimizer/plan/new_plans.go @@ -0,0 +1,89 @@ +// Copyright 2016 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package plan + +import ( + "github.com/juju/errors" + "github.com/pingcap/tidb/ast" +) + +// JoinType contains CrossJoin, InnerJoin, LeftOuterJoin, RightOuterJoin, FullOuterJoin, SemiJoin. +type JoinType int + +const ( + // CrossJoin means Cartesian Product, but not used now + CrossJoin JoinType = iota + // InnerJoin means inner join + InnerJoin + // LeftOuterJoin means left join + LeftOuterJoin + // RightOuterJoin means right join + RightOuterJoin + // TODO: support semi join. +) + +// Join is the logical join plan. +type Join struct { + basePlan + + JoinType JoinType + + EqualConditions []ast.ExprNode + LeftConditions []ast.ExprNode + RightConditions []ast.ExprNode + OtherConditions []ast.ExprNode +} + +// AddChild for parent. +func addChild(parent Plan, child Plan) { + if child == nil || parent == nil { + return + } + child.AddParent(parent) + parent.AddChild(child) +} + +// InsertPlan means inserting plan between two plans. +func InsertPlan(parent Plan, child Plan, insert Plan) error { + err := child.ReplaceParent(parent, insert) + if err != nil { + return errors.Trace(err) + } + err = parent.ReplaceChild(child, insert) + if err != nil { + return errors.Trace(err) + } + insert.AddChild(child) + insert.AddParent(parent) + return err +} + +// RemovePlan means removing a plan. +func RemovePlan(p Plan) error { + parents := p.GetParents() + children := p.GetChildren() + if len(parents) != 1 || len(children) != 1 { + return SystemInternalErrorType.Gen("can't remove this plan") + } + parent, child := parents[0], children[0] + err := parent.ReplaceChild(p, child) + if err != nil { + return errors.Trace(err) + } + err = child.ReplaceParent(p, parent) + if err != nil { + return errors.Trace(err) + } + return nil +} diff --git a/optimizer/plan/newplanbuilder.go b/optimizer/plan/newplanbuilder.go new file mode 100644 index 0000000000000..f9a8c5a54106d --- /dev/null +++ b/optimizer/plan/newplanbuilder.go @@ -0,0 +1,248 @@ +// Copyright 2016 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package plan + +import ( + "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/model" + "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/parser/opcode" +) + +// UseNewPlanner means if use the new planner. +var UseNewPlanner = false + +func (b *planBuilder) buildNewSinglePathPlan(node ast.ResultSetNode) Plan { + switch x := node.(type) { + case *ast.Join: + return b.buildNewJoin(x) + case *ast.TableSource: + switch v := x.Source.(type) { + case *ast.SelectStmt: + return b.buildSelect(v) + case *ast.UnionStmt: + return b.buildUnion(v) + case *ast.TableName: + //TODO: select physical algorithm during cbo phase. + return b.buildNewTableScanPlan(v) + default: + b.err = ErrUnsupportedType.Gen("unsupported table source type %T", v) + return nil + } + default: + b.err = ErrUnsupportedType.Gen("unsupported table source type %T", x) + return nil + } +} + +func fromFields(col *ast.ColumnNameExpr, fields []*ast.ResultField) bool { + for _, field := range fields { + if field == col.Refer { + return true + } + } + return false +} + +type columnsExtractor struct { + result []*ast.ColumnNameExpr +} + +func (ce *columnsExtractor) Enter(expr ast.Node) (ret ast.Node, skipChildren bool) { + switch v := expr.(type) { + case *ast.ColumnNameExpr: + ce.result = append(ce.result, v) + } + return expr, false +} + +func (ce *columnsExtractor) Leave(expr ast.Node) (ret ast.Node, skipChildren bool) { + return expr, true +} + +func extractOnCondition(conditions []ast.ExprNode, left Plan, right Plan) (eqCond []ast.ExprNode, leftCond []ast.ExprNode, rightCond []ast.ExprNode, otherCond []ast.ExprNode) { + for _, expr := range conditions { + binop, ok := expr.(*ast.BinaryOperationExpr) + if ok && binop.Op == opcode.EQ { + ln, lOK := binop.L.(*ast.ColumnNameExpr) + rn, rOK := binop.R.(*ast.ColumnNameExpr) + if lOK && rOK { + if fromFields(ln, left.Fields()) && fromFields(rn, right.Fields()) { + eqCond = append(eqCond, expr) + continue + } else if fromFields(rn, left.Fields()) && fromFields(ln, right.Fields()) { + eqCond = append(eqCond, &ast.BinaryOperationExpr{Op: opcode.EQ, L: rn, R: ln}) + continue + } + } + } + ce := &columnsExtractor{} + expr.Accept(ce) + columns := ce.result + allFromLeft, allFromRight := true, true + for _, col := range columns { + if fromFields(col, left.Fields()) { + allFromRight = false + } else { + allFromLeft = false + } + } + if allFromRight { + rightCond = append(rightCond, expr) + } else if allFromLeft { + leftCond = append(leftCond, expr) + } else { + otherCond = append(otherCond, expr) + } + } + return eqCond, leftCond, rightCond, otherCond +} + +func (b *planBuilder) buildNewJoin(join *ast.Join) Plan { + if join.Right == nil { + return b.buildNewSinglePathPlan(join.Left) + } + leftPlan := b.buildNewSinglePathPlan(join.Left) + rightPlan := b.buildNewSinglePathPlan(join.Right) + var eqCond, leftCond, rightCond, otherCond []ast.ExprNode + if join.On != nil { + onCondition := splitWhere(join.On.Expr) + eqCond, leftCond, rightCond, otherCond = extractOnCondition(onCondition, leftPlan, rightPlan) + } + joinPlan := &Join{EqualConditions: eqCond, LeftConditions: leftCond, RightConditions: rightCond, OtherConditions: otherCond} + if join.Tp == ast.LeftJoin { + joinPlan.JoinType = LeftOuterJoin + } else if join.Tp == ast.RightJoin { + joinPlan.JoinType = RightOuterJoin + } else { + joinPlan.JoinType = InnerJoin + } + addChild(joinPlan, leftPlan) + addChild(joinPlan, rightPlan) + joinPlan.SetFields(append(leftPlan.Fields(), rightPlan.Fields()...)) + return joinPlan +} + +func (b *planBuilder) buildFilter(p Plan, where ast.ExprNode) Plan { + conditions := splitWhere(where) + filter := &Filter{Conditions: conditions} + addChild(filter, p) + filter.SetFields(p.Fields()) + return filter +} + +func (b *planBuilder) buildNewSelect(sel *ast.SelectStmt) Plan { + var aggFuncs []*ast.AggregateFuncExpr + hasAgg := b.detectSelectAgg(sel) + if hasAgg { + aggFuncs = b.extractSelectAgg(sel) + } + // Build subquery + // Convert subquery to expr with plan + b.buildSubquery(sel) + var p Plan + if sel.From != nil { + p = b.buildNewSinglePathPlan(sel.From.TableRefs) + if sel.Where != nil { + p = b.buildFilter(p, sel.Where) + } + if b.err != nil { + return nil + } + if sel.LockTp != ast.SelectLockNone { + p = b.buildSelectLock(p, sel.LockTp) + if b.err != nil { + return nil + } + } + if hasAgg { + p = b.buildAggregate(p, aggFuncs, sel.GroupBy) + } + p = b.buildSelectFields(p, sel.GetResultFields()) + if b.err != nil { + return nil + } + } else { + if sel.Where != nil { + p = b.buildTableDual(sel) + } + if hasAgg { + p = b.buildAggregate(p, aggFuncs, nil) + } + p = b.buildSelectFields(p, sel.GetResultFields()) + if b.err != nil { + return nil + } + } + if sel.Having != nil { + p = b.buildFilter(p, sel.Having.Expr) + if b.err != nil { + return nil + } + } + if sel.Distinct { + p = b.buildDistinct(p) + if b.err != nil { + return nil + } + } + if sel.OrderBy != nil && !pushOrder(p, sel.OrderBy.Items) { + p = b.buildSort(p, sel.OrderBy.Items) + if b.err != nil { + return nil + } + } + if sel.Limit != nil { + p = b.buildLimit(p, sel.Limit) + if b.err != nil { + return nil + } + } + return p +} + +func (ts *TableScan) attachCondition(conditions []ast.ExprNode) { + var pkName model.CIStr + if ts.Table.PKIsHandle { + for _, colInfo := range ts.Table.Columns { + if mysql.HasPriKeyFlag(colInfo.Flag) { + pkName = colInfo.Name + } + } + } + for _, con := range conditions { + if pkName.L != "" { + checker := conditionChecker{tableName: ts.Table.Name, pkName: pkName} + if checker.check(con) { + ts.AccessConditions = append(ts.AccessConditions, con) + } else { + ts.FilterConditions = append(ts.FilterConditions, con) + } + } else { + ts.FilterConditions = append(ts.FilterConditions, con) + } + } +} + +func (b *planBuilder) buildNewTableScanPlan(tn *ast.TableName) Plan { + p := &TableScan{ + Table: tn.TableInfo, + TableName: tn, + } + // Equal condition contains a column from previous joined table. + p.RefAccess = false + p.SetFields(tn.GetResultFields()) + p.TableAsName = getTableAsName(p.Fields()) + return p +} diff --git a/optimizer/plan/plan.go b/optimizer/plan/plan.go index 4e362e8dd880f..69091de7fe43a 100644 --- a/optimizer/plan/plan.go +++ b/optimizer/plan/plan.go @@ -39,6 +39,10 @@ type Plan interface { AddParent(parent Plan) // AddChild means append a child for plan. AddChild(children Plan) + // ReplaceParent means replace a parent with another one. + ReplaceParent(parent, newPar Plan) error + // ReplaceChild means replace a child with another one. + ReplaceChild(children, newChild Plan) error // Retrieve parent by index. GetParentByIndex(index int) Plan // Retrieve child by index. @@ -97,16 +101,34 @@ func (p *basePlan) SetFields(fields []*ast.ResultField) { // AddParent implements Plan AddParent interface. func (p *basePlan) AddParent(parent Plan) { - if parent != nil { - p.parents = append(p.parents, parent) - } + p.parents = append(p.parents, parent) } // AddChild implements Plan AddChild interface. func (p *basePlan) AddChild(child Plan) { - if child != nil { - p.children = append(p.children, child) + p.children = append(p.children, child) +} + +// ReplaceParent means replace a parent for another one. +func (p *basePlan) ReplaceParent(parent, newPar Plan) error { + for i, par := range p.parents { + if par == parent { + p.parents[i] = newPar + return nil + } + } + return SystemInternalErrorType.Gen("RemoveParent Failed!") +} + +// ReplaceChild means replace a child with another one. +func (p *basePlan) ReplaceChild(child, newChild Plan) error { + for i, ch := range p.children { + if ch == child { + p.children[i] = newChild + return nil + } } + return SystemInternalErrorType.Gen("RemoveChildren Failed!") } // GetParentByIndex implements Plan GetParentByIndex interface. diff --git a/optimizer/plan/planbuilder.go b/optimizer/plan/planbuilder.go index d48b0998d6b09..037c27b0c85a7 100644 --- a/optimizer/plan/planbuilder.go +++ b/optimizer/plan/planbuilder.go @@ -31,12 +31,14 @@ import ( // Error instances. var ( - ErrUnsupportedType = terror.ClassOptimizerPlan.New(CodeUnsupportedType, "Unsupported type") + ErrUnsupportedType = terror.ClassOptimizerPlan.New(CodeUnsupportedType, "Unsupported type") + SystemInternalErrorType = terror.ClassOptimizerPlan.New(SystemInternalError, "System internal error") ) // Error codes. const ( CodeUnsupportedType terror.ErrCode = 1 + SystemInternalError terror.ErrCode = 2 ) // BuildPlan builds a plan from a node. @@ -87,6 +89,9 @@ func (b *planBuilder) build(node ast.Node) Plan { case *ast.PrepareStmt: return b.buildPrepare(x) case *ast.SelectStmt: + if UseNewPlanner { + return b.buildNewSelect(x) + } return b.buildSelect(x) case *ast.UnionStmt: return b.buildUnion(x) @@ -536,12 +541,12 @@ func (b *planBuilder) buildPseudoSelectPlan(p Plan, sel *ast.SelectStmt) Plan { x.NoLimit = true } np := &Sort{ByItems: sel.OrderBy.Items} - np.AddChild(p) + addChild(np, p) p = np } if sel.Limit != nil { np := &Limit{Offset: sel.Limit.Offset, Count: sel.Limit.Count} - np.AddChild(p) + addChild(np, p) np.SetLimit(0) p = np } else { @@ -557,14 +562,14 @@ func (b *planBuilder) buildSelectLock(src Plan, lock ast.SelectLockType) *Select selectLock := &SelectLock{ Lock: lock, } - selectLock.AddChild(src) + addChild(selectLock, src) selectLock.SetFields(src.Fields()) return selectLock } func (b *planBuilder) buildSelectFields(src Plan, fields []*ast.ResultField) Plan { selectFields := &SelectFields{} - selectFields.AddChild(src) + addChild(selectFields, src) selectFields.SetFields(fields) return selectFields } @@ -574,7 +579,7 @@ func (b *planBuilder) buildAggregate(src Plan, aggFuncs []*ast.AggregateFuncExpr aggPlan := &Aggregate{ AggFuncs: aggFuncs, } - aggPlan.AddChild(src) + addChild(aggPlan, src) if src != nil { aggPlan.SetFields(src.Fields()) } @@ -588,7 +593,7 @@ func (b *planBuilder) buildHaving(src Plan, having *ast.HavingClause) Plan { p := &Having{ Conditions: splitWhere(having.Expr), } - p.AddChild(src) + addChild(p, src) p.SetFields(src.Fields()) return p } @@ -597,7 +602,7 @@ func (b *planBuilder) buildSort(src Plan, byItems []*ast.ByItem) Plan { sort := &Sort{ ByItems: byItems, } - sort.AddChild(src) + addChild(sort, src) sort.SetFields(src.Fields()) return sort } @@ -611,7 +616,7 @@ func (b *planBuilder) buildLimit(src Plan, limit *ast.Limit) Plan { s.ExecLimit = li return s } - li.AddChild(src) + addChild(li, src) li.SetFields(src.Fields()) return li } @@ -835,7 +840,11 @@ func (se *subqueryVisitor) Leave(in ast.Node) (out ast.Node, ok bool) { func (b *planBuilder) buildUnion(union *ast.UnionStmt) Plan { sels := make([]Plan, len(union.SelectList.Selects)) for i, sel := range union.SelectList.Selects { - sels[i] = b.buildSelect(sel) + if UseNewPlanner { + sels[i] = b.buildNewSelect(sel) + } else { + sels[i] = b.buildSelect(sel) + } } var p Plan p = &Union{ @@ -871,6 +880,7 @@ func (b *planBuilder) buildUnion(union *ast.UnionStmt) Plan { uField.Column.Tp = f.Column.Tp } } + addChild(p, sel) } for _, v := range unionFields { v.Expr.SetType(&v.Column.FieldType) @@ -891,7 +901,7 @@ func (b *planBuilder) buildUnion(union *ast.UnionStmt) Plan { func (b *planBuilder) buildDistinct(src Plan) Plan { d := &Distinct{} - d.AddChild(src) + addChild(d, src) d.SetFields(src.Fields()) return d } @@ -1034,7 +1044,7 @@ func (b *planBuilder) buildShow(show *ast.ShowStmt) Plan { } if len(conditions) != 0 { filter := &Filter{Conditions: conditions} - filter.AddChild(p) + addChild(filter, p) p = filter } return p @@ -1056,7 +1066,7 @@ func (b *planBuilder) buildInsert(insert *ast.InsertStmt) Plan { } if insert.Select != nil { insertPlan.SelectPlan = b.build(insert.Select) - insertPlan.AddChild(insertPlan.SelectPlan) + addChild(insertPlan, insertPlan.SelectPlan) if b.err != nil { return nil } @@ -1077,7 +1087,7 @@ func (b *planBuilder) buildExplain(explain *ast.ExplainStmt) Plan { return nil } p := &Explain{StmtPlan: targetPlan} - p.AddChild(targetPlan) + addChild(p, targetPlan) p.SetFields(buildExplainFields()) return p } diff --git a/optimizer/plan/planbuilder_join.go b/optimizer/plan/planbuilder_join.go index daceae4d68de2..14c7aaa974885 100644 --- a/optimizer/plan/planbuilder_join.go +++ b/optimizer/plan/planbuilder_join.go @@ -619,7 +619,7 @@ func (b *planBuilder) buildJoin(sel *ast.SelectStmt) Plan { p.SetFields(rfs) if filterConditions != nil { filterPlan := &Filter{Conditions: filterConditions} - filterPlan.AddChild(p) + addChild(filterPlan, p) filterPlan.SetFields(p.Fields()) return filterPlan } @@ -737,8 +737,8 @@ func (b *planBuilder) buildPlanFromJoinPath(path *joinPath) Plan { Outer: b.buildPlanFromJoinPath(path.outer), Inner: b.buildPlanFromJoinPath(path.inner), } - join.AddChild(join.Outer) - join.AddChild(join.Inner) + addChild(join, join.Outer) + addChild(join, join.Inner) if path.rightJoin { join.SetFields(append(join.Inner.Fields(), join.Outer.Fields()...)) } else { @@ -751,7 +751,7 @@ func (b *planBuilder) buildPlanFromJoinPath(path *joinPath) Plan { inPlan := b.buildPlanFromJoinPath(in) join.Inners = append(join.Inners, inPlan) join.fields = append(join.fields, in.resultFields()...) - join.AddChild(inPlan) + addChild(join, inPlan) } join.Conditions = path.conditions for _, equiv := range path.eqConds { @@ -817,7 +817,7 @@ func (b *planBuilder) buildSubqueryJoinPath(path *joinPath) Plan { return p } filterPlan := &Filter{Conditions: path.conditions} - filterPlan.AddChild(p) + addChild(filterPlan, p) filterPlan.SetFields(p.Fields()) return filterPlan } diff --git a/optimizer/plan/predicate_push_down.go b/optimizer/plan/predicate_push_down.go new file mode 100644 index 0000000000000..2f8c498322299 --- /dev/null +++ b/optimizer/plan/predicate_push_down.go @@ -0,0 +1,187 @@ +// Copyright 2016 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package plan + +import ( + "github.com/juju/errors" + "github.com/pingcap/tidb/ast" +) + +func addFilter(p Plan, child Plan, conditions []ast.ExprNode) error { + filter := &Filter{Conditions: conditions} + return InsertPlan(p, child, filter) +} + +// columnSubstituor substitutes the columns in filter to expressions in select fields. +// e.g. select * from (select b as a from t) k where a < 10 => select * from (select b as a from t where b < 10) k. +type columnSubstitutor struct { + fields []*ast.ResultField +} + +func (cl *columnSubstitutor) Enter(inNode ast.Node) (node ast.Node, skipChild bool) { + return inNode, false +} + +func (cl *columnSubstitutor) Leave(inNode ast.Node) (node ast.Node, ok bool) { + switch v := inNode.(type) { + case *ast.ColumnNameExpr: + for _, field := range cl.fields { + if v.Refer == field { + return field.Expr, true + } + } + } + return inNode, true +} + +// PredicatePushDown applies predicate push down to all kinds of plans, except aggregation and union. +func PredicatePushDown(p Plan, predicates []ast.ExprNode) (ret []ast.ExprNode, err error) { + switch v := p.(type) { + case *TableScan: + v.attachCondition(predicates) + return ret, nil + case *Filter: + conditions := v.Conditions + retConditions, err1 := PredicatePushDown(p.GetChildByIndex(0), append(conditions, predicates...)) + if err1 != nil { + return nil, errors.Trace(err1) + } + if len(retConditions) > 0 { + v.Conditions = retConditions + } else { + if len(p.GetParents()) == 0 { + return ret, nil + } + err1 = RemovePlan(p) + if err1 != nil { + return nil, errors.Trace(err1) + } + } + return ret, nil + case *Join: + //TODO: add null rejecter + var leftCond, rightCond []ast.ExprNode + leftPlan := v.GetChildByIndex(0) + rightPlan := v.GetChildByIndex(1) + equalCond, leftPushCond, rightPushCond, otherCond := extractOnCondition(predicates, leftPlan, rightPlan) + if v.JoinType == LeftOuterJoin { + rightCond = v.RightConditions + leftCond = leftPushCond + ret = append(equalCond, otherCond...) + ret = append(ret, rightPushCond...) + } else if v.JoinType == RightOuterJoin { + leftCond = v.LeftConditions + rightCond = rightPushCond + ret = append(equalCond, otherCond...) + ret = append(ret, leftPushCond...) + } else { + leftCond = append(v.LeftConditions, leftPushCond...) + rightCond = append(v.RightConditions, rightPushCond...) + } + leftRet, err1 := PredicatePushDown(leftPlan, leftCond) + if err1 != nil { + return nil, errors.Trace(err1) + } + rightRet, err2 := PredicatePushDown(rightPlan, rightCond) + if err2 != nil { + return nil, errors.Trace(err2) + } + if len(leftRet) > 0 { + err2 = addFilter(p, leftPlan, leftRet) + if err2 != nil { + return nil, errors.Trace(err2) + } + } + if len(rightRet) > 0 { + err2 = addFilter(p, rightPlan, rightRet) + if err2 != nil { + return nil, errors.Trace(err2) + } + } + if v.JoinType == InnerJoin { + v.EqualConditions = append(v.EqualConditions, equalCond...) + v.OtherConditions = append(v.OtherConditions, otherCond...) + } + return ret, nil + case *SelectFields: + if len(v.GetChildren()) == 0 { + return predicates, nil + } + cs := &columnSubstitutor{fields: v.Fields()} + var push []ast.ExprNode + for _, cond := range predicates { + ce := &columnsExtractor{} + ok := true + cond.Accept(ce) + for _, col := range ce.result { + match := false + for _, field := range v.Fields() { + if col.Refer == field { + switch field.Expr.(type) { + case *ast.ColumnNameExpr: + match = true + } + break + } + } + if !match { + ok = false + break + } + } + if ok { + cond1, _ := cond.Accept(cs) + cond = cond1.(ast.ExprNode) + push = append(push, cond) + } else { + ret = append(ret, cond) + } + } + restConds, err1 := PredicatePushDown(v.GetChildByIndex(0), push) + if err1 != nil { + return nil, errors.Trace(err1) + } + if len(restConds) > 0 { + err1 = addFilter(v, v.GetChildByIndex(0), restConds) + if err1 != nil { + return nil, errors.Trace(err1) + } + } + return ret, nil + case *Sort, *Limit, *Distinct: + rest, err1 := PredicatePushDown(p.GetChildByIndex(0), predicates) + if err1 != nil { + return nil, errors.Trace(err1) + } + if len(rest) > 0 { + err1 = addFilter(p, p.GetChildByIndex(0), rest) + if err1 != nil { + return nil, errors.Trace(err1) + } + } + return ret, nil + default: + if len(v.GetChildren()) == 0 { + return predicates, nil + } + //TODO: support union and sub queries when abandon result field. + for _, child := range v.GetChildren() { + _, err = PredicatePushDown(child, []ast.ExprNode{}) + if err != nil { + return nil, errors.Trace(err) + } + } + return predicates, nil + } +} diff --git a/optimizer/plan/stringer.go b/optimizer/plan/stringer.go index c2afa54c3f45c..f9bb915f168b5 100644 --- a/optimizer/plan/stringer.go +++ b/optimizer/plan/stringer.go @@ -27,7 +27,7 @@ func ToString(p Plan) string { func toString(in Plan, strs []string, idxs []int) ([]string, []int) { switch in.(type) { - case *JoinOuter, *JoinInner: + case *JoinOuter, *JoinInner, *Join: idxs = append(idxs, len(strs)) } @@ -87,6 +87,13 @@ func toString(in Plan, strs []string, idxs []int) ([]string, []int) { strs = strs[:idx] str = "InnerJoin{" + strings.Join(children, "->") + "}" idxs = idxs[:last] + case *Join: + last := len(idxs) - 1 + idx := idxs[last] + children := strs[idx:] + strs = strs[:idx] + str = "Join{" + strings.Join(children, "->") + "}" + idxs = idxs[:last] case *Aggregate: str = "Aggregate" case *Distinct: diff --git a/optimizer/resolver.go b/optimizer/resolver.go index 1071933736dd0..6434edf7da097 100644 --- a/optimizer/resolver.go +++ b/optimizer/resolver.go @@ -36,6 +36,13 @@ func ResolveName(node ast.Node, info infoschema.InfoSchema, ctx context.Context) return errors.Trace(resolver.Err) } +// MockResolveName only serves for test. +func MockResolveName(node ast.Node, info infoschema.InfoSchema, defaultSchema string) error { + resolver := nameResolver{Info: info, Ctx: nil, DefaultSchema: model.NewCIStr(defaultSchema)} + node.Accept(&resolver) + return resolver.Err +} + // nameResolver is the visitor to resolve table name and column name. // In general, a reference can only refer to information that are available for it. // So children elements are visited in the order that previous elements make information diff --git a/table/table.go b/table/table.go index b1364f4013e87..a0b0668b1e7b6 100644 --- a/table/table.go +++ b/table/table.go @@ -123,3 +123,6 @@ func GetColDefaultValue(ctx context.Context, col *model.ColumnInfo) (types.Datum return types.NewDatum(col.DefaultValue), true, nil } + +// MockTableFromMeta only serves for test. +var MockTableFromMeta func(tableInfo *model.TableInfo) Table diff --git a/table/tables/tables.go b/table/tables/tables.go index 495128d1a18ec..9ff671196678a 100644 --- a/table/tables/tables.go +++ b/table/tables/tables.go @@ -56,6 +56,11 @@ type Table struct { meta *model.TableInfo } +// MockTableFromMeta only serves for test. +func MockTableFromMeta(tableInfo *model.TableInfo) table.Table { + return &Table{ID: 0, meta: tableInfo} +} + // TableFromMeta creates a Table instance from model.TableInfo. func TableFromMeta(alloc autoid.Allocator, tblInfo *model.TableInfo) (table.Table, error) { if tblInfo.State == model.StateNone { @@ -886,4 +891,5 @@ func FindIndexByColName(t table.Table, name string) *column.IndexedCol { func init() { table.TableFromMeta = TableFromMeta + table.MockTableFromMeta = MockTableFromMeta }