diff --git a/executor/aggfuncs/builder.go b/executor/aggfuncs/builder.go index b01dcf8203598..6d98d946ac42e 100644 --- a/executor/aggfuncs/builder.go +++ b/executor/aggfuncs/builder.go @@ -52,12 +52,24 @@ func Build(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordinal return buildBitXor(aggFuncDesc, ordinal) case ast.AggFuncBitAnd: return buildBitAnd(aggFuncDesc, ordinal) - case ast.WindowFuncRowNumber: - return buildRowNumber(aggFuncDesc, ordinal) } return nil } +// BuildWindowFunctions builds specific window function according to function description and order by columns. +func BuildWindowFunctions(ctx sessionctx.Context, windowFuncDesc *aggregation.AggFuncDesc, ordinal int, orderByCols []*expression.Column) AggFunc { + switch windowFuncDesc.Name { + case ast.WindowFuncRank: + return buildRank(ordinal, orderByCols, false) + case ast.WindowFuncDenseRank: + return buildRank(ordinal, orderByCols, true) + case ast.WindowFuncRowNumber: + return buildRowNumber(windowFuncDesc, ordinal) + default: + return Build(ctx, windowFuncDesc, ordinal) + } +} + // buildCount builds the AggFunc implementation for function "COUNT". func buildCount(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { // If mode is DedupMode, we return nil for not implemented. @@ -324,3 +336,15 @@ func buildRowNumber(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { } return &rowNumber{base} } + +func buildRank(ordinal int, orderByCols []*expression.Column, isDense bool) AggFunc { + base := baseAggFunc{ + ordinal: ordinal, + } + r := &rank{baseAggFunc: base, isDense: isDense} + for _, col := range orderByCols { + r.cmpFuncs = append(r.cmpFuncs, chunk.GetCompareFunc(col.RetType)) + r.colIdx = append(r.colIdx, col.Index) + } + return r +} diff --git a/executor/aggfuncs/func_rank.go b/executor/aggfuncs/func_rank.go new file mode 100644 index 0000000000000..e73c46e3c45ec --- /dev/null +++ b/executor/aggfuncs/func_rank.go @@ -0,0 +1,80 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package aggfuncs + +import ( + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/util/chunk" +) + +type rank struct { + baseAggFunc + isDense bool + cmpFuncs []chunk.CompareFunc + colIdx []int +} + +type partialResult4Rank struct { + curIdx int64 + lastRank int64 + rows []chunk.Row +} + +func (r *rank) AllocPartialResult() PartialResult { + return PartialResult(&partialResult4Rank{}) +} + +func (r *rank) ResetPartialResult(pr PartialResult) { + p := (*partialResult4Rank)(pr) + p.curIdx = 0 + p.lastRank = 0 + p.rows = p.rows[:0] +} + +func (r *rank) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { + p := (*partialResult4Rank)(pr) + p.rows = append(p.rows, rowsInGroup...) + return nil +} + +func (r *rank) compareRows(prev, curr chunk.Row) int { + for i, idx := range r.colIdx { + res := r.cmpFuncs[i](prev, idx, curr, idx) + if res != 0 { + return res + } + } + return 0 +} + +func (r *rank) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { + p := (*partialResult4Rank)(pr) + p.curIdx++ + if p.curIdx == 1 { + p.lastRank = 1 + chk.AppendInt64(r.ordinal, p.lastRank) + return nil + } + if r.compareRows(p.rows[p.curIdx-2], p.rows[p.curIdx-1]) == 0 { + chk.AppendInt64(r.ordinal, p.lastRank) + return nil + } + if r.isDense { + p.lastRank++ + } else { + p.lastRank = p.curIdx + } + chk.AppendInt64(r.ordinal, p.lastRank) + return nil +} diff --git a/executor/builder.go b/executor/builder.go index 8e94f78a30552..f698519324c8f 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -1918,7 +1918,11 @@ func (b *executorBuilder) buildWindow(v *plannercore.PhysicalWindow) *WindowExec } aggDesc := aggregation.NewAggFuncDesc(b.ctx, v.WindowFuncDesc.Name, v.WindowFuncDesc.Args, false) resultColIdx := len(v.Schema().Columns) - 1 - agg := aggfuncs.Build(b.ctx, aggDesc, resultColIdx) + orderByCols := make([]*expression.Column, 0, len(v.OrderBy)) + for _, item := range v.OrderBy { + orderByCols = append(orderByCols, item.Col) + } + agg := aggfuncs.BuildWindowFunctions(b.ctx, aggDesc, resultColIdx, orderByCols) var processor windowProcessor if v.Frame == nil { processor = &aggWindowProcessor{ diff --git a/executor/window_test.go b/executor/window_test.go index e97234f994593..3597689ad6008 100644 --- a/executor/window_test.go +++ b/executor/window_test.go @@ -64,4 +64,21 @@ func (s *testSuite2) TestWindowFunctions(c *C) { result.Check(testkit.Rows(" ", "1 2019-02-01 6", "2 2019-02-02 6", "3 2019-02-03 10", "5 2019-02-05 5")) result = tk.MustQuery("select a, b, sum(a) over(order by b desc range between interval 1 day preceding and interval 2 day following) from t") result.Check(testkit.Rows("5 2019-02-05 8", "3 2019-02-03 6", "2 2019-02-02 6", "1 2019-02-01 3", " ")) + + tk.MustExec("drop table t") + tk.MustExec("create table t(a int, b int)") + tk.MustExec("insert into t values (1,1),(1,2),(2,1),(2,2)") + result = tk.MustQuery("select a, b, rank() over() from t") + result.Check(testkit.Rows("1 1 1", "1 2 1", "2 1 1", "2 2 1")) + result = tk.MustQuery("select a, b, rank() over(order by a) from t") + result.Check(testkit.Rows("1 1 1", "1 2 1", "2 1 3", "2 2 3")) + result = tk.MustQuery("select a, b, rank() over(order by a, b) from t") + result.Check(testkit.Rows("1 1 1", "1 2 2", "2 1 3", "2 2 4")) + + result = tk.MustQuery("select a, b, dense_rank() over() from t") + result.Check(testkit.Rows("1 1 1", "1 2 1", "2 1 1", "2 2 1")) + result = tk.MustQuery("select a, b, dense_rank() over(order by a) from t") + result.Check(testkit.Rows("1 1 1", "1 2 1", "2 1 2", "2 2 2")) + result = tk.MustQuery("select a, b, dense_rank() over(order by a, b) from t") + result.Check(testkit.Rows("1 1 1", "1 2 2", "2 1 3", "2 2 4")) } diff --git a/expression/aggregation/base_func.go b/expression/aggregation/base_func.go index 2f241bca2a1ee..717bdf2b9613f 100644 --- a/expression/aggregation/base_func.go +++ b/expression/aggregation/base_func.go @@ -95,8 +95,8 @@ func (a *baseFuncDesc) typeInfer(ctx sessionctx.Context) { a.typeInfer4MaxMin(ctx) case ast.AggFuncBitAnd, ast.AggFuncBitOr, ast.AggFuncBitXor: a.typeInfer4BitFuncs(ctx) - case ast.WindowFuncRowNumber: - a.typeInfer4RowNumber() + case ast.WindowFuncRowNumber, ast.WindowFuncRank, ast.WindowFuncDenseRank: + a.typeInfer4NumberFuncs() default: panic("unsupported agg function: " + a.Name) } @@ -186,7 +186,7 @@ func (a *baseFuncDesc) typeInfer4BitFuncs(ctx sessionctx.Context) { // TODO: a.Args[0] = expression.WrapWithCastAsInt(ctx, a.Args[0]) } -func (a *baseFuncDesc) typeInfer4RowNumber() { +func (a *baseFuncDesc) typeInfer4NumberFuncs() { a.RetTp = types.NewFieldType(mysql.TypeLonglong) a.RetTp.Flen = 21 types.SetBinChsClnFlag(a.RetTp)