planner, executor: inject proj below TopN and Sort if byItem contains scalarFunc (#9197) #9319

Merged · 1 commit · Feb 15, 2019
Changes from all commits
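In short: for #9197, an ORDER BY item that is a scalar function such as i + 1 used to be re-evaluated by SortExec/TopNExec through a dedicated keyChunks buffer. This PR moves the work into the planner, which wraps PhysicalSort/PhysicalTopN between two projections, roughly (in the notation of physical_plan_test.go, simplified):

	before:  TableReader(Table(t))->Sort([plus(test.t.i, 1)])
	after:   TableReader(Table(t))->Projection(test.t.i, plus(test.t.i, 1))->Sort([col_1])->Projection(test.t.i)

The bottom Projection evaluates the expression once per row into col_1, Sort/TopN then compares that plain column, and the top Projection prunes col_1 from the final schema. With every ByItem guaranteed to be a column, the executor's keyChunks machinery becomes dead code and is deleted.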
executor/aggregate_test.go (26 additions, 0 deletions)
@@ -603,6 +603,32 @@ func (s *testSuite) TestBuildProjBelowAgg(c *C) {
 		"4 3 18 7,7,7 8"))
 }
 
+func (s *testSuite) TestInjectProjBelowTopN(c *C) {
+	tk := testkit.NewTestKitWithInit(c, s.store)
+	tk.MustExec("drop table if exists t;")
+	tk.MustExec("create table t (i int);")
+	tk.MustExec("insert into t values (1), (1), (1),(2),(3),(2),(3),(2),(3);")
+	tk.MustQuery("explain select * from t order by i + 1").Check(testkit.Rows(
+		"Projection_8 10000.00 root test.t.i",
+		"└─Sort_4 10000.00 root col_1:asc",
+		"  └─Projection_9 10000.00 root test.t.i, plus(test.t.i, 1)",
+		"    └─TableReader_7 10000.00 root data:TableScan_6",
+		"      └─TableScan_6 10000.00 cop table:t, range:[-inf,+inf], keep order:false, stats:pseudo"))
+	rs := tk.MustQuery("select * from t order by i + 1 ")
+	rs.Check(testkit.Rows(
+		"1", "1", "1", "2", "2", "2", "3", "3", "3"))
+	tk.MustQuery("explain select * from t order by i + 1 limit 2").Check(testkit.Rows(
+		"Projection_15 2.00 root test.t.i",
+		"└─TopN_7 2.00 root col_1:asc, offset:0, count:2",
+		"  └─Projection_16 2.00 root test.t.i, plus(test.t.i, 1)",
+		"    └─TableReader_12 2.00 root data:TopN_11",
+		"      └─TopN_11 2.00 cop plus(test.t.i, 1):asc, offset:0, count:2",
+		"        └─TableScan_10 10000.00 cop table:t, range:[-inf,+inf], keep order:false, stats:pseudo"))
+	rs = tk.MustQuery("select * from t order by i + 1 limit 2")
+	rs.Check(testkit.Rows("1", "1"))
+	tk.MustQuery("select i, i, i from t order by i + 1").Check(testkit.Rows("1 1 1", "1 1 1", "1 1 1", "2 2 2", "2 2 2", "2 2 2", "3 3 3", "3 3 3", "3 3 3"))
+}
+
 func (s *testSuite) TestFirstRowEnum(c *C) {
 	tk := testkit.NewTestKitWithInit(c, s.store)
 	tk.MustExec(`use test;`)
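The explain output above encodes a project-sort-prune pipeline. For intuition, here is a minimal standalone sketch of the same shape in plain Go (toy rows, not TiDB's chunk types): the bottom projection materializes plus(test.t.i, 1) once per row, the sort compares only that column, and the top projection drops it.

	package main

	import (
		"fmt"
		"sort"
	)

	func main() {
		// Column i of table t, matching the data inserted by the test above.
		rows := [][]int64{{1}, {1}, {1}, {2}, {3}, {2}, {3}, {2}, {3}}

		// Bottom projection: evaluate the ORDER BY expression i+1 once per
		// row and store it as an extra key column (the planner's "col_1").
		for idx := range rows {
			rows[idx] = append(rows[idx], rows[idx][0]+1)
		}

		// Sort compares the materialized key column only; the scalar
		// function is never re-evaluated during comparisons.
		sort.SliceStable(rows, func(a, b int) bool { return rows[a][1] < rows[b][1] })

		// Top projection: prune the key column from the final output.
		for _, r := range rows {
			fmt.Println(r[0])
		}
		// Output: 1 1 1 2 2 2 3 3 3, matching rs.Check above.
	}

Computing the key eagerly trades one extra column of memory for never re-evaluating the expression inside the O(n log n) comparison loop.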
executor/sort.go (11 additions, 117 deletions)
@@ -42,8 +42,6 @@ type SortExec struct {
 	keyColumns []int
 	// keyCmpFuncs is used to compare each ByItem.
 	keyCmpFuncs []chunk.CompareFunc
-	// keyChunks is used to store ByItems values when not all ByItems are column.
-	keyChunks *chunk.List
 	// rowChunks is the chunks to store row values.
 	rowChunks *chunk.List
 	// rowPointer stores the chunk index and row index for each row.
@@ -86,17 +84,8 @@ func (e *SortExec) Next(ctx context.Context, chk *chunk.Chunk) error {
 		}
 		e.initPointers()
 		e.initCompareFuncs()
-		allColumnExpr := e.buildKeyColumns()
-		if allColumnExpr {
-			sort.Slice(e.rowPtrs, e.keyColumnsLess)
-		} else {
-			e.buildKeyExprsAndTypes()
-			err = e.buildKeyChunks()
-			if err != nil {
-				return errors.Trace(err)
-			}
-			sort.Slice(e.rowPtrs, e.keyChunksLess)
-		}
+		e.buildKeyColumns()
+		sort.Slice(e.rowPtrs, e.keyColumnsLess)
 		e.fetched = true
 	}
 	for chk.NumRows() < e.maxChunkSize {
@@ -149,49 +138,14 @@ func (e *SortExec) initCompareFuncs() {
 	}
 }
 
-func (e *SortExec) buildKeyColumns() (allColumnExpr bool) {
+func (e *SortExec) buildKeyColumns() {
 	e.keyColumns = make([]int, 0, len(e.ByItems))
 	for _, by := range e.ByItems {
-		if col, ok := by.Expr.(*expression.Column); ok {
-			e.keyColumns = append(e.keyColumns, col.Index)
-		} else {
-			e.keyColumns = e.keyColumns[:0]
-			for i := range e.ByItems {
-				e.keyColumns = append(e.keyColumns, i)
-			}
-			return false
-		}
-	}
-	return true
-}
-
-func (e *SortExec) buildKeyExprsAndTypes() {
-	keyLen := len(e.ByItems)
-	e.keyTypes = make([]*types.FieldType, keyLen)
-	e.keyExprs = make([]expression.Expression, keyLen)
-	for keyColIdx := range e.ByItems {
-		e.keyExprs[keyColIdx] = e.ByItems[keyColIdx].Expr
-		e.keyTypes[keyColIdx] = e.ByItems[keyColIdx].Expr.GetType()
+		col := by.Expr.(*expression.Column)
+		e.keyColumns = append(e.keyColumns, col.Index)
 	}
 }
 
-func (e *SortExec) buildKeyChunks() error {
-	e.keyChunks = chunk.NewList(e.keyTypes, e.initCap, e.maxChunkSize)
-	e.keyChunks.GetMemTracker().SetLabel("keyChunks")
-	e.keyChunks.GetMemTracker().AttachTo(e.memTracker)
-
-	for chkIdx := 0; chkIdx < e.rowChunks.NumChunks(); chkIdx++ {
-		keyChk := chunk.NewChunkWithCapacity(e.keyTypes, e.rowChunks.GetChunk(chkIdx).NumRows())
-		childIter := chunk.NewIterator4Chunk(e.rowChunks.GetChunk(chkIdx))
-		err := expression.VectorizedExecute(e.ctx, e.keyExprs, childIter, keyChk)
-		if err != nil {
-			return errors.Trace(err)
-		}
-		e.keyChunks.Add(keyChk)
-	}
-	return nil
-}
-
 func (e *SortExec) lessRow(rowI, rowJ chunk.Row) bool {
 	for i, colIdx := range e.keyColumns {
 		cmpFunc := e.keyCmpFuncs[i]
@@ -215,13 +169,6 @@ func (e *SortExec) keyColumnsLess(i, j int) bool {
 	return e.lessRow(rowI, rowJ)
 }
 
-// keyChunksLess is the less function for key chunk.
-func (e *SortExec) keyChunksLess(i, j int) bool {
-	keyRowI := e.keyChunks.GetRow(e.rowPtrs[i])
-	keyRowJ := e.keyChunks.GetRow(e.rowPtrs[j])
-	return e.lessRow(keyRowI, keyRowJ)
-}
-
 // TopNExec implements a Top-N algorithm and it is built from a SELECT statement with ORDER BY and LIMIT.
 // Instead of sorting all the rows fetched from the table, it keeps the Top-N elements only in a heap to reduce memory usage.
 type TopNExec struct {
@@ -240,19 +187,6 @@ type topNChunkHeap struct {
 // Less implements heap.Interface, but since we maintain a max heap,
 // this function returns true if row i is greater than row j.
 func (h *topNChunkHeap) Less(i, j int) bool {
-	if h.keyChunks != nil {
-		return h.keyChunksGreater(i, j)
-	}
-	return h.keyColumnsGreater(i, j)
-}
-
-func (h *topNChunkHeap) keyChunksGreater(i, j int) bool {
-	keyRowI := h.keyChunks.GetRow(h.rowPtrs[i])
-	keyRowJ := h.keyChunks.GetRow(h.rowPtrs[j])
-	return h.greaterRow(keyRowI, keyRowJ)
-}
-
-func (h *topNChunkHeap) keyColumnsGreater(i, j int) bool {
 	rowI := h.rowChunks.GetRow(h.rowPtrs[i])
 	rowJ := h.rowChunks.GetRow(h.rowPtrs[j])
 	return h.greaterRow(rowI, rowJ)
@@ -348,14 +282,7 @@ func (e *TopNExec) loadChunksUntilTotalLimit(ctx context.Context) error {
 	}
 	e.initPointers()
 	e.initCompareFuncs()
-	allColumnExpr := e.buildKeyColumns()
-	if !allColumnExpr {
-		e.buildKeyExprsAndTypes()
-		err := e.buildKeyChunks()
-		if err != nil {
-			return errors.Trace(err)
-		}
-	}
+	e.buildKeyColumns()
 	return nil
 }
 
@@ -367,10 +294,6 @@ func (e *TopNExec) executeTopN(ctx context.Context) error {
 		// The number of rows we loaded may exceed the total limit, remove greatest rows by Pop.
 		heap.Pop(e.chkHeap)
 	}
-	var childKeyChk *chunk.Chunk
-	if e.keyChunks != nil {
-		childKeyChk = chunk.NewChunkWithCapacity(e.keyTypes, e.maxChunkSize)
-	}
 	childRowChk := e.children[0].newFirstChunk()
 	for {
 		err := e.children[0].Next(ctx, childRowChk)
@@ -380,7 +303,7 @@ func (e *TopNExec) executeTopN(ctx context.Context) error {
 		if childRowChk.NumRows() == 0 {
 			break
 		}
-		err = e.processChildChk(childRowChk, childKeyChk)
+		err = e.processChildChk(childRowChk)
 		if err != nil {
 			return errors.Trace(err)
 		}
@@ -391,38 +314,19 @@ func (e *TopNExec) executeTopN(ctx context.Context) error {
 			}
 		}
 	}
-	if e.keyChunks != nil {
-		sort.Slice(e.rowPtrs, e.keyChunksLess)
-	} else {
-		sort.Slice(e.rowPtrs, e.keyColumnsLess)
-	}
+	sort.Slice(e.rowPtrs, e.keyColumnsLess)
 	return nil
 }
 
-func (e *TopNExec) processChildChk(childRowChk, childKeyChk *chunk.Chunk) error {
-	if childKeyChk != nil {
-		childKeyChk.Reset()
-		err := expression.VectorizedExecute(e.ctx, e.keyExprs, chunk.NewIterator4Chunk(childRowChk), childKeyChk)
-		if err != nil {
-			return errors.Trace(err)
-		}
-	}
+func (e *TopNExec) processChildChk(childRowChk *chunk.Chunk) error {
 	for i := 0; i < childRowChk.NumRows(); i++ {
 		heapMaxPtr := e.rowPtrs[0]
 		var heapMax, next chunk.Row
-		if childKeyChk != nil {
-			heapMax = e.keyChunks.GetRow(heapMaxPtr)
-			next = childKeyChk.GetRow(i)
-		} else {
-			heapMax = e.rowChunks.GetRow(heapMaxPtr)
-			next = childRowChk.GetRow(i)
-		}
+		heapMax = e.rowChunks.GetRow(heapMaxPtr)
+		next = childRowChk.GetRow(i)
 		if e.chkHeap.greaterRow(heapMax, next) {
 			// Evict heap max, keep the next row.
 			e.rowPtrs[0] = e.rowChunks.AppendRow(childRowChk.GetRow(i))
-			if childKeyChk != nil {
-				e.keyChunks.AppendRow(childKeyChk.GetRow(i))
-			}
 			heap.Fix(e.chkHeap, 0)
 		}
 	}
@@ -444,16 +348,6 @@ func (e *TopNExec) doCompaction() error {
 	e.memTracker.ReplaceChild(e.rowChunks.GetMemTracker(), newRowChunks.GetMemTracker())
 	e.rowChunks = newRowChunks
 
-	if e.keyChunks != nil {
-		newKeyChunks := chunk.NewList(e.keyTypes, e.initCap, e.maxChunkSize)
-		for _, rowPtr := range e.rowPtrs {
-			newKeyChunks.AppendRow(e.keyChunks.GetRow(rowPtr))
-		}
-		newKeyChunks.GetMemTracker().SetLabel("keyChunks")
-		e.memTracker.ReplaceChild(e.keyChunks.GetMemTracker(), newKeyChunks.GetMemTracker())
-		e.keyChunks = newKeyChunks
-	}
-
 	e.memTracker.Consume(int64(-8 * len(e.rowPtrs)))
 	e.memTracker.Consume(int64(8 * len(newRowPtrs)))
 	e.rowPtrs = newRowPtrs
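The TopN side keeps its existing strategy, now with a single code path: hold the current N best rows in a max-heap keyed by the already-projected sort columns, and whenever a new row beats the heap top, overwrite the top and heap.Fix it, as processChildChk does above. A self-contained sketch of that keep-N-in-a-max-heap pattern (toy int64 keys in place of chunk rows, not the executor's types):

	package main

	import (
		"container/heap"
		"fmt"
		"sort"
	)

	// maxHeap holds the N smallest keys seen so far; the root is the
	// largest of them, i.e. the first candidate for eviction.
	type maxHeap []int64

	func (h maxHeap) Len() int            { return len(h) }
	func (h maxHeap) Less(i, j int) bool  { return h[i] > h[j] } // max-heap order
	func (h maxHeap) Swap(i, j int)       { h[i], h[j] = h[j], h[i] }
	func (h *maxHeap) Push(x interface{}) { *h = append(*h, x.(int64)) }
	func (h *maxHeap) Pop() interface{} {
		old := *h
		x := old[len(old)-1]
		*h = old[:len(old)-1]
		return x
	}

	func topN(keys []int64, n int) []int64 {
		h := &maxHeap{}
		for _, k := range keys {
			if h.Len() < n {
				heap.Push(h, k)
			} else if (*h)[0] > k {
				// Evict heap max, keep the smaller key
				// (cf. greaterRow + heap.Fix above).
				(*h)[0] = k
				heap.Fix(h, 0)
			}
		}
		out := append([]int64(nil), (*h)...)
		sort.Slice(out, func(i, j int) bool { return out[i] < out[j] })
		return out
	}

	func main() {
		keys := []int64{2, 2, 2, 3, 4, 3, 4, 3, 4} // i+1 for the test table
		fmt.Println(topN(keys, 2))                 // [2 2]
	}

Memory stays proportional to N rather than to the input size, which is why TopNExec avoids a full sort.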
planner/core/physical_plan_test.go (2 additions, 2 deletions)
@@ -113,7 +113,7 @@ func (s *testPlanSuite) TestDAGPlanBuilderSimpleCase(c *C) {
 		// Test TopN push down in table single read.
 		{
 			sql:  "select c from t order by t.a + t.b limit 1",
-			best: "TableReader(Table(t)->TopN([plus(test.t.a, test.t.b)],0,1))->TopN([plus(test.t.a, test.t.b)],0,1)->Projection",
+			best: "TableReader(Table(t)->TopN([plus(test.t.a, test.t.b)],0,1))->Projection->TopN([col_3],0,1)->Projection",
 		},
 		// Test Limit push down in table single read.
 		{
@@ -1215,7 +1215,7 @@ func (s *testPlanSuite) TestAggEliminater(c *C) {
 		// If max/min contains scalar function, we can still do transformation.
 		{
 			sql:  "select max(a+1) from t;",
-			best: "TableReader(Table(t)->Sel([not(isnull(plus(test.t.a, 1)))])->TopN([plus(test.t.a, 1) true],0,1))->TopN([plus(test.t.a, 1) true],0,1)->Projection->StreamAgg",
+			best: "TableReader(Table(t)->Sel([not(isnull(plus(test.t.a, 1)))])->TopN([plus(test.t.a, 1) true],0,1))->Projection->TopN([col_1 true],0,1)->Projection->StreamAgg",
		},
 		// Do nothing to max+min.
 		{
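Both updated plan strings show the same two-stage pattern: the coprocessor-side TopN inside TableReader still orders by the raw expression plus(...), while the root-side TopN above the injected Projection orders by the projected column (col_3, col_1). The split itself is the usual partial/final Top-N: each data region returns its local top rows and the root merges them and cuts again. A rough standalone sketch of that merge (hypothetical partitions, not TiDB's coprocessor protocol):

	package main

	import (
		"fmt"
		"sort"
	)

	// localTopN returns the n smallest keys of one partition, like the
	// TopN executed inside each coprocessor task.
	func localTopN(part []int, n int) []int {
		out := append([]int(nil), part...)
		sort.Ints(out)
		if len(out) > n {
			out = out[:n]
		}
		return out
	}

	func main() {
		partitions := [][]int{{5, 1, 9}, {4, 2, 8}, {7, 3, 6}}
		n := 2

		// Final TopN at root: merge the partial results and cut to n again.
		var merged []int
		for _, part := range partitions {
			merged = append(merged, localTopN(part, n)...)
		}
		merged = localTopN(merged, n)
		fmt.Println(merged) // [1 2]
	}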
planner/core/rule_inject_extra_projection.go (86 additions, 9 deletions)
@@ -24,9 +24,10 @@ import (
 	"github.com/pingcap/tidb/types"
 )
 
-// injectExtraProjection is used to extract the expressions of specific operators into a
-// physical Projection operator and inject the Projection below the operators.
-// Thus we can accelerate the expression evaluation by eager evaluation.
+// injectExtraProjection is used to extract the expressions of specific
+// operators into a physical Projection operator and inject the Projection below
+// the operators. Thus we can accelerate the expression evaluation by eager
+// evaluation.
 func injectExtraProjection(plan PhysicalPlan) PhysicalPlan {
 	return NewProjInjector().inject(plan)
 }
@@ -43,18 +44,23 @@ func (pe *projInjector) inject(plan PhysicalPlan) PhysicalPlan {
 	for _, child := range plan.Children() {
 		pe.inject(child)
 	}
+
 	switch p := plan.(type) {
 	case *PhysicalHashAgg:
 		plan = injectProjBelowAgg(plan, p.AggFuncs, p.GroupByItems)
 	case *PhysicalStreamAgg:
 		plan = injectProjBelowAgg(plan, p.AggFuncs, p.GroupByItems)
+	case *PhysicalSort:
+		plan = injectProjBelowSort(p, p.ByItems)
+	case *PhysicalTopN:
+		plan = injectProjBelowSort(p, p.ByItems)
 	}
 	return plan
 }
 
-// injectProjBelowAgg injects a ProjOperator below AggOperator.
-// If all the args of `aggFuncs`, and all the item of `groupByItems`
-// are columns or constants, we do not need to build the `proj`.
+// injectProjBelowAgg injects a ProjOperator below AggOperator. If all the args
+// of `aggFuncs` and all the items of `groupByItems` are columns or constants,
+// we do not need to build the `proj`.
 func injectProjBelowAgg(aggPlan PhysicalPlan, aggFuncs []*aggregation.AggFuncDesc, groupByItems []expression.Expression) PhysicalPlan {
 	hasScalarFunc := false
 
@@ -105,9 +111,10 @@ func injectProjBelowAgg(aggPlan PhysicalPlan, aggFuncs []*aggregation.AggFuncDesc, groupByItems []expression.Expression) PhysicalPlan {
 		}
 		projExprs = append(projExprs, item)
 		newArg := &expression.Column{
-			RetType: item.GetType(),
-			ColName: model.NewCIStr(fmt.Sprintf("col_%d", len(projSchemaCols))),
-			Index:   cursor,
+			UniqueID: aggPlan.context().GetSessionVars().AllocPlanColumnID(),
+			RetType:  item.GetType(),
+			ColName:  model.NewCIStr(fmt.Sprintf("col_%d", len(projSchemaCols))),
+			Index:    cursor,
 		}
 		projSchemaCols = append(projSchemaCols, newArg)
 		groupByItems[i] = newArg
Expand Down Expand Up @@ -174,3 +181,73 @@ func wrapCastForAggArgs(ctx sessionctx.Context, funcs []*aggregation.AggFuncDesc
}
}
}

// injectProjBelowSort extracts the ScalarFunctions of `orderByItems` into a
// PhysicalProjection and injects it below PhysicalTopN/PhysicalSort. The schema
// of PhysicalSort and PhysicalTopN are the same as the schema of their
// children. When a projection is injected as the child of PhysicalSort and
// PhysicalTopN, some extra columns will be added into the schema of the
// Projection, thus we need to add another Projection upon them to prune the
// redundant columns.
func injectProjBelowSort(p PhysicalPlan, orderByItems []*ByItems) PhysicalPlan {
hasScalarFunc, numOrderByItems := false, len(orderByItems)
for i := 0; !hasScalarFunc && i < numOrderByItems; i++ {
_, isScalarFunc := orderByItems[i].Expr.(*expression.ScalarFunction)
hasScalarFunc = hasScalarFunc || isScalarFunc
}
if !hasScalarFunc {
return p
}

topProjExprs := make([]expression.Expression, 0, p.Schema().Len())
for i := range p.Schema().Columns {
col := p.Schema().Columns[i]
col.Index = i
topProjExprs = append(topProjExprs, col)
}
topProj := PhysicalProjection{
Exprs: topProjExprs,
AvoidColumnEvaluator: false,
}.init(p.context(), p.statsInfo(), nil)
topProj.SetSchema(p.Schema().Clone())
topProj.SetChildren(p)

childPlan := p.Children()[0]
bottomProjSchemaCols := make([]*expression.Column, 0, len(childPlan.Schema().Columns)+numOrderByItems)
bottomProjExprs := make([]expression.Expression, 0, len(childPlan.Schema().Columns)+numOrderByItems)
for i, col := range childPlan.Schema().Columns {
col.Index = i
bottomProjSchemaCols = append(bottomProjSchemaCols, col)
bottomProjExprs = append(bottomProjExprs, col)
}

for _, item := range orderByItems {
itemExpr := item.Expr
if _, isScalarFunc := itemExpr.(*expression.ScalarFunction); !isScalarFunc {
continue
}
bottomProjExprs = append(bottomProjExprs, itemExpr)
newArg := &expression.Column{
UniqueID: p.context().GetSessionVars().AllocPlanColumnID(),
RetType: itemExpr.GetType(),
ColName: model.NewCIStr(fmt.Sprintf("col_%d", len(bottomProjSchemaCols))),
Index: len(bottomProjSchemaCols),
}
bottomProjSchemaCols = append(bottomProjSchemaCols, newArg)
item.Expr = newArg
}

childProp := p.getChildReqProps(0).Clone()
bottomProj := PhysicalProjection{
Exprs: bottomProjExprs,
AvoidColumnEvaluator: false,
}.init(p.context(), childPlan.statsInfo().ScaleByExpectCnt(childProp.ExpectedCnt), childProp)
bottomProj.SetSchema(expression.NewSchema(bottomProjSchemaCols...))
bottomProj.SetChildren(childPlan)
p.SetChildren(bottomProj)

if origChildProj, isChildProj := childPlan.(*PhysicalProjection); isChildProj {
refine4NeighbourProj(bottomProj, origChildProj)
}
return topProj
}
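The loop over orderByItems is the heart of the rewrite: each scalar-function item is appended to the bottom projection's expression list and then replaced, in place, by a column pointing at the new slot. A toy version of just that substitution, with stand-in types (expr, column, and scalarFunc are assumptions for illustration, not TiDB's expression package):

	package main

	import "fmt"

	// Stand-ins for expression.Expression, *expression.Column and
	// *expression.ScalarFunction.
	type expr interface{ fmt.Stringer }

	type column struct{ index int }

	func (c *column) String() string { return fmt.Sprintf("col_%d", c.index) }

	type scalarFunc struct{ repr string }

	func (f *scalarFunc) String() string { return f.repr }

	func main() {
		// Child schema has one column; ORDER BY has one scalar function item.
		bottomProj := []expr{&column{index: 0}}
		byItems := []expr{&scalarFunc{repr: "plus(t.i, 1)"}}

		for i, item := range byItems {
			if _, ok := item.(*scalarFunc); !ok {
				continue // plain columns are left alone, as in injectProjBelowSort
			}
			// Materialize the function in the bottom projection and point
			// the ByItem at the projected slot instead.
			bottomProj = append(bottomProj, item)
			byItems[i] = &column{index: len(bottomProj) - 1}
		}

		fmt.Println("bottom projection:", bottomProj) // [col_0 plus(t.i, 1)]
		fmt.Println("sort keys:        ", byItems)    // [col_1]
	}

After the substitution the sort operator only ever sees plain columns, which is exactly the invariant the simplified SortExec.buildKeyColumns above relies on.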