Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sql: Add initial support for window functions #8928

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions sql/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -1032,8 +1032,8 @@ func makeCheckConstraint(
}

var p parser.Parser
if p.AggregateInExpr(expr) {
return nil, fmt.Errorf("aggregate functions are not allowed in CHECK expressions")
if err := p.AssertNoAggregationOrWindowing(expr, "CHECK expressions"); err != nil {
return nil, err
}

if err := sqlbase.SanitizeVarFreeExpr(expr, parser.TypeBool, "CHECK"); err != nil {
Expand Down
16 changes: 3 additions & 13 deletions sql/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ func (p *planner) groupBy(n *parser.SelectClause, s *selectNode) (*groupNode, er
// that determination is made during validation, which will require matching
// expressions.
for i := range groupBy {
if p.parser.AggregateInExpr(groupBy[i]) {
return nil, fmt.Errorf("aggregate functions are not allowed in GROUP BY")
if err := p.parser.AssertNoAggregationOrWindowing(groupBy[i], "GROUP BY"); err != nil {
return nil, err
}

// We do not need to fully analyze the GROUP BY expression here
Expand Down Expand Up @@ -276,10 +276,10 @@ func (n *groupNode) Next() (bool, error) {
}
}
if !next {
n.populated = true
if err := n.computeAggregates(); err != nil {
return false, err
}
n.populated = true
break
}
if n.explain == explainDebug && n.plan.DebugValues().output != debugValueRow {
Expand Down Expand Up @@ -325,9 +325,6 @@ func (n *groupNode) computeAggregates() error {
n.buckets[""] = struct{}{}
}

// Since this controls Eval behavior of aggregateFuncHolder, it is not set until init is complete.
n.populated = true

// Render the results.
n.values.rows = make([]parser.DTuple, 0, len(n.buckets))
for k := range n.buckets {
Expand Down Expand Up @@ -617,13 +614,6 @@ func (a *aggregateFuncHolder) TypeCheck(_ *parser.SemaContext, desired parser.Da
}

func (a *aggregateFuncHolder) Eval(ctx *parser.EvalContext) (parser.Datum, error) {
// During init of the group buckets, grouped expressions (i.e. wrapped
// qvalues) are Eval()'ed to determine the bucket for a row, so pass these
// calls through to the underlying `arg` expr Eval until init is done.
if !a.group.populated {
return a.arg.Eval(ctx)
}

found, ok := a.buckets[a.group.currentBucket]
if !ok {
found = a.create()
Expand Down
4 changes: 2 additions & 2 deletions sql/limit.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ func (p *planner) Limit(n *parser.Limit) (*limitNode, error) {

for _, datum := range data {
if datum.src != nil {
if p.parser.AggregateInExpr(datum.src) {
return nil, fmt.Errorf("aggregate functions are not allowed in %s", datum.name)
if err := p.parser.AssertNoAggregationOrWindowing(datum.src, datum.name); err != nil {
return nil, err
}

normalized, err := p.analyzeExpr(datum.src, nil, nil, parser.TypeInt, true, datum.name)
Expand Down
17 changes: 17 additions & 0 deletions sql/parser/aggregate_builtins.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ type IsAggregateVisitor struct {
func (v *IsAggregateVisitor) VisitPre(expr Expr) (recurse bool, newExpr Expr) {
switch t := expr.(type) {
case *FuncExpr:
if t.IsWindowFunction() {
// A window function is not an aggregate function, but it can
// contain aggregate functions.
return true, expr
}
fn, err := t.Name.Normalize()
if err != nil {
return false, expr
Expand Down Expand Up @@ -106,6 +111,18 @@ func (p *Parser) IsAggregate(n *SelectClause) bool {
return false
}

// AssertNoAggregationOrWindowing checks if the provided expression contains either
// aggregate functions or window functions, returning an error in either case.
func (p *Parser) AssertNoAggregationOrWindowing(expr Expr, op string) error {
if p.AggregateInExpr(expr) {
return fmt.Errorf("aggregate functions are not allowed in %s", op)
}
if p.WindowFuncInExpr(expr) {
return fmt.Errorf("window functions are not allowed in %s", op)
}
return nil
}

// Aggregates are a special class of builtin functions that are wrapped
// at execution in a bucketing layer to combine (aggregate) the result
// of the function being run over many rows.
Expand Down
34 changes: 31 additions & 3 deletions sql/parser/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -786,9 +786,10 @@ func (node *UnaryExpr) TypedInnerExpr() TypedExpr {

// FuncExpr represents a function call.
type FuncExpr struct {
Name NormalizableFunctionName
Type funcType
Exprs Exprs
Name NormalizableFunctionName
Type funcType
Exprs Exprs
WindowDef *WindowDef

typeAnnotation
fn Builtin
Expand All @@ -800,6 +801,25 @@ func (node *FuncExpr) GetAggregateConstructor() func() AggregateFunc {
return node.fn.AggregateFunc
}

// GetWindowConstructor returns a window function constructor if the
// FuncExpr is a built-in window function.
func (node *FuncExpr) GetWindowConstructor() func() {
// TODO(nvanbenschoten) Support built-in window functions.
return nil
}

// IsWindowFunction returns if the function is being applied as a window function.
func (node *FuncExpr) IsWindowFunction() bool {
return node.WindowDef != nil
}

// IsImpure returns whether the function application is impure, meaning that it
// potentially returns a different value when called in the same statement with
// the same parameters.
func (node *FuncExpr) IsImpure() bool {
return node.fn.impure || node.IsWindowFunction()
}

type funcType int

// FuncExpr.Type
Expand All @@ -825,6 +845,14 @@ func (node *FuncExpr) Format(buf *bytes.Buffer, f FmtFlags) {
buf.WriteString(typ)
FormatNode(buf, f, node.Exprs)
buf.WriteByte(')')
if window := node.WindowDef; window != nil {
buf.WriteString(" OVER ")
if window.Name != "" {
FormatNode(buf, f, window.Name)
} else {
FormatNode(buf, f, window)
}
}
}

// OverlayExpr represents an overlay function call.
Expand Down
2 changes: 1 addition & 1 deletion sql/parser/normalize.go
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@ func (v *isConstVisitor) VisitPre(expr Expr) (recurse bool, newExpr Expr) {

switch t := expr.(type) {
case *FuncExpr:
if t.fn.impure {
if t.IsImpure() {
v.isConst = false
return false, expr
}
Expand Down
1 change: 1 addition & 0 deletions sql/parser/normalize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ func TestNormalizeExpr(t *testing.T) {
{`a/2=1`, `a = 2`},
{`1=a/2`, `a = 2`},
{`s=lower('FOO')`, `s = 'foo'`},
{`s=lower('FOO') OVER ()`, `s = lower('FOO') OVER ()`},
{`lower(s)='foo'`, `lower(s) = 'foo'`},
{`random()`, `random()`},
{`9223372036854775808`, `9223372036854775808`},
Expand Down
9 changes: 5 additions & 4 deletions sql/parser/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,11 @@ const (
// Parser wraps a scanner, parser and other utilities present in the parser
// package.
type Parser struct {
scanner Scanner
parserImpl sqlParserImpl
normalizeVisitor normalizeVisitor
isAggregateVisitor IsAggregateVisitor
scanner Scanner
parserImpl sqlParserImpl
normalizeVisitor normalizeVisitor
isAggregateVisitor IsAggregateVisitor
containsWindowVisitor ContainsWindowVisitor
}

// Parse parses the sql and returns a list of statements.
Expand Down
16 changes: 16 additions & 0 deletions sql/parser/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,22 @@ func TestParse(t *testing.T) {

{`SELECT a FROM t HAVING a = b`},

{`SELECT a FROM t WINDOW w AS ()`},
{`SELECT a FROM t WINDOW w AS (w2)`},
{`SELECT a FROM t WINDOW w AS (PARTITION BY b)`},
{`SELECT a FROM t WINDOW w AS (PARTITION BY b, 1 + 2)`},
{`SELECT a FROM t WINDOW w AS (ORDER BY c)`},
{`SELECT a FROM t WINDOW w AS (ORDER BY c, 1 + 2)`},
{`SELECT a FROM t WINDOW w AS (PARTITION BY b ORDER BY c)`},

{`SELECT avg(1) OVER w FROM t`},
{`SELECT avg(1) OVER () FROM t`},
{`SELECT avg(1) OVER (w) FROM t`},
{`SELECT avg(1) OVER (PARTITION BY b) FROM t`},
{`SELECT avg(1) OVER (ORDER BY c) FROM t`},
{`SELECT avg(1) OVER (PARTITION BY b ORDER BY c) FROM t`},
{`SELECT avg(1) OVER (w PARTITION BY b ORDER BY c) FROM t`},

{`SELECT a FROM t UNION SELECT 1 FROM t`},
{`SELECT a FROM t UNION SELECT 1 FROM t UNION SELECT 1 FROM t`},
{`SELECT a FROM t UNION ALL SELECT 1 FROM t`},
Expand Down
58 changes: 58 additions & 0 deletions sql/parser/select.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ type SelectClause struct {
Where *Where
GroupBy GroupBy
Having *Where
Window Window
Lock string
tableSelect bool
}
Expand All @@ -91,6 +92,7 @@ func (node *SelectClause) Format(buf *bytes.Buffer, f FmtFlags) {
FormatNode(buf, f, node.Where)
FormatNode(buf, f, node.GroupBy)
FormatNode(buf, f, node.Having)
FormatNode(buf, f, node.Window)
buf.WriteString(node.Lock)
}
}
Expand Down Expand Up @@ -468,3 +470,59 @@ func (node *Limit) Format(buf *bytes.Buffer, f FmtFlags) {
}
}
}

// Window represents a WINDOW clause.
type Window []*WindowDef

// Format implements the NodeFormatter interface.
func (node Window) Format(buf *bytes.Buffer, f FmtFlags) {
prefix := " WINDOW "
for _, n := range node {
buf.WriteString(prefix)
FormatNode(buf, f, n.Name)
buf.WriteString(" AS ")
FormatNode(buf, f, n)
prefix = ", "
}
}

// WindowDef represents a single window definition expression.
type WindowDef struct {
Name Name
RefName Name
Partitions Exprs
OrderBy OrderBy
}

// Format implements the NodeFormatter interface.
func (node *WindowDef) Format(buf *bytes.Buffer, f FmtFlags) {
buf.WriteRune('(')
needSpaceSeparator := false
if node.RefName != "" {
FormatNode(buf, f, node.RefName)
needSpaceSeparator = true
}
if node.Partitions != nil {
if needSpaceSeparator {
buf.WriteRune(' ')
}
buf.WriteString("PARTITION BY ")
FormatNode(buf, f, node.Partitions)
needSpaceSeparator = true
}
if node.OrderBy != nil {
if needSpaceSeparator {
FormatNode(buf, f, node.OrderBy)
} else {
// We need to remove the initial space produced by OrderBy.Format.
var tmpBuf bytes.Buffer
FormatNode(&tmpBuf, f, node.OrderBy)
buf.WriteString(tmpBuf.String()[1:])
}
needSpaceSeparator = true
_ = needSpaceSeparator // avoid compiler warning until TODO below is addressed.
}
// TODO(nvanbenschoten): Support Window Frames.
// if node.Frame != nil {}
buf.WriteRune(')')
}
Loading