Skip to content
This repository has been archived by the owner on Jan 28, 2021. It is now read-only.

Commit

Permalink
sql: implement new API for node transformation
Browse files Browse the repository at this point in the history
Instead of having TransformUp and TransformExpressionsUp implemented in
each node, which was really error-prone, we now have a different API.

Each node will have a WithChildren method that will receive
the children from which a new node of the same type will be created.
These nodes must come in the same number and order as the ones
returned by the Children method.

Expressioner nodes will also have a WithExpressions method, which works
the same way, except that it creates a new node with its expressions
changed instead of its children nodes.

The plan package will expose 3 new helpers:

- TransformUp: which transforms a node from the bottom-up.
- TransformExpressionsUp: which transforms expressions of a node from
  the bottom up.
- TransformExpressions: which transforms the expressions of just the
  given node.

Just like with nodes, expressions will also have a new WithChildren
method that does the exact same thing as it does in the nodes.

The expression package will expose a new helper:

- TransformUp: which transforms an expression from the bottom up.

Caveats and limitations:

One thing that may seem odd is the limitation that WithChildren and
WithExpressions must receive the children in the exact same order and
number as they were returned from Expressions or Children.

This is because without this limitation there is no way to build
certain nodes. If we force this limitation on one, it would feel odd
not to have it elsewhere.

For example, take the Case expression into account. It may or may not
have an Expr, it has a list of branches (each having 2 expressions) and
it may or may not have an Else expression. If WithChildren receives N
children, how do we know where to put all these expressions?
This limitation allows us to build the node because we know the shape
of the children must match the current shape.

Signed-off-by: Miguel Molina <miguel@erizocosmi.co>
  • Loading branch information
erizocosmico committed Jun 26, 2019
1 parent 8702d43 commit 2c7cd89
Show file tree
Hide file tree
Showing 125 changed files with 1,088 additions and 1,817 deletions.
8 changes: 4 additions & 4 deletions mem/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ import (
"io"
"strconv"

errors "gopkg.in/src-d/go-errors.v1"
"github.com/src-d/go-mysql-server/sql"
"github.com/src-d/go-mysql-server/sql/expression"
errors "gopkg.in/src-d/go-errors.v1"
)

