Skip to content

Commit

Permalink
Merge pull request #14369 from benesch/srf-expressions
Browse files Browse the repository at this point in the history
sql: support set-returning functions within other render expressions
  • Loading branch information
benesch authored Apr 25, 2017
2 parents 3c7fc52 + f179160 commit 922f206
Show file tree
Hide file tree
Showing 11 changed files with 198 additions and 57 deletions.
5 changes: 3 additions & 2 deletions pkg/sql/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,8 @@ func (p *planner) groupBy(
group.groupByIdx = make([]int, 0, len(groupByExprs))
for _, g := range groupByExprs {
cols, exprs, hasStar, err := s.planner.computeRenderAllowingStars(
ctx, parser.SelectExpr{Expr: g}, parser.TypeAny, s.sourceInfo, s.ivarHelper)
ctx, parser.SelectExpr{Expr: g}, parser.TypeAny, s.sourceInfo, s.ivarHelper,
autoGenerateRenderOutputName)
if err != nil {
return nil, err
}
Expand All @@ -227,7 +228,7 @@ func (p *planner) groupBy(

cols, exprs, hasStar, err := s.planner.computeRenderAllowingStars(
ctx, parser.SelectExpr{Expr: f.filter}, parser.TypeAny,
s.sourceInfo, s.ivarHelper)
s.sourceInfo, s.ivarHelper, autoGenerateRenderOutputName)
if err != nil {
return nil, err
}
Expand Down
6 changes: 4 additions & 2 deletions pkg/sql/parser/generator_builtins.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ var _ ValueGenerator = &arrayValueGenerator{}

func initGeneratorBuiltins() {
// Add all generators to the Builtins map after a few sanity checks.
for k, v := range generators {
for k, v := range Generators {
for _, g := range v {
if !g.impure {
panic(fmt.Sprintf("generator functions should all be impure, found %v", g))
Expand All @@ -97,7 +97,9 @@ func initGeneratorBuiltins() {
}
}

var generators = map[string][]Builtin{
// Generators is a map from name to slice of Builtins for all built-in
// generators.
var Generators = map[string][]Builtin{
"generate_series": {
makeGeneratorBuiltin(
ArgTypes{{"start", TypeInt}, {"end", TypeInt}},
Expand Down
7 changes: 7 additions & 0 deletions pkg/sql/parser/indexed_vars.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,13 @@ func (h *IndexedVarHelper) AssertSameContainer(ivar *IndexedVar) {
}
}

// AppendSlot grows this IndexedVarHelper by one empty IndexedVar slot and
// returns the index of that newly added slot.
func (h *IndexedVarHelper) AppendSlot() int {
	// The new slot lands at the current length, so capture it before appending.
	idx := len(h.vars)
	h.vars = append(h.vars, IndexedVar{})
	return idx
}

func (h *IndexedVarHelper) checkIndex(idx int) {
if idx < 0 || idx >= len(h.vars) {
panic(fmt.Sprintf("invalid var index %d (columns: %d)", idx, len(h.vars)))
Expand Down
1 change: 1 addition & 0 deletions pkg/sql/planner.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ type planner struct {
subqueryVisitor subqueryVisitor
subqueryPlanVisitor subqueryPlanVisitor
nameResolutionVisitor nameResolutionVisitor
srfExtractionVisitor srfExtractionVisitor
}

// noteworthyInternalMemoryUsageBytes is the minimum size tracked by each
Expand Down
159 changes: 119 additions & 40 deletions pkg/sql/render.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package sql

import (
"bytes"
"errors"
"fmt"

"golang.org/x/net/context"
Expand Down Expand Up @@ -341,20 +342,26 @@ func (s *renderNode) initTargets(
if len(desiredTypes) > i {
desiredType = desiredTypes[i]
}
cols, exprs, hasStar, err := s.planner.computeRenderAllowingStars(ctx, target, desiredType,
s.sourceInfo, s.ivarHelper)

// Output column names should exactly match the original expression, so we
// have to determine the output column name before we rewrite SRFs below.
outputName, err := getRenderColName(s.planner.session.SearchPath, target)
if err != nil {
return err
}

// If the current expression is a set-returning function, we need to move
// it up to the sources list as a cross join and add a render for the
// If the current expression contains a set-returning function, we need to
// move it up to the sources list as a cross join and add a render for the
// function's column in the join.
if e := extractSetReturningFunction(exprs); e != nil {
cols, exprs, hasStar, err = s.transformToCrossJoin(ctx, e, desiredType)
if err != nil {
return err
}
newTarget, err := s.rewriteSRFs(ctx, target)
if err != nil {
return err
}

cols, exprs, hasStar, err := s.planner.computeRenderAllowingStars(ctx, newTarget, desiredType,
s.sourceInfo, s.ivarHelper, outputName)
if err != nil {
return err
}

s.isStar = s.isStar || hasStar
Expand All @@ -370,43 +377,91 @@ func (s *renderNode) initTargets(
return nil
}

// extractSetReturningFunction checks if the first expression in the list is a
// FuncExpr that returns a TypeTable, returning it if so.
func extractSetReturningFunction(exprs []parser.TypedExpr) *parser.FuncExpr {
if len(exprs) == 1 && exprs[0].ResolvedType().FamilyEqual(parser.TypeTable) {
switch e := exprs[0].(type) {
case *parser.FuncExpr:
return e
// srfExtractionVisitor replaces the innermost set-returning function in an
// expression with an IndexedVar that points at a new index at the end of the
// ivarHelper. The extracted SRF is retained in the srf field.
//
// This visitor is intentionally limited to extracting only one SRF, because we
// don't support lateral correlated subqueries.
type srfExtractionVisitor struct {
	// err holds an error recorded by VisitPost during the walk (a function
	// name that failed to resolve, or a second SRF in the same expression).
	// Callers must check it once the walk has finished.
	err error
	// srf is the set-returning function that was extracted, or nil if the
	// walk did not encounter one.
	srf *parser.FuncExpr
	// ivarHelper is the helper that receives a new slot for the extracted SRF
	// and that produces the IndexedVar substituted in its place.
	ivarHelper *parser.IndexedVarHelper
	// searchPath is used to resolve function names encountered during the
	// walk.
	searchPath parser.SearchPath
}

// Compile-time check that *srfExtractionVisitor satisfies parser.Visitor.
var _ parser.Visitor = &srfExtractionVisitor{}

// VisitPre implements the parser.Visitor interface. It never rewrites the
// node; it only tells the walker not to descend into subqueries, so that
// anything inside a subquery is left untouched by this visitor.
func (v *srfExtractionVisitor) VisitPre(expr parser.Expr) (recurse bool, newNode parser.Expr) {
	if _, ok := expr.(*parser.Subquery); ok {
		return false, expr
	}
	return true, expr
}

// VisitPost implements the parser.Visitor interface. When expr is a call to
// one of the built-in generators (parser.Generators), the call is remembered
// in v.srf and replaced by an IndexedVar bound to a freshly appended slot of
// v.ivarHelper. Every other node is returned unchanged.
//
// Failures are reported through v.err rather than a return value: a function
// name that cannot be resolved, or a second set-returning function within the
// same expression. Once an error has been recorded the remaining nodes are
// passed through untouched, so the first error is the one the caller sees and
// no further slots are appended.
func (v *srfExtractionVisitor) VisitPost(expr parser.Expr) parser.Expr {
	if v.err != nil {
		return expr
	}

	fn, isFunc := expr.(*parser.FuncExpr)
	if !isFunc {
		return expr
	}

	fd, err := fn.Func.Resolve(v.searchPath)
	if err != nil {
		v.err = err
		return expr
	}
	if _, isGenerator := parser.Generators[fd.Name]; !isGenerator {
		return expr
	}

	// Only a single SRF per render expression can be extracted.
	if v.srf != nil {
		v.err = errors.New("cannot specify two set-returning functions in the same SELECT expression")
		return expr
	}
	v.srf = fn
	return v.ivarHelper.IndexedVar(v.ivarHelper.AppendSlot())
}

// transformToCrossJoin moves a would-be render expression into a data source
// cross-joined with the renderNode's existing data sources, returning a
// render expression that points at the new data source.
func (s *renderNode) transformToCrossJoin(
ctx context.Context, e *parser.FuncExpr, desiredType parser.Type,
) (columns sqlbase.ResultColumns, exprs []parser.TypedExpr, hasStar bool, err error) {
src, err := s.planner.getDataSource(ctx, e, nil, publicColumns)
// rewriteSRFs looks for a set-returning function (SRF) in the provided render
// expression. If it finds one, it creates a data source for that SRF,
// cross-joins it with the renderNode's existing data source, and returns a new
// render expression in which the SRF has been replaced by an IndexedVar that
// points at the new data source. If the expression contains no SRF, target is
// returned unchanged.
//
// Expressions with more than one SRF require lateral correlated subqueries,
// which are not yet supported. For now, this function returns an error if more
// than one SRF is present in the render expression.
//
// On any error the original target is returned alongside the error.
func (s *renderNode) rewriteSRFs(
	ctx context.Context, target parser.SelectExpr,
) (parser.SelectExpr, error) {
	// Walk the render expression looking for SRFs. The visitor lives on the
	// planner and is reset here before each walk.
	v := &s.planner.srfExtractionVisitor
	*v = srfExtractionVisitor{
		err:        nil,
		srf:        nil,
		ivarHelper: &s.ivarHelper,
		searchPath: s.planner.session.SearchPath,
	}
	expr, _ := parser.WalkExpr(v, target.Expr)
	if v.err != nil {
		return target, v.err
	}

	// Return the original render expression unchanged if the
	// srfExtractionVisitor didn't find any SRFs.
	if v.srf == nil {
		return target, nil
	}

	// We rewrote exactly one SRF; cross-join it with our sources and return the
	// new render expression.
	src, err := s.planner.getDataSource(ctx, v.srf, nil, publicColumns)
	if err != nil {
		return target, err
	}
	src, err = s.planner.makeJoin(ctx, "CROSS JOIN", s.source, src, nil)
	if err != nil {
		return target, err
	}
	s.source = src
	s.sourceInfo = multiSourceInfo{s.source.info}

	// The visitor already appended a slot to s.ivarHelper for the extracted
	// SRF, so the helper is not rebuilt here.
	// NOTE(review): this presumes the joined SRF source contributes exactly the
	// column that the appended slot refers to; confirm against getDataSource.
	return parser.SelectExpr{Expr: expr}, nil
}

func (s *renderNode) initWhere(ctx context.Context, where *parser.Where) (*filterNode, error) {
Expand Down Expand Up @@ -441,22 +496,46 @@ func (s *renderNode) initWhere(ctx context.Context, where *parser.Where) (*filte
}

// getRenderColName returns the output column name for a render expression.
func getRenderColName(target parser.SelectExpr) string {
func getRenderColName(searchPath parser.SearchPath, target parser.SelectExpr) (string, error) {
if target.As != "" {
return string(target.As)
return string(target.As), nil
}

// If the expression designates a column, try to reuse that column's name
// as render name.
if c, ok := target.Expr.(*parser.ColumnItem); ok {
if err := target.NormalizeTopLevelVarName(); err != nil {
return "", err
}

// If target.Expr is a funcExpr, resolving the function within will normalize
// target.Expr's string representation. We want the output column name to be
// unnormalized, so we compute target.Expr's string representation now, even
// though we may choose to return something other than exprStr in the switch
// below.
exprStr := target.Expr.String()

switch t := target.Expr.(type) {
case *parser.ColumnItem:
// We only shorten the name of the result column to become the
// unqualified column part of this expr name if there is
// no additional subscript on the column.
if len(c.Selector) == 0 {
return c.Column()
if len(t.Selector) == 0 {
return t.Column(), nil
}

// For compatibility with Postgres, a render expression rooted by a
// set-returning function is named after that SRF.
case *parser.FuncExpr:
fd, err := t.Func.Resolve(searchPath)
if err != nil {
return "", err
}
if _, ok := parser.Generators[fd.Name]; ok {
return fd.Name, nil
}
}
return target.Expr.String()

return exprStr, nil
}

// appendRenderColumn adds a new render expression at the end of the current list.
Expand Down
3 changes: 2 additions & 1 deletion pkg/sql/returning.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ func (p *planner) newReturningHelper(
ivarHelper := parser.MakeIndexedVarHelper(rh, len(tablecols))
for _, target := range rExprs {
cols, typedExprs, _, err := p.computeRenderAllowingStars(
ctx, target, parser.TypeAny, multiSourceInfo{rh.source}, ivarHelper)
ctx, target, parser.TypeAny, multiSourceInfo{rh.source}, ivarHelper,
autoGenerateRenderOutputName)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ func (p *planner) orderBy(
if index == -1 && s != nil {
cols, exprs, hasStar, err := p.computeRenderAllowingStars(
ctx, parser.SelectExpr{Expr: expr}, parser.TypeAny,
s.sourceInfo, s.ivarHelper)
s.sourceInfo, s.ivarHelper, autoGenerateRenderOutputName)
if err != nil {
return nil, err
}
Expand Down
17 changes: 13 additions & 4 deletions pkg/sql/targets.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,26 @@ import (
"github.com/pkg/errors"
)

// autoGenerateRenderOutputName is the sentinel value callers pass as the
// outputName argument of computeRender and computeRenderAllowingStars to
// request that the output column name be derived from the target expression
// (via getRenderColName) instead of being supplied by the caller.
const autoGenerateRenderOutputName = ""

// computeRender expands a target expression into a result column.
func (p *planner) computeRender(
ctx context.Context,
target parser.SelectExpr,
desiredType parser.Type,
info multiSourceInfo,
ivarHelper parser.IndexedVarHelper,
outputName string,
) (column sqlbase.ResultColumn, expr parser.TypedExpr, err error) {
// When generating an output column name it should exactly match the original
// expression, so determine the output column name before we perform any
// manipulations to the expression.
outputName := getRenderColName(target)
// expression, so if our caller has requested that we generate the output
// column name, we determine the name before we perform any manipulations to
// the expression.
if outputName == autoGenerateRenderOutputName {
if outputName, err = getRenderColName(p.session.SearchPath, target); err != nil {
return sqlbase.ResultColumn{}, nil, err
}
}

normalized, err := p.analyzeExpr(ctx, target.Expr, info, ivarHelper, desiredType, false, "")
if err != nil {
Expand All @@ -54,6 +62,7 @@ func (p *planner) computeRenderAllowingStars(
desiredType parser.Type,
info multiSourceInfo,
ivarHelper parser.IndexedVarHelper,
outputName string,
) (columns sqlbase.ResultColumns, exprs []parser.TypedExpr, hasStar bool, err error) {
// Pre-normalize any VarName so the work is not done twice below.
if err := target.NormalizeTopLevelVarName(); err != nil {
Expand All @@ -66,7 +75,7 @@ func (p *planner) computeRenderAllowingStars(
return cols, typedExprs, hasStar, nil
}

col, expr, err := p.computeRender(ctx, target, desiredType, info, ivarHelper)
col, expr, err := p.computeRender(ctx, target, desiredType, info, ivarHelper, outputName)
if err != nil {
return nil, nil, false, err
}
Expand Down
45 changes: 42 additions & 3 deletions pkg/sql/testdata/logic_test/generators
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,20 @@ SELECT 3 + x FROM generate_series(1,2) AS a(x)
4
5

# Not supported yet: transforming set-returning functions that aren't top-level
# render expressions into cross joins.
query error pq: unsupported binary operator: <int> \+ <setof tuple\{int\}>

query I colnames
SELECT 3 + generate_series(1,2)
----
3 + generate_series(1, 2)
4
5

query I
SELECT 3 + (3 * generate_series(1,3))
----
6
9
12

query I
SELECT * from unnest(ARRAY[1,2])
Expand All @@ -131,3 +141,32 @@ SELECT unnest(ARRAY[1,2]), unnest(ARRAY['a', 'b'])
2 a
2 b

query I
SELECT unnest(ARRAY[3,4]) - 2
----
1
2

query II
SELECT 1 + generate_series(0, 1), unnest(ARRAY[2, 4]) - 1
----
1 1
1 3
2 1
2 3

query I
SELECT ascii(unnest(ARRAY['a', 'b', 'c']));
----
97
98
99

query error pq: cannot specify two set-returning functions in the same SELECT expression
SELECT generate_series(generate_series(1, 3), 3)

query error pq: cannot specify two set-returning functions in the same SELECT expression
SELECT generate_series(1, 3) + generate_series(1, 3)

query error pq: column name "generate_series" not found
SELECT generate_series(1, 3) FROM t WHERE generate_series > 3
Loading

0 comments on commit 922f206

Please sign in to comment.