Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

executor: support window function rank and dense_rank #9500

Merged
merged 2 commits into from
Feb 28, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions executor/aggfuncs/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,24 @@ func Build(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordinal
return buildBitXor(aggFuncDesc, ordinal)
case ast.AggFuncBitAnd:
return buildBitAnd(aggFuncDesc, ordinal)
case ast.WindowFuncRowNumber:
return buildRowNumber(aggFuncDesc, ordinal)
}
return nil
}

// BuildWindowFunctions builds specific window function according to function description and order by columns.
func BuildWindowFunctions(ctx sessionctx.Context, windowFuncDesc *aggregation.AggFuncDesc, ordinal int, orderByCols []*expression.Column) AggFunc {
switch windowFuncDesc.Name {
case ast.WindowFuncRank:
return buildRank(ordinal, orderByCols, false)
case ast.WindowFuncDenseRank:
return buildRank(ordinal, orderByCols, true)
case ast.WindowFuncRowNumber:
return buildRowNumber(windowFuncDesc, ordinal)
default:
return Build(ctx, windowFuncDesc, ordinal)
}
}

// buildCount builds the AggFunc implementation for function "COUNT".
func buildCount(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
// If mode is DedupMode, we return nil for not implemented.
Expand Down Expand Up @@ -324,3 +336,15 @@ func buildRowNumber(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
}
return &rowNumber{base}
}

func buildRank(ordinal int, orderByCols []*expression.Column, isDense bool) AggFunc {
base := baseAggFunc{
ordinal: ordinal,
}
r := &rank{baseAggFunc: base, isDense: isDense}
for _, col := range orderByCols {
r.cmpFuncs = append(r.cmpFuncs, chunk.GetCompareFunc(col.RetType))
r.colIdx = append(r.colIdx, col.Index)
}
return r
}
80 changes: 80 additions & 0 deletions executor/aggfuncs/func_rank.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Copyright 2019 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package aggfuncs

import (
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/util/chunk"
)

type rank struct {
baseAggFunc
isDense bool
cmpFuncs []chunk.CompareFunc
colIdx []int
}

type partialResult4Rank struct {
curIdx int64
lastRank int64
rows []chunk.Row
}

func (r *rank) AllocPartialResult() PartialResult {
return PartialResult(&partialResult4Rank{})
}

func (r *rank) ResetPartialResult(pr PartialResult) {
p := (*partialResult4Rank)(pr)
p.curIdx = 0
p.lastRank = 0
p.rows = p.rows[:0]
}

func (r *rank) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error {
p := (*partialResult4Rank)(pr)
p.rows = append(p.rows, rowsInGroup...)
return nil
}

func (r *rank) compareRows(prev, curr chunk.Row) int {
for i, idx := range r.colIdx {
res := r.cmpFuncs[i](prev, idx, curr, idx)
if res != 0 {
return res
}
}
return 0
}

func (r *rank) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error {
p := (*partialResult4Rank)(pr)
p.curIdx++
if p.curIdx == 1 {
p.lastRank = 1
chk.AppendInt64(r.ordinal, p.lastRank)
return nil
}
if r.compareRows(p.rows[p.curIdx-2], p.rows[p.curIdx-1]) == 0 {
chk.AppendInt64(r.ordinal, p.lastRank)
return nil
}
if r.isDense {
p.lastRank++
} else {
p.lastRank = p.curIdx
}
chk.AppendInt64(r.ordinal, p.lastRank)
return nil
}
6 changes: 5 additions & 1 deletion executor/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -1908,7 +1908,11 @@ func (b *executorBuilder) buildWindow(v *plannercore.PhysicalWindow) *WindowExec
}
aggDesc := aggregation.NewAggFuncDesc(b.ctx, v.WindowFuncDesc.Name, v.WindowFuncDesc.Args, false)
resultColIdx := len(v.Schema().Columns) - 1
agg := aggfuncs.Build(b.ctx, aggDesc, resultColIdx)
orderByCols := make([]*expression.Column, 0, len(v.OrderBy))
for _, item := range v.OrderBy {
orderByCols = append(orderByCols, item.Col)
}
agg := aggfuncs.BuildWindowFunctions(b.ctx, aggDesc, resultColIdx, orderByCols)
var processor windowProcessor
if v.Frame == nil {
processor = &aggWindowProcessor{
Expand Down
17 changes: 17 additions & 0 deletions executor/window_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,21 @@ func (s *testSuite2) TestWindowFunctions(c *C) {
result.Check(testkit.Rows("<nil> <nil> <nil>", "1 2019-02-01 6", "2 2019-02-02 6", "3 2019-02-03 10", "5 2019-02-05 5"))
result = tk.MustQuery("select a, b, sum(a) over(order by b desc range between interval 1 day preceding and interval 2 day following) from t")
result.Check(testkit.Rows("5 2019-02-05 8", "3 2019-02-03 6", "2 2019-02-02 6", "1 2019-02-01 3", "<nil> <nil> <nil>"))

tk.MustExec("drop table t")
tk.MustExec("create table t(a int, b int)")
tk.MustExec("insert into t values (1,1),(1,2),(2,1),(2,2)")
result = tk.MustQuery("select a, b, rank() over() from t")
result.Check(testkit.Rows("1 1 1", "1 2 1", "2 1 1", "2 2 1"))
result = tk.MustQuery("select a, b, rank() over(order by a) from t")
result.Check(testkit.Rows("1 1 1", "1 2 1", "2 1 3", "2 2 3"))
result = tk.MustQuery("select a, b, rank() over(order by a, b) from t")
result.Check(testkit.Rows("1 1 1", "1 2 2", "2 1 3", "2 2 4"))

result = tk.MustQuery("select a, b, dense_rank() over() from t")
result.Check(testkit.Rows("1 1 1", "1 2 1", "2 1 1", "2 2 1"))
result = tk.MustQuery("select a, b, dense_rank() over(order by a) from t")
result.Check(testkit.Rows("1 1 1", "1 2 1", "2 1 2", "2 2 2"))
result = tk.MustQuery("select a, b, dense_rank() over(order by a, b) from t")
result.Check(testkit.Rows("1 1 1", "1 2 2", "2 1 3", "2 2 4"))
}
6 changes: 3 additions & 3 deletions expression/aggregation/base_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ func (a *baseFuncDesc) typeInfer(ctx sessionctx.Context) {
a.typeInfer4MaxMin(ctx)
case ast.AggFuncBitAnd, ast.AggFuncBitOr, ast.AggFuncBitXor:
a.typeInfer4BitFuncs(ctx)
case ast.WindowFuncRowNumber:
a.typeInfer4RowNumber()
case ast.WindowFuncRowNumber, ast.WindowFuncRank, ast.WindowFuncDenseRank:
a.typeInfer4NumberFuncs()
default:
panic("unsupported agg function: " + a.Name)
}
Expand Down Expand Up @@ -186,7 +186,7 @@ func (a *baseFuncDesc) typeInfer4BitFuncs(ctx sessionctx.Context) {
// TODO: a.Args[0] = expression.WrapWithCastAsInt(ctx, a.Args[0])
}

func (a *baseFuncDesc) typeInfer4RowNumber() {
func (a *baseFuncDesc) typeInfer4NumberFuncs() {
a.RetTp = types.NewFieldType(mysql.TypeLonglong)
a.RetTp.Flen = 21
types.SetBinChsClnFlag(a.RetTp)
Expand Down