diff --git a/executor/builder.go b/executor/builder.go index f0c5a9efd9328..e04a51245b824 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -2114,25 +2114,32 @@ func (b *executorBuilder) buildWindow(v *plannercore.PhysicalWindow) *WindowExec for _, item := range v.PartitionBy { groupByItems = append(groupByItems, item.Col) } - aggDesc := aggregation.NewAggFuncDesc(b.ctx, v.WindowFuncDesc.Name, v.WindowFuncDesc.Args, false) - resultColIdx := len(v.Schema().Columns) - 1 orderByCols := make([]*expression.Column, 0, len(v.OrderBy)) for _, item := range v.OrderBy { orderByCols = append(orderByCols, item.Col) } - agg := aggfuncs.BuildWindowFunctions(b.ctx, aggDesc, resultColIdx, orderByCols) + windowFuncs := make([]aggfuncs.AggFunc, 0, len(v.WindowFuncDescs)) + partialResults := make([]aggfuncs.PartialResult, 0, len(v.WindowFuncDescs)) + resultColIdx := v.Schema().Len() - len(v.WindowFuncDescs) + for _, desc := range v.WindowFuncDescs { + aggDesc := aggregation.NewAggFuncDesc(b.ctx, desc.Name, desc.Args, false) + agg := aggfuncs.BuildWindowFunctions(b.ctx, aggDesc, resultColIdx, orderByCols) + windowFuncs = append(windowFuncs, agg) + partialResults = append(partialResults, agg.AllocPartialResult()) + resultColIdx++ + } var processor windowProcessor if v.Frame == nil { processor = &aggWindowProcessor{ - windowFunc: agg, - partialResult: agg.AllocPartialResult(), + windowFuncs: windowFuncs, + partialResults: partialResults, } } else if v.Frame.Type == ast.Rows { processor = &rowFrameWindowProcessor{ - windowFunc: agg, - partialResult: agg.AllocPartialResult(), - start: v.Frame.Start, - end: v.Frame.End, + windowFuncs: windowFuncs, + partialResults: partialResults, + start: v.Frame.Start, + end: v.Frame.End, } } else { cmpResult := int64(-1) @@ -2140,8 +2147,8 @@ func (b *executorBuilder) buildWindow(v *plannercore.PhysicalWindow) *WindowExec cmpResult = 1 } processor = &rangeFrameWindowProcessor{ - windowFunc: agg, - partialResult: agg.AllocPartialResult(), + windowFuncs: windowFuncs, + partialResults: partialResults, start: v.Frame.Start, end: v.Frame.End, orderByCols: orderByCols, @@ -2149,8 +2156,9 @@ func (b *executorBuilder) buildWindow(v *plannercore.PhysicalWindow) *WindowExec } } return &WindowExec{baseExecutor: base, - processor: processor, - groupChecker: newGroupChecker(b.ctx.GetSessionVars().StmtCtx, groupByItems), + processor: processor, + groupChecker: newGroupChecker(b.ctx.GetSessionVars().StmtCtx, groupByItems), + numWindowFuncs: len(v.WindowFuncDescs), } } diff --git a/executor/window.go b/executor/window.go index 5451acd45bf1d..bf4e5a2dab0b1 100644 --- a/executor/window.go +++ b/executor/window.go @@ -41,6 +41,7 @@ type WindowExec struct { meetNewGroup bool remainingRowsInGroup int remainingRowsInChunk int + numWindowFuncs int processor windowProcessor } @@ -171,7 +172,7 @@ func (e *WindowExec) copyChk(chk *chunk.Chunk) { childResult := e.childResults[0] e.childResults = e.childResults[1:] e.remainingRowsInChunk = childResult.NumRows() - columns := e.Schema().Columns[:len(e.Schema().Columns)-1] + columns := e.Schema().Columns[:len(e.Schema().Columns)-e.numWindowFuncs] for i, col := range columns { chk.MakeRefTo(i, childResult, col.Index) } @@ -190,22 +191,29 @@ type windowProcessor interface { } type aggWindowProcessor struct { - windowFunc aggfuncs.AggFunc - partialResult aggfuncs.PartialResult + windowFuncs []aggfuncs.AggFunc + partialResults []aggfuncs.PartialResult } func (p *aggWindowProcessor) consumeGroupRows(ctx sessionctx.Context, rows []chunk.Row) ([]chunk.Row, error) { - err := p.windowFunc.UpdatePartialResult(ctx, rows, p.partialResult) + for i, windowFunc := range p.windowFuncs { + err := windowFunc.UpdatePartialResult(ctx, rows, p.partialResults[i]) + if err != nil { + return nil, err + } + } rows = rows[:0] - return rows, err + return rows, nil } func (p *aggWindowProcessor) appendResult2Chunk(ctx sessionctx.Context, rows []chunk.Row, chk *chunk.Chunk, remained int) ([]chunk.Row, error) { for remained > 0 { - // TODO: We can extend the agg func interface to avoid the `for` loop here. - err := p.windowFunc.AppendFinalResult2Chunk(ctx, p.partialResult, chk) - if err != nil { - return rows, err + for i, windowFunc := range p.windowFuncs { + // TODO: We can extend the agg func interface to avoid the `for` loop here. + err := windowFunc.AppendFinalResult2Chunk(ctx, p.partialResults[i], chk) + if err != nil { + return nil, err + } } remained-- } @@ -213,15 +221,17 @@ func (p *aggWindowProcessor) appendResult2Chunk(ctx sessionctx.Context, rows []c } func (p *aggWindowProcessor) resetPartialResult() { - p.windowFunc.ResetPartialResult(p.partialResult) + for i, windowFunc := range p.windowFuncs { + windowFunc.ResetPartialResult(p.partialResults[i]) + } } type rowFrameWindowProcessor struct { - windowFunc aggfuncs.AggFunc - partialResult aggfuncs.PartialResult - start *core.FrameBound - end *core.FrameBound - curRowIdx uint64 + windowFuncs []aggfuncs.AggFunc + partialResults []aggfuncs.PartialResult + start *core.FrameBound + end *core.FrameBound + curRowIdx uint64 } func (p *rowFrameWindowProcessor) getStartOffset(numRows uint64) uint64 { @@ -283,33 +293,36 @@ func (p *rowFrameWindowProcessor) appendResult2Chunk(ctx sessionctx.Context, row p.curRowIdx++ remained-- if start >= end { - err := p.windowFunc.AppendFinalResult2Chunk(ctx, p.partialResult, chk) - if err != nil { - return nil, err + for i, windowFunc := range p.windowFuncs { + err := windowFunc.AppendFinalResult2Chunk(ctx, p.partialResults[i], chk) + if err != nil { + return nil, err + } } continue } - err := p.windowFunc.UpdatePartialResult(ctx, rows[start:end], p.partialResult) - if err != nil { - return nil, err - } - err = p.windowFunc.AppendFinalResult2Chunk(ctx, p.partialResult, chk) - if err != nil { - return nil, err + for i, windowFunc := range p.windowFuncs { + err := windowFunc.UpdatePartialResult(ctx, rows[start:end], p.partialResults[i]) + if err != nil { + return nil, err + } + err = windowFunc.AppendFinalResult2Chunk(ctx, p.partialResults[i], chk) + if err != nil { + return nil, err + } + windowFunc.ResetPartialResult(p.partialResults[i]) } - p.windowFunc.ResetPartialResult(p.partialResult) } return rows, nil } func (p *rowFrameWindowProcessor) resetPartialResult() { - p.windowFunc.ResetPartialResult(p.partialResult) p.curRowIdx = 0 } type rangeFrameWindowProcessor struct { - windowFunc aggfuncs.AggFunc - partialResult aggfuncs.PartialResult + windowFuncs []aggfuncs.AggFunc + partialResults []aggfuncs.PartialResult start *core.FrameBound end *core.FrameBound curRowIdx uint64 @@ -385,21 +398,25 @@ func (p *rangeFrameWindowProcessor) appendResult2Chunk(ctx sessionctx.Context, r p.curRowIdx++ remained-- if start >= end { - err := p.windowFunc.AppendFinalResult2Chunk(ctx, p.partialResult, chk) - if err != nil { - return nil, err + for i, windowFunc := range p.windowFuncs { + err := windowFunc.AppendFinalResult2Chunk(ctx, p.partialResults[i], chk) + if err != nil { + return nil, err + } } continue } - err = p.windowFunc.UpdatePartialResult(ctx, rows[start:end], p.partialResult) - if err != nil { - return nil, err - } - err = p.windowFunc.AppendFinalResult2Chunk(ctx, p.partialResult, chk) - if err != nil { - return nil, err + for i, windowFunc := range p.windowFuncs { + err := windowFunc.UpdatePartialResult(ctx, rows[start:end], p.partialResults[i]) + if err != nil { + return nil, err + } + err = windowFunc.AppendFinalResult2Chunk(ctx, p.partialResults[i], chk) + if err != nil { + return nil, err + } + windowFunc.ResetPartialResult(p.partialResults[i]) } - p.windowFunc.ResetPartialResult(p.partialResult) } return rows, nil } @@ -409,7 +426,6 @@ func (p *rangeFrameWindowProcessor) consumeGroupRows(ctx sessionctx.Context, row } func (p *rangeFrameWindowProcessor) resetPartialResult() { - p.windowFunc.ResetPartialResult(p.partialResult) p.curRowIdx = 0 p.lastStartOffset = 0 p.lastEndOffset = 0 diff --git a/executor/window_test.go b/executor/window_test.go index a7b04f0b7f01f..357c6ff2c785e 100644 --- a/executor/window_test.go +++ b/executor/window_test.go @@ -159,4 +159,11 @@ func (s *testSuite4) TestWindowFunctions(c *C) { "5 2013-01-01 00:00:00 15", ), ) + + result = tk.MustQuery("select sum(a) over w, sum(b) over w from t window w as (order by a)") + result.Check(testkit.Rows("2 3", "2 3", "6 6", "6 6")) + result = tk.MustQuery("select row_number() over w, sum(b) over w from t window w as (order by a)") + result.Check(testkit.Rows("1 3", "2 3", "3 6", "4 6")) + result = tk.MustQuery("select row_number() over w, sum(b) over w from t window w as (rows between 1 preceding and 1 following)") + result.Check(testkit.Rows("1 3", "2 4", "3 5", "4 3")) } diff --git a/planner/core/exhaust_physical_plans.go b/planner/core/exhaust_physical_plans.go index 929fb31ba511b..e1eecf0766540 100644 --- a/planner/core/exhaust_physical_plans.go +++ b/planner/core/exhaust_physical_plans.go @@ -1158,10 +1158,10 @@ func (p *LogicalWindow) exhaustPhysicalPlans(prop *property.PhysicalProperty) [] return nil } window := PhysicalWindow{ - WindowFuncDesc: p.WindowFuncDesc, - PartitionBy: p.PartitionBy, - OrderBy: p.OrderBy, - Frame: p.Frame, + WindowFuncDescs: p.WindowFuncDescs, + PartitionBy: p.PartitionBy, + OrderBy: p.OrderBy, + Frame: p.Frame, }.Init(p.ctx, p.stats.ScaleByExpectCnt(prop.ExpectedCnt), childProperty) window.SetSchema(p.Schema()) return []PhysicalPlan{window} diff --git a/planner/core/explain.go b/planner/core/explain.go index a70ed67776ba3..3b3a027fbe606 100644 --- a/planner/core/explain.go +++ b/planner/core/explain.go @@ -323,7 +323,8 @@ func (p *PhysicalWindow) formatFrameBound(buffer *bytes.Buffer, bound *FrameBoun // ExplainInfo implements PhysicalPlan interface. func (p *PhysicalWindow) ExplainInfo() string { - buffer := bytes.NewBufferString(p.WindowFuncDesc.String()) + buffer := bytes.NewBufferString("") + formatWindowFuncDescs(buffer, p.WindowFuncDescs) buffer.WriteString(" over(") isFirst := true if len(p.PartitionBy) > 0 { @@ -370,3 +371,13 @@ func (p *PhysicalWindow) ExplainInfo() string { buffer.WriteString(")") return buffer.String() } + +func formatWindowFuncDescs(buffer *bytes.Buffer, descs []*aggregation.WindowFuncDesc) *bytes.Buffer { + for i, desc := range descs { + if i != 0 { + buffer.WriteString(", ") + } + buffer.WriteString(desc.String()) + } + return buffer +} diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 8383b18824b0c..91ea45fca7a26 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -29,7 +29,7 @@ import ( "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/types" - driver "github.com/pingcap/tidb/types/parser_driver" + "github.com/pingcap/tidb/types/parser_driver" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/stringutil" ) @@ -82,14 +82,14 @@ func (b *PlanBuilder) rewriteInsertOnDuplicateUpdate(exprNode ast.ExprNode, mock // asScalar means whether this expression must be treated as a scalar expression. // And this function returns a result expression, a new plan that may have apply or semi-join. func (b *PlanBuilder) rewrite(exprNode ast.ExprNode, p LogicalPlan, aggMapper map[*ast.AggregateFuncExpr]int, asScalar bool) (expression.Expression, LogicalPlan, error) { - expr, resultPlan, err := b.rewriteWithPreprocess(exprNode, p, aggMapper, asScalar, nil) + expr, resultPlan, err := b.rewriteWithPreprocess(exprNode, p, aggMapper, nil, asScalar, nil) return expr, resultPlan, err } // rewriteWithPreprocess is for handling the situation that we need to adjust the input ast tree // before really using its node in `expressionRewriter.Leave`. In that case, we first call // er.preprocess(expr), which returns a new expr. Then we use the new expr in `Leave`. -func (b *PlanBuilder) rewriteWithPreprocess(exprNode ast.ExprNode, p LogicalPlan, aggMapper map[*ast.AggregateFuncExpr]int, asScalar bool, preprocess func(ast.Node) ast.Node) (expression.Expression, LogicalPlan, error) { +func (b *PlanBuilder) rewriteWithPreprocess(exprNode ast.ExprNode, p LogicalPlan, aggMapper map[*ast.AggregateFuncExpr]int, windowMapper map[*ast.WindowFuncExpr]int, asScalar bool, preprocess func(ast.Node) ast.Node) (expression.Expression, LogicalPlan, error) { b.rewriterCounter++ defer func() { b.rewriterCounter-- }() @@ -103,6 +103,7 @@ func (b *PlanBuilder) rewriteWithPreprocess(exprNode ast.ExprNode, p LogicalPlan } rewriter.aggrMap = aggMapper + rewriter.windowMap = windowMapper rewriter.asScalar = asScalar rewriter.preprocess = preprocess @@ -153,13 +154,14 @@ func (b *PlanBuilder) rewriteExprNode(rewriter *expressionRewriter, exprNode ast } type expressionRewriter struct { - ctxStack []expression.Expression - p LogicalPlan - schema *expression.Schema - err error - aggrMap map[*ast.AggregateFuncExpr]int - b *PlanBuilder - ctx sessionctx.Context + ctxStack []expression.Expression + p LogicalPlan + schema *expression.Schema + err error + aggrMap map[*ast.AggregateFuncExpr]int + windowMap map[*ast.WindowFuncExpr]int + b *PlanBuilder + ctx sessionctx.Context // asScalar indicates the return value must be a scalar value. // NOTE: This value can be changed during expression rewritten. @@ -315,7 +317,16 @@ func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) { er.ctxStack = append(er.ctxStack, expression.NewValuesFunc(er.ctx, col.Index, col.RetType)) return inNode, true case *ast.WindowFuncExpr: - return er.handleWindowFunction(v) + index, ok := -1, false + if er.windowMap != nil { + index, ok = er.windowMap[v] + } + if !ok { + er.err = ErrWindowInvalidWindowFuncUse.GenWithStackByArgs(v.F) + return inNode, true + } + er.ctxStack = append(er.ctxStack, er.schema.Columns[index]) + return inNode, true case *ast.FuncCallExpr: if _, ok := expression.DisableFoldFunctions[v.FnName.L]; ok { er.disableFoldCounter++ @@ -326,17 +337,6 @@ func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) { return inNode, false } -func (er *expressionRewriter) handleWindowFunction(v *ast.WindowFuncExpr) (ast.Node, bool) { - windowPlan, err := er.b.buildWindowFunction(er.p, v, er.aggrMap) - if err != nil { - er.err = err - return v, false - } - er.ctxStack = append(er.ctxStack, windowPlan.GetWindowResultColumn()) - er.p = windowPlan - return v, true -} - func (er *expressionRewriter) buildSemiApplyFromEqualSubq(np LogicalPlan, l, r expression.Expression, not bool) { var condition expression.Expression if rCol, ok := r.(*expression.Column); ok && (er.asScalar || not) { diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 55aa289940c70..f4c9ccb1a53ee 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -694,7 +694,7 @@ func (b *PlanBuilder) buildProjectionField(id, position int, field *ast.SelectFi } // buildProjection returns a Projection plan and non-aux columns length. -func (b *PlanBuilder) buildProjection(p LogicalPlan, fields []*ast.SelectField, mapper map[*ast.AggregateFuncExpr]int, considerWindow bool) (LogicalPlan, int, error) { +func (b *PlanBuilder) buildProjection(p LogicalPlan, fields []*ast.SelectField, mapper map[*ast.AggregateFuncExpr]int, windowMapper map[*ast.WindowFuncExpr]int, considerWindow bool) (LogicalPlan, int, error) { b.optFlag |= flagEliminateProjection b.curClause = fieldList proj := LogicalProjection{Exprs: make([]expression.Expression, 0, len(fields))}.Init(b.ctx) @@ -726,11 +726,19 @@ func (b *PlanBuilder) buildProjection(p LogicalPlan, fields []*ast.SelectField, schema.Append(col) continue } - newExpr, np, err := b.rewrite(field.Expr, p, mapper, true) + newExpr, np, err := b.rewriteWithPreprocess(field.Expr, p, mapper, windowMapper, true, nil) if err != nil { return nil, 0, err } + // For window functions in the order by clause, we will append an field for it. + // We need rewrite the window mapper here so order by clause could find the added field. + if considerWindow && isWindowFuncField && field.Auxiliary { + if windowExpr, ok := field.Expr.(*ast.WindowFuncExpr); ok { + windowMapper[windowExpr] = i + } + } + p = np proj.Exprs = append(proj.Exprs, newExpr) @@ -852,7 +860,7 @@ func (b *PlanBuilder) buildUnion(union *ast.UnionStmt) (LogicalPlan, error) { oldLen := unionPlan.Schema().Len() if union.OrderBy != nil { - unionPlan, err = b.buildSort(unionPlan, union.OrderBy.Items, nil) + unionPlan, err = b.buildSort(unionPlan, union.OrderBy.Items, nil, nil) if err != nil { return nil, err } @@ -957,7 +965,7 @@ func (t *itemTransformer) Leave(inNode ast.Node) (ast.Node, bool) { return inNode, false } -func (b *PlanBuilder) buildSort(p LogicalPlan, byItems []*ast.ByItem, aggMapper map[*ast.AggregateFuncExpr]int) (*LogicalSort, error) { +func (b *PlanBuilder) buildSort(p LogicalPlan, byItems []*ast.ByItem, aggMapper map[*ast.AggregateFuncExpr]int, windowMapper map[*ast.WindowFuncExpr]int) (*LogicalSort, error) { if _, isUnion := p.(*LogicalUnionAll); isUnion { b.curClause = globalOrderByClause } else { @@ -969,7 +977,7 @@ func (b *PlanBuilder) buildSort(p LogicalPlan, byItems []*ast.ByItem, aggMapper for _, item := range byItems { newExpr, _ := item.Expr.Accept(transformer) item.Expr = newExpr.(ast.ExprNode) - it, np, err := b.rewrite(item.Expr, p, aggMapper, true) + it, np, err := b.rewriteWithPreprocess(item.Expr, p, aggMapper, windowMapper, true, nil) if err != nil { return nil, err } @@ -1210,6 +1218,13 @@ func (a *havingWindowAndOrderbyExprResolver) Leave(n ast.Node) (node ast.Node, o a.err = ErrWindowInvalidWindowFuncUse.GenWithStackByArgs(v.F) return node, false } + if a.curClause == orderByClause { + a.selectFields = append(a.selectFields, &ast.SelectField{ + Auxiliary: true, + Expr: v, + AsName: model.NewCIStr(fmt.Sprintf("sel_window_%d", len(a.selectFields))), + }) + } case *ast.WindowSpec: a.inWindowSpec = false case *ast.ColumnNameExpr: @@ -1315,6 +1330,9 @@ func (b *PlanBuilder) resolveHavingAndOrderBy(sel *ast.SelectStmt, p LogicalPlan if sel.OrderBy != nil { extractor.curClause = orderByClause for _, item := range sel.OrderBy.Items { + if ast.HasWindowFlag(item.Expr) { + continue + } n, ok := item.Expr.Accept(extractor) if !ok { return nil, nil, errors.Trace(extractor.err) @@ -1368,6 +1386,19 @@ func (b *PlanBuilder) resolveWindowFunction(sel *ast.SelectStmt, p LogicalPlan) return nil, extractor.err } } + if sel.OrderBy != nil { + extractor.curClause = orderByClause + for _, item := range sel.OrderBy.Items { + if !ast.HasWindowFlag(item.Expr) { + continue + } + n, ok := item.Expr.Accept(extractor) + if !ok { + return nil, extractor.err + } + item.Expr = n.(ast.ExprNode) + } + } sel.Fields.Fields = extractor.selectFields return extractor.aggMapper, nil } @@ -1939,7 +1970,7 @@ func (b *PlanBuilder) buildSelect(sel *ast.SelectStmt) (p LogicalPlan, err error var ( aggFuncs []*ast.AggregateFuncExpr havingMap, orderMap, totalMap map[*ast.AggregateFuncExpr]int - windowMap map[*ast.AggregateFuncExpr]int + windowAggMap map[*ast.AggregateFuncExpr]int gbyCols []expression.Expression ) @@ -1974,7 +2005,7 @@ func (b *PlanBuilder) buildSelect(sel *ast.SelectStmt) (p LogicalPlan, err error hasWindowFuncField := b.detectSelectWindow(sel) if hasWindowFuncField { - windowMap, err = b.resolveWindowFunction(sel, p) + windowAggMap, err = b.resolveWindowFunction(sel, p) if err != nil { return nil, err } @@ -2014,7 +2045,7 @@ func (b *PlanBuilder) buildSelect(sel *ast.SelectStmt) (p LogicalPlan, err error var oldLen int // According to https://dev.mysql.com/doc/refman/8.0/en/window-functions-usage.html, // we can only process window functions after having clause, so `considerWindow` is false now. - p, oldLen, err = b.buildProjection(p, sel.Fields.Fields, totalMap, false) + p, oldLen, err = b.buildProjection(p, sel.Fields.Fields, totalMap, nil, false) if err != nil { return nil, err } @@ -2032,9 +2063,19 @@ func (b *PlanBuilder) buildSelect(sel *ast.SelectStmt) (p LogicalPlan, err error return nil, err } + var windowMapper map[*ast.WindowFuncExpr]int if hasWindowFuncField { + windowFuncs := extractWindowFuncs(sel.Fields.Fields) + groupedFuncs, err := b.groupWindowFuncs(windowFuncs) + if err != nil { + return nil, err + } + p, windowMapper, err = b.buildWindowFunctions(p, groupedFuncs, windowAggMap) + if err != nil { + return nil, err + } // Now we build the window function fields. - p, oldLen, err = b.buildProjection(p, sel.Fields.Fields, windowMap, true) + p, oldLen, err = b.buildProjection(p, sel.Fields.Fields, windowAggMap, windowMapper, true) if err != nil { return nil, err } @@ -2045,7 +2086,7 @@ func (b *PlanBuilder) buildSelect(sel *ast.SelectStmt) (p LogicalPlan, err error } if sel.OrderBy != nil { - p, err = b.buildSort(p, sel.OrderBy.Items, orderMap) + p, err = b.buildSort(p, sel.OrderBy.Items, orderMap, windowMapper) if err != nil { return nil, err } @@ -2503,7 +2544,7 @@ func (b *PlanBuilder) buildUpdate(update *ast.UpdateStmt) (Plan, error) { } } if sel.OrderBy != nil { - p, err = b.buildSort(p, sel.OrderBy.Items, nil) + p, err = b.buildSort(p, sel.OrderBy.Items, nil, nil) if err != nil { return nil, err } @@ -2599,7 +2640,7 @@ func (b *PlanBuilder) buildUpdateLists(tableList []*ast.TableName, list []*ast.A return expr } } - newExpr, np, err = b.rewriteWithPreprocess(assign.Expr, p, nil, false, rewritePreprocess) + newExpr, np, err = b.rewriteWithPreprocess(assign.Expr, p, nil, nil, false, rewritePreprocess) } if err != nil { return nil, nil, err @@ -2698,7 +2739,7 @@ func (b *PlanBuilder) buildDelete(delete *ast.DeleteStmt) (Plan, error) { } if sel.OrderBy != nil { - p, err = b.buildSort(p, sel.OrderBy.Items, nil) + p, err = b.buildSort(p, sel.OrderBy.Items, nil, nil) if err != nil { return nil, err } @@ -2805,28 +2846,9 @@ func getWindowName(name string) string { // buildProjectionForWindow builds the projection for expressions in the window specification that is not an column, // so after the projection, window functions only needs to deal with columns. -func (b *PlanBuilder) buildProjectionForWindow(p LogicalPlan, expr *ast.WindowFuncExpr, aggMap map[*ast.AggregateFuncExpr]int) (LogicalPlan, []property.Item, []property.Item, []expression.Expression, error) { +func (b *PlanBuilder) buildProjectionForWindow(p LogicalPlan, spec *ast.WindowSpec, args []ast.ExprNode, aggMap map[*ast.AggregateFuncExpr]int) (LogicalPlan, []property.Item, []property.Item, []expression.Expression, error) { b.optFlag |= flagEliminateProjection - if expr.Spec.Name.L != "" { - ref, ok := b.windowSpecs[expr.Spec.Name.L] - if !ok { - return nil, nil, nil, nil, ErrWindowNoSuchWindow.GenWithStackByArgs(expr.Spec.Name.O) - } - expr.Spec = ref - } - spec := expr.Spec - if spec.Ref.L != "" { - ref, ok := b.windowSpecs[spec.Ref.L] - if !ok { - return nil, nil, nil, nil, ErrWindowNoSuchWindow.GenWithStackByArgs(spec.Ref.O) - } - err := mergeWindowSpec(&spec, &ref) - if err != nil { - return nil, nil, nil, nil, err - } - } - var partitionItems, orderItems []*ast.ByItem if spec.PartitionBy != nil { partitionItems = spec.PartitionBy.Items @@ -2835,7 +2857,7 @@ func (b *PlanBuilder) buildProjectionForWindow(p LogicalPlan, expr *ast.WindowFu orderItems = spec.OrderBy.Items } - projLen := len(p.Schema().Columns) + len(partitionItems) + len(orderItems) + len(expr.Args) + projLen := len(p.Schema().Columns) + len(partitionItems) + len(orderItems) + len(args) proj := LogicalProjection{Exprs: make([]expression.Expression, 0, projLen)}.Init(b.ctx) proj.SetSchema(expression.NewSchema(make([]*expression.Column, 0, projLen)...)) for _, col := range p.Schema().Columns { @@ -2855,8 +2877,8 @@ func (b *PlanBuilder) buildProjectionForWindow(p LogicalPlan, expr *ast.WindowFu return nil, nil, nil, nil, err } - newArgList := make([]expression.Expression, 0, len(expr.Args)) - for _, arg := range expr.Args { + newArgList := make([]expression.Expression, 0, len(args)) + for _, arg := range args { newArg, np, err := b.rewrite(arg, p, aggMap, true) if err != nil { return nil, nil, nil, nil, err @@ -3076,62 +3098,138 @@ func (b *PlanBuilder) buildWindowFunctionFrame(spec *ast.WindowSpec, orderByItem return frame, err } -func (b *PlanBuilder) buildWindowFunction(p LogicalPlan, expr *ast.WindowFuncExpr, aggMap map[*ast.AggregateFuncExpr]int) (*LogicalWindow, error) { - p, partitionBy, orderBy, args, err := b.buildProjectionForWindow(p, expr, aggMap) - if err != nil { - return nil, err +func (b *PlanBuilder) buildWindowFunctions(p LogicalPlan, groupedFuncs map[*ast.WindowSpec][]*ast.WindowFuncExpr, aggMap map[*ast.AggregateFuncExpr]int) (LogicalPlan, map[*ast.WindowFuncExpr]int, error) { + args := make([]ast.ExprNode, 0, 4) + windowMap := make(map[*ast.WindowFuncExpr]int) + for spec, funcs := range groupedFuncs { + args = args[:0] + for _, windowFunc := range funcs { + args = append(args, windowFunc.Args...) + } + np, partitionBy, orderBy, args, err := b.buildProjectionForWindow(p, spec, args, aggMap) + if err != nil { + return nil, nil, err + } + frame, err := b.buildWindowFunctionFrame(spec, orderBy) + if err != nil { + return nil, nil, err + } + + window := LogicalWindow{ + PartitionBy: partitionBy, + OrderBy: orderBy, + Frame: frame, + }.Init(b.ctx) + schema := np.Schema().Clone() + descs := make([]*aggregation.WindowFuncDesc, 0, len(funcs)) + preArgs := 0 + for _, windowFunc := range funcs { + desc := aggregation.NewWindowFuncDesc(b.ctx, windowFunc.F, args[preArgs:preArgs+len(windowFunc.Args)]) + if desc == nil { + return nil, nil, ErrWrongArguments.GenWithStackByArgs(windowFunc.F) + } + preArgs += len(windowFunc.Args) + desc.WrapCastForAggArgs(b.ctx) + descs = append(descs, desc) + windowMap[windowFunc] = schema.Len() + schema.Append(&expression.Column{ + ColName: model.NewCIStr(fmt.Sprintf("%d_window_%d", window.id, schema.Len())), + UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), + IsReferenced: true, + RetType: desc.RetTp, + }) + } + window.WindowFuncDescs = descs + window.SetChildren(np) + window.SetSchema(schema) + p = window } + return p, windowMap, nil +} - needFrame := aggregation.NeedFrame(expr.F) +func extractWindowFuncs(fields []*ast.SelectField) []*ast.WindowFuncExpr { + extractor := &WindowFuncExtractor{} + for _, f := range fields { + n, _ := f.Expr.Accept(extractor) + f.Expr = n.(ast.ExprNode) + } + return extractor.windowFuncs +} + +func (b *PlanBuilder) handleDefaultFrame(spec *ast.WindowSpec, name string) (*ast.WindowSpec, bool) { + needFrame := aggregation.NeedFrame(name) // According to MySQL, In the absence of a frame clause, the default frame depends on whether an ORDER BY clause is present: // (1) With order by, the default frame is equivalent to "RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"; // (2) Without order by, the default frame is equivalent to "RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", // which is the same as an empty frame. - if needFrame && expr.Spec.Frame == nil && len(orderBy) > 0 { - expr.Spec.Frame = &ast.FrameClause{ + if needFrame && spec.Frame == nil && spec.OrderBy != nil { + newSpec := *spec + newSpec.Frame = &ast.FrameClause{ Type: ast.Ranges, Extent: ast.FrameExtent{ Start: ast.FrameBound{Type: ast.Preceding, UnBounded: true}, End: ast.FrameBound{Type: ast.CurrentRow}, }, } + return &newSpec, true } // For functions that operate on the entire partition, the frame clause will be ignored. - if !needFrame && expr.Spec.Frame != nil { - b.ctx.GetSessionVars().StmtCtx.AppendNote(ErrWindowFunctionIgnoresFrame.GenWithStackByArgs(expr.F, getWindowName(expr.Spec.Name.O))) - expr.Spec.Frame = nil - } - frame, err := b.buildWindowFunctionFrame(&expr.Spec, orderBy) - if err != nil { - return nil, err + if !needFrame && spec.Frame != nil { + specName := spec.Name.O + b.ctx.GetSessionVars().StmtCtx.AppendNote(ErrWindowFunctionIgnoresFrame.GenWithStackByArgs(name, specName)) + newSpec := *spec + newSpec.Frame = nil + return &newSpec, true + } + return spec, false +} + +// groupWindowFuncs groups the window functions according to the window specification name. +// TODO: We can group the window function by the definition of window specification. +func (b *PlanBuilder) groupWindowFuncs(windowFuncs []*ast.WindowFuncExpr) (map[*ast.WindowSpec][]*ast.WindowFuncExpr, error) { + // updatedSpecMap is used to handle the specifications that have frame clause changed. + updatedSpecMap := make(map[string]*ast.WindowSpec) + groupedWindow := make(map[*ast.WindowSpec][]*ast.WindowFuncExpr) + for _, windowFunc := range windowFuncs { + if windowFunc.Spec.Name.L == "" { + spec := &windowFunc.Spec + if spec.Ref.L != "" { + ref, ok := b.windowSpecs[spec.Ref.L] + if !ok { + return nil, ErrWindowNoSuchWindow.GenWithStackByArgs(getWindowName(spec.Ref.O)) + } + err := mergeWindowSpec(spec, ref) + if err != nil { + return nil, err + } + } + spec, _ = b.handleDefaultFrame(spec, windowFunc.F) + groupedWindow[spec] = append(groupedWindow[spec], windowFunc) + continue + } + + name := windowFunc.Spec.Name.L + spec, ok := b.windowSpecs[name] + if !ok { + return nil, ErrWindowNoSuchWindow.GenWithStackByArgs(windowFunc.Spec.Name.O) + } + newSpec, updated := b.handleDefaultFrame(spec, windowFunc.F) + if !updated { + groupedWindow[spec] = append(groupedWindow[spec], windowFunc) + } else { + if _, ok := updatedSpecMap[name]; !ok { + updatedSpecMap[name] = newSpec + } + updatedSpec := updatedSpecMap[name] + groupedWindow[updatedSpec] = append(groupedWindow[updatedSpec], windowFunc) + } } - desc := aggregation.NewWindowFuncDesc(b.ctx, expr.F, args) - if desc == nil { - return nil, ErrWrongArguments.GenWithStackByArgs(expr.F) - } - // TODO: Check if the function is aggregation function after we support more functions. - desc.WrapCastForAggArgs(b.ctx) - window := LogicalWindow{ - WindowFuncDesc: desc, - PartitionBy: partitionBy, - OrderBy: orderBy, - Frame: frame, - }.Init(b.ctx) - schema := p.Schema().Clone() - schema.Append(&expression.Column{ - ColName: model.NewCIStr(fmt.Sprintf("%d_window_%d", window.id, p.Schema().Len())), - UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), - IsReferenced: true, - RetType: desc.RetTp, - }) - window.SetChildren(p) - window.SetSchema(schema) - return window, nil + return groupedWindow, nil } // resolveWindowSpec resolve window specifications for sql like `select ... from t window w1 as (w2), w2 as (partition by a)`. // We need to resolve the referenced window to get the definition of current window spec. -func resolveWindowSpec(spec *ast.WindowSpec, specs map[string]ast.WindowSpec, inStack map[string]bool) error { +func resolveWindowSpec(spec *ast.WindowSpec, specs map[string]*ast.WindowSpec, inStack map[string]bool) error { if inStack[spec.Name.L] { return errors.Trace(ErrWindowCircularityInWindowGraph) } @@ -3143,12 +3241,12 @@ func resolveWindowSpec(spec *ast.WindowSpec, specs map[string]ast.WindowSpec, in return ErrWindowNoSuchWindow.GenWithStackByArgs(spec.Ref.O) } inStack[spec.Name.L] = true - err := resolveWindowSpec(&ref, specs, inStack) + err := resolveWindowSpec(ref, specs, inStack) if err != nil { return err } inStack[spec.Name.L] = false - return mergeWindowSpec(spec, &ref) + return mergeWindowSpec(spec, ref) } func mergeWindowSpec(spec, ref *ast.WindowSpec) error { @@ -3169,17 +3267,18 @@ func mergeWindowSpec(spec, ref *ast.WindowSpec) error { return nil } -func buildWindowSpecs(specs []ast.WindowSpec) (map[string]ast.WindowSpec, error) { - specsMap := make(map[string]ast.WindowSpec, len(specs)) +func buildWindowSpecs(specs []ast.WindowSpec) (map[string]*ast.WindowSpec, error) { + specsMap := make(map[string]*ast.WindowSpec, len(specs)) for _, spec := range specs { if _, ok := specsMap[spec.Name.L]; ok { return nil, ErrWindowDuplicateName.GenWithStackByArgs(spec.Name.O) } - specsMap[spec.Name.L] = spec + newSpec := spec + specsMap[spec.Name.L] = &newSpec } inStack := make(map[string]bool, len(specs)) - for _, spec := range specs { - err := resolveWindowSpec(&spec, specsMap, inStack) + for _, spec := range specsMap { + err := resolveWindowSpec(spec, specsMap, inStack) if err != nil { return nil, err } diff --git a/planner/core/logical_plan_test.go b/planner/core/logical_plan_test.go index eba96e539e687..71598d87d2203 100644 --- a/planner/core/logical_plan_test.go +++ b/planner/core/logical_plan_test.go @@ -2196,8 +2196,8 @@ func (s *testPlanSuite) TestWindowFunction(c *C) { result: "[planner:3591]Window 'w1' is defined twice.", }, { - sql: "select sum(a) over(w1), avg(a) over(w2) from t window w1 as (partition by a), w2 as (w1)", - result: "TableReader(Table(t))->Window(sum(cast(test.t.a)) over(partition by test.t.a))->Window(avg(cast(test.t.a)) over())->Projection", + sql: "select avg(a) over(w2) from t window w1 as (partition by a), w2 as (w1)", + result: "TableReader(Table(t))->Window(avg(cast(test.t.a)) over(partition by test.t.a))->Projection", }, { sql: "select a from t window w1 as (partition by a) order by (sum(a) over(w1))", @@ -2256,8 +2256,8 @@ func (s *testPlanSuite) TestWindowFunction(c *C) { result: "TableReader(Table(t))->Window(row_number() over())->Projection", }, { - sql: "select avg(b), max(avg(b)) over(rows between 1 preceding and 1 following) max, min(avg(b)) over(rows between 1 preceding and 1 following) min from t group by c", - result: "IndexLookUp(Index(t.c_d_e)[[NULL,+inf]], Table(t))->Projection->StreamAgg->Window(max(sel_agg_4) over(rows between 1 preceding and 1 following))->Window(min(sel_agg_5) over(rows between 1 preceding and 1 following))->Projection", + sql: "select avg(b), max(avg(b)) over(rows between 1 preceding and 1 following) max from t group by c", + result: "IndexLookUp(Index(t.c_d_e)[[NULL,+inf]], Table(t))->Projection->StreamAgg->Window(max(sel_agg_3) over(rows between 1 preceding and 1 following))->Projection", }, { sql: "select nth_value(a, 1.0) over() from t", @@ -2283,6 +2283,14 @@ func (s *testPlanSuite) TestWindowFunction(c *C) { sql: "select nth_value(i_date, 1) over() from t", result: "TableReader(Table(t))->Window(nth_value(test.t.i_date, 1) over())->Projection", }, + { + sql: "select sum(b) over w, sum(c) over w from t window w as (order by a)", + result: "TableReader(Table(t))->Window(sum(cast(test.t.b)), sum(cast(test.t.c)) over(order by test.t.a asc range between unbounded preceding and current row))->Projection", + }, + { + sql: "delete from t order by (sum(a) over())", + result: "[planner:3593]You cannot use the window function 'sum' in this context.'", + }, } s.Parser.EnableWindowFunc(true) diff --git a/planner/core/logical_plans.go b/planner/core/logical_plans.go index 09cff86980d69..d6c2bfadbf922 100644 --- a/planner/core/logical_plans.go +++ b/planner/core/logical_plans.go @@ -686,15 +686,15 @@ type FrameBound struct { type LogicalWindow struct { logicalSchemaProducer - WindowFuncDesc *aggregation.WindowFuncDesc - PartitionBy []property.Item - OrderBy []property.Item - Frame *WindowFrame + WindowFuncDescs []*aggregation.WindowFuncDesc + PartitionBy []property.Item + OrderBy []property.Item + Frame *WindowFrame } -// GetWindowResultColumn returns the column storing the result of the window function. -func (p *LogicalWindow) GetWindowResultColumn() *expression.Column { - return p.schema.Columns[p.schema.Len()-1] +// GetWindowResultColumns returns the columns storing the result of the window function. +func (p *LogicalWindow) GetWindowResultColumns() []*expression.Column { + return p.schema.Columns[p.schema.Len()-len(p.WindowFuncDescs):] } // extractCorColumnsBySchema only extracts the correlated columns that match the specified schema. diff --git a/planner/core/physical_plans.go b/planner/core/physical_plans.go index 9bec563393e3f..32d257090270d 100644 --- a/planner/core/physical_plans.go +++ b/planner/core/physical_plans.go @@ -384,10 +384,10 @@ type PhysicalTableDual struct { type PhysicalWindow struct { physicalSchemaProducer - WindowFuncDesc *aggregation.WindowFuncDesc - PartitionBy []property.Item - OrderBy []property.Item - Frame *WindowFrame + WindowFuncDescs []*aggregation.WindowFuncDesc + PartitionBy []property.Item + OrderBy []property.Item + Frame *WindowFrame } // CollectPlanStatsVersion uses to collect the statistics version of the plan. diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index 2c6b5689dfc62..be8b3843f59ba 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -184,7 +184,7 @@ type PlanBuilder struct { // "STRAIGHT_JOIN" option. inStraightJoin bool - windowSpecs map[string]ast.WindowSpec + windowSpecs map[string]*ast.WindowSpec } // GetVisitInfo gets the visitInfo of the PlanBuilder. @@ -424,6 +424,13 @@ func (b *PlanBuilder) detectSelectWindow(sel *ast.SelectStmt) bool { return true } } + if sel.OrderBy != nil { + for _, item := range sel.OrderBy.Items { + if ast.HasWindowFlag(item.Expr) { + return true + } + } + } return false } @@ -1447,7 +1454,7 @@ func (b *PlanBuilder) buildSetValuesOfInsert(insert *ast.InsertStmt, insertPlan } for i, assign := range insert.Setlist { - expr, _, err := b.rewriteWithPreprocess(assign.Expr, mockTablePlan, nil, true, checkRefColumn) + expr, _, err := b.rewriteWithPreprocess(assign.Expr, mockTablePlan, nil, nil, true, checkRefColumn) if err != nil { return err } @@ -1508,7 +1515,7 @@ func (b *PlanBuilder) buildValuesListOfInsert(insert *ast.InsertStmt, insertPlan RetType: &x.Type, } default: - expr, _, err = b.rewriteWithPreprocess(valueItem, mockTablePlan, nil, true, checkRefColumn) + expr, _, err = b.rewriteWithPreprocess(valueItem, mockTablePlan, nil, nil, true, checkRefColumn) } if err != nil { return err diff --git a/planner/core/resolve_indices.go b/planner/core/resolve_indices.go index 6812eb71a11fa..def6cea1f32f5 100644 --- a/planner/core/resolve_indices.go +++ b/planner/core/resolve_indices.go @@ -321,7 +321,7 @@ func (p *PhysicalWindow) ResolveIndices() (err error) { if err != nil { return err } - for i := 0; i < len(p.Schema().Columns)-1; i++ { + for i := 0; i < len(p.Schema().Columns)-len(p.WindowFuncDescs); i++ { col := p.Schema().Columns[i] newCol, err := col.ResolveIndices(p.children[0].Schema()) if err != nil { @@ -343,10 +343,12 @@ func (p *PhysicalWindow) ResolveIndices() (err error) { } p.OrderBy[i].Col = newCol.(*expression.Column) } - for i, arg := range p.WindowFuncDesc.Args { - p.WindowFuncDesc.Args[i], err = arg.ResolveIndices(p.children[0].Schema()) - if err != nil { - return err + for _, desc := range p.WindowFuncDescs { + for i, arg := range desc.Args { + desc.Args[i], err = arg.ResolveIndices(p.children[0].Schema()) + if err != nil { + return err + } } } if p.Frame != nil { diff --git a/planner/core/rule_column_pruning.go b/planner/core/rule_column_pruning.go index 33247d1f0d7aa..1a78a8ecb4d90 100644 --- a/planner/core/rule_column_pruning.go +++ b/planner/core/rule_column_pruning.go @@ -347,10 +347,17 @@ func (p *LogicalLock) PruneColumns(parentUsedCols []*expression.Column) error { // PruneColumns implements LogicalPlan interface. func (p *LogicalWindow) PruneColumns(parentUsedCols []*expression.Column) error { - windowColumn := p.GetWindowResultColumn() + windowColumns := p.GetWindowResultColumns() len := 0 for _, col := range parentUsedCols { - if !windowColumn.Equal(nil, col) { + used := false + for _, windowColumn := range windowColumns { + if windowColumn.Equal(nil, col) { + used = true + break + } + } + if !used { parentUsedCols[len] = col len++ } @@ -363,13 +370,15 @@ func (p *LogicalWindow) PruneColumns(parentUsedCols []*expression.Column) error } p.SetSchema(p.children[0].Schema().Clone()) - p.Schema().Append(windowColumn) + p.Schema().Append(windowColumns...) return nil } func (p *LogicalWindow) extractUsedCols(parentUsedCols []*expression.Column) []*expression.Column { - for _, arg := range p.WindowFuncDesc.Args { - parentUsedCols = append(parentUsedCols, expression.ExtractColumns(arg)...) + for _, desc := range p.WindowFuncDescs { + for _, arg := range desc.Args { + parentUsedCols = append(parentUsedCols, expression.ExtractColumns(arg)...) + } } for _, by := range p.PartitionBy { parentUsedCols = append(parentUsedCols, by.Col) diff --git a/planner/core/rule_eliminate_projection.go b/planner/core/rule_eliminate_projection.go index ab1ad4ba2c532..5e72fdafe204b 100644 --- a/planner/core/rule_eliminate_projection.go +++ b/planner/core/rule_eliminate_projection.go @@ -208,8 +208,10 @@ func (lt *LogicalTopN) replaceExprColumns(replace map[string]*expression.Column) } func (p *LogicalWindow) replaceExprColumns(replace map[string]*expression.Column) { - for _, arg := range p.WindowFuncDesc.Args { - resolveExprAndReplace(arg, replace) + for _, desc := range p.WindowFuncDescs { + for _, arg := range desc.Args { + resolveExprAndReplace(arg, replace) + } } for _, item := range p.PartitionBy { resolveColumnAndReplace(item.Col, replace) diff --git a/planner/core/stats.go b/planner/core/stats.go index 0df8d0eb38e8c..38b73f3981d2d 100644 --- a/planner/core/stats.go +++ b/planner/core/stats.go @@ -352,10 +352,13 @@ func (p *LogicalWindow) DeriveStats(childStats []*property.StatsInfo) (*property RowCount: childProfile.RowCount, Cardinality: make([]float64, p.schema.Len()), } - for i := 0; i < p.schema.Len()-1; i++ { + childLen := p.schema.Len() - len(p.WindowFuncDescs) + for i := 0; i < childLen; i++ { colIdx := p.children[0].Schema().ColumnIndex(p.schema.Columns[i]) p.stats.Cardinality[i] = childProfile.Cardinality[colIdx] } - p.stats.Cardinality[p.schema.Len()-1] = childProfile.RowCount + for i := childLen; i < p.schema.Len(); i++ { + p.stats.Cardinality[i] = childProfile.RowCount + } return p.stats, nil } diff --git a/planner/core/stringer.go b/planner/core/stringer.go index b91e8348162b6..beba86d8cd527 100644 --- a/planner/core/stringer.go +++ b/planner/core/stringer.go @@ -14,6 +14,7 @@ package core import ( + "bytes" "fmt" "strings" ) @@ -221,7 +222,9 @@ func toString(in Plan, strs []string, idxs []int) ([]string, []int) { str = fmt.Sprintf("%s->Insert", ToString(x.SelectPlan)) } case *LogicalWindow: - str = fmt.Sprintf("Window(%s)", x.WindowFuncDesc.String()) + buffer := bytes.NewBufferString("") + formatWindowFuncDescs(buffer, x.WindowFuncDescs) + str = fmt.Sprintf("Window(%s)", buffer.String()) case *PhysicalWindow: str = fmt.Sprintf("Window(%s)", x.ExplainInfo()) default: diff --git a/planner/core/util.go b/planner/core/util.go index 308344beff531..1dd8fcd762413 100644 --- a/planner/core/util.go +++ b/planner/core/util.go @@ -47,6 +47,31 @@ func (a *AggregateFuncExtractor) Leave(n ast.Node) (ast.Node, bool) { return n, true } +// WindowFuncExtractor visits Expr tree. +// It converts ColunmNameExpr to WindowFuncExpr and collects WindowFuncExpr. +type WindowFuncExtractor struct { + // WindowFuncs is the collected WindowFuncExprs. + windowFuncs []*ast.WindowFuncExpr +} + +// Enter implements Visitor interface. +func (a *WindowFuncExtractor) Enter(n ast.Node) (ast.Node, bool) { + switch n.(type) { + case *ast.SelectStmt, *ast.UnionStmt: + return n, true + } + return n, false +} + +// Leave implements Visitor interface. +func (a *WindowFuncExtractor) Leave(n ast.Node) (ast.Node, bool) { + switch v := n.(type) { + case *ast.WindowFuncExpr: + a.windowFuncs = append(a.windowFuncs, v) + } + return n, true +} + // logicalSchemaProducer stores the schema for the logical plans who can produce schema directly. type logicalSchemaProducer struct { schema *expression.Schema