diff --git a/engine_test.go b/engine_test.go index e7a0f64d3..4fb6ed35e 100644 --- a/engine_test.go +++ b/engine_test.go @@ -1052,6 +1052,14 @@ var queries = []struct { "SELECT DATE_ADD('9999-12-31 23:59:59', INTERVAL 1 DAY)", []sql.Row{{nil}}, }, + { + `SELECT t.date_col FROM (SELECT CONVERT('2019-06-06 00:00:00', DATETIME) as date_col) t WHERE t.date_col > '0000-01-01 00:00:00'`, + []sql.Row{{time.Date(2019, time.June, 6, 0, 0, 0, 0, time.UTC)}}, + }, + { + `SELECT t.date_col FROM (SELECT CONVERT('2019-06-06 00:00:00', DATETIME) as date_col) t GROUP BY t.date_col`, + []sql.Row{{time.Date(2019, time.June, 6, 0, 0, 0, 0, time.UTC)}}, + }, } func TestQueries(t *testing.T) { diff --git a/sql/analyzer/convert_dates.go b/sql/analyzer/convert_dates.go index 72b59f721..10bddeb34 100644 --- a/sql/analyzer/convert_dates.go +++ b/sql/analyzer/convert_dates.go @@ -1,11 +1,19 @@ package analyzer import ( + "fmt" + "gopkg.in/src-d/go-mysql-server.v0/sql" "gopkg.in/src-d/go-mysql-server.v0/sql/expression" "gopkg.in/src-d/go-mysql-server.v0/sql/expression/function" + "gopkg.in/src-d/go-mysql-server.v0/sql/plan" ) +type tableCol struct { + table string + col string +} + // convertDates wraps all expressions of date and datetime type with converts // to ensure the date range is validated. func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { @@ -13,28 +21,136 @@ func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { return n, nil } - return n.TransformExpressionsUp(func(e sql.Expression) (sql.Expression, error) { - // No need to wrap expressions that already validate times, such as - // convert, date_add, etc and those expressions whose Type method - // cannot be called because they are placeholders. - switch e.(type) { - case *expression.Convert, - *expression.Arithmetic, - *function.DateAdd, - *function.DateSub, - *expression.Star, - *expression.DefaultColumn, - *expression.Alias: - return e, nil - default: - switch e.Type() { - case sql.Date: - return expression.NewConvert(e, expression.ConvertToDate), nil - case sql.Timestamp: - return expression.NewConvert(e, expression.ConvertToDatetime), nil - default: - return e, nil + // Replacements contains a mapping from columns to the alias they will be + // replaced by. + var replacements = make(map[tableCol]string) + + return n.TransformUp(func(n sql.Node) (sql.Node, error) { + exp, ok := n.(sql.Expressioner) + if !ok { + return n, nil + } + + // nodeReplacements are all the replacements found in the current node. + // These replacements are not applied to the current node, only to + // parent nodes. + var nodeReplacements = make(map[tableCol]string) + + var expressions = make(map[string]bool) + switch exp := exp.(type) { + case *plan.Project: + for _, e := range exp.Projections { + expressions[e.String()] = true + } + case *plan.GroupBy: + for _, e := range exp.Aggregate { + expressions[e.String()] = true } } + + var result sql.Node + var err error + switch exp := exp.(type) { + case *plan.GroupBy: + var aggregate = make([]sql.Expression, len(exp.Aggregate)) + for i, a := range exp.Aggregate { + agg, err := a.TransformUp(func(e sql.Expression) (sql.Expression, error) { + return addDateConvert(e, exp, replacements, nodeReplacements, expressions, true) + }) + if err != nil { + return nil, err + } + aggregate[i] = agg + } + + var grouping = make([]sql.Expression, len(exp.Grouping)) + for i, g := range exp.Grouping { + gr, err := g.TransformUp(func(e sql.Expression) (sql.Expression, error) { + return addDateConvert(e, exp, replacements, nodeReplacements, expressions, false) + }) + if err != nil { + return nil, err + } + grouping[i] = gr + } + + result = plan.NewGroupBy(aggregate, grouping, exp.Child) + default: + result, err = exp.TransformExpressions(func(e sql.Expression) (sql.Expression, error) { + return addDateConvert(e, n, replacements, nodeReplacements, expressions, true) + }) + } + + if err != nil { + return nil, err + } + + // We're done with this node, so copy all the replacements found in + // this node to the global replacements in order to make the necesssary + // changes in parent nodes. + for tc, n := range nodeReplacements { + replacements[tc] = n + } + + return result, err }) } + +func addDateConvert( + e sql.Expression, + node sql.Node, + replacements, nodeReplacements map[tableCol]string, + expressions map[string]bool, + aliasRootProjections bool, +) (sql.Expression, error) { + var result sql.Expression + + // No need to wrap expressions that already validate times, such as + // convert, date_add, etc and those expressions whose Type method + // cannot be called because they are placeholders. + switch e.(type) { + case *expression.Convert, + *expression.Arithmetic, + *function.DateAdd, + *function.DateSub, + *expression.Star, + *expression.DefaultColumn, + *expression.Alias: + return e, nil + default: + // If it's a replacement, just replace it with the correct GetField + // because we know that it's already converted to a correct date + // and there is no point to do so again. + if gf, ok := e.(*expression.GetField); ok { + if name, ok := replacements[tableCol{gf.Table(), gf.Name()}]; ok { + return expression.NewGetField(gf.Index(), gf.Type(), name, gf.IsNullable()), nil + } + } + + switch e.Type() { + case sql.Date: + result = expression.NewConvert(e, expression.ConvertToDate) + case sql.Timestamp: + result = expression.NewConvert(e, expression.ConvertToDatetime) + default: + result = e + } + } + + // Only do this if it's a root expression in a project or group by. + switch node.(type) { + case *plan.Project, *plan.GroupBy: + // If it was originally a GetField, and it's not anymore it's + // because we wrapped it in a convert. We need to make it an alias + // and propagate the changes up the chain. + if gf, ok := e.(*expression.GetField); ok && expressions[e.String()] && aliasRootProjections { + if _, ok := result.(*expression.GetField); !ok { + name := fmt.Sprintf("%s__%s", gf.Table(), gf.Name()) + result = expression.NewAlias(result, name) + nodeReplacements[tableCol{gf.Table(), gf.Name()}] = name + } + } + } + + return result, nil +} diff --git a/sql/analyzer/convert_dates_test.go b/sql/analyzer/convert_dates_test.go index 3989b5246..bc83c0f02 100644 --- a/sql/analyzer/convert_dates_test.go +++ b/sql/analyzer/convert_dates_test.go @@ -146,6 +146,84 @@ func TestConvertDates(t *testing.T) { } } +func TestConvertDatesProject(t *testing.T) { + table := plan.NewResolvedTable(mem.NewTable("t", nil)) + input := plan.NewFilter( + expression.NewEquals( + expression.NewGetField(0, sql.Int64, "foo", false), + expression.NewLiteral("2019-06-06 00:00:00", sql.Text), + ), + plan.NewProject([]sql.Expression{ + expression.NewGetField(0, sql.Timestamp, "foo", false), + }, table), + ) + expected := plan.NewFilter( + expression.NewEquals( + expression.NewGetField(0, sql.Int64, "__foo", false), + expression.NewLiteral("2019-06-06 00:00:00", sql.Text), + ), + plan.NewProject([]sql.Expression{ + expression.NewAlias( + expression.NewConvert( + expression.NewGetField(0, sql.Timestamp, "foo", false), + expression.ConvertToDatetime, + ), + "__foo", + ), + }, table), + ) + + result, err := convertDates(sql.NewEmptyContext(), nil, input) + require.NoError(t, err) + require.Equal(t, expected, result) +} + +func TestConvertDatesGroupBy(t *testing.T) { + table := plan.NewResolvedTable(mem.NewTable("t", nil)) + input := plan.NewFilter( + expression.NewEquals( + expression.NewGetField(0, sql.Int64, "foo", false), + expression.NewLiteral("2019-06-06 00:00:00", sql.Text), + ), + plan.NewGroupBy( + []sql.Expression{ + expression.NewGetField(0, sql.Timestamp, "foo", false), + }, + []sql.Expression{ + expression.NewGetField(0, sql.Timestamp, "foo", false), + }, table, + ), + ) + expected := plan.NewFilter( + expression.NewEquals( + expression.NewGetField(0, sql.Int64, "__foo", false), + expression.NewLiteral("2019-06-06 00:00:00", sql.Text), + ), + plan.NewGroupBy( + []sql.Expression{ + expression.NewAlias( + expression.NewConvert( + expression.NewGetField(0, sql.Timestamp, "foo", false), + expression.ConvertToDatetime, + ), + "__foo", + ), + }, + []sql.Expression{ + expression.NewConvert( + expression.NewGetField(0, sql.Timestamp, "foo", false), + expression.ConvertToDatetime, + ), + }, + table, + ), + ) + + result, err := convertDates(sql.NewEmptyContext(), nil, input) + require.NoError(t, err) + require.Equal(t, expected, result) +} + func newDateAdd(l, r sql.Expression) sql.Expression { e, _ := function.NewDateAdd(l, r) return e diff --git a/sql/analyzer/rules.go b/sql/analyzer/rules.go index 0a9112384..fcd0b2723 100644 --- a/sql/analyzer/rules.go +++ b/sql/analyzer/rules.go @@ -20,7 +20,6 @@ var DefaultRules = []Rule{ {"reorder_projection", reorderProjection}, {"move_join_conds_to_filter", moveJoinConditionsToFilter}, {"eval_filter", evalFilter}, - {"convert_dates", convertDates}, {"optimize_distinct", optimizeDistinct}, } @@ -36,6 +35,7 @@ var OnceBeforeDefault = []Rule{ // DefaultRules. var OnceAfterDefault = []Rule{ {"remove_unnecessary_converts", removeUnnecessaryConverts}, + {"convert_dates", convertDates}, {"assign_catalog", assignCatalog}, {"prune_columns", pruneColumns}, {"pushdown", pushdown},