Skip to content
This repository has been archived by the owner on Jan 28, 2021. It is now read-only.

sql/analyzer: alias projected columns wrapped in converts #701

Merged
merged 2 commits into from
May 6, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1052,6 +1052,14 @@ var queries = []struct {
"SELECT DATE_ADD('9999-12-31 23:59:59', INTERVAL 1 DAY)",
[]sql.Row{{nil}},
},
{
`SELECT t.date_col FROM (SELECT CONVERT('2019-06-06 00:00:00', DATETIME) as date_col) t WHERE t.date_col > '0000-01-01 00:00:00'`,
[]sql.Row{{time.Date(2019, time.June, 6, 0, 0, 0, 0, time.UTC)}},
},
{
`SELECT t.date_col FROM (SELECT CONVERT('2019-06-06 00:00:00', DATETIME) as date_col) t GROUP BY t.date_col`,
[]sql.Row{{time.Date(2019, time.June, 6, 0, 0, 0, 0, time.UTC)}},
},
}

func TestQueries(t *testing.T) {
Expand Down
158 changes: 137 additions & 21 deletions sql/analyzer/convert_dates.go
Original file line number Diff line number Diff line change
@@ -1,40 +1,156 @@
package analyzer

import (
"fmt"

"gopkg.in/src-d/go-mysql-server.v0/sql"
"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
"gopkg.in/src-d/go-mysql-server.v0/sql/expression/function"
"gopkg.in/src-d/go-mysql-server.v0/sql/plan"
)

type tableCol struct {
table string
col string
}

// convertDates wraps all expressions of date and datetime type with converts
// to ensure the date range is validated.
func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
if !n.Resolved() {
return n, nil
}

return n.TransformExpressionsUp(func(e sql.Expression) (sql.Expression, error) {
// No need to wrap expressions that already validate times, such as
// convert, date_add, etc and those expressions whose Type method
// cannot be called because they are placeholders.
switch e.(type) {
case *expression.Convert,
*expression.Arithmetic,
*function.DateAdd,
*function.DateSub,
*expression.Star,
*expression.DefaultColumn,
*expression.Alias:
return e, nil
default:
switch e.Type() {
case sql.Date:
return expression.NewConvert(e, expression.ConvertToDate), nil
case sql.Timestamp:
return expression.NewConvert(e, expression.ConvertToDatetime), nil
default:
return e, nil
// Replacements contains a mapping from columns to the alias they will be
// replaced by.
var replacements = make(map[tableCol]string)

return n.TransformUp(func(n sql.Node) (sql.Node, error) {
exp, ok := n.(sql.Expressioner)
if !ok {
return n, nil
}

// nodeReplacements are all the replacements found in the current node.
// These replacements are not applied to the current node, only to
// parent nodes.
var nodeReplacements = make(map[tableCol]string)

var expressions = make(map[string]bool)
switch exp := exp.(type) {
case *plan.Project:
for _, e := range exp.Projections {
expressions[e.String()] = true
}
case *plan.GroupBy:
for _, e := range exp.Aggregate {
expressions[e.String()] = true
}
}

var result sql.Node
var err error
switch exp := exp.(type) {
case *plan.GroupBy:
var aggregate = make([]sql.Expression, len(exp.Aggregate))
for i, a := range exp.Aggregate {
agg, err := a.TransformUp(func(e sql.Expression) (sql.Expression, error) {
return addDateConvert(e, exp, replacements, nodeReplacements, expressions, true)
})
if err != nil {
return nil, err
}
aggregate[i] = agg
}

var grouping = make([]sql.Expression, len(exp.Grouping))
for i, g := range exp.Grouping {
gr, err := g.TransformUp(func(e sql.Expression) (sql.Expression, error) {
return addDateConvert(e, exp, replacements, nodeReplacements, expressions, false)
})
if err != nil {
return nil, err
}
grouping[i] = gr
}

result = plan.NewGroupBy(aggregate, grouping, exp.Child)
default:
result, err = exp.TransformExpressions(func(e sql.Expression) (sql.Expression, error) {
return addDateConvert(e, n, replacements, nodeReplacements, expressions, true)
})
}

if err != nil {
return nil, err
}

// We're done with this node, so copy all the replacements found in
// this node to the global replacements in order to make the necesssary
// changes in parent nodes.
for tc, n := range nodeReplacements {
replacements[tc] = n
}

return result, err
})
}

func addDateConvert(
e sql.Expression,
node sql.Node,
replacements, nodeReplacements map[tableCol]string,
expressions map[string]bool,
aliasRootProjections bool,
) (sql.Expression, error) {
var result sql.Expression

// No need to wrap expressions that already validate times, such as
// convert, date_add, etc and those expressions whose Type method
// cannot be called because they are placeholders.
switch e.(type) {
case *expression.Convert,
*expression.Arithmetic,
*function.DateAdd,
*function.DateSub,
*expression.Star,
*expression.DefaultColumn,
*expression.Alias:
return e, nil
default:
// If it's a replacement, just replace it with the correct GetField
// because we know that it's already converted to a correct date
// and there is no point to do so again.
if gf, ok := e.(*expression.GetField); ok {
if name, ok := replacements[tableCol{gf.Table(), gf.Name()}]; ok {
return expression.NewGetField(gf.Index(), gf.Type(), name, gf.IsNullable()), nil
}
}

switch e.Type() {
case sql.Date:
result = expression.NewConvert(e, expression.ConvertToDate)
case sql.Timestamp:
result = expression.NewConvert(e, expression.ConvertToDatetime)
default:
result = e
}
}

// Only do this if it's a root expression in a project or group by.
switch node.(type) {
case *plan.Project, *plan.GroupBy:
// If it was originally a GetField, and it's not anymore it's
// because we wrapped it in a convert. We need to make it an alias
// and propagate the changes up the chain.
if gf, ok := e.(*expression.GetField); ok && expressions[e.String()] && aliasRootProjections {
if _, ok := result.(*expression.GetField); !ok {
name := fmt.Sprintf("%s__%s", gf.Table(), gf.Name())
result = expression.NewAlias(result, name)
nodeReplacements[tableCol{gf.Table(), gf.Name()}] = name
}
}
}

return result, nil
}
78 changes: 78 additions & 0 deletions sql/analyzer/convert_dates_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,84 @@ func TestConvertDates(t *testing.T) {
}
}

func TestConvertDatesProject(t *testing.T) {
table := plan.NewResolvedTable(mem.NewTable("t", nil))
input := plan.NewFilter(
expression.NewEquals(
expression.NewGetField(0, sql.Int64, "foo", false),
expression.NewLiteral("2019-06-06 00:00:00", sql.Text),
),
plan.NewProject([]sql.Expression{
expression.NewGetField(0, sql.Timestamp, "foo", false),
}, table),
)
expected := plan.NewFilter(
expression.NewEquals(
expression.NewGetField(0, sql.Int64, "__foo", false),
expression.NewLiteral("2019-06-06 00:00:00", sql.Text),
),
plan.NewProject([]sql.Expression{
expression.NewAlias(
expression.NewConvert(
expression.NewGetField(0, sql.Timestamp, "foo", false),
expression.ConvertToDatetime,
),
"__foo",
),
}, table),
)

result, err := convertDates(sql.NewEmptyContext(), nil, input)
require.NoError(t, err)
require.Equal(t, expected, result)
}

func TestConvertDatesGroupBy(t *testing.T) {
table := plan.NewResolvedTable(mem.NewTable("t", nil))
input := plan.NewFilter(
expression.NewEquals(
expression.NewGetField(0, sql.Int64, "foo", false),
expression.NewLiteral("2019-06-06 00:00:00", sql.Text),
),
plan.NewGroupBy(
[]sql.Expression{
expression.NewGetField(0, sql.Timestamp, "foo", false),
},
[]sql.Expression{
expression.NewGetField(0, sql.Timestamp, "foo", false),
}, table,
),
)
expected := plan.NewFilter(
expression.NewEquals(
expression.NewGetField(0, sql.Int64, "__foo", false),
expression.NewLiteral("2019-06-06 00:00:00", sql.Text),
),
plan.NewGroupBy(
[]sql.Expression{
expression.NewAlias(
expression.NewConvert(
expression.NewGetField(0, sql.Timestamp, "foo", false),
expression.ConvertToDatetime,
),
"__foo",
),
},
[]sql.Expression{
expression.NewConvert(
expression.NewGetField(0, sql.Timestamp, "foo", false),
expression.ConvertToDatetime,
),
},
table,
),
)

result, err := convertDates(sql.NewEmptyContext(), nil, input)
require.NoError(t, err)
require.Equal(t, expected, result)
}

func newDateAdd(l, r sql.Expression) sql.Expression {
e, _ := function.NewDateAdd(l, r)
return e
Expand Down
2 changes: 1 addition & 1 deletion sql/analyzer/rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ var DefaultRules = []Rule{
{"reorder_projection", reorderProjection},
{"move_join_conds_to_filter", moveJoinConditionsToFilter},
{"eval_filter", evalFilter},
{"convert_dates", convertDates},
{"optimize_distinct", optimizeDistinct},
}

Expand All @@ -36,6 +35,7 @@ var OnceBeforeDefault = []Rule{
// DefaultRules.
var OnceAfterDefault = []Rule{
{"remove_unnecessary_converts", removeUnnecessaryConverts},
{"convert_dates", convertDates},
{"assign_catalog", assignCatalog},
{"prune_columns", pruneColumns},
{"pushdown", pushdown},
Expand Down