// Table represents an in-memory database table.
Expand Down Expand Up @@ -312,14 +312,14 @@ func (t *Table) HandledFilters(filters []sql.Expression) []sql.Expression {
var handled []sql.Expression
for _, f := range filters {
var hasOtherFields bool
_, _ = f.TransformUp(func(e sql.Expression) (sql.Expression, error) {
expression.Inspect(f, func(e sql.Expression) bool {
if e, ok := e.(*expression.GetField); ok {
if e.Table() != t.name || !t.schema.Contains(e.Name(), t.name) {
hasOtherFields = true
return false
}
}

return e, nil
return true
})

if !hasOtherFields {
Expand Down
4 changes: 2 additions & 2 deletions sql/analyzer/aggregations.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func reorderAggregations(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, e

a.Log("reorder aggregations, node of type: %T", n)

return n.TransformUp(func(n sql.Node) (sql.Node, error) {
return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) {
switch n := n.(type) {
case *plan.GroupBy:
if !hasHiddenAggregations(n.Aggregate...) {
Expand All @@ -38,7 +38,7 @@ func fixAggregations(projection, grouping []sql.Expression, child sql.Node) (sql

for i, p := range projection {
var transformed bool
e, err := p.TransformUp(func(e sql.Expression) (sql.Expression, error) {
e, err := expression.TransformUp(p, func(e sql.Expression) (sql.Expression, error) {
agg, ok := e.(sql.Aggregation)
if !ok {
return e, nil
Expand Down
2 changes: 1 addition & 1 deletion sql/analyzer/assign_catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ func assignCatalog(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error)
span, _ := ctx.Span("assign_catalog")
defer span.Finish()

return n.TransformUp(func(n sql.Node) (sql.Node, error) {
return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) {
if !n.Resolved() {
return n, nil
}
Expand Down
10 changes: 5 additions & 5 deletions sql/analyzer/convert_dates.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
// replaced by.
var replacements = make(map[tableCol]string)

return n.TransformUp(func(n sql.Node) (sql.Node, error) {
return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) {
exp, ok := n.(sql.Expressioner)
if !ok {
return n, nil
Expand Down Expand Up @@ -48,7 +48,7 @@ func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
case *plan.GroupBy:
var aggregate = make([]sql.Expression, len(exp.Aggregate))
for i, a := range exp.Aggregate {
agg, err := a.TransformUp(func(e sql.Expression) (sql.Expression, error) {
agg, err := expression.TransformUp(a, func(e sql.Expression) (sql.Expression, error) {
return addDateConvert(e, exp, replacements, nodeReplacements, expressions, true)
})
if err != nil {
Expand All @@ -64,7 +64,7 @@ func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {

var grouping = make([]sql.Expression, len(exp.Grouping))
for i, g := range exp.Grouping {
gr, err := g.TransformUp(func(e sql.Expression) (sql.Expression, error) {
gr, err := expression.TransformUp(g, func(e sql.Expression) (sql.Expression, error) {
return addDateConvert(e, exp, replacements, nodeReplacements, expressions, false)
})
if err != nil {
Expand All @@ -77,7 +77,7 @@ func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
case *plan.Project:
var projections = make([]sql.Expression, len(exp.Projections))
for i, e := range exp.Projections {
expr, err := e.TransformUp(func(e sql.Expression) (sql.Expression, error) {
expr, err := expression.TransformUp(e, func(e sql.Expression) (sql.Expression, error) {
return addDateConvert(e, exp, replacements, nodeReplacements, expressions, true)
})
if err != nil {
Expand All @@ -93,7 +93,7 @@ func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {

result = plan.NewProject(projections, exp.Child)
default:
result, err = exp.TransformExpressions(func(e sql.Expression) (sql.Expression, error) {
result, err = plan.TransformExpressions(n, func(e sql.Expression) (sql.Expression, error) {
return addDateConvert(e, n, replacements, nodeReplacements, expressions, false)
})
}
Expand Down
4 changes: 2 additions & 2 deletions sql/analyzer/filters.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func exprToTableFilters(expr sql.Expression) filters {
for _, expr := range splitExpression(expr) {
var seenTables = make(map[string]struct{})
var lastTable string
_, _ = expr.TransformUp(func(e sql.Expression) (sql.Expression, error) {
expression.Inspect(expr, func(e sql.Expression) bool {
f, ok := e.(*expression.GetField)
if ok {
if _, ok := seenTables[f.Table()]; !ok {
Expand All @@ -29,7 +29,7 @@ func exprToTableFilters(expr sql.Expression) filters {
}
}

return e, nil
return true
})

if len(seenTables) == 1 {
Expand Down
21 changes: 11 additions & 10 deletions sql/analyzer/optimization_rules.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package analyzer

import (
"gopkg.in/src-d/go-errors.v1"
"github.com/src-d/go-mysql-server/sql"
"github.com/src-d/go-mysql-server/sql/expression"
"github.com/src-d/go-mysql-server/sql/plan"
"gopkg.in/src-d/go-errors.v1"
)

func eraseProjection(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, error) {
Expand All @@ -17,7 +17,7 @@ func eraseProjection(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, er

a.Log("erase projection, node of type: %T", node)

return node.TransformUp(func(node sql.Node) (sql.Node, error) {
return plan.TransformUp(node, func(node sql.Node) (sql.Node, error) {
project, ok := node.(*plan.Project)
if ok && project.Schema().Equals(project.Child.Schema()) {
a.Log("project erased")
Expand All @@ -35,12 +35,13 @@ func optimizeDistinct(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, e
a.Log("optimize distinct, node of type: %T", node)
if n, ok := node.(*plan.Distinct); ok {
var isSorted bool
_, _ = node.TransformUp(func(node sql.Node) (sql.Node, error) {
plan.Inspect(n, func(node sql.Node) bool {
a.Log("checking for optimization in node of type: %T", node)
if _, ok := node.(*plan.Sort); ok {
isSorted = true
return false
}
return node, nil
return true
})

if isSorted {
Expand All @@ -65,7 +66,7 @@ func reorderProjection(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, err
a.Log("reorder projection, node of type: %T", n)

// Then we transform the projection
return n.TransformUp(func(node sql.Node) (sql.Node, error) {
return plan.TransformUp(n, func(node sql.Node) (sql.Node, error) {
project, ok := node.(*plan.Project)
// When we transform the projection, the children will always be
// unresolved in the case we want to fix, as the reorder happens just
Expand All @@ -92,7 +93,7 @@ func reorderProjection(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, err

// And add projection nodes where needed in the child tree.
var didNeedReorder bool
child, err := project.Child.TransformUp(func(node sql.Node) (sql.Node, error) {
child, err := plan.TransformUp(project.Child, func(node sql.Node) (sql.Node, error) {
var requiredColumns []string
switch node := node.(type) {
case *plan.Sort, *plan.Filter:
Expand Down Expand Up @@ -200,7 +201,7 @@ func moveJoinConditionsToFilter(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.

a.Log("moving join conditions to filter, node of type: %T", n)

return n.TransformUp(func(n sql.Node) (sql.Node, error) {
return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) {
join, ok := n.(*plan.InnerJoin)
if !ok {
return n, nil
Expand Down Expand Up @@ -268,7 +269,7 @@ func removeUnnecessaryConverts(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.N

a.Log("removing unnecessary converts, node of type: %T", n)

return n.TransformExpressionsUp(func(e sql.Expression) (sql.Expression, error) {
return plan.TransformExpressionsUp(n, func(e sql.Expression) (sql.Expression, error) {
if c, ok := e.(*expression.Convert); ok && c.Child.Type() == c.Type() {
return c.Child, nil
}
Expand Down Expand Up @@ -336,13 +337,13 @@ func evalFilter(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, error)

a.Log("evaluating filters, node of type: %T", node)

return node.TransformUp(func(node sql.Node) (sql.Node, error) {
return plan.TransformUp(node, func(node sql.Node) (sql.Node, error) {
filter, ok := node.(*plan.Filter)
if !ok {
return node, nil
}

e, err := filter.Expression.TransformUp(func(e sql.Expression) (sql.Expression, error) {
e, err := expression.TransformUp(filter.Expression, func(e sql.Expression) (sql.Expression, error) {
switch e := e.(type) {
case *expression.Or:
if isTrue(e.Left) {
Expand Down
30 changes: 8 additions & 22 deletions sql/analyzer/parallelize.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func parallelize(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, error)
return node, nil
}

node, err := node.TransformUp(func(node sql.Node) (sql.Node, error) {
node, err := plan.TransformUp(node, func(node sql.Node) (sql.Node, error) {
if !isParallelizable(node) {
return node, nil
}
Expand All @@ -47,7 +47,7 @@ func parallelize(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, error)
return nil, err
}

return node.TransformUp(removeRedundantExchanges)
return plan.TransformUp(node, removeRedundantExchanges)
}

// removeRedundantExchanges removes all the exchanges except for the topmost
Expand All @@ -58,13 +58,17 @@ func removeRedundantExchanges(node sql.Node) (sql.Node, error) {
return node, nil
}

e := &protectedExchange{exchange}
return e.TransformUp(func(node sql.Node) (sql.Node, error) {
child, err := plan.TransformUp(exchange.Child, func(node sql.Node) (sql.Node, error) {
if exchange, ok := node.(*plan.Exchange); ok {
return exchange.Child, nil
}
return node, nil
})
if err != nil {
return nil, err
}

return exchange.WithChildren(child)
}

func isParallelizable(node sql.Node) bool {
Expand Down Expand Up @@ -103,21 +107,3 @@ func isParallelizable(node sql.Node) bool {

return ok && tableSeen && lastWasTable
}

// protectedExchange is a placeholder node that protects a certain exchange
// node from being removed during transformations.
type protectedExchange struct {
*plan.Exchange
}

// TransformUp transforms the child with the given transform function but it
// will not call the transform function with the new instance. Instead of
// another protectedExchange, it will return an Exchange.
func (e *protectedExchange) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) {
child, err := e.Child.TransformUp(f)
if err != nil {
return nil, err
}

return plan.NewExchange(e.Parallelism, child), nil
}
4 changes: 2 additions & 2 deletions sql/analyzer/parallelize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ package analyzer
import (
"testing"

"github.com/stretchr/testify/require"
"github.com/src-d/go-mysql-server/mem"
"github.com/src-d/go-mysql-server/sql"
"github.com/src-d/go-mysql-server/sql/expression"
"github.com/src-d/go-mysql-server/sql/plan"
"github.com/stretchr/testify/require"
)

func TestParallelize(t *testing.T) {
Expand Down Expand Up @@ -222,7 +222,7 @@ func TestRemoveRedundantExchanges(t *testing.T) {
),
)

result, err := node.TransformUp(removeRedundantExchanges)
result, err := plan.TransformUp(node, removeRedundantExchanges)
require.NoError(err)
require.Equal(expected, result)
}
4 changes: 2 additions & 2 deletions sql/analyzer/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func trackProcess(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
processList := a.Catalog.ProcessList

var seen = make(map[string]struct{})
n, err := n.TransformUp(func(n sql.Node) (sql.Node, error) {
n, err := plan.TransformUp(n, func(n sql.Node) (sql.Node, error) {
switch n := n.(type) {
case *plan.ResolvedTable:
switch n.Table.(type) {
Expand Down Expand Up @@ -73,7 +73,7 @@ func trackProcess(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {

// Remove QueryProcess nodes from the subqueries. Otherwise, the process
// will be marked as done as soon as a subquery finishes.
node, err := n.TransformUp(func(n sql.Node) (sql.Node, error) {
node, err := plan.TransformUp(n, func(n sql.Node) (sql.Node, error) {
if sq, ok := n.(*plan.SubqueryAlias); ok {
if qp, ok := sq.Child.(*plan.QueryProcess); ok {
return plan.NewSubqueryAlias(sq.Name(), qp.Child), nil
Expand Down
11 changes: 5 additions & 6 deletions sql/analyzer/prune_columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func pruneSubqueries(
n sql.Node,
parentColumns usedColumns,
) (sql.Node, error) {
return n.TransformUp(func(n sql.Node) (sql.Node, error) {
return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) {
subq, ok := n.(*plan.SubqueryAlias)
if !ok {
return n, nil
Expand All @@ -142,7 +142,7 @@ func pruneSubqueries(
}

func pruneUnusedColumns(n sql.Node, columns usedColumns) (sql.Node, error) {
return n.TransformUp(func(n sql.Node) (sql.Node, error) {
return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) {
switch n := n.(type) {
case *plan.Project:
return pruneProject(n, columns), nil
Expand All @@ -155,7 +155,7 @@ func pruneUnusedColumns(n sql.Node, columns usedColumns) (sql.Node, error) {
}

func fixRemainingFieldsIndexes(n sql.Node) (sql.Node, error) {
return n.TransformUp(func(n sql.Node) (sql.Node, error) {
return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) {
switch n := n.(type) {
case *plan.SubqueryAlias:
child, err := fixRemainingFieldsIndexes(n.Child)
Expand All @@ -165,8 +165,7 @@ func fixRemainingFieldsIndexes(n sql.Node) (sql.Node, error) {

return plan.NewSubqueryAlias(n.Name(), child), nil
default:
exp, ok := n.(sql.Expressioner)
if !ok {
if _, ok := n.(sql.Expressioner); !ok {
return n, nil
}

Expand All @@ -184,7 +183,7 @@ func fixRemainingFieldsIndexes(n sql.Node) (sql.Node, error) {
indexes[tableCol{col.Source, col.Name}] = i
}

return exp.TransformExpressions(func(e sql.Expression) (sql.Expression, error) {
return plan.TransformExpressions(n, func(e sql.Expression) (sql.Expression, error) {
gf, ok := e.(*expression.GetField)
if !ok {
return e, nil
Expand Down
Loading

0 comments on commit 2c7cd89

Please sign in to comment.