Skip to content

Commit

Permalink
executor: support window function rank and dense_rank (#9500)
Browse files Browse the repository at this point in the history
  • Loading branch information
alivxxx authored and zz-jason committed Feb 28, 2019
1 parent 8b84b94 commit 9259785
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 6 deletions.
28 changes: 26 additions & 2 deletions executor/aggfuncs/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,24 @@ func Build(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordinal
return buildBitXor(aggFuncDesc, ordinal)
case ast.AggFuncBitAnd:
return buildBitAnd(aggFuncDesc, ordinal)
case ast.WindowFuncRowNumber:
return buildRowNumber(aggFuncDesc, ordinal)
}
return nil
}

// BuildWindowFunctions builds specific window function according to function description and order by columns.
func BuildWindowFunctions(ctx sessionctx.Context, windowFuncDesc *aggregation.AggFuncDesc, ordinal int, orderByCols []*expression.Column) AggFunc {
switch windowFuncDesc.Name {
case ast.WindowFuncRank:
return buildRank(ordinal, orderByCols, false)
case ast.WindowFuncDenseRank:
return buildRank(ordinal, orderByCols, true)
case ast.WindowFuncRowNumber:
return buildRowNumber(windowFuncDesc, ordinal)
default:
return Build(ctx, windowFuncDesc, ordinal)
}
}

// buildCount builds the AggFunc implementation for function "COUNT".
func buildCount(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
// If mode is DedupMode, we return nil for not implemented.
Expand Down Expand Up @@ -324,3 +336,15 @@ func buildRowNumber(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
}
return &rowNumber{base}
}

func buildRank(ordinal int, orderByCols []*expression.Column, isDense bool) AggFunc {
base := baseAggFunc{
ordinal: ordinal,
}
r := &rank{baseAggFunc: base, isDense: isDense}
for _, col := range orderByCols {
r.cmpFuncs = append(r.cmpFuncs, chunk.GetCompareFunc(col.RetType))
r.colIdx = append(r.colIdx, col.Index)
}
return r
}
80 changes: 80 additions & 0 deletions executor/aggfuncs/func_rank.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Copyright 2019 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package aggfuncs

import (
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/util/chunk"
)

type rank struct {
baseAggFunc
isDense bool
cmpFuncs []chunk.CompareFunc
colIdx []int
}

type partialResult4Rank struct {
curIdx int64
lastRank int64
rows []chunk.Row
}

func (r *rank) AllocPartialResult() PartialResult {
return PartialResult(&partialResult4Rank{})
}

func (r *rank) ResetPartialResult(pr PartialResult) {
p := (*partialResult4Rank)(pr)
p.curIdx = 0
p.lastRank = 0
p.rows = p.rows[:0]
}

func (r *rank) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error {
p := (*partialResult4Rank)(pr)
p.rows = append(p.rows, rowsInGroup...)
return nil
}

func (r *rank) compareRows(prev, curr chunk.Row) int {
for i, idx := range r.colIdx {
res := r.cmpFuncs[i](prev, idx, curr, idx)
if res != 0 {
return res
}
}
return 0
}

