From 2c7cd8915e07b63a553551f02c2174fa3628c67d Mon Sep 17 00:00:00 2001 From: Miguel Molina Date: Wed, 26 Jun 2019 16:42:03 +0200 Subject: [PATCH] sql: implement new API for node transformation Instead of having TransformUp and TransformExpressionsUp in each node, which was really error prone, now we have a different API. Each node will have a WithChildren method that will receive the children from which a new node of the same type will be created. These nodes must come in the same number and order as the ones returned by the Children method. Expressioner nodes will also have WithExpressions method, which is the same, except it will create a new node with its expressions changed, instead of the children nodes. The plan package will expose 3 new helpers: - TransformUp: which transforms a node from the bottom-up. - TransformExpressionsUp: which transforms expressions of a node from the bottom up. - TransformExpressions: which transforms the expressions of just the given node. Just like with nodes, expressions will also have a new WithChildren method that does the exact same thing as it does in the nodes. The expression package will expose a new helper: - TransformUp: which transforms an expression from the bottom up. Caveats and limitations: One thing that may seem odd is the limitation that WithChildren and WithExpressions must receive the children in the exact same order and number as they were returned from Expressions or Children. This is because without this limitation there is no way to build certain nodes. If we force this limitation on one, it would feel odd not to have it elsewhere. For example, take Case expression into account. It may or may not have Expr, it has a list of branches (each having 2 expressions) and it may or may not have an Else expression. If WithChildren receives N children, how do we know where to put all these expressions? This limitation allows us to build the node beacause we know the shape of children must match the current shape. Signed-off-by: Miguel Molina --- mem/table.go | 8 +- sql/analyzer/aggregations.go | 4 +- sql/analyzer/assign_catalog.go | 2 +- sql/analyzer/convert_dates.go | 10 +- sql/analyzer/filters.go | 4 +- sql/analyzer/optimization_rules.go | 21 +-- sql/analyzer/parallelize.go | 30 +--- sql/analyzer/parallelize_test.go | 4 +- sql/analyzer/process.go | 4 +- sql/analyzer/prune_columns.go | 11 +- sql/analyzer/pushdown.go | 26 +-- sql/analyzer/pushdown_test.go | 4 +- sql/analyzer/resolve_columns.go | 33 ++-- sql/analyzer/resolve_database.go | 3 +- sql/analyzer/resolve_functions.go | 5 +- sql/analyzer/resolve_generators.go | 4 +- sql/analyzer/resolve_having.go | 6 +- sql/analyzer/resolve_natural_joins.go | 10 +- sql/analyzer/resolve_orderby.go | 6 +- sql/analyzer/resolve_stars.go | 2 +- sql/analyzer/resolve_subqueries.go | 2 +- sql/analyzer/resolve_tables.go | 2 +- sql/analyzer/validation_rules_test.go | 17 +- sql/core.go | 46 +++--- sql/expression/alias.go | 11 +- sql/expression/arithmetic.go | 29 ++-- sql/expression/between.go | 22 +-- sql/expression/boolean.go | 11 +- sql/expression/case.go | 53 +++---- sql/expression/comparison.go | 138 +++++----------- sql/expression/convert.go | 14 +- sql/expression/default.go | 10 +- sql/expression/function/aggregation/avg.go | 11 +- sql/expression/function/aggregation/count.go | 11 +- sql/expression/function/aggregation/max.go | 11 +- sql/expression/function/aggregation/min.go | 11 +- sql/expression/function/aggregation/sum.go | 11 +- sql/expression/function/arraylength.go | 11 +- sql/expression/function/ceil_round_floor.go | 51 ++---- sql/expression/function/coalesce.go | 26 +-- sql/expression/function/concat.go | 22 +-- sql/expression/function/concat_ws.go | 20 +-- sql/expression/function/connection_id.go | 9 +- sql/expression/function/database.go | 9 +- sql/expression/function/date.go | 30 +--- sql/expression/function/explode.go | 24 ++- sql/expression/function/greatest_least.go | 42 +---- sql/expression/function/ifnull.go | 17 +- sql/expression/function/isbinary.go | 11 +- sql/expression/function/json_extract.go | 19 +-- sql/expression/function/json_unquote.go | 11 +- sql/expression/function/length.go | 13 +- sql/expression/function/logarithm.go | 36 ++--- sql/expression/function/lower_upper.go | 25 ++- sql/expression/function/nullif.go | 17 +- .../function/reverse_repeat_replace.go | 52 ++---- sql/expression/function/rpad_lpad.go | 23 +-- sql/expression/function/sleep.go | 13 +- sql/expression/function/soundex.go | 11 +- sql/expression/function/split.go | 17 +- sql/expression/function/sqrt_power.go | 30 ++-- sql/expression/function/substring.go | 51 +----- sql/expression/function/time.go | 149 +++++++----------- .../function/tobase64_frombase64.go | 23 ++- sql/expression/function/trim_ltrim_rtrim.go | 15 +- sql/expression/function/version.go | 9 +- sql/expression/get_field.go | 21 ++- sql/expression/interval.go | 14 +- sql/expression/isnull.go | 11 +- sql/expression/like.go | 17 +- sql/expression/literal.go | 10 +- sql/expression/logic.go | 34 ++-- sql/expression/star.go | 10 +- sql/expression/transform.go | 28 ++++ sql/expression/tuple.go | 16 +- sql/expression/unresolved.go | 26 ++- sql/index_test.go | 4 +- sql/parse/parse.go | 2 +- sql/plan/common.go | 17 -- sql/plan/create_index.go | 52 ++---- sql/plan/cross_join.go | 31 +--- sql/plan/ddl.go | 14 +- sql/plan/describe.go | 40 ++--- sql/plan/distinct.go | 38 ++--- sql/plan/drop_index.go | 29 +--- sql/plan/empty_table.go | 11 +- sql/plan/exchange.go | 32 ++-- sql/plan/exchange_test.go | 11 +- sql/plan/filter.go | 38 ++--- sql/plan/generate.go | 46 ++---- sql/plan/generate_test.go | 7 +- sql/plan/group_by.go | 54 +++---- sql/plan/having.go | 37 ++--- sql/plan/insert.go | 33 +--- sql/plan/join.go | 141 ++++------------- sql/plan/limit.go | 21 +-- sql/plan/lock.go | 45 +++--- sql/plan/naturaljoin.go | 31 +--- sql/plan/nothing.go | 13 +- sql/plan/offset.go | 19 +-- sql/plan/process.go | 25 +-- sql/plan/processlist.go | 11 +- sql/plan/project.go | 44 ++---- sql/plan/resolved_table.go | 11 +- sql/plan/set.go | 44 +++--- sql/plan/show_collation.go | 13 +- sql/plan/show_create_database.go | 13 +- sql/plan/show_create_table.go | 13 +- sql/plan/show_indexes.go | 11 +- sql/plan/show_tables.go | 11 +- sql/plan/showcolumns.go | 21 +-- sql/plan/showdatabases.go | 13 +- sql/plan/showtablestatus.go | 11 +- sql/plan/showvariables.go | 11 +- sql/plan/showwarnings.go | 11 +- sql/plan/sort.go | 65 +++----- sql/plan/subqueryalias.go | 18 ++- sql/plan/tablealias.go | 19 +-- sql/plan/transaction.go | 11 +- sql/plan/transform.go | 89 +++++++++++ sql/plan/transform_test.go | 2 +- sql/plan/unresolved.go | 13 +- sql/plan/use.go | 11 +- sql/plan/values.go | 53 +++---- sql/session_test.go | 15 +- 125 files changed, 1088 insertions(+), 1817 deletions(-) create mode 100644 sql/expression/transform.go create mode 100644 sql/plan/transform.go diff --git a/mem/table.go b/mem/table.go index 8217a1211..d12a1e316 100644 --- a/mem/table.go +++ b/mem/table.go @@ -7,9 +7,9 @@ import ( "io" "strconv" - errors "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" "github.com/src-d/go-mysql-server/sql/expression" + errors "gopkg.in/src-d/go-errors.v1" ) // Table represents an in-memory database table. @@ -312,14 +312,14 @@ func (t *Table) HandledFilters(filters []sql.Expression) []sql.Expression { var handled []sql.Expression for _, f := range filters { var hasOtherFields bool - _, _ = f.TransformUp(func(e sql.Expression) (sql.Expression, error) { + expression.Inspect(f, func(e sql.Expression) bool { if e, ok := e.(*expression.GetField); ok { if e.Table() != t.name || !t.schema.Contains(e.Name(), t.name) { hasOtherFields = true + return false } } - - return e, nil + return true }) if !hasOtherFields { diff --git a/sql/analyzer/aggregations.go b/sql/analyzer/aggregations.go index 90e5baab9..9afa5b49f 100644 --- a/sql/analyzer/aggregations.go +++ b/sql/analyzer/aggregations.go @@ -16,7 +16,7 @@ func reorderAggregations(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, e a.Log("reorder aggregations, node of type: %T", n) - return n.TransformUp(func(n sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { switch n := n.(type) { case *plan.GroupBy: if !hasHiddenAggregations(n.Aggregate...) { @@ -38,7 +38,7 @@ func fixAggregations(projection, grouping []sql.Expression, child sql.Node) (sql for i, p := range projection { var transformed bool - e, err := p.TransformUp(func(e sql.Expression) (sql.Expression, error) { + e, err := expression.TransformUp(p, func(e sql.Expression) (sql.Expression, error) { agg, ok := e.(sql.Aggregation) if !ok { return e, nil diff --git a/sql/analyzer/assign_catalog.go b/sql/analyzer/assign_catalog.go index 409ae082e..0a8f76ee8 100644 --- a/sql/analyzer/assign_catalog.go +++ b/sql/analyzer/assign_catalog.go @@ -10,7 +10,7 @@ func assignCatalog(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) span, _ := ctx.Span("assign_catalog") defer span.Finish() - return n.TransformUp(func(n sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { if !n.Resolved() { return n, nil } diff --git a/sql/analyzer/convert_dates.go b/sql/analyzer/convert_dates.go index 1397d867a..c7af55431 100644 --- a/sql/analyzer/convert_dates.go +++ b/sql/analyzer/convert_dates.go @@ -19,7 +19,7 @@ func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { // replaced by. var replacements = make(map[tableCol]string) - return n.TransformUp(func(n sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { exp, ok := n.(sql.Expressioner) if !ok { return n, nil @@ -48,7 +48,7 @@ func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { case *plan.GroupBy: var aggregate = make([]sql.Expression, len(exp.Aggregate)) for i, a := range exp.Aggregate { - agg, err := a.TransformUp(func(e sql.Expression) (sql.Expression, error) { + agg, err := expression.TransformUp(a, func(e sql.Expression) (sql.Expression, error) { return addDateConvert(e, exp, replacements, nodeReplacements, expressions, true) }) if err != nil { @@ -64,7 +64,7 @@ func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { var grouping = make([]sql.Expression, len(exp.Grouping)) for i, g := range exp.Grouping { - gr, err := g.TransformUp(func(e sql.Expression) (sql.Expression, error) { + gr, err := expression.TransformUp(g, func(e sql.Expression) (sql.Expression, error) { return addDateConvert(e, exp, replacements, nodeReplacements, expressions, false) }) if err != nil { @@ -77,7 +77,7 @@ func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { case *plan.Project: var projections = make([]sql.Expression, len(exp.Projections)) for i, e := range exp.Projections { - expr, err := e.TransformUp(func(e sql.Expression) (sql.Expression, error) { + expr, err := expression.TransformUp(e, func(e sql.Expression) (sql.Expression, error) { return addDateConvert(e, exp, replacements, nodeReplacements, expressions, true) }) if err != nil { @@ -93,7 +93,7 @@ func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { result = plan.NewProject(projections, exp.Child) default: - result, err = exp.TransformExpressions(func(e sql.Expression) (sql.Expression, error) { + result, err = plan.TransformExpressions(n, func(e sql.Expression) (sql.Expression, error) { return addDateConvert(e, n, replacements, nodeReplacements, expressions, false) }) } diff --git a/sql/analyzer/filters.go b/sql/analyzer/filters.go index 9cad5207f..bbe54adc7 100644 --- a/sql/analyzer/filters.go +++ b/sql/analyzer/filters.go @@ -20,7 +20,7 @@ func exprToTableFilters(expr sql.Expression) filters { for _, expr := range splitExpression(expr) { var seenTables = make(map[string]struct{}) var lastTable string - _, _ = expr.TransformUp(func(e sql.Expression) (sql.Expression, error) { + expression.Inspect(expr, func(e sql.Expression) bool { f, ok := e.(*expression.GetField) if ok { if _, ok := seenTables[f.Table()]; !ok { @@ -29,7 +29,7 @@ func exprToTableFilters(expr sql.Expression) filters { } } - return e, nil + return true }) if len(seenTables) == 1 { diff --git a/sql/analyzer/optimization_rules.go b/sql/analyzer/optimization_rules.go index bfa83aeaa..48f70f159 100644 --- a/sql/analyzer/optimization_rules.go +++ b/sql/analyzer/optimization_rules.go @@ -1,10 +1,10 @@ package analyzer import ( - "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" "github.com/src-d/go-mysql-server/sql/expression" "github.com/src-d/go-mysql-server/sql/plan" + "gopkg.in/src-d/go-errors.v1" ) func eraseProjection(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, error) { @@ -17,7 +17,7 @@ func eraseProjection(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, er a.Log("erase projection, node of type: %T", node) - return node.TransformUp(func(node sql.Node) (sql.Node, error) { + return plan.TransformUp(node, func(node sql.Node) (sql.Node, error) { project, ok := node.(*plan.Project) if ok && project.Schema().Equals(project.Child.Schema()) { a.Log("project erased") @@ -35,12 +35,13 @@ func optimizeDistinct(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, e a.Log("optimize distinct, node of type: %T", node) if n, ok := node.(*plan.Distinct); ok { var isSorted bool - _, _ = node.TransformUp(func(node sql.Node) (sql.Node, error) { + plan.Inspect(n, func(node sql.Node) bool { a.Log("checking for optimization in node of type: %T", node) if _, ok := node.(*plan.Sort); ok { isSorted = true + return false } - return node, nil + return true }) if isSorted { @@ -65,7 +66,7 @@ func reorderProjection(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, err a.Log("reorder projection, node of type: %T", n) // Then we transform the projection - return n.TransformUp(func(node sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(node sql.Node) (sql.Node, error) { project, ok := node.(*plan.Project) // When we transform the projection, the children will always be // unresolved in the case we want to fix, as the reorder happens just @@ -92,7 +93,7 @@ func reorderProjection(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, err // And add projection nodes where needed in the child tree. var didNeedReorder bool - child, err := project.Child.TransformUp(func(node sql.Node) (sql.Node, error) { + child, err := plan.TransformUp(project.Child, func(node sql.Node) (sql.Node, error) { var requiredColumns []string switch node := node.(type) { case *plan.Sort, *plan.Filter: @@ -200,7 +201,7 @@ func moveJoinConditionsToFilter(ctx *sql.Context, a *Analyzer, n sql.Node) (sql. a.Log("moving join conditions to filter, node of type: %T", n) - return n.TransformUp(func(n sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { join, ok := n.(*plan.InnerJoin) if !ok { return n, nil @@ -268,7 +269,7 @@ func removeUnnecessaryConverts(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.N a.Log("removing unnecessary converts, node of type: %T", n) - return n.TransformExpressionsUp(func(e sql.Expression) (sql.Expression, error) { + return plan.TransformExpressionsUp(n, func(e sql.Expression) (sql.Expression, error) { if c, ok := e.(*expression.Convert); ok && c.Child.Type() == c.Type() { return c.Child, nil } @@ -336,13 +337,13 @@ func evalFilter(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, error) a.Log("evaluating filters, node of type: %T", node) - return node.TransformUp(func(node sql.Node) (sql.Node, error) { + return plan.TransformUp(node, func(node sql.Node) (sql.Node, error) { filter, ok := node.(*plan.Filter) if !ok { return node, nil } - e, err := filter.Expression.TransformUp(func(e sql.Expression) (sql.Expression, error) { + e, err := expression.TransformUp(filter.Expression, func(e sql.Expression) (sql.Expression, error) { switch e := e.(type) { case *expression.Or: if isTrue(e.Left) { diff --git a/sql/analyzer/parallelize.go b/sql/analyzer/parallelize.go index 9af592339..a56b9479e 100644 --- a/sql/analyzer/parallelize.go +++ b/sql/analyzer/parallelize.go @@ -34,7 +34,7 @@ func parallelize(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, error) return node, nil } - node, err := node.TransformUp(func(node sql.Node) (sql.Node, error) { + node, err := plan.TransformUp(node, func(node sql.Node) (sql.Node, error) { if !isParallelizable(node) { return node, nil } @@ -47,7 +47,7 @@ func parallelize(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, error) return nil, err } - return node.TransformUp(removeRedundantExchanges) + return plan.TransformUp(node, removeRedundantExchanges) } // removeRedundantExchanges removes all the exchanges except for the topmost @@ -58,13 +58,17 @@ func removeRedundantExchanges(node sql.Node) (sql.Node, error) { return node, nil } - e := &protectedExchange{exchange} - return e.TransformUp(func(node sql.Node) (sql.Node, error) { + child, err := plan.TransformUp(exchange.Child, func(node sql.Node) (sql.Node, error) { if exchange, ok := node.(*plan.Exchange); ok { return exchange.Child, nil } return node, nil }) + if err != nil { + return nil, err + } + + return exchange.WithChildren(child) } func isParallelizable(node sql.Node) bool { @@ -103,21 +107,3 @@ func isParallelizable(node sql.Node) bool { return ok && tableSeen && lastWasTable } - -// protectedExchange is a placeholder node that protects a certain exchange -// node from being removed during transformations. -type protectedExchange struct { - *plan.Exchange -} - -// TransformUp transforms the child with the given transform function but it -// will not call the transform function with the new instance. Instead of -// another protectedExchange, it will return an Exchange. -func (e *protectedExchange) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := e.Child.TransformUp(f) - if err != nil { - return nil, err - } - - return plan.NewExchange(e.Parallelism, child), nil -} diff --git a/sql/analyzer/parallelize_test.go b/sql/analyzer/parallelize_test.go index 7af884aff..41282c03e 100644 --- a/sql/analyzer/parallelize_test.go +++ b/sql/analyzer/parallelize_test.go @@ -3,11 +3,11 @@ package analyzer import ( "testing" - "github.com/stretchr/testify/require" "github.com/src-d/go-mysql-server/mem" "github.com/src-d/go-mysql-server/sql" "github.com/src-d/go-mysql-server/sql/expression" "github.com/src-d/go-mysql-server/sql/plan" + "github.com/stretchr/testify/require" ) func TestParallelize(t *testing.T) { @@ -222,7 +222,7 @@ func TestRemoveRedundantExchanges(t *testing.T) { ), ) - result, err := node.TransformUp(removeRedundantExchanges) + result, err := plan.TransformUp(node, removeRedundantExchanges) require.NoError(err) require.Equal(expected, result) } diff --git a/sql/analyzer/process.go b/sql/analyzer/process.go index 2d2d0f3f1..926d52c83 100644 --- a/sql/analyzer/process.go +++ b/sql/analyzer/process.go @@ -19,7 +19,7 @@ func trackProcess(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { processList := a.Catalog.ProcessList var seen = make(map[string]struct{}) - n, err := n.TransformUp(func(n sql.Node) (sql.Node, error) { + n, err := plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { switch n := n.(type) { case *plan.ResolvedTable: switch n.Table.(type) { @@ -73,7 +73,7 @@ func trackProcess(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { // Remove QueryProcess nodes from the subqueries. Otherwise, the process // will be marked as done as soon as a subquery finishes. - node, err := n.TransformUp(func(n sql.Node) (sql.Node, error) { + node, err := plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { if sq, ok := n.(*plan.SubqueryAlias); ok { if qp, ok := sq.Child.(*plan.QueryProcess); ok { return plan.NewSubqueryAlias(sq.Name(), qp.Child), nil diff --git a/sql/analyzer/prune_columns.go b/sql/analyzer/prune_columns.go index db424de9f..0b4121ebe 100644 --- a/sql/analyzer/prune_columns.go +++ b/sql/analyzer/prune_columns.go @@ -131,7 +131,7 @@ func pruneSubqueries( n sql.Node, parentColumns usedColumns, ) (sql.Node, error) { - return n.TransformUp(func(n sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { subq, ok := n.(*plan.SubqueryAlias) if !ok { return n, nil @@ -142,7 +142,7 @@ func pruneSubqueries( } func pruneUnusedColumns(n sql.Node, columns usedColumns) (sql.Node, error) { - return n.TransformUp(func(n sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { switch n := n.(type) { case *plan.Project: return pruneProject(n, columns), nil @@ -155,7 +155,7 @@ func pruneUnusedColumns(n sql.Node, columns usedColumns) (sql.Node, error) { } func fixRemainingFieldsIndexes(n sql.Node) (sql.Node, error) { - return n.TransformUp(func(n sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { switch n := n.(type) { case *plan.SubqueryAlias: child, err := fixRemainingFieldsIndexes(n.Child) @@ -165,8 +165,7 @@ func fixRemainingFieldsIndexes(n sql.Node) (sql.Node, error) { return plan.NewSubqueryAlias(n.Name(), child), nil default: - exp, ok := n.(sql.Expressioner) - if !ok { + if _, ok := n.(sql.Expressioner); !ok { return n, nil } @@ -184,7 +183,7 @@ func fixRemainingFieldsIndexes(n sql.Node) (sql.Node, error) { indexes[tableCol{col.Source, col.Name}] = i } - return exp.TransformExpressions(func(e sql.Expression) (sql.Expression, error) { + return plan.TransformExpressions(n, func(e sql.Expression) (sql.Expression, error) { gf, ok := e.(*expression.GetField) if !ok { return e, nil diff --git a/sql/analyzer/pushdown.go b/sql/analyzer/pushdown.go index 55321f54d..4c500ffcb 100644 --- a/sql/analyzer/pushdown.go +++ b/sql/analyzer/pushdown.go @@ -66,7 +66,7 @@ func fixFieldIndexesOnExpressions(schema sql.Schema, expressions ...sql.Expressi // for GetField expressions according to the schema of the row in the table // and not the one where the filter came from. func fixFieldIndexes(schema sql.Schema, exp sql.Expression) (sql.Expression, error) { - return exp.TransformUp(func(e sql.Expression) (sql.Expression, error) { + return expression.TransformUp(exp, func(e sql.Expression) (sql.Expression, error) { switch e := e.(type) { case *expression.GetField: // we need to rewrite the indexes for the table row @@ -134,7 +134,7 @@ func transformPushdown( var handledFilters []sql.Expression var queryIndexes []sql.Index - node, err := n.TransformUp(func(node sql.Node) (sql.Node, error) { + node, err := plan.TransformUp(n, func(node sql.Node) (sql.Node, error) { a.Log("transforming node of type: %T", node) switch node := node.(type) { case *plan.Filter: @@ -173,8 +173,7 @@ func transformPushdown( } func transformExpressioners(node sql.Node) (sql.Node, error) { - expressioner, ok := node.(sql.Expressioner) - if !ok { + if _, ok := node.(sql.Expressioner); !ok { return node, nil } @@ -187,7 +186,7 @@ func transformExpressioners(node sql.Node) (sql.Node, error) { return node, nil } - n, err := expressioner.TransformExpressions(func(e sql.Expression) (sql.Expression, error) { + n, err := plan.TransformExpressions(node, func(e sql.Expression) (sql.Expression, error) { for _, schema := range schemas { fixed, err := fixFieldIndexes(schema, e) if err == nil { @@ -338,20 +337,11 @@ func (r *releaser) Schema() sql.Schema { return r.Child.Schema() } -func (r *releaser) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := r.Child.TransformUp(f) - if err != nil { - return nil, err - } - return f(&releaser{child, r.Release}) -} - -func (r *releaser) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - child, err := r.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err +func (r *releaser) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), 1) } - return &releaser{child, r.Release}, nil + return &releaser{children[0], r.Release}, nil } func (r *releaser) String() string { diff --git a/sql/analyzer/pushdown_test.go b/sql/analyzer/pushdown_test.go index b30a54630..d0cb737ad 100644 --- a/sql/analyzer/pushdown_test.go +++ b/sql/analyzer/pushdown_test.go @@ -3,11 +3,11 @@ package analyzer import ( "testing" - "github.com/stretchr/testify/require" "github.com/src-d/go-mysql-server/mem" "github.com/src-d/go-mysql-server/sql" "github.com/src-d/go-mysql-server/sql/expression" "github.com/src-d/go-mysql-server/sql/plan" + "github.com/stretchr/testify/require" ) func TestPushdownProjectionAndFilters(t *testing.T) { @@ -212,7 +212,7 @@ func TestPushdownIndexable(t *testing.T) { require.NoError(err) // we need to remove the release function to compare, otherwise it will fail - result, err = result.TransformUp(func(node sql.Node) (sql.Node, error) { + result, err = plan.TransformUp(result, func(node sql.Node) (sql.Node, error) { switch node := node.(type) { case *releaser: return &releaser{Child: node.Child}, nil diff --git a/sql/analyzer/resolve_columns.go b/sql/analyzer/resolve_columns.go index 5c4c33769..79e976739 100644 --- a/sql/analyzer/resolve_columns.go +++ b/sql/analyzer/resolve_columns.go @@ -5,11 +5,11 @@ import ( "sort" "strings" - "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/internal/similartext" "github.com/src-d/go-mysql-server/sql" "github.com/src-d/go-mysql-server/sql/expression" "github.com/src-d/go-mysql-server/sql/plan" + "gopkg.in/src-d/go-errors.v1" "vitess.io/vitess/go/vt/sqlparser" ) @@ -98,8 +98,15 @@ func (deferredColumn) IsNullable() bool { return true } -func (e deferredColumn) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - return fn(e) +// Children implements the Expression interface. +func (deferredColumn) Children() []sql.Expression { return nil } + +// WithChildren implements the Expression interface. +func (e deferredColumn) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(e, len(children), 0) + } + return e, nil } type tableCol struct { @@ -120,16 +127,15 @@ type column interface { } func qualifyColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { - return n.TransformUp(func(n sql.Node) (sql.Node, error) { - exp, ok := n.(sql.Expressioner) - if !ok || n.Resolved() { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { + if _, ok := n.(sql.Expressioner); !ok || n.Resolved() { return n, nil } columns := getNodeAvailableColumns(n) tables := getNodeAvailableTables(n) - return exp.TransformExpressions(func(e sql.Expression) (sql.Expression, error) { + return plan.TransformExpressions(n, func(e sql.Expression) (sql.Expression, error) { return qualifyExpression(e, columns, tables) }) }) @@ -198,7 +204,7 @@ func qualifyExpression( default: // If any other kind of expression has a star, just replace it // with an unqualified star because it cannot be expanded. - return e.TransformUp(func(e sql.Expression) (sql.Expression, error) { + return expression.TransformUp(e, func(e sql.Expression) (sql.Expression, error) { if _, ok := e.(*expression.Star); ok { return expression.NewStar(), nil } @@ -289,14 +295,13 @@ func resolveColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) defer span.Finish() a.Log("resolve columns, node of type: %T", n) - return n.TransformUp(func(n sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { a.Log("transforming node of type: %T", n) if n.Resolved() { return n, nil } - expressioner, ok := n.(sql.Expressioner) - if !ok { + if _, ok := n.(sql.Expressioner); !ok { return n, nil } @@ -308,7 +313,7 @@ func resolveColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) } columns := findChildIndexedColumns(n) - return expressioner.TransformExpressions(func(e sql.Expression) (sql.Expression, error) { + return plan.TransformExpressions(n, func(e sql.Expression) (sql.Expression, error) { a.Log("transforming expression of type: %T", e) uc, ok := e.(column) @@ -394,7 +399,7 @@ func resolveGroupingColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node return n, nil } - return n.TransformUp(func(n sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { g, ok := n.(*plan.GroupBy) if n.Resolved() || !ok || len(g.Grouping) == 0 { return n, nil @@ -510,7 +515,7 @@ func resolveGroupingColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node if len(renames) > 0 { for i, expr := range newAggregate { var err error - newAggregate[i], err = expr.TransformUp(func(e sql.Expression) (sql.Expression, error) { + newAggregate[i], err = expression.TransformUp(expr, func(e sql.Expression) (sql.Expression, error) { col, ok := e.(*expression.UnresolvedColumn) if ok { // We need to make sure we don't rename the reference to the diff --git a/sql/analyzer/resolve_database.go b/sql/analyzer/resolve_database.go index 6590e0748..2c5d0f628 100644 --- a/sql/analyzer/resolve_database.go +++ b/sql/analyzer/resolve_database.go @@ -2,6 +2,7 @@ package analyzer import ( "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/plan" ) func resolveDatabase(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { @@ -10,7 +11,7 @@ func resolveDatabase(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error a.Log("resolve database, node of type: %T", n) - return n.TransformUp(func(n sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { d, ok := n.(sql.Databaser) if !ok { return n, nil diff --git a/sql/analyzer/resolve_functions.go b/sql/analyzer/resolve_functions.go index a1a200a87..34bac11b0 100644 --- a/sql/analyzer/resolve_functions.go +++ b/sql/analyzer/resolve_functions.go @@ -3,6 +3,7 @@ package analyzer import ( "github.com/src-d/go-mysql-server/sql" "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/plan" ) func resolveFunctions(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { @@ -10,13 +11,13 @@ func resolveFunctions(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, erro defer span.Finish() a.Log("resolve functions, node of type %T", n) - return n.TransformUp(func(n sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { a.Log("transforming node of type: %T", n) if n.Resolved() { return n, nil } - return n.TransformExpressionsUp(func(e sql.Expression) (sql.Expression, error) { + return plan.TransformExpressionsUp(n, func(e sql.Expression) (sql.Expression, error) { a.Log("transforming expression of type: %T", e) if e.Resolved() { return e, nil diff --git a/sql/analyzer/resolve_generators.go b/sql/analyzer/resolve_generators.go index 437cf332a..4635e24d6 100644 --- a/sql/analyzer/resolve_generators.go +++ b/sql/analyzer/resolve_generators.go @@ -1,11 +1,11 @@ package analyzer import ( - "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" "github.com/src-d/go-mysql-server/sql/expression" "github.com/src-d/go-mysql-server/sql/expression/function" "github.com/src-d/go-mysql-server/sql/plan" + "gopkg.in/src-d/go-errors.v1" ) var ( @@ -14,7 +14,7 @@ var ( ) func resolveGenerators(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { - return n.TransformUp(func(n sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { p, ok := n.(*plan.Project) if !ok { return n, nil diff --git a/sql/analyzer/resolve_having.go b/sql/analyzer/resolve_having.go index 64399b3b1..a891bab8c 100644 --- a/sql/analyzer/resolve_having.go +++ b/sql/analyzer/resolve_having.go @@ -3,15 +3,15 @@ package analyzer import ( "reflect" - "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" "github.com/src-d/go-mysql-server/sql/expression" "github.com/src-d/go-mysql-server/sql/expression/function/aggregation" "github.com/src-d/go-mysql-server/sql/plan" + "gopkg.in/src-d/go-errors.v1" ) func resolveHaving(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, error) { - return node.TransformUp(func(node sql.Node) (sql.Node, error) { + return plan.TransformUp(node, func(node sql.Node) (sql.Node, error) { having, ok := node.(*plan.Having) if !ok { return node, nil @@ -38,7 +38,7 @@ func resolveHaving(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, erro // We need to find all the aggregations in the having that are already present in // the group by and replace them with a GetField. If the aggregation is not // present, we need to move it to the GroupBy and reference it with a GetField. - cond, err := having.Cond.TransformUp(func(e sql.Expression) (sql.Expression, error) { + cond, err := expression.TransformUp(having.Cond, func(e sql.Expression) (sql.Expression, error) { agg, ok := e.(sql.Aggregation) if !ok { return e, nil diff --git a/sql/analyzer/resolve_natural_joins.go b/sql/analyzer/resolve_natural_joins.go index 0ada6eb36..a6cf1fdb9 100644 --- a/sql/analyzer/resolve_natural_joins.go +++ b/sql/analyzer/resolve_natural_joins.go @@ -15,8 +15,8 @@ func resolveNaturalJoins(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, e var replacements = make(map[tableCol]tableCol) var tableAliases = make(map[string]string) - return n.TransformUp(func(n sql.Node) (sql.Node, error) { - switch n := n.(type) { + return plan.TransformUp(n, func(node sql.Node) (sql.Node, error) { + switch n := node.(type) { case *plan.TableAlias: alias := n.Name() table := n.Child.(*plan.ResolvedTable).Name() @@ -25,7 +25,7 @@ func resolveNaturalJoins(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, e case *plan.NaturalJoin: return resolveNaturalJoin(n, replacements) case sql.Expressioner: - return replaceExpressions(n, replacements, tableAliases) + return replaceExpressions(node, replacements, tableAliases) default: return n, nil } @@ -115,11 +115,11 @@ func findCol(s sql.Schema, name string) (int, *sql.Column) { } func replaceExpressions( - n sql.Expressioner, + n sql.Node, replacements map[tableCol]tableCol, tableAliases map[string]string, ) (sql.Node, error) { - return n.TransformExpressions(func(e sql.Expression) (sql.Expression, error) { + return plan.TransformExpressions(n, func(e sql.Expression) (sql.Expression, error) { switch e := e.(type) { case *expression.GetField, *expression.UnresolvedColumn: var tableName = e.(sql.Tableable).Table() diff --git a/sql/analyzer/resolve_orderby.go b/sql/analyzer/resolve_orderby.go index dc24e674a..706ace606 100644 --- a/sql/analyzer/resolve_orderby.go +++ b/sql/analyzer/resolve_orderby.go @@ -3,10 +3,10 @@ package analyzer import ( "strings" - errors "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" "github.com/src-d/go-mysql-server/sql/expression" "github.com/src-d/go-mysql-server/sql/plan" + errors "gopkg.in/src-d/go-errors.v1" ) func resolveOrderBy(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { @@ -14,7 +14,7 @@ func resolveOrderBy(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) defer span.Finish() a.Log("resolving order bys, node of type: %T", n) - return n.TransformUp(func(n sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { a.Log("transforming node of type: %T", n) sort, ok := n.(*plan.Sort) if !ok { @@ -175,7 +175,7 @@ func pushSortDown(sort *plan.Sort) (sql.Node, error) { func resolveOrderByLiterals(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { a.Log("resolve order by literals") - return n.TransformUp(func(n sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { sort, ok := n.(*plan.Sort) if !ok { return n, nil diff --git a/sql/analyzer/resolve_stars.go b/sql/analyzer/resolve_stars.go index 9f8e8534f..2261752f3 100644 --- a/sql/analyzer/resolve_stars.go +++ b/sql/analyzer/resolve_stars.go @@ -11,7 +11,7 @@ func resolveStar(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { defer span.Finish() a.Log("resolving star, node of type: %T", n) - return n.TransformUp(func(n sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { a.Log("transforming node of type: %T", n) if n.Resolved() { return n, nil diff --git a/sql/analyzer/resolve_subqueries.go b/sql/analyzer/resolve_subqueries.go index 6c315b50b..5015253c1 100644 --- a/sql/analyzer/resolve_subqueries.go +++ b/sql/analyzer/resolve_subqueries.go @@ -10,7 +10,7 @@ func resolveSubqueries(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, err defer span.Finish() a.Log("resolving subqueries") - return n.TransformUp(func(n sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { switch n := n.(type) { case *plan.SubqueryAlias: a.Log("found subquery %q with child of type %T", n.Name(), n.Child) diff --git a/sql/analyzer/resolve_tables.go b/sql/analyzer/resolve_tables.go index afe00d47f..7632d1a48 100644 --- a/sql/analyzer/resolve_tables.go +++ b/sql/analyzer/resolve_tables.go @@ -21,7 +21,7 @@ func resolveTables(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) defer span.Finish() a.Log("resolve table, node of type: %T", n) - return n.TransformUp(func(n sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { a.Log("transforming node of type: %T", n) if n.Resolved() { return n, nil diff --git a/sql/analyzer/validation_rules_test.go b/sql/analyzer/validation_rules_test.go index 7bb83d1bb..2d87614b1 100644 --- a/sql/analyzer/validation_rules_test.go +++ b/sql/analyzer/validation_rules_test.go @@ -676,17 +676,12 @@ func TestValidateExplodeUsage(t *testing.T) { type dummyNode struct{ resolved bool } -func (n dummyNode) String() string { return "dummynode" } -func (n dummyNode) Resolved() bool { return n.resolved } -func (dummyNode) Schema() sql.Schema { return nil } -func (dummyNode) Children() []sql.Node { return nil } -func (dummyNode) RowIter(*sql.Context) (sql.RowIter, error) { return nil, nil } -func (dummyNode) TransformUp(sql.TransformNodeFunc) (sql.Node, error) { return nil, nil } -func (dummyNode) TransformExpressionsUp( - sql.TransformExprFunc, -) (sql.Node, error) { - return nil, nil -} +func (n dummyNode) String() string { return "dummynode" } +func (n dummyNode) Resolved() bool { return n.resolved } +func (dummyNode) Schema() sql.Schema { return nil } +func (dummyNode) Children() []sql.Node { return nil } +func (dummyNode) RowIter(*sql.Context) (sql.RowIter, error) { return nil, nil } +func (dummyNode) WithChildren(...sql.Node) (sql.Node, error) { return nil, nil } func getValidationRule(name string) Rule { for _, rule := range DefaultValidationRules { diff --git a/sql/core.go b/sql/core.go index 14c916ef3..d676372ec 100644 --- a/sql/core.go +++ b/sql/core.go @@ -25,6 +25,10 @@ var ( //ErrUnexpectedRowLength is thrown when the obtained row has more columns than the schema ErrUnexpectedRowLength = errors.NewKind("expected %d values, got %d") + + // ErrInvalidChildrenNumber is returned when the WithChildren method of a + // node or expression is called with an invalid number of arguments. + ErrInvalidChildrenNumber = errors.NewKind("%T: invalid children number, got %d, expected %d") ) // Nameable is something that has a name. @@ -45,17 +49,6 @@ type Resolvable interface { Resolved() bool } -// Transformable is a node which can be transformed. -type Transformable interface { - // TransformUp transforms all nodes and returns the result of this transformation. - // Transformation is not propagated to subqueries. - TransformUp(TransformNodeFunc) (Node, error) - // TransformExpressionsUp transforms all expressions inside the node and all its - // children and returns a node with the result of the transformations. - // Transformation is not propagated to subqueries. - TransformExpressionsUp(TransformExprFunc) (Node, error) -} - // TransformNodeFunc is a function that given a node will return that node // as is or transformed along with an error, if any. type TransformNodeFunc func(Node) (Node, error) @@ -74,11 +67,13 @@ type Expression interface { IsNullable() bool // Eval evaluates the given row and returns a result. Eval(*Context, Row) (interface{}, error) - // TransformUp transforms the expression and all its children with the - // given transform function. - TransformUp(TransformExprFunc) (Expression, error) // Children returns the children expressions of this expression. Children() []Expression + // WithChildren returns a copy of the expression with children replaced. + // It will return an error if the number of children is different than + // the current number of children. They must be given in the same order + // as they are returned by Children. + WithChildren(...Expression) (Expression, error) } // Aggregation implements an aggregation expression, where an @@ -100,7 +95,6 @@ type Aggregation interface { // Node is a node in the execution plan tree. type Node interface { Resolvable - Transformable fmt.Stringer // Schema of the node. Schema() Schema @@ -108,16 +102,30 @@ type Node interface { Children() []Node // RowIter produces a row iterator from this node. RowIter(*Context) (RowIter, error) + // WithChildren returns a copy of the node with children replaced. + // It will return an error if the number of children is different than + // the current number of children. They must be given in the same order + // as they are returned by Children. + WithChildren(...Node) (Node, error) +} + +// OpaqueNode is a node that doesn't allow transformations to its children and +// acts a a black box. +type OpaqueNode interface { + Node + // Opaque reports whether the node is opaque or not. + Opaque() bool } // Expressioner is a node that contains expressions. type Expressioner interface { // Expressions returns the list of expressions contained by the node. Expressions() []Expression - // TransformExpressions applies for each expression in this node - // the expression's TransformUp method with the given function, and - // return a new node with the transformed expressions. - TransformExpressions(TransformExprFunc) (Node, error) + // WithExpressions returns a copy of the node with expressions replaced. + // It will return an error if the number of expressions is different than + // the current number of expressions. They must be given in the same order + // as they are returned by Expressions. + WithExpressions(...Expression) (Node, error) } // Databaser is a node that contains a reference to a database. diff --git a/sql/expression/alias.go b/sql/expression/alias.go index 6b9b1edd4..c7485dfd9 100644 --- a/sql/expression/alias.go +++ b/sql/expression/alias.go @@ -31,13 +31,12 @@ func (e *Alias) String() string { return fmt.Sprintf("%s as %s", e.Child, e.name) } -// TransformUp implements the Expression interface. -func (e *Alias) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := e.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (e *Alias) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(e, len(children), 1) } - return f(NewAlias(child, e.name)) + return NewAlias(children[0], e.name), nil } // Name implements the Nameable interface. diff --git a/sql/expression/arithmetic.go b/sql/expression/arithmetic.go index ecdf26961..d7044e06e 100644 --- a/sql/expression/arithmetic.go +++ b/sql/expression/arithmetic.go @@ -137,19 +137,12 @@ func isInterval(expr sql.Expression) bool { return ok } -// TransformUp implements the Expression interface. -func (a *Arithmetic) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - l, err := a.Left.TransformUp(f) - if err != nil { - return nil, err - } - - r, err := a.Right.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (a *Arithmetic) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(a, len(children), 2) } - - return f(NewArithmetic(l, r, a.Op)) + return NewArithmetic(children[0], children[1], a.Op), nil } // Eval implements the Expression interface. @@ -549,12 +542,10 @@ func (e *UnaryMinus) String() string { return fmt.Sprintf("-%s", e.Child) } -// TransformUp implements the sql.Expression interface. -func (e *UnaryMinus) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - c, err := e.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (e *UnaryMinus) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(e, len(children), 1) } - - return f(NewUnaryMinus(c)) + return NewUnaryMinus(children[0]), nil } diff --git a/sql/expression/between.go b/sql/expression/between.go index ce892158a..15114890b 100644 --- a/sql/expression/between.go +++ b/sql/expression/between.go @@ -98,22 +98,10 @@ func (b *Between) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return cmpLower >= 0 && cmpUpper <= 0, nil } -// TransformUp implements the Expression interface. -func (b *Between) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - val, err := b.Val.TransformUp(f) - if err != nil { - return nil, err - } - - lower, err := b.Lower.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (b *Between) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 3 { + return nil, sql.ErrInvalidChildrenNumber.New(b, len(children), 3) } - - upper, err := b.Upper.TransformUp(f) - if err != nil { - return nil, err - } - - return f(NewBetween(val, lower, upper)) + return NewBetween(children[0], children[1], children[2]), nil } diff --git a/sql/expression/boolean.go b/sql/expression/boolean.go index c4991029a..73815fb7d 100644 --- a/sql/expression/boolean.go +++ b/sql/expression/boolean.go @@ -52,11 +52,10 @@ func (e *Not) String() string { return fmt.Sprintf("NOT(%s)", e.Child) } -// TransformUp implements the Expression interface. -func (e *Not) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := e.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (e *Not) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(e, len(children), 1) } - return f(NewNot(child)) + return NewNot(children[0]), nil } diff --git a/sql/expression/case.go b/sql/expression/case.go index 28eef6d03..31f037fb3 100644 --- a/sql/expression/case.go +++ b/sql/expression/case.go @@ -130,44 +130,41 @@ func (c *Case) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, nil } -// TransformUp implements the sql.Expression interface. -func (c *Case) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - var expr sql.Expression - var err error - +// WithChildren implements the Expression interface. +func (c *Case) WithChildren(children ...sql.Expression) (sql.Expression, error) { + var expected = len(c.Branches) * 2 if c.Expr != nil { - expr, err = c.Expr.TransformUp(f) - if err != nil { - return nil, err - } + expected++ } - var branches []CaseBranch - for _, b := range c.Branches { - var nb CaseBranch - - nb.Cond, err = b.Cond.TransformUp(f) - if err != nil { - return nil, err - } + if c.Else != nil { + expected++ + } - nb.Value, err = b.Value.TransformUp(f) - if err != nil { - return nil, err - } + if len(children) != expected { + return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), expected) + } - branches = append(branches, nb) + var expr, elseExpr sql.Expression + if c.Expr != nil { + expr = children[0] + children = children[1:] } - var elseExpr sql.Expression if c.Else != nil { - elseExpr, err = c.Else.TransformUp(f) - if err != nil { - return nil, err - } + elseExpr = children[len(children)-1] + children = children[:len(children)-1] + } + + var branches []CaseBranch + for i := 0; i < len(children); i += 2 { + branches = append(branches, CaseBranch{ + Cond: children[i], + Value: children[i+1], + }) } - return f(NewCase(expr, branches, elseExpr)) + return NewCase(expr, branches, elseExpr), nil } func (c *Case) String() string { diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index c1fd93e04..8491f9cef 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -4,9 +4,9 @@ import ( "fmt" "sync" - errors "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/internal/regex" "github.com/src-d/go-mysql-server/sql" + errors "gopkg.in/src-d/go-errors.v1" ) // Comparer implements a comparison expression. @@ -157,19 +157,12 @@ func (e *Equals) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return result == 0, nil } -// TransformUp implements the Expression interface. -func (e *Equals) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - left, err := e.Left().TransformUp(f) - if err != nil { - return nil, err - } - - right, err := e.Right().TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (e *Equals) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(e, len(children), 2) } - - return f(NewEquals(left, right)) + return NewEquals(children[0], children[1]), nil } func (e *Equals) String() string { @@ -278,19 +271,12 @@ func (re *Regexp) compareRegexp(ctx *sql.Context, row sql.Row) (interface{}, err return ok, nil } -// TransformUp implements the Expression interface. -func (re *Regexp) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - left, err := re.Left().TransformUp(f) - if err != nil { - return nil, err - } - - right, err := re.Right().TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (re *Regexp) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(re, len(children), 2) } - - return f(NewRegexp(left, right)) + return NewRegexp(children[0], children[1]), nil } func (re *Regexp) String() string { @@ -321,19 +307,12 @@ func (gt *GreaterThan) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) return result == 1, nil } -// TransformUp implements the Expression interface. -func (gt *GreaterThan) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - left, err := gt.Left().TransformUp(f) - if err != nil { - return nil, err - } - - right, err := gt.Right().TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (gt *GreaterThan) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(gt, len(children), 2) } - - return f(NewGreaterThan(left, right)) + return NewGreaterThan(children[0], children[1]), nil } func (gt *GreaterThan) String() string { @@ -364,19 +343,12 @@ func (lt *LessThan) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return result == -1, nil } -// TransformUp implements the Expression interface. -func (lt *LessThan) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - left, err := lt.Left().TransformUp(f) - if err != nil { - return nil, err - } - - right, err := lt.Right().TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (lt *LessThan) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(lt, len(children), 2) } - - return f(NewLessThan(left, right)) + return NewLessThan(children[0], children[1]), nil } func (lt *LessThan) String() string { @@ -408,19 +380,12 @@ func (gte *GreaterThanOrEqual) Eval(ctx *sql.Context, row sql.Row) (interface{}, return result > -1, nil } -// TransformUp implements the Expression interface. -func (gte *GreaterThanOrEqual) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - left, err := gte.Left().TransformUp(f) - if err != nil { - return nil, err - } - - right, err := gte.Right().TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (gte *GreaterThanOrEqual) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(gte, len(children), 2) } - - return f(NewGreaterThanOrEqual(left, right)) + return NewGreaterThanOrEqual(children[0], children[1]), nil } func (gte *GreaterThanOrEqual) String() string { @@ -452,19 +417,12 @@ func (lte *LessThanOrEqual) Eval(ctx *sql.Context, row sql.Row) (interface{}, er return result < 1, nil } -// TransformUp implements the Expression interface. -func (lte *LessThanOrEqual) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - left, err := lte.Left().TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (lte *LessThanOrEqual) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(lte, len(children), 2) } - - right, err := lte.Right().TransformUp(f) - if err != nil { - return nil, err - } - - return f(NewLessThanOrEqual(left, right)) + return NewLessThanOrEqual(children[0], children[1]), nil } func (lte *LessThanOrEqual) String() string { @@ -544,19 +502,12 @@ func (in *In) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } } -// TransformUp implements the Expression interface. -func (in *In) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - left, err := in.Left().TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (in *In) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(in, len(children), 2) } - - right, err := in.Right().TransformUp(f) - if err != nil { - return nil, err - } - - return f(NewIn(left, right)) + return NewIn(children[0], children[1]), nil } func (in *In) String() string { @@ -632,19 +583,12 @@ func (in *NotIn) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } } -// TransformUp implements the Expression interface. -func (in *NotIn) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - left, err := in.Left().TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (in *NotIn) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(in, len(children), 2) } - - right, err := in.Right().TransformUp(f) - if err != nil { - return nil, err - } - - return f(NewNotIn(left, right)) + return NewNotIn(children[0], children[1]), nil } func (in *NotIn) String() string { diff --git a/sql/expression/convert.go b/sql/expression/convert.go index d40640354..bcfe778e0 100644 --- a/sql/expression/convert.go +++ b/sql/expression/convert.go @@ -8,8 +8,8 @@ import ( "time" "github.com/spf13/cast" - errors "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" + errors "gopkg.in/src-d/go-errors.v1" ) // ErrConvertExpression is returned when a conversion is not possible. @@ -90,14 +90,12 @@ func (c *Convert) String() string { return fmt.Sprintf("convert(%v, %v)", c.Child, c.castToType) } -// TransformUp implements the Expression interface. -func (c *Convert) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := c.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (c *Convert) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 1) } - - return f(NewConvert(child, c.castToType)) + return NewConvert(children[0], c.castToType), nil } // Eval implements the Expression interface. diff --git a/sql/expression/default.go b/sql/expression/default.go index 5757a6cff..82c5cb9e7 100644 --- a/sql/expression/default.go +++ b/sql/expression/default.go @@ -53,8 +53,10 @@ func (*DefaultColumn) Eval(ctx *sql.Context, r sql.Row) (interface{}, error) { panic("default column is a placeholder node, but Eval was called") } -// TransformUp implements the sql.Expression interface. -func (c *DefaultColumn) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - n := *c - return f(&n) +// WithChildren implements the Expression interface. +func (c *DefaultColumn) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 0) + } + return c, nil } diff --git a/sql/expression/function/aggregation/avg.go b/sql/expression/function/aggregation/avg.go index 57153f12a..5e6b2ff49 100644 --- a/sql/expression/function/aggregation/avg.go +++ b/sql/expression/function/aggregation/avg.go @@ -53,13 +53,12 @@ func (a *Avg) Eval(ctx *sql.Context, buffer sql.Row) (interface{}, error) { return sum / float64(rows), nil } -// TransformUp implements AggregationExpression interface. -func (a *Avg) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := a.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (a *Avg) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(a, len(children), 1) } - return f(NewAvg(child)) + return NewAvg(children[0]), nil } // NewBuffer implements AggregationExpression interface. (AggregationExpression) diff --git a/sql/expression/function/aggregation/count.go b/sql/expression/function/aggregation/count.go index 96406b11f..b3f900b4d 100644 --- a/sql/expression/function/aggregation/count.go +++ b/sql/expression/function/aggregation/count.go @@ -45,13 +45,12 @@ func (c *Count) String() string { return fmt.Sprintf("COUNT(%s)", c.Child) } -// TransformUp implements the Expression interface. -func (c *Count) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := c.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (c *Count) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 1) } - return f(NewCount(child)) + return NewCount(children[0]), nil } // Update implements the Aggregation interface. diff --git a/sql/expression/function/aggregation/max.go b/sql/expression/function/aggregation/max.go index 40ec86399..e47211f2c 100644 --- a/sql/expression/function/aggregation/max.go +++ b/sql/expression/function/aggregation/max.go @@ -38,13 +38,12 @@ func (m *Max) IsNullable() bool { return false } -// TransformUp implements the Transformable interface. -func (m *Max) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := m.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (m *Max) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(m, len(children), 1) } - return f(NewMax(child)) + return NewMax(children[0]), nil } // NewBuffer creates a new buffer to compute the result. diff --git a/sql/expression/function/aggregation/min.go b/sql/expression/function/aggregation/min.go index 19c08c296..8e73e0812 100644 --- a/sql/expression/function/aggregation/min.go +++ b/sql/expression/function/aggregation/min.go @@ -38,13 +38,12 @@ func (m *Min) IsNullable() bool { return true } -// TransformUp implements the Transformable interface. -func (m *Min) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := m.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (m *Min) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(m, len(children), 1) } - return f(NewMin(child)) + return NewMin(children[0]), nil } // NewBuffer creates a new buffer to compute the result. diff --git a/sql/expression/function/aggregation/sum.go b/sql/expression/function/aggregation/sum.go index dd62e0dfe..09df362be 100644 --- a/sql/expression/function/aggregation/sum.go +++ b/sql/expression/function/aggregation/sum.go @@ -27,13 +27,12 @@ func (m *Sum) String() string { return fmt.Sprintf("SUM(%s)", m.Child) } -// TransformUp implements the Transformable interface. -func (m *Sum) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := m.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (m *Sum) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(m, len(children), 1) } - return f(NewSum(child)) + return NewSum(children[0]), nil } // NewBuffer creates a new buffer to compute the result. diff --git a/sql/expression/function/arraylength.go b/sql/expression/function/arraylength.go index 3002f2df9..00d10cfd2 100644 --- a/sql/expression/function/arraylength.go +++ b/sql/expression/function/arraylength.go @@ -26,13 +26,12 @@ func (f *ArrayLength) String() string { return fmt.Sprintf("array_length(%s)", f.Child) } -// TransformUp implements the Expression interface. -func (f *ArrayLength) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - child, err := f.Child.TransformUp(fn) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (f *ArrayLength) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(f, len(children), 1) } - return fn(NewArrayLength(child)) + return NewArrayLength(children[0]), nil } // Eval implements the Expression interface. diff --git a/sql/expression/function/ceil_round_floor.go b/sql/expression/function/ceil_round_floor.go index 1aafa66f7..4c8c1d757 100644 --- a/sql/expression/function/ceil_round_floor.go +++ b/sql/expression/function/ceil_round_floor.go @@ -32,13 +32,12 @@ func (c *Ceil) String() string { return fmt.Sprintf("CEIL(%s)", c.Child) } -// TransformUp implements the Expression interface. -func (c *Ceil) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - child, err := c.Child.TransformUp(fn) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (c *Ceil) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 1) } - return fn(NewCeil(child)) + return NewCeil(children[0]), nil } // Eval implements the Expression interface. @@ -99,13 +98,12 @@ func (f *Floor) String() string { return fmt.Sprintf("FLOOR(%s)", f.Child) } -// TransformUp implements the Expression interface. -func (f *Floor) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - child, err := f.Child.TransformUp(fn) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (f *Floor) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(f, len(children), 1) } - return fn(NewFloor(child)) + return NewFloor(children[0]), nil } // Eval implements the Expression interface. @@ -269,30 +267,7 @@ func (r *Round) Type() sql.Type { return sql.Int32 } -// TransformUp implements the Expression interface. -func (r *Round) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - var args = make([]sql.Expression, 2) - - arg, err := r.Left.TransformUp(f) - if err != nil { - return nil, err - } - args[0] = arg - - args[1] = nil - if r.Right != nil { - var arg sql.Expression - arg, err = r.Right.TransformUp(f) - if err != nil { - return nil, err - } - args[1] = arg - } - - expr, err := NewRound(args...) - if err != nil { - return nil, err - } - - return f(expr) +// WithChildren implements the Expression interface. +func (r *Round) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewRound(children...) } diff --git a/sql/expression/function/coalesce.go b/sql/expression/function/coalesce.go index 529a8c455..07f7f64e6 100644 --- a/sql/expression/function/coalesce.go +++ b/sql/expression/function/coalesce.go @@ -59,29 +59,9 @@ func (c *Coalesce) String() string { return fmt.Sprintf("coalesce(%s)", strings.Join(args, ", ")) } -// TransformUp implements the sql.Expression interface. -func (c *Coalesce) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - var ( - args = make([]sql.Expression, len(c.args)) - err error - ) - - for i, arg := range c.args { - if arg != nil { - arg, err = arg.TransformUp(fn) - if err != nil { - return nil, err - } - } - args[i] = arg - } - - expr, err := NewCoalesce(args...) - if err != nil { - return nil, err - } - - return fn(expr) +// WithChildren implements the Expression interface. +func (*Coalesce) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewCoalesce(children...) } // Resolved implements the sql.Expression interface. diff --git a/sql/expression/function/concat.go b/sql/expression/function/concat.go index 77a384a43..56e7bcbab 100644 --- a/sql/expression/function/concat.go +++ b/sql/expression/function/concat.go @@ -4,8 +4,8 @@ import ( "fmt" "strings" - errors "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" + errors "gopkg.in/src-d/go-errors.v1" ) // Concat joins several strings together. @@ -63,23 +63,9 @@ func (f *Concat) String() string { return fmt.Sprintf("concat(%s)", strings.Join(args, ", ")) } -// TransformUp implements the Expression interface. -func (f *Concat) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - var args = make([]sql.Expression, len(f.args)) - for i, arg := range f.args { - a, err := arg.TransformUp(fn) - if err != nil { - return nil, err - } - args[i] = a - } - - expr, err := NewConcat(args...) - if err != nil { - return nil, err - } - - return fn(expr) +// WithChildren implements the Expression interface. +func (*Concat) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewConcat(children...) } // Resolved implements the Expression interface. diff --git a/sql/expression/function/concat_ws.go b/sql/expression/function/concat_ws.go index 7ceffc195..c1e2dacc1 100644 --- a/sql/expression/function/concat_ws.go +++ b/sql/expression/function/concat_ws.go @@ -61,23 +61,9 @@ func (f *ConcatWithSeparator) String() string { return fmt.Sprintf("concat_ws(%s)", strings.Join(args, ", ")) } -// TransformUp implements the Expression interface. -func (f *ConcatWithSeparator) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - var args = make([]sql.Expression, len(f.args)) - for i, arg := range f.args { - a, err := arg.TransformUp(fn) - if err != nil { - return nil, err - } - args[i] = a - } - - expr, err := NewConcatWithSeparator(args...) - if err != nil { - return nil, err - } - - return fn(expr) +// WithChildren implements the Expression interface. +func (*ConcatWithSeparator) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewConcatWithSeparator(children...) } // Resolved implements the Expression interface. diff --git a/sql/expression/function/connection_id.go b/sql/expression/function/connection_id.go index a34ba020e..7ee96ef58 100644 --- a/sql/expression/function/connection_id.go +++ b/sql/expression/function/connection_id.go @@ -19,9 +19,12 @@ func (ConnectionID) Type() sql.Type { return sql.Uint32 } // Resolved implements the sql.Expression interface. func (ConnectionID) Resolved() bool { return true } -// TransformUp implements the sql.Expression interface. -func (ConnectionID) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - return f(ConnectionID{}) +// WithChildren implements the Expression interface. +func (c ConnectionID) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 0) + } + return c, nil } // IsNullable implements the sql.Expression interface. diff --git a/sql/expression/function/database.go b/sql/expression/function/database.go index f9d6ac5e2..1246e488c 100644 --- a/sql/expression/function/database.go +++ b/sql/expression/function/database.go @@ -29,9 +29,12 @@ func (*Database) String() string { return "DATABASE()" } -// TransformUp implements the sql.Expression interface. -func (db *Database) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - return fn(db) +// WithChildren implements the Expression interface. +func (d *Database) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 0) + } + return NewDatabase(d.catalog)(), nil } // Resolved implements the sql.Expression interface. diff --git a/sql/expression/function/date.go b/sql/expression/function/date.go index cdca25692..919775a42 100644 --- a/sql/expression/function/date.go +++ b/sql/expression/function/date.go @@ -46,18 +46,9 @@ func (d *DateAdd) IsNullable() bool { // Type implements the sql.Expression interface. func (d *DateAdd) Type() sql.Type { return sql.Date } -// TransformUp implements the sql.Expression interface. -func (d *DateAdd) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - date, err := d.Date.TransformUp(f) - if err != nil { - return nil, err - } - interval, err := d.Interval.TransformUp(f) - if err != nil { - return nil, err - } - - return &DateAdd{date, interval.(*expression.Interval)}, nil +// WithChildren implements the Expression interface. +func (d *DateAdd) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewDateAdd(children...) } // Eval implements the sql.Expression interface. @@ -130,18 +121,9 @@ func (d *DateSub) IsNullable() bool { // Type implements the sql.Expression interface. func (d *DateSub) Type() sql.Type { return sql.Date } -// TransformUp implements the sql.Expression interface. -func (d *DateSub) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - date, err := d.Date.TransformUp(f) - if err != nil { - return nil, err - } - interval, err := d.Interval.TransformUp(f) - if err != nil { - return nil, err - } - - return &DateSub{date, interval.(*expression.Interval)}, nil +// WithChildren implements the Expression interface. +func (d *DateSub) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewDateSub(children...) } // Eval implements the sql.Expression interface. diff --git a/sql/expression/function/explode.go b/sql/expression/function/explode.go index 64e7d85a2..51cd2b66b 100644 --- a/sql/expression/function/explode.go +++ b/sql/expression/function/explode.go @@ -40,14 +40,12 @@ func (e *Explode) String() string { return fmt.Sprintf("EXPLODE(%s)", e.Child) } -// TransformUp implements the sql.Expression interface. -func (e *Explode) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - c, err := e.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (e *Explode) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(e, len(children), 1) } - - return f(NewExplode(c)) + return NewExplode(children[0]), nil } // Generate is a function that generates a row for each value of its child. @@ -84,12 +82,10 @@ func (e *Generate) String() string { return fmt.Sprintf("EXPLODE(%s)", e.Child) } -// TransformUp implements the sql.Expression interface. -func (e *Generate) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - c, err := e.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (e *Generate) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(e, len(children), 1) } - - return f(NewGenerate(c)) + return NewGenerate(children[0]), nil } diff --git a/sql/expression/function/greatest_least.go b/sql/expression/function/greatest_least.go index 0dcc67905..9e822d8ed 100644 --- a/sql/expression/function/greatest_least.go +++ b/sql/expression/function/greatest_least.go @@ -5,8 +5,8 @@ import ( "strconv" "strings" - "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" + "gopkg.in/src-d/go-errors.v1" ) var ErrUintOverflow = errors.NewKind( @@ -194,23 +194,9 @@ func (f *Greatest) String() string { return fmt.Sprintf("greatest(%s)", strings.Join(args, ", ")) } -// TransformUp implements the Expression interface. -func (f *Greatest) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - var args = make([]sql.Expression, len(f.Args)) - for i, arg := range f.Args { - a, err := arg.TransformUp(fn) - if err != nil { - return nil, err - } - args[i] = a - } - - expr, err := NewGreatest(args...) - if err != nil { - return nil, err - } - - return fn(expr) +// WithChildren implements the Expression interface. +func (f *Greatest) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewGreatest(children...) } // Resolved implements the Expression interface. @@ -298,23 +284,9 @@ func (f *Least) String() string { return fmt.Sprintf("least(%s)", strings.Join(args, ", ")) } -// TransformUp implements the Expression interface. -func (f *Least) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - var args = make([]sql.Expression, len(f.Args)) - for i, arg := range f.Args { - a, err := arg.TransformUp(fn) - if err != nil { - return nil, err - } - args[i] = a - } - - expr, err := NewLeast(args...) - if err != nil { - return nil, err - } - - return fn(expr) +// WithChildren implements the Expression interface. +func (f *Least) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewLeast(children...) } // Resolved implements the Expression interface. diff --git a/sql/expression/function/ifnull.go b/sql/expression/function/ifnull.go index f607da7f1..566d62ea6 100644 --- a/sql/expression/function/ifnull.go +++ b/sql/expression/function/ifnull.go @@ -65,17 +65,10 @@ func (f *IfNull) String() string { return fmt.Sprintf("ifnull(%s, %s)", f.Left, f.Right) } -// TransformUp implements the Expression interface. -func (f *IfNull) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - left, err := f.Left.TransformUp(fn) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (f *IfNull) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(f, len(children), 2) } - - right, err := f.Right.TransformUp(fn) - if err != nil { - return nil, err - } - - return fn(NewIfNull(left, right)) + return NewIfNull(children[0], children[1]), nil } diff --git a/sql/expression/function/isbinary.go b/sql/expression/function/isbinary.go index 150048356..e3edf74d6 100644 --- a/sql/expression/function/isbinary.go +++ b/sql/expression/function/isbinary.go @@ -45,13 +45,12 @@ func (ib *IsBinary) String() string { return fmt.Sprintf("IS_BINARY(%s)", ib.Child) } -// TransformUp implements the Expression interface. -func (ib *IsBinary) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := ib.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (ib *IsBinary) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(ib, len(children), 1) } - return f(NewIsBinary(child)) + return NewIsBinary(children[0]), nil } // Type implements the Expression interface. diff --git a/sql/expression/function/json_extract.go b/sql/expression/function/json_extract.go index 1dbdb758b..c3f8f65eb 100644 --- a/sql/expression/function/json_extract.go +++ b/sql/expression/function/json_extract.go @@ -108,22 +108,9 @@ func (j *JSONExtract) Children() []sql.Expression { return append([]sql.Expression{j.JSON}, j.Paths...) } -// TransformUp implements the sql.Expression interface. -func (j *JSONExtract) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - json, err := j.JSON.TransformUp(f) - if err != nil { - return nil, err - } - - paths := make([]sql.Expression, len(j.Paths)) - for i, p := range j.Paths { - paths[i], err = p.TransformUp(f) - if err != nil { - return nil, err - } - } - - return f(&JSONExtract{json, paths}) +// WithChildren implements the Expression interface. +func (j *JSONExtract) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewJSONExtract(children...) } func (j *JSONExtract) String() string { diff --git a/sql/expression/function/json_unquote.go b/sql/expression/function/json_unquote.go index 4b5715c4e..8a0c42de3 100644 --- a/sql/expression/function/json_unquote.go +++ b/sql/expression/function/json_unquote.go @@ -33,13 +33,12 @@ func (*JSONUnquote) Type() sql.Type { return sql.Text } -// TransformUp implements the Expression interface. -func (js *JSONUnquote) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - json, err := js.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (js *JSONUnquote) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(js, len(children), 1) } - return f(NewJSONUnquote(json)) + return NewJSONUnquote(children[0]), nil } // Eval implements the Expression interface. diff --git a/sql/expression/function/length.go b/sql/expression/function/length.go index 5e3d7c8c8..49d46aaf8 100644 --- a/sql/expression/function/length.go +++ b/sql/expression/function/length.go @@ -35,16 +35,13 @@ func NewCharLength(e sql.Expression) sql.Expression { return &Length{expression.UnaryExpression{Child: e}, NumChars} } -// TransformUp implements the sql.Expression interface. -func (l *Length) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := l.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (l *Length) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(l, len(children), 1) } - nl := *l - nl.Child = child - return &nl, nil + return &Length{expression.UnaryExpression{Child: children[0]}, l.CountType}, nil } // Type implements the sql.Expression interface. diff --git a/sql/expression/function/logarithm.go b/sql/expression/function/logarithm.go index 61afc2d81..eb6355a83 100644 --- a/sql/expression/function/logarithm.go +++ b/sql/expression/function/logarithm.go @@ -5,9 +5,9 @@ import ( "math" "reflect" - "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" "github.com/src-d/go-mysql-server/sql/expression" + "gopkg.in/src-d/go-errors.v1" ) // ErrInvalidArgumentForLogarithm is returned when an invalid argument value is passed to a @@ -45,13 +45,12 @@ func (l *LogBase) String() string { } } -// TransformUp implements the Expression interface. -func (l *LogBase) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := l.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (l *LogBase) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(l, len(children), 1) } - return f(NewLogBase(l.base, child)) + return NewLogBase(l.base, children[0]), nil } // Type returns the resultant type of the function. @@ -108,26 +107,9 @@ func (l *Log) String() string { return fmt.Sprintf("log(%s, %s)", l.Left, l.Right) } -// TransformUp implements the Expression interface. -func (l *Log) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - var args = make([]sql.Expression, 2) - arg, err := l.Left.TransformUp(f) - if err != nil { - return nil, err - } - args[0] = arg - - arg, err = l.Right.TransformUp(f) - if err != nil { - return nil, err - } - args[1] = arg - expr, err := NewLog(args...) - if err != nil { - return nil, err - } - - return f(expr) +// WithChildren implements the Expression interface. +func (l *Log) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewLog(children...) } // Children implements the Expression interface. diff --git a/sql/expression/function/lower_upper.go b/sql/expression/function/lower_upper.go index 2c2b09928..f0d6cdb9d 100644 --- a/sql/expression/function/lower_upper.go +++ b/sql/expression/function/lower_upper.go @@ -2,9 +2,10 @@ package function import ( "fmt" + "strings" + "github.com/src-d/go-mysql-server/sql" "github.com/src-d/go-mysql-server/sql/expression" - "strings" ) // Lower is a function that returns the lowercase of the text provided. @@ -43,13 +44,12 @@ func (l *Lower) String() string { return fmt.Sprintf("LOWER(%s)", l.Child) } -// TransformUp implements the Expression interface. -func (l *Lower) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := l.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (l *Lower) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(l, len(children), 1) } - return f(NewLower(child)) + return NewLower(children[0]), nil } // Type implements the Expression interface. @@ -93,13 +93,12 @@ func (u *Upper) String() string { return fmt.Sprintf("UPPER(%s)", u.Child) } -// TransformUp implements the Expression interface. -func (u *Upper) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := u.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (u *Upper) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(u, len(children), 1) } - return f(NewUpper(child)) + return NewUpper(children[0]), nil } // Type implements the Expression interface. diff --git a/sql/expression/function/nullif.go b/sql/expression/function/nullif.go index fc08b98ce..49b5a5d9d 100644 --- a/sql/expression/function/nullif.go +++ b/sql/expression/function/nullif.go @@ -57,17 +57,10 @@ func (f *NullIf) String() string { return fmt.Sprintf("nullif(%s, %s)", f.Left, f.Right) } -// TransformUp implements the Expression interface. -func (f *NullIf) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - left, err := f.Left.TransformUp(fn) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (f *NullIf) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(f, len(children), 2) } - - right, err := f.Right.TransformUp(fn) - if err != nil { - return nil, err - } - - return fn(NewNullIf(left, right)) + return NewNullIf(children[0], children[1]), nil } diff --git a/sql/expression/function/reverse_repeat_replace.go b/sql/expression/function/reverse_repeat_replace.go index 5ec196fc3..cef9e9691 100644 --- a/sql/expression/function/reverse_repeat_replace.go +++ b/sql/expression/function/reverse_repeat_replace.go @@ -4,8 +4,8 @@ import ( "fmt" "strings" - "github.com/src-d/go-mysql-server/sql/expression" "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" "gopkg.in/src-d/go-errors.v1" ) @@ -39,7 +39,7 @@ func (r *Reverse) Eval( func reverseString(s string) string { r := []rune(s) - for i, j := 0, len(r) - 1; i < j; i, j = i+1, j-1 { + for i, j := 0, len(r)-1; i < j; i, j = i+1, j-1 { r[i], r[j] = r[j], r[i] } return string(r) @@ -49,13 +49,12 @@ func (r *Reverse) String() string { return fmt.Sprintf("reverse(%s)", r.Child) } -// TransformUp implements the Expression interface. -func (r *Reverse) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := r.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (r *Reverse) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), 1) } - return f(NewReverse(child)) + return NewReverse(children[0]), nil } // Type implements the Expression interface. @@ -84,18 +83,12 @@ func (r *Repeat) Type() sql.Type { return sql.Text } -// TransformUp implements the Expression interface. -func (r *Repeat) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - left, err := r.Left.TransformUp(f) - if err != nil { - return nil, err - } - - right, err := r.Right.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (r *Repeat) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), 2) } - return f(NewRepeat(left, right)) + return NewRepeat(children[0], children[1]), nil } // Eval implements the Expression interface. @@ -165,23 +158,12 @@ func (r *Replace) Type() sql.Type { return sql.Text } -// TransformUp implements the Expression interface. -func (r *Replace) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - str, err := r.str.TransformUp(f) - if err != nil { - return nil, err - } - - fromStr, err := r.fromStr.TransformUp(f) - if err != nil { - return nil, err - } - - toStr, err := r.toStr.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (r *Replace) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 3 { + return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), 3) } - return f(NewReplace(str, fromStr, toStr)) + return NewReplace(children[0], children[1], children[2]), nil } // Eval implements the Expression interface. diff --git a/sql/expression/function/rpad_lpad.go b/sql/expression/function/rpad_lpad.go index 74c1d762c..12b33695b 100644 --- a/sql/expression/function/rpad_lpad.go +++ b/sql/expression/function/rpad_lpad.go @@ -5,8 +5,8 @@ import ( "reflect" "strings" - "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" + "gopkg.in/src-d/go-errors.v1" ) var ErrDivisionByZero = errors.NewKind("division by zero") @@ -68,24 +68,9 @@ func (p *Pad) String() string { return fmt.Sprintf("rpad(%s, %s, %s)", p.str, p.length, p.padStr) } -// TransformUp implements the Expression interface. -func (p *Pad) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - str, err := p.str.TransformUp(f) - if err != nil { - return nil, err - } - - len, err := p.length.TransformUp(f) - if err != nil { - return nil, err - } - - padStr, err := p.padStr.TransformUp(f) - if err != nil { - return nil, err - } - padded, _ := NewPad(p.padType, str, len, padStr) - return f(padded) +// WithChildren implements the Expression interface. +func (p *Pad) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewPad(p.padType, children...) } // Eval implements the Expression interface. diff --git a/sql/expression/function/sleep.go b/sql/expression/function/sleep.go index c8119bcbc..2c672464b 100644 --- a/sql/expression/function/sleep.go +++ b/sql/expression/function/sleep.go @@ -37,7 +37,7 @@ func (s *Sleep) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, err } - time.Sleep(time.Duration(child.(float64) * 1000) * time.Millisecond) + time.Sleep(time.Duration(child.(float64)*1000) * time.Millisecond) return 0, nil } @@ -51,13 +51,12 @@ func (s *Sleep) IsNullable() bool { return false } -// TransformUp implements the Expression interface. -func (s *Sleep) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := s.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (s *Sleep) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 1) } - return f(NewSleep(child)) + return NewSleep(children[0]), nil } // Type implements the Expression interface. diff --git a/sql/expression/function/soundex.go b/sql/expression/function/soundex.go index 1f33767f1..37774228e 100644 --- a/sql/expression/function/soundex.go +++ b/sql/expression/function/soundex.go @@ -87,13 +87,12 @@ func (s *Soundex) String() string { return fmt.Sprintf("SOUNDEX(%s)", s.Child) } -// TransformUp implements the Expression interface. -func (s *Soundex) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := s.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (s *Soundex) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 1) } - return f(NewSoundex(child)) + return NewSoundex(children[0]), nil } // Type implements the Expression interface. diff --git a/sql/expression/function/split.go b/sql/expression/function/split.go index 57765c7b6..20e2a49f9 100644 --- a/sql/expression/function/split.go +++ b/sql/expression/function/split.go @@ -76,17 +76,10 @@ func (f *Split) String() string { return fmt.Sprintf("split(%s, %s)", f.Left, f.Right) } -// TransformUp implements the Expression interface. -func (f *Split) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - left, err := f.Left.TransformUp(fn) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (f *Split) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(f, len(children), 2) } - - right, err := f.Right.TransformUp(fn) - if err != nil { - return nil, err - } - - return fn(NewSplit(left, right)) + return NewSplit(children[0], children[1]), nil } diff --git a/sql/expression/function/sqrt_power.go b/sql/expression/function/sqrt_power.go index c5bdf630a..020c8b7bf 100644 --- a/sql/expression/function/sqrt_power.go +++ b/sql/expression/function/sqrt_power.go @@ -4,8 +4,8 @@ import ( "fmt" "math" - "github.com/src-d/go-mysql-server/sql/expression" "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" ) // Sqrt is a function that returns the square value of the number provided. @@ -32,13 +32,12 @@ func (s *Sqrt) IsNullable() bool { return s.Child.IsNullable() } -// TransformUp implements the Expression interface. -func (s *Sqrt) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - child, err := s.Child.TransformUp(fn) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (s *Sqrt) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 1) } - return fn(NewSqrt(child)) + return NewSqrt(children[0]), nil } // Eval implements the Expression interface. @@ -86,19 +85,12 @@ func (p *Power) String() string { return fmt.Sprintf("power(%s, %s)", p.Left, p.Right) } -// TransformUp implements the Expression interface. -func (p *Power) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - left, err := p.Left.TransformUp(fn) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (p *Power) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 2) } - - right, err := p.Right.TransformUp(fn) - if err != nil { - return nil, err - } - - return fn(NewPower(left, right)) + return NewPower(children[0], children[0]), nil } // Eval implements the Expression interface. diff --git a/sql/expression/function/substring.go b/sql/expression/function/substring.go index bc1337d88..c5227b9bc 100644 --- a/sql/expression/function/substring.go +++ b/sql/expression/function/substring.go @@ -142,32 +142,9 @@ func (s *Substring) Resolved() bool { // Type implements the Expression interface. func (*Substring) Type() sql.Type { return sql.Text } -// TransformUp implements the Expression interface. -func (s *Substring) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - str, err := s.str.TransformUp(f) - if err != nil { - return nil, err - } - - start, err := s.start.TransformUp(f) - if err != nil { - return nil, err - } - - // It is safe to omit the errors of NewSubstring here because to be able to call - // this method, you need a valid instance of Substring, so the arity must be correct - // and that's the only error NewSubstring can return. - var sub sql.Expression - if s.len != nil { - len, err := s.len.TransformUp(f) - if err != nil { - return nil, err - } - sub, _ = NewSubstring(str, start, len) - } else { - sub, _ = NewSubstring(str, start) - } - return f(sub) +/// WithChildren implements the Expression interface. +func (*Substring) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewSubstring(children...) } // SubstringIndex returns the substring from string str before count occurrences of the delimiter delim. @@ -273,22 +250,10 @@ func (s *SubstringIndex) Resolved() bool { // Type implements the Expression interface. func (*SubstringIndex) Type() sql.Type { return sql.Text } -// TransformUp implements the Expression interface. -func (s *SubstringIndex) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - str, err := s.str.TransformUp(f) - if err != nil { - return nil, err - } - - delim, err := s.delim.TransformUp(f) - if err != nil { - return nil, err - } - - count, err := s.count.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (s *SubstringIndex) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 3 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 3) } - - return f(NewSubstringIndex(str, delim, count)) + return NewSubstringIndex(children[0], children[1], children[2]), nil } diff --git a/sql/expression/function/time.go b/sql/expression/function/time.go index 345c5b98d..2385aec0d 100644 --- a/sql/expression/function/time.go +++ b/sql/expression/function/time.go @@ -66,14 +66,12 @@ func (y *Year) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return getDatePart(ctx, y.UnaryExpression, row, year) } -// TransformUp implements the Expression interface. -func (y *Year) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := y.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (y *Year) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(y, len(children), 1) } - - return f(NewYear(child)) + return NewYear(children[0]), nil } // Month is a function that returns the month of a date. @@ -96,14 +94,12 @@ func (m *Month) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return getDatePart(ctx, m.UnaryExpression, row, month) } -// TransformUp implements the Expression interface. -func (m *Month) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := m.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (m *Month) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(m, len(children), 1) } - - return f(NewMonth(child)) + return NewMonth(children[0]), nil } // Day is a function that returns the day of a date. @@ -126,14 +122,12 @@ func (d *Day) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return getDatePart(ctx, d.UnaryExpression, row, day) } -// TransformUp implements the Expression interface. -func (d *Day) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := d.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (d *Day) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 1) } - - return f(NewDay(child)) + return NewDay(children[0]), nil } // Weekday is a function that returns the weekday of a date where 0 = Monday, @@ -157,14 +151,12 @@ func (d *Weekday) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return getDatePart(ctx, d.UnaryExpression, row, weekday) } -// TransformUp implements the Expression interface. -func (d *Weekday) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := d.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (d *Weekday) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 1) } - - return f(NewWeekday(child)) + return NewWeekday(children[0]), nil } // Hour is a function that returns the hour of a date. @@ -187,14 +179,12 @@ func (h *Hour) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return getDatePart(ctx, h.UnaryExpression, row, hour) } -// TransformUp implements the Expression interface. -func (h *Hour) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := h.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (h *Hour) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(h, len(children), 1) } - - return f(NewHour(child)) + return NewHour(children[0]), nil } // Minute is a function that returns the minute of a date. @@ -217,14 +207,12 @@ func (m *Minute) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return getDatePart(ctx, m.UnaryExpression, row, minute) } -// TransformUp implements the Expression interface. -func (m *Minute) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := m.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (m *Minute) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(m, len(children), 1) } - - return f(NewMinute(child)) + return NewMinute(children[0]), nil } // Second is a function that returns the second of a date. @@ -247,14 +235,12 @@ func (s *Second) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return getDatePart(ctx, s.UnaryExpression, row, second) } -// TransformUp implements the Expression interface. -func (s *Second) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := s.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (s *Second) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 1) } - - return f(NewSecond(child)) + return NewSecond(children[0]), nil } // DayOfWeek is a function that returns the day of the week from a date where @@ -278,14 +264,12 @@ func (d *DayOfWeek) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return getDatePart(ctx, d.UnaryExpression, row, dayOfWeek) } -// TransformUp implements the Expression interface. -func (d *DayOfWeek) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := d.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (d *DayOfWeek) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 1) } - - return f(NewDayOfWeek(child)) + return NewDayOfWeek(children[0]), nil } // DayOfYear is a function that returns the day of the year from a date. @@ -308,14 +292,12 @@ func (d *DayOfYear) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return getDatePart(ctx, d.UnaryExpression, row, dayOfYear) } -// TransformUp implements the Expression interface. -func (d *DayOfYear) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := d.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (d *DayOfYear) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 1) } - - return f(NewDayOfYear(child)) + return NewDayOfYear(children[0]), nil } func datePartFunc(fn func(time.Time) int) func(interface{}) interface{} { @@ -403,23 +385,9 @@ func (d *YearWeek) IsNullable() bool { return d.date.IsNullable() } -// TransformUp implements the Expression interface. -func (d *YearWeek) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - date, err := d.date.TransformUp(f) - if err != nil { - return nil, err - } - - mode, err := d.mode.TransformUp(f) - if err != nil { - return nil, err - } - - yw, err := NewYearWeek(date, mode) - if err != nil { - return nil, err - } - return f(yw) +// WithChildren implements the Expression interface. +func (*YearWeek) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewYearWeek(children...) } // Following solution of YearWeek was taken from tidb: https://github.com/pingcap/tidb/blob/master/types/mytime.go @@ -567,9 +535,12 @@ func (n *Now) Eval(*sql.Context, sql.Row) (interface{}, error) { return n.clock(), nil } -// TransformUp implements the sql.Expression interface. -func (n *Now) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - return f(n) +// WithChildren implements the Expression interface. +func (n *Now) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(n, len(children), 0) + } + return n, nil } // Date a function takes the DATE part out from a datetime expression. @@ -598,12 +569,10 @@ func (d *Date) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { }) } -// TransformUp implements the sql.Expression interface. -func (d *Date) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := d.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (d *Date) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 1) } - - return f(NewDate(child)) + return NewDate(children[0]), nil } diff --git a/sql/expression/function/tobase64_frombase64.go b/sql/expression/function/tobase64_frombase64.go index 543d482f7..f3c638983 100644 --- a/sql/expression/function/tobase64_frombase64.go +++ b/sql/expression/function/tobase64_frombase64.go @@ -72,13 +72,12 @@ func (t *ToBase64) IsNullable() bool { return t.Child.IsNullable() } -// TransformUp implements the Expression interface. -func (t *ToBase64) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := t.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (t *ToBase64) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 1) } - return f(NewToBase64(child)) + return NewToBase64(children[0]), nil } // Type implements the Expression interface. @@ -86,7 +85,6 @@ func (t *ToBase64) Type() sql.Type { return sql.Text } - // FromBase64 is a function to decode a Base64-formatted string // using the same dialect that MySQL's FROM_BASE64 uses type FromBase64 struct { @@ -133,13 +131,12 @@ func (t *FromBase64) IsNullable() bool { return t.Child.IsNullable() } -// TransformUp implements the Expression interface. -func (t *FromBase64) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := t.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (t *FromBase64) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 1) } - return f(NewFromBase64(child)) + return NewFromBase64(children[0]), nil } // Type implements the Expression interface. diff --git a/sql/expression/function/trim_ltrim_rtrim.go b/sql/expression/function/trim_ltrim_rtrim.go index 9aacacb9b..b08704dfb 100644 --- a/sql/expression/function/trim_ltrim_rtrim.go +++ b/sql/expression/function/trim_ltrim_rtrim.go @@ -6,11 +6,12 @@ import ( "strings" "unicode" - "github.com/src-d/go-mysql-server/sql/expression" "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" ) type trimType rune + const ( lTrimType trimType = 'l' rTrimType trimType = 'r' @@ -54,14 +55,12 @@ func (t *Trim) IsNullable() bool { return t.Child.IsNullable() } -// TransformUp implements the Expression interface. -func (t *Trim) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - str, err := t.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (t *Trim) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 1) } - - return f(NewTrim(t.trimType, str)) + return NewTrim(t.trimType, children[0]), nil } // Eval implements the Expression interface. diff --git a/sql/expression/function/version.go b/sql/expression/function/version.go index b100128f6..eafeb4454 100644 --- a/sql/expression/function/version.go +++ b/sql/expression/function/version.go @@ -30,9 +30,12 @@ func (f Version) String() string { return "VERSION()" } -// TransformUp implements the Expression interface. -func (f Version) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - return fn(f) +// WithChildren implements the Expression interface. +func (f Version) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(f, len(children), 0) + } + return f, nil } // Resolved implements the Expression interface. diff --git a/sql/expression/get_field.go b/sql/expression/get_field.go index cb8ac9aae..e6a5884dc 100644 --- a/sql/expression/get_field.go +++ b/sql/expression/get_field.go @@ -3,8 +3,8 @@ package expression import ( "fmt" - errors "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" + errors "gopkg.in/src-d/go-errors.v1" ) // GetField is an expression to get the field of a table. @@ -72,10 +72,12 @@ func (p *GetField) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return row[p.fieldIndex], nil } -// TransformUp implements the Expression interface. -func (p *GetField) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - n := *p - return f(&n) +// WithChildren implements the Expression interface. +func (p *GetField) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 0) + } + return p, nil } func (p *GetField) String() string { @@ -124,7 +126,10 @@ func (f *GetSessionField) Resolved() bool { return true } // String implements the sql.Expression interface. func (f *GetSessionField) String() string { return "@@" + f.name } -// TransformUp implements the sql.Expression interface. -func (f *GetSessionField) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - return fn(f) +// WithChildren implements the Expression interface. +func (f *GetSessionField) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(f, len(children), 0) + } + return f, nil } diff --git a/sql/expression/interval.go b/sql/expression/interval.go index 6ec5cfa91..a175d55b5 100644 --- a/sql/expression/interval.go +++ b/sql/expression/interval.go @@ -7,8 +7,8 @@ import ( "strings" "time" - errors "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" + errors "gopkg.in/src-d/go-errors.v1" ) // Interval defines a time duration. @@ -148,14 +148,12 @@ func (i *Interval) EvalDelta(ctx *sql.Context, row sql.Row) (*TimeDelta, error) return &td, nil } -// TransformUp implements the sql.Expression interface. -func (i *Interval) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := i.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (i *Interval) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(i, len(children), 1) } - - return NewInterval(child, i.Unit), nil + return NewInterval(children[0], i.Unit), nil } func (i *Interval) String() string { diff --git a/sql/expression/isnull.go b/sql/expression/isnull.go index 252b95f70..a9ae575d5 100644 --- a/sql/expression/isnull.go +++ b/sql/expression/isnull.go @@ -36,11 +36,10 @@ func (e IsNull) String() string { return e.Child.String() + " IS NULL" } -// TransformUp implements the Expression interface. -func (e *IsNull) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := e.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (e *IsNull) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(e, len(children), 1) } - return f(NewIsNull(child)) + return NewIsNull(children[0]), nil } diff --git a/sql/expression/like.go b/sql/expression/like.go index 6c361b22c..1fcf08bf8 100644 --- a/sql/expression/like.go +++ b/sql/expression/like.go @@ -107,19 +107,12 @@ func (l *Like) String() string { return fmt.Sprintf("%s LIKE %s", l.Left, l.Right) } -// TransformUp implements the sql.Expression interface. -func (l *Like) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - left, err := l.Left.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (l *Like) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(l, len(children), 2) } - - right, err := l.Right.TransformUp(f) - if err != nil { - return nil, err - } - - return f(NewLike(left, right)) + return NewLike(children[0], children[1]), nil } func patternToRegex(pattern string) string { diff --git a/sql/expression/literal.go b/sql/expression/literal.go index cb4158865..65b3aadb1 100644 --- a/sql/expression/literal.go +++ b/sql/expression/literal.go @@ -51,10 +51,12 @@ func (p *Literal) String() string { } } -// TransformUp implements the Expression interface. -func (p *Literal) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - n := *p - return f(&n) +// WithChildren implements the Expression interface. +func (p *Literal) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 0) + } + return p, nil } // Children implements the Expression interface. diff --git a/sql/expression/logic.go b/sql/expression/logic.go index 6551c793a..08a31c087 100644 --- a/sql/expression/logic.go +++ b/sql/expression/logic.go @@ -68,19 +68,12 @@ func (a *And) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return true, nil } -// TransformUp implements the Expression interface. -func (a *And) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - left, err := a.Left.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (a *And) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(a, len(children), 2) } - - right, err := a.Right.TransformUp(f) - if err != nil { - return nil, err - } - - return f(NewAnd(left, right)) + return NewAnd(children[0], children[1]), nil } // Or checks whether one of the two given expressions is true. @@ -125,17 +118,10 @@ func (o *Or) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return rval == true, nil } -// TransformUp implements the Expression interface. -func (o *Or) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - left, err := o.Left.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (o *Or) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(o, len(children), 2) } - - right, err := o.Right.TransformUp(f) - if err != nil { - return nil, err - } - - return f(NewOr(left, right)) + return NewOr(children[0], children[1]), nil } diff --git a/sql/expression/star.go b/sql/expression/star.go index c2ce8691f..5d6603be1 100644 --- a/sql/expression/star.go +++ b/sql/expression/star.go @@ -55,8 +55,10 @@ func (*Star) Eval(ctx *sql.Context, r sql.Row) (interface{}, error) { panic("star is just a placeholder node, but Eval was called") } -// TransformUp implements the Expression interface. -func (s *Star) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - n := *s - return f(&n) +// WithChildren implements the Expression interface. +func (s *Star) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 0) + } + return s, nil } diff --git a/sql/expression/transform.go b/sql/expression/transform.go new file mode 100644 index 000000000..ccf3d4276 --- /dev/null +++ b/sql/expression/transform.go @@ -0,0 +1,28 @@ +package expression + +import "github.com/src-d/go-mysql-server/sql" + +// TransformUp applies a transformation function to the given expression from the +// bottom up. +func TransformUp(e sql.Expression, f sql.TransformExprFunc) (sql.Expression, error) { + children := e.Children() + if len(children) == 0 { + return f(e) + } + + newChildren := make([]sql.Expression, len(children)) + for i, c := range children { + c, err := TransformUp(c, f) + if err != nil { + return nil, err + } + newChildren[i] = c + } + + e, err := e.WithChildren(newChildren...) + if err != nil { + return nil, err + } + + return f(e) +} diff --git a/sql/expression/tuple.go b/sql/expression/tuple.go index 05af0367a..11d35e1f5 100644 --- a/sql/expression/tuple.go +++ b/sql/expression/tuple.go @@ -77,18 +77,12 @@ func (t Tuple) Type() sql.Type { return sql.Tuple(types...) } -// TransformUp implements the Expression interface. -func (t Tuple) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - var exprs = make([]sql.Expression, len(t)) - for i, e := range t { - var err error - exprs[i], err = f(e) - if err != nil { - return nil, err - } +// WithChildren implements the Expression interface. +func (t Tuple) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != len(t) { + return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), len(t)) } - - return f(Tuple(exprs)) + return NewTuple(children...), nil } // Children implements the Expression interface. diff --git a/sql/expression/unresolved.go b/sql/expression/unresolved.go index 276367cd7..6580655d2 100644 --- a/sql/expression/unresolved.go +++ b/sql/expression/unresolved.go @@ -64,10 +64,12 @@ func (*UnresolvedColumn) Eval(ctx *sql.Context, r sql.Row) (interface{}, error) panic("unresolved column is a placeholder node, but Eval was called") } -// TransformUp implements the Expression interface. -func (uc *UnresolvedColumn) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - n := *uc - return f(&n) +// WithChildren implements the Expression interface. +func (uc *UnresolvedColumn) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(uc, len(children), 0) + } + return uc, nil } // UnresolvedFunction represents a function that is not yet resolved. @@ -126,16 +128,10 @@ func (*UnresolvedFunction) Eval(ctx *sql.Context, r sql.Row) (interface{}, error panic("unresolved function is a placeholder node, but Eval was called") } -// TransformUp implements the Expression interface. -func (uf *UnresolvedFunction) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - var rc []sql.Expression - for _, c := range uf.Arguments { - ct, err := c.TransformUp(f) - if err != nil { - return nil, err - } - rc = append(rc, ct) +// WithChildren implements the Expression interface. +func (uf *UnresolvedFunction) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != len(uf.Arguments) { + return nil, sql.ErrInvalidChildrenNumber.New(uf, len(children), len(uf.Arguments)) } - - return f(NewUnresolvedFunction(uf.name, uf.IsAggregate, rc...)) + return NewUnresolvedFunction(uf.name, uf.IsAggregate, children...), nil } diff --git a/sql/index_test.go b/sql/index_test.go index da42c2464..1e631cba5 100644 --- a/sql/index_test.go +++ b/sql/index_test.go @@ -401,8 +401,8 @@ var _ Expression = (*dummyExpr)(nil) func (dummyExpr) Children() []Expression { return nil } func (dummyExpr) Eval(*Context, Row) (interface{}, error) { panic("not implemented") } -func (e dummyExpr) TransformUp(fn TransformExprFunc) (Expression, error) { - return fn(e) +func (e dummyExpr) WithChildren(children ...Expression) (Expression, error) { + return e, nil } func (e dummyExpr) String() string { return fmt.Sprintf("dummyExpr{%d, %s}", e.index, e.colName) } func (dummyExpr) IsNullable() bool { return false } diff --git a/sql/parse/parse.go b/sql/parse/parse.go index c8a69c1ea..8ea1803ce 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -168,7 +168,7 @@ func convertSet(ctx *sql.Context, n *sqlparser.Set) (sql.Node, error) { } name := strings.TrimSpace(e.Name.Lowered()) - if expr, err = expr.TransformUp(func(e sql.Expression) (sql.Expression, error) { + if expr, err = expression.TransformUp(expr, func(e sql.Expression) (sql.Expression, error) { if _, ok := e.(*expression.DefaultColumn); ok { return e, nil } diff --git a/sql/plan/common.go b/sql/plan/common.go index 061a9f13a..beec46177 100644 --- a/sql/plan/common.go +++ b/sql/plan/common.go @@ -57,20 +57,3 @@ func expressionsResolved(exprs ...sql.Expression) bool { return true } - -func transformExpressionsUp( - f sql.TransformExprFunc, - exprs []sql.Expression, -) ([]sql.Expression, error) { - - var es []sql.Expression - for _, e := range exprs { - te, err := e.TransformUp(f) - if err != nil { - return nil, err - } - es = append(es, te) - } - - return es, nil -} diff --git a/sql/plan/create_index.go b/sql/plan/create_index.go index ce170bbd4..f45de2f7e 100644 --- a/sql/plan/create_index.go +++ b/sql/plan/create_index.go @@ -8,9 +8,9 @@ import ( opentracing "github.com/opentracing/opentracing-go" otlog "github.com/opentracing/opentracing-go/log" "github.com/sirupsen/logrus" - errors "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" "github.com/src-d/go-mysql-server/sql/expression" + errors "gopkg.in/src-d/go-errors.v1" ) var ( @@ -260,58 +260,28 @@ func (c *CreateIndex) Expressions() []sql.Expression { return c.Exprs } -// TransformExpressions implements the Expressioner interface. -func (c *CreateIndex) TransformExpressions(fn sql.TransformExprFunc) (sql.Node, error) { - var exprs = make([]sql.Expression, len(c.Exprs)) - var err error - for i, e := range c.Exprs { - exprs[i], err = e.TransformUp(fn) - if err != nil { - return nil, err - } +// WithExpressions implements the Expressioner interface. +func (c *CreateIndex) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + if len(exprs) != len(c.Exprs) { + return nil, sql.ErrInvalidChildrenNumber.New(c, len(exprs), len(c.Exprs)) } nc := *c nc.Exprs = exprs - return &nc, nil } -// TransformExpressionsUp implements the Node interface. -func (c *CreateIndex) TransformExpressionsUp(fn sql.TransformExprFunc) (sql.Node, error) { - table, err := c.Table.TransformExpressionsUp(fn) - if err != nil { - return nil, err - } - - var exprs = make([]sql.Expression, len(c.Exprs)) - for i, e := range c.Exprs { - exprs[i], err = e.TransformUp(fn) - if err != nil { - return nil, err - } +// WithChildren implements the Node interface. +func (c *CreateIndex) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 1) } nc := *c - nc.Table = table - nc.Exprs = exprs - + nc.Table = children[0] return &nc, nil } -// TransformUp implements the Node interface. -func (c *CreateIndex) TransformUp(fn sql.TransformNodeFunc) (sql.Node, error) { - table, err := c.Table.TransformUp(fn) - if err != nil { - return nil, err - } - - nc := *c - nc.Table = table - - return fn(&nc) -} - // getColumnsAndPrepareExpressions extracts the unique columns required by all // those expressions and fixes the indexes of the GetFields in the expressions // to match a row with only the returned columns in that same order. @@ -323,7 +293,7 @@ func getColumnsAndPrepareExpressions( var expressions = make([]sql.Expression, len(exprs)) for i, e := range exprs { - ex, err := e.TransformUp(func(e sql.Expression) (sql.Expression, error) { + ex, err := expression.TransformUp(e, func(e sql.Expression) (sql.Expression, error) { gf, ok := e.(*expression.GetField) if !ok { return e, nil diff --git a/sql/plan/cross_join.go b/sql/plan/cross_join.go index 6943719b9..f7dc1d80d 100644 --- a/sql/plan/cross_join.go +++ b/sql/plan/cross_join.go @@ -66,34 +66,13 @@ func (p *CrossJoin) RowIter(ctx *sql.Context) (sql.RowIter, error) { }), nil } -// TransformUp implements the Transformable interface. -func (p *CrossJoin) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - left, err := p.Left.TransformUp(f) - if err != nil { - return nil, err - } - - right, err := p.Right.TransformUp(f) - if err != nil { - return nil, err - } - - return f(NewCrossJoin(left, right)) -} - -// TransformExpressionsUp implements the Transformable interface. -func (p *CrossJoin) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - left, err := p.Left.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - - right, err := p.Right.TransformExpressionsUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (p *CrossJoin) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 2) } - return NewCrossJoin(left, right), nil + return NewCrossJoin(children[0], children[1]), nil } func (p *CrossJoin) String() string { diff --git a/sql/plan/ddl.go b/sql/plan/ddl.go index b92075e52..3eb97331e 100644 --- a/sql/plan/ddl.go +++ b/sql/plan/ddl.go @@ -1,8 +1,8 @@ package plan import ( - "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" + "gopkg.in/src-d/go-errors.v1" ) // ErrCreateTable is thrown when the database doesn't support table creation @@ -64,13 +64,11 @@ func (c *CreateTable) Schema() sql.Schema { return nil } // Children implements the Node interface. func (c *CreateTable) Children() []sql.Node { return nil } -// TransformUp implements the Transformable interface. -func (c *CreateTable) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(NewCreateTable(c.db, c.name, c.schema)) -} - -// TransformExpressionsUp implements the Transformable interface. -func (c *CreateTable) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { +// WithChildren implements the Node interface. +func (c *CreateTable) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 0) + } return c, nil } diff --git a/sql/plan/describe.go b/sql/plan/describe.go index aed8aeaf5..e84cdc8a0 100644 --- a/sql/plan/describe.go +++ b/sql/plan/describe.go @@ -33,22 +33,13 @@ func (d *Describe) RowIter(ctx *sql.Context) (sql.RowIter, error) { return &describeIter{schema: d.Child.Schema()}, nil } -// TransformUp implements the Transformable interface. -func (d *Describe) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := d.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (d *Describe) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 1) } - return f(NewDescribe(child)) -} -// TransformExpressionsUp implements the Transformable interface. -func (d *Describe) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - child, err := d.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - return NewDescribe(child), nil + return NewDescribe(children[0]), nil } func (d Describe) String() string { @@ -116,22 +107,11 @@ func (d *DescribeQuery) String() string { return pr.String() } -// TransformUp implements the Node interface. -func (d *DescribeQuery) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := d.Child.TransformUp(f) - if err != nil { - return nil, err - } - - return f(NewDescribeQuery(d.Format, child)) -} - -// TransformExpressionsUp implements the Node interface. -func (d *DescribeQuery) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - child, err := d.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (d *DescribeQuery) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 1) } - return NewDescribeQuery(d.Format, child), nil + return NewDescribeQuery(d.Format, children[0]), nil } diff --git a/sql/plan/distinct.go b/sql/plan/distinct.go index 556c1377f..e14aee5fb 100644 --- a/sql/plan/distinct.go +++ b/sql/plan/distinct.go @@ -37,22 +37,13 @@ func (d *Distinct) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.NewSpanIter(span, newDistinctIter(it)), nil } -// TransformUp implements the Transformable interface. -func (d *Distinct) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := d.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (d *Distinct) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 1) } - return f(NewDistinct(child)) -} -// TransformExpressionsUp implements the Transformable interface. -func (d *Distinct) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - child, err := d.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - return NewDistinct(child), nil + return NewDistinct(children[0]), nil } func (d Distinct) String() string { @@ -135,22 +126,13 @@ func (d *OrderedDistinct) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.NewSpanIter(span, newOrderedDistinctIter(it, d.Child.Schema())), nil } -// TransformUp implements the Transformable interface. -func (d *OrderedDistinct) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := d.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (d *OrderedDistinct) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 1) } - return f(NewOrderedDistinct(child)) -} -// TransformExpressionsUp implements the Transformable interface. -func (d *OrderedDistinct) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - child, err := d.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - return NewOrderedDistinct(child), nil + return NewOrderedDistinct(children[0]), nil } func (d OrderedDistinct) String() string { diff --git a/sql/plan/drop_index.go b/sql/plan/drop_index.go index e6917ec04..f08caf711 100644 --- a/sql/plan/drop_index.go +++ b/sql/plan/drop_index.go @@ -1,9 +1,9 @@ package plan import ( - "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/internal/similartext" "github.com/src-d/go-mysql-server/sql" + "gopkg.in/src-d/go-errors.v1" ) var ( @@ -104,26 +104,13 @@ func (d *DropIndex) String() string { return pr.String() } -// TransformExpressionsUp implements the Node interface. -func (d *DropIndex) TransformExpressionsUp(fn sql.TransformExprFunc) (sql.Node, error) { - t, err := d.Table.TransformExpressionsUp(fn) - if err != nil { - return nil, err - } - - nc := *d - nc.Table = t - return &nc, nil -} - -// TransformUp implements the Node interface. -func (d *DropIndex) TransformUp(fn sql.TransformNodeFunc) (sql.Node, error) { - t, err := d.Table.TransformUp(fn) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (d *DropIndex) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 1) } - nc := *d - nc.Table = t - return fn(&nc) + nd := *d + nd.Table = children[0] + return &nd, nil } diff --git a/sql/plan/empty_table.go b/sql/plan/empty_table.go index 9a10ebe73..198cef41d 100644 --- a/sql/plan/empty_table.go +++ b/sql/plan/empty_table.go @@ -16,12 +16,11 @@ func (emptyTable) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.RowsToRowIter(), nil } -// TransformUp implements the Transformable interface. -func (e *emptyTable) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(e) -} +// WithChildren implements the Node interface. +func (e *emptyTable) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(e, len(children), 0) + } -// TransformExpressionsUp implements the Transformable interface. -func (e *emptyTable) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { return e, nil } diff --git a/sql/plan/exchange.go b/sql/plan/exchange.go index 9c4347369..8ff51ee6f 100644 --- a/sql/plan/exchange.go +++ b/sql/plan/exchange.go @@ -61,24 +61,13 @@ func (e *Exchange) String() string { return p.String() } -// TransformUp implements the sql.Node interface. -func (e *Exchange) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := e.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (e *Exchange) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(e, len(children), 1) } - return f(NewExchange(e.Parallelism, child)) -} - -// TransformExpressionsUp implements the sql.Node interface. -func (e *Exchange) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - child, err := e.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - - return NewExchange(e.Parallelism, child), nil + return NewExchange(e.Parallelism, children[0]), nil } type exchangeRowIter struct { @@ -208,7 +197,7 @@ func (it *exchangeRowIter) iterPartitions(ch chan<- sql.Partition) { } func (it *exchangeRowIter) iterPartition(p sql.Partition) { - node, err := it.tree.TransformUp(func(n sql.Node) (sql.Node, error) { + node, err := TransformUp(it.tree, func(n sql.Node) (sql.Node, error) { if t, ok := n.(sql.Table); ok { return &exchangePartition{p, t}, nil } @@ -310,10 +299,11 @@ func (p *exchangePartition) Schema() sql.Schema { return p.table.Schema() } -func (p *exchangePartition) TransformExpressionsUp(sql.TransformExprFunc) (sql.Node, error) { - return p, nil -} +// WithChildren implements the Node interface. +func (p *exchangePartition) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 0) + } -func (p *exchangePartition) TransformUp(sql.TransformNodeFunc) (sql.Node, error) { return p, nil } diff --git a/sql/plan/exchange_test.go b/sql/plan/exchange_test.go index eca284216..5a8e8e317 100644 --- a/sql/plan/exchange_test.go +++ b/sql/plan/exchange_test.go @@ -6,9 +6,9 @@ import ( "io" "testing" - "github.com/stretchr/testify/require" "github.com/src-d/go-mysql-server/sql" "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" ) func TestExchange(t *testing.T) { @@ -106,11 +106,12 @@ type partitionable struct { rowsPerPartition int } -func (p *partitionable) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(p) -} +// WithChildren implements the Node interface. +func (p *partitionable) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 0) + } -func (p *partitionable) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { return p, nil } diff --git a/sql/plan/filter.go b/sql/plan/filter.go index a86c51eab..f160b3f66 100644 --- a/sql/plan/filter.go +++ b/sql/plan/filter.go @@ -36,28 +36,22 @@ func (p *Filter) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.NewSpanIter(span, NewFilterIter(ctx, p.Expression, i)), nil } -// TransformUp implements the Transformable interface. -func (p *Filter) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := p.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (p *Filter) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 1) } - return f(NewFilter(p.Expression, child)) + + return NewFilter(p.Expression, children[0]), nil } -// TransformExpressionsUp implements the Transformable interface. -func (p *Filter) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - expr, err := p.Expression.TransformUp(f) - if err != nil { - return nil, err +// WithExpressions implements the Expressioner interface. +func (p *Filter) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + if len(exprs) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(exprs), 1) } - child, err := p.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - - return NewFilter(expr, child), nil + return NewFilter(exprs[0], p.Child), nil } func (p *Filter) String() string { @@ -72,16 +66,6 @@ func (p *Filter) Expressions() []sql.Expression { return []sql.Expression{p.Expression} } -// TransformExpressions implements the Expressioner interface. -func (p *Filter) TransformExpressions(f sql.TransformExprFunc) (sql.Node, error) { - e, err := p.Expression.TransformUp(f) - if err != nil { - return nil, err - } - - return NewFilter(e, p.Child), nil -} - // FilterIter is an iterator that filters another iterator and skips rows that // don't match the given condition. type FilterIter struct { diff --git a/sql/plan/generate.go b/sql/plan/generate.go index 841a45890..c259b8d8d 100644 --- a/sql/plan/generate.go +++ b/sql/plan/generate.go @@ -46,48 +46,30 @@ func (g *Generate) RowIter(ctx *sql.Context) (sql.RowIter, error) { }), nil } -func (g *Generate) TransformExpressions(f sql.TransformExprFunc) (sql.Node, error) { - col, err := g.Column.TransformUp(f) - if err != nil { - return nil, err - } - - field, ok := col.(*expression.GetField) - if !ok { - return nil, fmt.Errorf("column of Generate node transformed into %T, must be GetField", col) - } - - return NewGenerate(g.Child, field), nil -} +// Expressions implements the Expressioner interface. +func (g *Generate) Expressions() []sql.Expression { return []sql.Expression{g.Column} } -// TransformUp implements the sql.Node interface. -func (g *Generate) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := g.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (g *Generate) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(g, len(children), 1) } - return f(NewGenerate(child, g.Column)) + return NewGenerate(children[0], g.Column), nil } -// TransformExpressionsUp implements the sql.Node interface. -func (g *Generate) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - child, err := g.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - - col, err := g.Column.TransformUp(f) - if err != nil { - return nil, err +// WithExpressions implements the Expressioner interface. +func (g *Generate) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + if len(exprs) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(g, len(exprs), 1) } - field, ok := col.(*expression.GetField) + gf, ok := exprs[0].(*expression.GetField) if !ok { - return nil, fmt.Errorf("column of Generate node transformed into %T, must be GetField", col) + return nil, fmt.Errorf("Generate expects child to be expression.GetField, but is %T", exprs[0]) } - return NewGenerate(child, field), nil + return NewGenerate(g.Child, gf), nil } func (g *Generate) String() string { diff --git a/sql/plan/generate_test.go b/sql/plan/generate_test.go index ba35f33cf..7f32db68c 100644 --- a/sql/plan/generate_test.go +++ b/sql/plan/generate_test.go @@ -3,9 +3,9 @@ package plan import ( "testing" - "github.com/stretchr/testify/require" "github.com/src-d/go-mysql-server/sql" "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" ) func TestGenerateRowIter(t *testing.T) { @@ -80,9 +80,6 @@ func (n *fakeNode) Resolved() bool { return true } func (n *fakeNode) Schema() sql.Schema { return n.schema } func (n *fakeNode) RowIter(*sql.Context) (sql.RowIter, error) { return n.iter, nil } func (n *fakeNode) String() string { return "fakeNode" } -func (n *fakeNode) TransformUp(sql.TransformNodeFunc) (sql.Node, error) { - panic("placeholder") -} -func (n *fakeNode) TransformExpressionsUp(sql.TransformExprFunc) (sql.Node, error) { +func (*fakeNode) WithChildren(children ...sql.Node) (sql.Node, error) { panic("placeholder") } diff --git a/sql/plan/group_by.go b/sql/plan/group_by.go index a61b9204a..620167568 100644 --- a/sql/plan/group_by.go +++ b/sql/plan/group_by.go @@ -7,9 +7,9 @@ import ( "strings" opentracing "github.com/opentracing/opentracing-go" - errors "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" "github.com/src-d/go-mysql-server/sql/expression" + errors "gopkg.in/src-d/go-errors.v1" ) // ErrGroupBy is returned when the aggregation is not supported. @@ -93,33 +93,34 @@ func (p *GroupBy) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.NewSpanIter(span, iter), nil } -// TransformUp implements the Transformable interface. -func (p *GroupBy) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := p.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (p *GroupBy) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 1) } - return f(NewGroupBy(p.Aggregate, p.Grouping, child)) + + return NewGroupBy(p.Aggregate, p.Grouping, children[0]), nil } -// TransformExpressionsUp implements the Transformable interface. -func (p *GroupBy) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - aggregate, err := transformExpressionsUp(f, p.Aggregate) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (p *GroupBy) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + expected := len(p.Aggregate) + len(p.Grouping) + if len(exprs) != expected { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(exprs), expected) } - grouping, err := transformExpressionsUp(f, p.Grouping) - if err != nil { - return nil, err + var agg = make([]sql.Expression, len(p.Aggregate)) + for i := 0; i < len(p.Aggregate); i++ { + agg[i] = exprs[i] } - child, err := p.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err + var grouping = make([]sql.Expression, len(p.Grouping)) + offset := len(p.Aggregate) + for i := 0; i < len(p.Grouping); i++ { + grouping[i] = exprs[i+offset] } - return NewGroupBy(aggregate, grouping, child), nil + return NewGroupBy(agg, grouping, p.Child), nil } func (p *GroupBy) String() string { @@ -152,21 +153,6 @@ func (p *GroupBy) Expressions() []sql.Expression { return exprs } -// TransformExpressions implements the Expressioner interface. -func (p *GroupBy) TransformExpressions(f sql.TransformExprFunc) (sql.Node, error) { - agg, err := transformExpressionsUp(f, p.Aggregate) - if err != nil { - return nil, err - } - - group, err := transformExpressionsUp(f, p.Grouping) - if err != nil { - return nil, err - } - - return NewGroupBy(agg, group, p.Child), nil -} - type groupByIter struct { aggregate []sql.Expression child sql.RowIter diff --git a/sql/plan/having.go b/sql/plan/having.go index 4a7f81b85..48a8bde25 100644 --- a/sql/plan/having.go +++ b/sql/plan/having.go @@ -24,39 +24,22 @@ func (h *Having) Resolved() bool { return h.Cond.Resolved() && h.Child.Resolved( // Expressions implements the sql.Expressioner interface. func (h *Having) Expressions() []sql.Expression { return []sql.Expression{h.Cond} } -// TransformExpressions implements the sql.Expressioner interface. -func (h *Having) TransformExpressions(f sql.TransformExprFunc) (sql.Node, error) { - e, err := h.Cond.TransformUp(f) - if err != nil { - return nil, err - } - - return &Having{h.UnaryNode, e}, nil -} - -// TransformExpressionsUp implements the sql.Node interface. -func (h *Having) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - child, err := h.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (h *Having) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(h, len(children), 1) } - e, err := h.Cond.TransformUp(f) - if err != nil { - return nil, err - } - - return &Having{UnaryNode{child}, e}, nil + return NewHaving(h.Cond, children[0]), nil } -// TransformUp implements the sql.Node interface. -func (h *Having) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := h.Child.TransformUp(f) - if err != nil { - return nil, err +// WithExpressions implements the Expressioner interface. +func (h *Having) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + if len(exprs) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(h, len(exprs), 1) } - return f(&Having{UnaryNode{child}, h.Cond}) + return NewHaving(exprs[0], h.Child), nil } // RowIter implements the sql.Node interface. diff --git a/sql/plan/insert.go b/sql/plan/insert.go index ebd6482d3..538fc160a 100644 --- a/sql/plan/insert.go +++ b/sql/plan/insert.go @@ -4,9 +4,9 @@ import ( "io" "strings" - "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" "github.com/src-d/go-mysql-server/sql/expression" + "gopkg.in/src-d/go-errors.v1" ) // ErrInsertIntoNotSupported is thrown when a table doesn't support inserts @@ -123,34 +123,13 @@ func (p *InsertInto) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.RowsToRowIter(sql.NewRow(int64(n))), nil } -// TransformUp implements the Transformable interface. -func (p *InsertInto) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - left, err := p.Left.TransformUp(f) - if err != nil { - return nil, err - } - - right, err := p.Right.TransformUp(f) - if err != nil { - return nil, err - } - - return f(NewInsertInto(left, right, p.Columns)) -} - -// TransformExpressionsUp implements the Transformable interface. -func (p *InsertInto) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - left, err := p.Left.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - - right, err := p.Right.TransformExpressionsUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (p *InsertInto) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 2) } - return NewInsertInto(left, right, p.Columns), nil + return NewInsertInto(children[0], children[1], p.Columns), nil } func (p InsertInto) String() string { diff --git a/sql/plan/join.go b/sql/plan/join.go index 1b3475908..a287a424a 100644 --- a/sql/plan/join.go +++ b/sql/plan/join.go @@ -83,39 +83,22 @@ func (j *InnerJoin) RowIter(ctx *sql.Context) (sql.RowIter, error) { return joinRowIter(ctx, innerJoin, j.Left, j.Right, j.Cond) } -// TransformUp implements the Transformable interface. -func (j *InnerJoin) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - left, err := j.Left.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (j *InnerJoin) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(j, len(children), 2) } - right, err := j.Right.TransformUp(f) - if err != nil { - return nil, err - } - - return f(NewInnerJoin(left, right, j.Cond)) + return NewInnerJoin(children[0], children[1], j.Cond), nil } -// TransformExpressionsUp implements the Transformable interface. -func (j *InnerJoin) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - left, err := j.Left.TransformExpressionsUp(f) - if err != nil { - return nil, err +// WithExpressions implements the Expressioner interface. +func (j *InnerJoin) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + if len(exprs) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(j, len(exprs), 1) } - right, err := j.Right.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - - cond, err := j.Cond.TransformUp(f) - if err != nil { - return nil, err - } - - return NewInnerJoin(left, right, cond), nil + return NewInnerJoin(j.Left, j.Right, exprs[0]), nil } func (j *InnerJoin) String() string { @@ -130,16 +113,6 @@ func (j *InnerJoin) Expressions() []sql.Expression { return []sql.Expression{j.Cond} } -// TransformExpressions implements the Expressioner interface. -func (j *InnerJoin) TransformExpressions(f sql.TransformExprFunc) (sql.Node, error) { - cond, err := j.Cond.TransformUp(f) - if err != nil { - return nil, err - } - - return NewInnerJoin(j.Left, j.Right, cond), nil -} - // LeftJoin is a left join between two tables. type LeftJoin struct { BinaryNode @@ -172,39 +145,22 @@ func (j *LeftJoin) RowIter(ctx *sql.Context) (sql.RowIter, error) { return joinRowIter(ctx, leftJoin, j.Left, j.Right, j.Cond) } -// TransformUp implements the Transformable interface. -func (j *LeftJoin) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - left, err := j.Left.TransformUp(f) - if err != nil { - return nil, err - } - - right, err := j.Right.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (j *LeftJoin) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(j, len(children), 1) } - return f(NewLeftJoin(left, right, j.Cond)) + return NewLeftJoin(children[0], children[1], j.Cond), nil } -// TransformExpressionsUp implements the Transformable interface. -func (j *LeftJoin) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - left, err := j.Left.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - - right, err := j.Right.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - - cond, err := j.Cond.TransformUp(f) - if err != nil { - return nil, err +// WithExpressions implements the Expressioner interface. +func (j *LeftJoin) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + if len(exprs) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(j, len(exprs), 1) } - return NewLeftJoin(left, right, cond), nil + return NewLeftJoin(j.Left, j.Right, exprs[0]), nil } func (j *LeftJoin) String() string { @@ -219,16 +175,6 @@ func (j *LeftJoin) Expressions() []sql.Expression { return []sql.Expression{j.Cond} } -// TransformExpressions implements the Expressioner interface. -func (j *LeftJoin) TransformExpressions(f sql.TransformExprFunc) (sql.Node, error) { - cond, err := j.Cond.TransformUp(f) - if err != nil { - return nil, err - } - - return NewLeftJoin(j.Left, j.Right, cond), nil -} - // RightJoin is a left join between two tables. type RightJoin struct { BinaryNode @@ -261,39 +207,22 @@ func (j *RightJoin) RowIter(ctx *sql.Context) (sql.RowIter, error) { return joinRowIter(ctx, rightJoin, j.Left, j.Right, j.Cond) } -// TransformUp implements the Transformable interface. -func (j *RightJoin) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - left, err := j.Left.TransformUp(f) - if err != nil { - return nil, err - } - - right, err := j.Right.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (j *RightJoin) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(j, len(children), 2) } - return f(NewRightJoin(left, right, j.Cond)) + return NewRightJoin(children[0], children[1], j.Cond), nil } -// TransformExpressionsUp implements the Transformable interface. -func (j *RightJoin) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - left, err := j.Left.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - - right, err := j.Right.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - - cond, err := j.Cond.TransformUp(f) - if err != nil { - return nil, err +// WithExpressions implements the Expressioner interface. +func (j *RightJoin) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + if len(exprs) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(j, len(exprs), 1) } - return NewRightJoin(left, right, cond), nil + return NewRightJoin(j.Left, j.Right, exprs[0]), nil } func (j *RightJoin) String() string { @@ -308,16 +237,6 @@ func (j *RightJoin) Expressions() []sql.Expression { return []sql.Expression{j.Cond} } -// TransformExpressions implements the Expressioner interface. -func (j *RightJoin) TransformExpressions(f sql.TransformExprFunc) (sql.Node, error) { - cond, err := j.Cond.TransformUp(f) - if err != nil { - return nil, err - } - - return NewRightJoin(j.Left, j.Right, cond), nil -} - type joinType byte const ( diff --git a/sql/plan/limit.go b/sql/plan/limit.go index 004b750bc..2d8446562 100644 --- a/sql/plan/limit.go +++ b/sql/plan/limit.go @@ -7,8 +7,6 @@ import ( "github.com/src-d/go-mysql-server/sql" ) -var _ sql.Node = &Limit{} - // Limit is a node that only allows up to N rows to be retrieved. type Limit struct { UnaryNode @@ -40,22 +38,13 @@ func (l *Limit) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.NewSpanIter(span, &limitIter{l, 0, li}), nil } -// TransformUp implements the Transformable interface. -func (l *Limit) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := l.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (l *Limit) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(l, len(children), 1) } - return f(NewLimit(l.size, child)) -} -// TransformExpressionsUp implements the Transformable interface. -func (l *Limit) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - child, err := l.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - return NewLimit(l.size, child), nil + return NewLimit(l.size, children[0]), nil } func (l Limit) String() string { diff --git a/sql/plan/lock.go b/sql/plan/lock.go index d7582c947..8e51edec6 100644 --- a/sql/plan/lock.go +++ b/sql/plan/lock.go @@ -3,8 +3,8 @@ package plan import ( "fmt" - errors "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" + errors "gopkg.in/src-d/go-errors.v1" ) // TableLock is a read or write lock on a table. @@ -25,8 +25,6 @@ func NewLockTables(locks []*TableLock) *LockTables { return &LockTables{Locks: locks} } -var _ sql.Node = (*LockTables)(nil) - // Children implements the sql.Node interface. func (t *LockTables) Children() []sql.Node { var children = make([]sql.Node, len(t.Locks)) @@ -89,25 +87,21 @@ func (t *LockTables) String() string { return p.String() } -// TransformUp implements the sql.Node interface. -func (t *LockTables) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - var children = make([]*TableLock, len(t.Locks)) - for i, l := range t.Locks { - node, err := l.Table.TransformUp(f) - if err != nil { - return nil, err - } - children[i] = &TableLock{node, l.Write} +// WithChildren implements the Node interface. +func (t *LockTables) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != len(t.Locks) { + return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), len(t.Locks)) } - nt := *t - nt.Locks = children - return f(&nt) -} + var locks = make([]*TableLock, len(t.Locks)) + for i, n := range children { + locks[i] = &TableLock{ + Table: n, + Write: t.Locks[i].Write, + } + } -// TransformExpressionsUp implements the sql.Node interface. -func (t *LockTables) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - return t, nil + return &LockTables{t.Catalog, locks}, nil } // ErrTableNotLockable is returned whenever a lockable table can't be found. @@ -143,8 +137,6 @@ func NewUnlockTables() *UnlockTables { return new(UnlockTables) } -var _ sql.Node = (*UnlockTables)(nil) - // Children implements the sql.Node interface. func (t *UnlockTables) Children() []sql.Node { return nil } @@ -172,12 +164,11 @@ func (t *UnlockTables) String() string { return p.String() } -// TransformUp implements the sql.Node interface. -func (t *UnlockTables) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(t) -} +// WithChildren implements the Node interface. +func (t *UnlockTables) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 0) + } -// TransformExpressionsUp implements the sql.Node interface. -func (t *UnlockTables) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { return t, nil } diff --git a/sql/plan/naturaljoin.go b/sql/plan/naturaljoin.go index fe8ec7a8b..6ccf0182b 100644 --- a/sql/plan/naturaljoin.go +++ b/sql/plan/naturaljoin.go @@ -35,32 +35,11 @@ func (j NaturalJoin) String() string { return pr.String() } -// TransformUp implements the Node interface. -func (j *NaturalJoin) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - left, err := j.Left.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (j *NaturalJoin) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(j, len(children), 2) } - right, err := j.Right.TransformUp(f) - if err != nil { - return nil, err - } - - return f(NewNaturalJoin(left, right)) -} - -// TransformExpressionsUp implements the Node interface. -func (j *NaturalJoin) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - left, err := j.Left.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - - right, err := j.Right.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - - return NewNaturalJoin(left, right), nil + return NewNaturalJoin(children[0], children[1]), nil } diff --git a/sql/plan/nothing.go b/sql/plan/nothing.go index 70d12c748..43792405d 100644 --- a/sql/plan/nothing.go +++ b/sql/plan/nothing.go @@ -7,8 +7,6 @@ var Nothing nothing type nothing struct{} -var _ sql.Node = nothing{} - func (nothing) String() string { return "NOTHING" } func (nothing) Resolved() bool { return true } func (nothing) Schema() sql.Schema { return nil } @@ -16,9 +14,12 @@ func (nothing) Children() []sql.Node { return nil } func (nothing) RowIter(*sql.Context) (sql.RowIter, error) { return sql.RowsToRowIter(), nil } -func (nothing) TransformUp(sql.TransformNodeFunc) (sql.Node, error) { - return Nothing, nil -} -func (nothing) TransformExpressionsUp(sql.TransformExprFunc) (sql.Node, error) { + +// WithChildren implements the Node interface. +func (n nothing) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(n, len(children), 0) + } + return Nothing, nil } diff --git a/sql/plan/offset.go b/sql/plan/offset.go index da481b7fa..514041fb6 100644 --- a/sql/plan/offset.go +++ b/sql/plan/offset.go @@ -36,22 +36,13 @@ func (o *Offset) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.NewSpanIter(span, &offsetIter{o.n, it}), nil } -// TransformUp implements the Transformable interface. -func (o *Offset) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := o.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (o *Offset) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(o, len(children), 1) } - return f(NewOffset(o.n, child)) -} -// TransformExpressionsUp implements the Transformable interface. -func (o *Offset) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - child, err := o.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - return NewOffset(o.n, child), nil + return NewOffset(o.n, children[0]), nil } func (o Offset) String() string { diff --git a/sql/plan/process.go b/sql/plan/process.go index e565e425f..32a8452d6 100644 --- a/sql/plan/process.go +++ b/sql/plan/process.go @@ -21,28 +21,13 @@ func NewQueryProcess(node sql.Node, notify NotifyFunc) *QueryProcess { return &QueryProcess{UnaryNode{Child: node}, notify} } -// TransformUp implements the sql.Node interface. -func (p *QueryProcess) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - n, err := p.Child.TransformUp(f) - if err != nil { - return nil, err - } - - np := *p - np.Child = n - return &np, nil -} - -// TransformExpressionsUp implements the sql.Node interface. -func (p *QueryProcess) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - n, err := p.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (p *QueryProcess) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 1) } - np := *p - np.Child = n - return &np, nil + return NewQueryProcess(children[0], p.Notify), nil } // RowIter implements the sql.Node interface. diff --git a/sql/plan/processlist.go b/sql/plan/processlist.go index 8a362a65a..bc9f4d18c 100644 --- a/sql/plan/processlist.go +++ b/sql/plan/processlist.go @@ -58,13 +58,12 @@ func (p *ShowProcessList) Children() []sql.Node { return nil } // Resolved implements the Node interface. func (p *ShowProcessList) Resolved() bool { return true } -// TransformUp implements the Node interface. -func (p *ShowProcessList) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(p) -} +// WithChildren implements the Node interface. +func (p *ShowProcessList) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 0) + } -// TransformExpressionsUp implements the Node interface. -func (p *ShowProcessList) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { return p, nil } diff --git a/sql/plan/project.go b/sql/plan/project.go index 99779e86b..8b166d449 100644 --- a/sql/plan/project.go +++ b/sql/plan/project.go @@ -69,30 +69,6 @@ func (p *Project) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.NewSpanIter(span, &iter{p, i, ctx}), nil } -// TransformUp implements the Transformable interface. -func (p *Project) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := p.Child.TransformUp(f) - if err != nil { - return nil, err - } - return f(NewProject(p.Projections, child)) -} - -// TransformExpressionsUp implements the Transformable interface. -func (p *Project) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - exprs, err := transformExpressionsUp(f, p.Projections) - if err != nil { - return nil, err - } - - child, err := p.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - - return NewProject(exprs, child), nil -} - func (p *Project) String() string { pr := sql.NewTreePrinter() var exprs = make([]string, len(p.Projections)) @@ -109,14 +85,22 @@ func (p *Project) Expressions() []sql.Expression { return p.Projections } -// TransformExpressions implements the Expressioner interface. -func (p *Project) TransformExpressions(f sql.TransformExprFunc) (sql.Node, error) { - projects, err := transformExpressionsUp(f, p.Projections) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (p *Project) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 1) + } + + return NewProject(p.Projections, children[0]), nil +} + +// WithExpressions implements the Expressioner interface. +func (p *Project) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + if len(exprs) != len(p.Projections) { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(exprs), len(p.Projections)) } - return NewProject(projects, p.Child), nil + return NewProject(exprs, p.Child), nil } type iter struct { diff --git a/sql/plan/resolved_table.go b/sql/plan/resolved_table.go index 2706739d2..dbbd689a1 100644 --- a/sql/plan/resolved_table.go +++ b/sql/plan/resolved_table.go @@ -44,13 +44,12 @@ func (t *ResolvedTable) RowIter(ctx *sql.Context) (sql.RowIter, error) { }), nil } -// TransformUp implements the Transformable interface. -func (t *ResolvedTable) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(NewResolvedTable(t.Table)) -} +// WithChildren implements the Node interface. +func (t *ResolvedTable) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 0) + } -// TransformExpressionsUp implements the Transformable interface. -func (t *ResolvedTable) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { return t, nil } diff --git a/sql/plan/set.go b/sql/plan/set.go index 08f994663..9b8101675 100644 --- a/sql/plan/set.go +++ b/sql/plan/set.go @@ -42,41 +42,41 @@ func (s *Set) Resolved() bool { // Children implements the sql.Node interface. func (s *Set) Children() []sql.Node { return nil } -// TransformUp implements the sql.Node interface. -func (s *Set) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(s) -} +// WithChildren implements the Node interface. +func (s *Set) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 0) + } -// TransformExpressions implements sql.Expressioner interface. -func (s *Set) TransformExpressions(f sql.TransformExprFunc) (sql.Node, error) { - return s.TransformExpressionsUp(f) + return s, nil } -// Expressions implements the sql.Expressioner interface. -func (s *Set) Expressions() []sql.Expression { - var exprs = make([]sql.Expression, len(s.Variables)) - for i, v := range s.Variables { - exprs[i] = v.Value +// WithExpressions implements the Expressioner interface. +func (s *Set) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + if len(exprs) != len(s.Variables) { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(exprs), len(s.Variables)) } - return exprs -} -// TransformExpressionsUp implements the sql.Node interface. -func (s *Set) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { var vars = make([]SetVariable, len(s.Variables)) for i, v := range s.Variables { - val, err := v.Value.TransformUp(f) - if err != nil { - return nil, err + vars[i] = SetVariable{ + Name: v.Name, + Value: exprs[i], } - - vars[i] = v - vars[i].Value = val } return NewSet(vars...), nil } +// Expressions implements the sql.Expressioner interface. +func (s *Set) Expressions() []sql.Expression { + var exprs = make([]sql.Expression, len(s.Variables)) + for i, v := range s.Variables { + exprs[i] = v.Value + } + return exprs +} + // RowIter implements the sql.Node interface. func (s *Set) RowIter(ctx *sql.Context) (sql.RowIter, error) { span, ctx := ctx.Span("plan.Set") diff --git a/sql/plan/show_collation.go b/sql/plan/show_collation.go index 4b01febb3..3c833b246 100644 --- a/sql/plan/show_collation.go +++ b/sql/plan/show_collation.go @@ -42,12 +42,11 @@ func (ShowCollation) RowIter(ctx *sql.Context) (sql.RowIter, error) { // Schema implements the sql.Node interface. func (ShowCollation) Schema() sql.Schema { return collationSchema } -// TransformUp implements the sql.Node interface. -func (ShowCollation) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(ShowCollation{}) -} +// WithChildren implements the Node interface. +func (s ShowCollation) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 0) + } -// TransformExpressionsUp implements the sql.Node interface. -func (ShowCollation) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - return ShowCollation{}, nil + return s, nil } diff --git a/sql/plan/show_create_database.go b/sql/plan/show_create_database.go index 57ca9764b..96de06175 100644 --- a/sql/plan/show_create_database.go +++ b/sql/plan/show_create_database.go @@ -82,12 +82,11 @@ func (s *ShowCreateDatabase) Resolved() bool { return !ok } -// TransformExpressionsUp implements the sql.Node interface. -func (s *ShowCreateDatabase) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - return s, nil -} +// WithChildren implements the Node interface. +func (s *ShowCreateDatabase) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 0) + } -// TransformUp implements the sql.Node interface. -func (s *ShowCreateDatabase) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(s) + return s, nil } diff --git a/sql/plan/show_create_table.go b/sql/plan/show_create_table.go index 89c27f6f4..7b627b54a 100644 --- a/sql/plan/show_create_table.go +++ b/sql/plan/show_create_table.go @@ -25,14 +25,13 @@ func (n *ShowCreateTable) Schema() sql.Schema { } } -// TransformExpressionsUp implements the Transformable interface. -func (n *ShowCreateTable) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - return n, nil -} +// WithChildren implements the Node interface. +func (n *ShowCreateTable) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(n, len(children), 0) + } -// TransformUp implements the Transformable interface. -func (n *ShowCreateTable) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(NewShowCreateTable(n.CurrentDatabase, n.Catalog, n.Table)) + return n, nil } // RowIter implements the Node interface diff --git a/sql/plan/show_indexes.go b/sql/plan/show_indexes.go index b28a4cbc3..d427509ea 100644 --- a/sql/plan/show_indexes.go +++ b/sql/plan/show_indexes.go @@ -39,13 +39,12 @@ func (n *ShowIndexes) Resolved() bool { return !ok } -// TransformUp implements the Transformable interface. -func (n *ShowIndexes) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(NewShowIndexes(n.db, n.Table, n.Registry)) -} +// WithChildren implements the Node interface. +func (n *ShowIndexes) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(n, len(children), 0) + } -// TransformExpressionsUp implements the Transformable interface. -func (n *ShowIndexes) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { return n, nil } diff --git a/sql/plan/show_tables.go b/sql/plan/show_tables.go index 7d44f73ae..930cb20b4 100644 --- a/sql/plan/show_tables.go +++ b/sql/plan/show_tables.go @@ -84,13 +84,12 @@ func (p *ShowTables) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.RowsToRowIter(rows...), nil } -// TransformUp implements the Transformable interface. -func (p *ShowTables) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(NewShowTables(p.db, p.Full)) -} +// WithChildren implements the Node interface. +func (p *ShowTables) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 0) + } -// TransformExpressionsUp implements the Transformable interface. -func (p *ShowTables) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { return p, nil } diff --git a/sql/plan/showcolumns.go b/sql/plan/showcolumns.go index 88a6a63b6..d2094fe3f 100644 --- a/sql/plan/showcolumns.go +++ b/sql/plan/showcolumns.go @@ -104,24 +104,13 @@ func (s *ShowColumns) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.NewSpanIter(span, sql.RowsToRowIter(rows...)), nil } -// TransformUp creates a new ShowColumns node. -func (s *ShowColumns) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := s.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (s *ShowColumns) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 1) } - return f(NewShowColumns(s.Full, child)) -} - -// TransformExpressionsUp creates a new ShowColumns node. -func (s *ShowColumns) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - child, err := s.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - - return NewShowColumns(s.Full, child), nil + return NewShowColumns(s.Full, children[0]), nil } func (s *ShowColumns) String() string { diff --git a/sql/plan/showdatabases.go b/sql/plan/showdatabases.go index f370d7eff..41873c7be 100644 --- a/sql/plan/showdatabases.go +++ b/sql/plan/showdatabases.go @@ -53,16 +53,15 @@ func (p *ShowDatabases) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.RowsToRowIter(rows...), nil } -// TransformUp implements the Transformable interface. -func (p *ShowDatabases) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - np := *p - return f(&np) -} +// WithChildren implements the Node interface. +func (p *ShowDatabases) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 0) + } -// TransformExpressionsUp implements the Transformable interface. -func (p *ShowDatabases) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { return p, nil } + func (p ShowDatabases) String() string { return "ShowDatabases" } diff --git a/sql/plan/showtablestatus.go b/sql/plan/showtablestatus.go index b30bf912f..158a1d65c 100644 --- a/sql/plan/showtablestatus.go +++ b/sql/plan/showtablestatus.go @@ -86,13 +86,12 @@ func (s *ShowTableStatus) String() string { return fmt.Sprintf("ShowTableStatus(%s)", strings.Join(s.Databases, ", ")) } -// TransformUp implements the sql.Node interface. -func (s *ShowTableStatus) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(s) -} +// WithChildren implements the Node interface. +func (s *ShowTableStatus) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 0) + } -// TransformExpressionsUp implements the sql.Node interface. -func (s *ShowTableStatus) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { return s, nil } diff --git a/sql/plan/showvariables.go b/sql/plan/showvariables.go index 7c4648d5e..8cb7fa0c1 100644 --- a/sql/plan/showvariables.go +++ b/sql/plan/showvariables.go @@ -28,13 +28,12 @@ func (sv *ShowVariables) Resolved() bool { return true } -// TransformUp implements the sq.Transformable interface. -func (sv *ShowVariables) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(NewShowVariables(sv.config, sv.pattern)) -} +// WithChildren implements the Node interface. +func (sv *ShowVariables) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(sv, len(children), 0) + } -// TransformExpressionsUp implements the sql.Transformable interface. -func (sv *ShowVariables) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { return sv, nil } diff --git a/sql/plan/showwarnings.go b/sql/plan/showwarnings.go index 5de5fb1f7..c990bfc81 100644 --- a/sql/plan/showwarnings.go +++ b/sql/plan/showwarnings.go @@ -12,13 +12,12 @@ func (ShowWarnings) Resolved() bool { return true } -// TransformUp implements the sq.Transformable interface. -func (sw ShowWarnings) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(sw) -} +// WithChildren implements the Node interface. +func (sw ShowWarnings) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(sw, len(children), 0) + } -// TransformExpressionsUp implements the sql.Transformable interface. -func (sw ShowWarnings) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { return sw, nil } diff --git a/sql/plan/sort.go b/sql/plan/sort.go index d6836ba40..8235b2386 100644 --- a/sql/plan/sort.go +++ b/sql/plan/sort.go @@ -6,8 +6,8 @@ import ( "sort" "strings" - "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" + "gopkg.in/src-d/go-errors.v1" ) // ErrUnableSort is thrown when something happens on sorting @@ -91,34 +91,6 @@ func (s *Sort) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.NewSpanIter(span, newSortIter(s, i)), nil } -// TransformUp implements the Transformable interface. -func (s *Sort) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := s.Child.TransformUp(f) - if err != nil { - return nil, err - } - return f(NewSort(s.SortFields, child)) -} - -// TransformExpressionsUp implements the Transformable interface. -func (s *Sort) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - var sfs = make([]SortField, len(s.SortFields)) - for i, sf := range s.SortFields { - col, err := sf.Column.TransformUp(f) - if err != nil { - return nil, err - } - sfs[i] = SortField{col, sf.Order, sf.NullOrdering} - } - - child, err := s.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - - return NewSort(sfs, child), nil -} - func (s *Sort) String() string { pr := sql.NewTreePrinter() var fields = make([]string, len(s.SortFields)) @@ -139,22 +111,31 @@ func (s *Sort) Expressions() []sql.Expression { return exprs } -// TransformExpressions implements the Expressioner interface. -func (s *Sort) TransformExpressions(f sql.TransformExprFunc) (sql.Node, error) { - var sortFields = make([]SortField, len(s.SortFields)) - for i, field := range s.SortFields { - transformed, err := field.Column.TransformUp(f) - if err != nil { - return nil, err - } - sortFields[i] = SortField{ - Column: transformed, - Order: field.Order, - NullOrdering: field.NullOrdering, +// WithChildren implements the Node interface. +func (s *Sort) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 1) + } + + return NewSort(s.SortFields, children[0]), nil +} + +// WithExpressions implements the Expressioner interface. +func (s *Sort) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + if len(exprs) != len(s.SortFields) { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(exprs), len(s.SortFields)) + } + + var fields = make([]SortField, len(s.SortFields)) + for i, expr := range exprs { + fields[i] = SortField{ + Column: expr, + NullOrdering: s.SortFields[i].NullOrdering, + Order: s.SortFields[i].Order, } } - return NewSort(sortFields, s.Child), nil + return NewSort(fields, s.Child), nil } type sortIter struct { diff --git a/sql/plan/subqueryalias.go b/sql/plan/subqueryalias.go index 70233a075..da1264c88 100644 --- a/sql/plan/subqueryalias.go +++ b/sql/plan/subqueryalias.go @@ -45,16 +45,22 @@ func (n *SubqueryAlias) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.NewSpanIter(span, iter), nil } -// TransformUp implements the Node interface. -func (n *SubqueryAlias) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(n) -} +// WithChildren implements the Node interface. +func (n *SubqueryAlias) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(n, len(children), 1) + } -// TransformExpressionsUp implements the Node interface. -func (n *SubqueryAlias) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { + nn := *n + nn.Child = children[0] return n, nil } +// Opaque implements the OpaqueNode interface. +func (n *SubqueryAlias) Opaque() bool { + return true +} + func (n SubqueryAlias) String() string { pr := sql.NewTreePrinter() _ = pr.WriteNode("SubqueryAlias(%s)", n.name) diff --git a/sql/plan/tablealias.go b/sql/plan/tablealias.go index 02795b177..be37fb109 100644 --- a/sql/plan/tablealias.go +++ b/sql/plan/tablealias.go @@ -23,22 +23,13 @@ func (t *TableAlias) Name() string { return t.name } -// TransformUp implements the Transformable interface. -func (t *TableAlias) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := t.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (t *TableAlias) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 1) } - return f(NewTableAlias(t.name, child)) -} -// TransformExpressionsUp implements the Transformable interface. -func (t *TableAlias) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - child, err := t.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - return NewTableAlias(t.name, child), nil + return NewTableAlias(t.name, children[0]), nil } // RowIter implements the Node interface. diff --git a/sql/plan/transaction.go b/sql/plan/transaction.go index 6565cffb2..3d09366ef 100644 --- a/sql/plan/transaction.go +++ b/sql/plan/transaction.go @@ -15,13 +15,12 @@ func (*Rollback) RowIter(*sql.Context) (sql.RowIter, error) { func (*Rollback) String() string { return "ROLLBACK" } -// TransformUp implements the sql.Node interface. -func (r *Rollback) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(r) -} +// WithChildren implements the Node interface. +func (r *Rollback) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), 0) + } -// TransformExpressionsUp implements the sql.Node interface. -func (r *Rollback) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { return r, nil } diff --git a/sql/plan/transform.go b/sql/plan/transform.go new file mode 100644 index 000000000..e437327ed --- /dev/null +++ b/sql/plan/transform.go @@ -0,0 +1,89 @@ +package plan + +import ( + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" +) + +// TransformUp applies a transformation function to the given tree from the +// bottom up. +func TransformUp(node sql.Node, f sql.TransformNodeFunc) (sql.Node, error) { + if o, ok := node.(sql.OpaqueNode); ok && o.Opaque() { + return f(node) + } + + children := node.Children() + if len(children) == 0 { + return f(node) + } + + newChildren := make([]sql.Node, len(children)) + for i, c := range children { + c, err := TransformUp(c, f) + if err != nil { + return nil, err + } + newChildren[i] = c + } + + node, err := node.WithChildren(newChildren...) + if err != nil { + return nil, err + } + + return f(node) +} + +// TransformExpressionsUp applies a transformation function to all expressions +// on the given tree from the bottom up. +func TransformExpressionsUp(node sql.Node, f sql.TransformExprFunc) (sql.Node, error) { + if o, ok := node.(sql.OpaqueNode); ok && o.Opaque() { + return TransformExpressions(node, f) + } + + children := node.Children() + if len(children) == 0 { + return TransformExpressions(node, f) + } + + newChildren := make([]sql.Node, len(children)) + for i, c := range children { + c, err := TransformExpressionsUp(c, f) + if err != nil { + return nil, err + } + newChildren[i] = c + } + + node, err := node.WithChildren(newChildren...) + if err != nil { + return nil, err + } + + return TransformExpressions(node, f) +} + +// TransformExpressions applies a transformation function to all expressions +// on the given node. +func TransformExpressions(node sql.Node, f sql.TransformExprFunc) (sql.Node, error) { + e, ok := node.(sql.Expressioner) + if !ok { + return node, nil + } + + exprs := e.Expressions() + if len(exprs) == 0 { + return node, nil + } + + newExprs := make([]sql.Expression, len(exprs)) + for i, e := range exprs { + e, err := expression.TransformUp(e, f) + if err != nil { + return nil, err + } + newExprs[i] = e + } + + return e.WithExpressions(newExprs...) +} diff --git a/sql/plan/transform_test.go b/sql/plan/transform_test.go index 88a6cc89a..730ab128f 100644 --- a/sql/plan/transform_test.go +++ b/sql/plan/transform_test.go @@ -24,7 +24,7 @@ func TestTransformUp(t *testing.T) { } table := mem.NewTable("resolved", schema) - pt, err := p.TransformUp(func(n sql.Node) (sql.Node, error) { + pt, err := TransformUp(p, func(n sql.Node) (sql.Node, error) { switch n.(type) { case *UnresolvedTable: return NewResolvedTable(table), nil diff --git a/sql/plan/unresolved.go b/sql/plan/unresolved.go index 84160804d..9bf70a639 100644 --- a/sql/plan/unresolved.go +++ b/sql/plan/unresolved.go @@ -3,8 +3,8 @@ package plan import ( "fmt" - errors "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" + errors "gopkg.in/src-d/go-errors.v1" ) // ErrUnresolvedTable is thrown when a table cannot be resolved @@ -42,13 +42,12 @@ func (*UnresolvedTable) RowIter(ctx *sql.Context) (sql.RowIter, error) { return nil, ErrUnresolvedTable.New() } -// TransformUp implements the Transformable interface. -func (t *UnresolvedTable) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(NewUnresolvedTable(t.name, t.Database)) -} +// WithChildren implements the Node interface. +func (t *UnresolvedTable) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 0) + } -// TransformExpressionsUp implements the Transformable interface. -func (t *UnresolvedTable) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { return t, nil } diff --git a/sql/plan/use.go b/sql/plan/use.go index fb721f8a4..b7eb517b7 100644 --- a/sql/plan/use.go +++ b/sql/plan/use.go @@ -50,13 +50,12 @@ func (u *Use) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.RowsToRowIter(), nil } -// TransformUp implements the sql.Node interface. -func (u *Use) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(u) -} +// WithChildren implements the Node interface. +func (u *Use) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(u, len(children), 1) + } -// TransformExpressionsUp implements the sql.Node interface. -func (u *Use) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { return u, nil } diff --git a/sql/plan/values.go b/sql/plan/values.go index d07eb5bff..ea1d5be39 100644 --- a/sql/plan/values.go +++ b/sql/plan/values.go @@ -76,25 +76,6 @@ func (p *Values) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.RowsToRowIter(rows...), nil } -// TransformUp implements the Transformable interface. -func (p *Values) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(p) -} - -// TransformExpressionsUp implements the Transformable interface. -func (p *Values) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - ets := make([][]sql.Expression, len(p.ExpressionTuples)) - var err error - for i, et := range p.ExpressionTuples { - ets[i], err = transformExpressionsUp(f, et) - if err != nil { - return nil, err - } - } - - return NewValues(ets), nil -} - func (p *Values) String() string { return fmt.Sprintf("Values(%d tuples)", len(p.ExpressionTuples)) } @@ -108,15 +89,33 @@ func (p *Values) Expressions() []sql.Expression { return exprs } -// TransformExpressions implements the Expressioner interface. -func (p *Values) TransformExpressions(f sql.TransformExprFunc) (sql.Node, error) { - tuples := [][]sql.Expression{} - for _, tuple := range p.ExpressionTuples { - transformed, err := transformExpressionsUp(f, tuple) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (p *Values) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 0) + } + + return p, nil +} + +// WithExpressions implements the Expressioner interface. +func (p *Values) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + var expected int + for _, t := range p.ExpressionTuples { + expected += len(t) + } + + if len(exprs) != expected { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(exprs), expected) + } + + var offset int + var tuples = make([][]sql.Expression, len(p.ExpressionTuples)) + for i, t := range p.ExpressionTuples { + for range t { + tuples[i] = append(tuples[i], exprs[offset]) + offset++ } - tuples = append(tuples, transformed) } return NewValues(tuples), nil diff --git a/sql/session_test.go b/sql/session_test.go index 3a109779f..315015ec5 100644 --- a/sql/session_test.go +++ b/sql/session_test.go @@ -51,27 +51,22 @@ func TestHasDefaultValue(t *testing.T) { type testNode struct{} -func (t *testNode) Resolved() bool { +func (*testNode) Resolved() bool { panic("not implemented") } - -func (t *testNode) TransformUp(func(Node) Node) Node { - panic("not implemented") -} - -func (t *testNode) TransformExpressionsUp(func(Expression) Expression) Node { +func (*testNode) WithChildren(...Node) (Node, error) { panic("not implemented") } -func (t *testNode) Schema() Schema { +func (*testNode) Schema() Schema { panic("not implemented") } -func (t *testNode) Children() []Node { +func (*testNode) Children() []Node { panic("not implemented") } -func (t *testNode) RowIter(ctx *Context) (RowIter, error) { +func (*testNode) RowIter(ctx *Context) (RowIter, error) { return newTestNodeIterator(ctx), nil }