func (r *rank) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error {
p := (*partialResult4Rank)(pr)
p.curIdx++
if p.curIdx == 1 {
p.lastRank = 1
chk.AppendInt64(r.ordinal, p.lastRank)
return nil
}
if r.compareRows(p.rows[p.curIdx-2], p.rows[p.curIdx-1]) == 0 {
chk.AppendInt64(r.ordinal, p.lastRank)
return nil
}
if r.isDense {
p.lastRank++
} else {
p.lastRank = p.curIdx
}
chk.AppendInt64(r.ordinal, p.lastRank)
return nil
}
6 changes: 5 additions & 1 deletion executor/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -1918,7 +1918,11 @@ func (b *executorBuilder) buildWindow(v *plannercore.PhysicalWindow) *WindowExec
}
aggDesc := aggregation.NewAggFuncDesc(b.ctx, v.WindowFuncDesc.Name, v.WindowFuncDesc.Args, false)
resultColIdx := len(v.Schema().Columns) - 1
agg := aggfuncs.Build(b.ctx, aggDesc, resultColIdx)
orderByCols := make([]*expression.Column, 0, len(v.OrderBy))
for _, item := range v.OrderBy {
orderByCols = append(orderByCols, item.Col)
}
agg := aggfuncs.BuildWindowFunctions(b.ctx, aggDesc, resultColIdx, orderByCols)
var processor windowProcessor
if v.Frame == nil {
processor = &aggWindowProcessor{
Expand Down
17 changes: 17 additions & 0 deletions executor/window_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,21 @@ func (s *testSuite2) TestWindowFunctions(c *C) {
result.Check(testkit.Rows("<nil> <nil> <nil>", "1 2019-02-01 6", "2 2019-02-02 6", "3 2019-02-03 10", "5 2019-02-05 5"))
result = tk.MustQuery("select a, b, sum(a) over(order by b desc range between interval 1 day preceding and interval 2 day following) from t")
result.Check(testkit.Rows("5 2019-02-05 8", "3 2019-02-03 6", "2 2019-02-02 6", "1 2019-02-01 3", "<nil> <nil> <nil>"))

tk.MustExec("drop table t")
tk.MustExec("create table t(a int, b int)")
tk.MustExec("insert into t values (1,1),(1,2),(2,1),(2,2)")
result = tk.MustQuery("select a, b, rank() over() from t")
result.Check(testkit.Rows("1 1 1", "1 2 1", "2 1 1", "2 2 1"))
result = tk.MustQuery("select a, b, rank() over(order by a) from t")
result.Check(testkit.Rows("1 1 1", "1 2 1", "2 1 3", "2 2 3"))
result = tk.MustQuery("select a, b, rank() over(order by a, b) from t")
result.Check(testkit.Rows("1 1 1", "1 2 2", "2 1 3", "2 2 4"))

result = tk.MustQuery("select a, b, dense_rank() over() from t")
result.Check(testkit.Rows("1 1 1", "1 2 1", "2 1 1", "2 2 1"))
result = tk.MustQuery("select a, b, dense_rank() over(order by a) from t")
result.Check(testkit.Rows("1 1 1", "1 2 1", "2 1 2", "2 2 2"))
result = tk.MustQuery("select a, b, dense_rank() over(order by a, b) from t")
result.Check(testkit.Rows("1 1 1", "1 2 2", "2 1 3", "2 2 4"))
}
6 changes: 3 additions & 3 deletions expression/aggregation/base_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ func (a *baseFuncDesc) typeInfer(ctx sessionctx.Context) {
a.typeInfer4MaxMin(ctx)
case ast.AggFuncBitAnd, ast.AggFuncBitOr, ast.AggFuncBitXor:
a.typeInfer4BitFuncs(ctx)
case ast.WindowFuncRowNumber:
a.typeInfer4RowNumber()
case ast.WindowFuncRowNumber, ast.WindowFuncRank, ast.WindowFuncDenseRank:
a.typeInfer4NumberFuncs()
default:
panic("unsupported agg function: " + a.Name)
}
Expand Down Expand Up @@ -186,7 +186,7 @@ func (a *baseFuncDesc) typeInfer4BitFuncs(ctx sessionctx.Context) {
// TODO: a.Args[0] = expression.WrapWithCastAsInt(ctx, a.Args[0])
}

func (a *baseFuncDesc) typeInfer4RowNumber() {
func (a *baseFuncDesc) typeInfer4NumberFuncs() {
a.RetTp = types.NewFieldType(mysql.TypeLonglong)
a.RetTp.Flen = 21
types.SetBinChsClnFlag(a.RetTp)
Expand Down

0 comments on commit 9259785

Please sign in to comment.