Skip to content
This repository has been archived by the owner on Jan 28, 2021. It is now read-only.

sql/analyzer: wrap all time and date expressions with convert #699

Merged
merged 3 commits into from
May 6, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1036,6 +1036,22 @@ var queries = []struct {
{string("first row"), int64(1)},
},
},
{
"SELECT CONVERT('9999-12-31 23:59:59', DATETIME)",
[]sql.Row{{time.Date(9999, time.December, 31, 23, 59, 59, 0, time.UTC)}},
},
{
"SELECT CONVERT('10000-12-31 23:59:59', DATETIME)",
[]sql.Row{{nil}},
},
{
"SELECT '9999-12-31 23:59:59' + INTERVAL 1 DAY",
[]sql.Row{{nil}},
},
{
"SELECT DATE_ADD('9999-12-31 23:59:59', INTERVAL 1 DAY)",
[]sql.Row{{nil}},
},
}

func TestQueries(t *testing.T) {
Expand Down
40 changes: 40 additions & 0 deletions sql/analyzer/convert_dates.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package analyzer

import (
"gopkg.in/src-d/go-mysql-server.v0/sql"
"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
"gopkg.in/src-d/go-mysql-server.v0/sql/expression/function"
)

// convertDates wraps all expressions of date and datetime type with converts
// to ensure the date range is validated.
func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
if !n.Resolved() {
return n, nil
}

return n.TransformExpressionsUp(func(e sql.Expression) (sql.Expression, error) {
// No need to wrap expressions that already validate times, such as
// convert, date_add, etc and those expressions whose Type method
// cannot be called because they are placeholders.
switch e.(type) {
case *expression.Convert,
*expression.Arithmetic,
*function.DateAdd,
*function.DateSub,
*expression.Star,
*expression.DefaultColumn,
*expression.Alias:
return e, nil
default:
switch e.Type() {
case sql.Date:
return expression.NewConvert(e, expression.ConvertToDate), nil
case sql.Timestamp:
return expression.NewConvert(e, expression.ConvertToDatetime), nil
default:
return e, nil
}
}
})
}
157 changes: 157 additions & 0 deletions sql/analyzer/convert_dates_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
package analyzer

import (
"testing"

"github.com/stretchr/testify/require"
"gopkg.in/src-d/go-mysql-server.v0/mem"
"gopkg.in/src-d/go-mysql-server.v0/sql"
"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
"gopkg.in/src-d/go-mysql-server.v0/sql/expression/function"
"gopkg.in/src-d/go-mysql-server.v0/sql/plan"
)

func TestConvertDates(t *testing.T) {
testCases := []struct {
name string
in sql.Expression
out sql.Expression
}{
{
"arithmetic with dates",
expression.NewPlus(expression.NewLiteral("", sql.Timestamp), expression.NewLiteral("", sql.Timestamp)),
expression.NewPlus(
expression.NewConvert(
expression.NewLiteral("", sql.Timestamp),
expression.ConvertToDatetime,
),
expression.NewConvert(
expression.NewLiteral("", sql.Timestamp),
expression.ConvertToDatetime,
),
),
},
{
"star",
expression.NewStar(),
expression.NewStar(),
},
{
"default column",
expression.NewDefaultColumn("foo"),
expression.NewDefaultColumn("foo"),
},
{
"convert to date",
expression.NewConvert(
expression.NewPlus(
expression.NewLiteral("", sql.Timestamp),
expression.NewLiteral("", sql.Timestamp),
),
expression.ConvertToDatetime,
),
expression.NewConvert(
expression.NewPlus(
expression.NewConvert(
expression.NewLiteral("", sql.Timestamp),
expression.ConvertToDatetime,
),
expression.NewConvert(
expression.NewLiteral("", sql.Timestamp),
expression.ConvertToDatetime,
),
),
expression.ConvertToDatetime,
),
},
{
"convert to other type",
expression.NewConvert(
expression.NewLiteral("", sql.Text),
expression.ConvertToBinary,
),
expression.NewConvert(
expression.NewLiteral("", sql.Text),
expression.ConvertToBinary,
),
},
{
"datetime col in alias",
expression.NewAlias(
expression.NewLiteral("", sql.Timestamp),
"foo",
),
expression.NewAlias(
expression.NewConvert(
expression.NewLiteral("", sql.Timestamp),
expression.ConvertToDatetime,
),
"foo",
),
},
{
"date col in alias",
expression.NewAlias(
expression.NewLiteral("", sql.Date),
"foo",
),
expression.NewAlias(
expression.NewConvert(
expression.NewLiteral("", sql.Date),
expression.ConvertToDate,
),
"foo",
),
},
{
"date add",
newDateAdd(
expression.NewLiteral("", sql.Timestamp),
expression.NewInterval(expression.NewLiteral(int64(1), sql.Int64), "DAY"),
),
newDateAdd(
expression.NewConvert(
expression.NewLiteral("", sql.Timestamp),
expression.ConvertToDatetime,
),
expression.NewInterval(expression.NewLiteral(int64(1), sql.Int64), "DAY"),
),
},
{
"date sub",
newDateSub(
expression.NewLiteral("", sql.Timestamp),
expression.NewInterval(expression.NewLiteral(int64(1), sql.Int64), "DAY"),
),
newDateSub(
expression.NewConvert(
expression.NewLiteral("", sql.Timestamp),
expression.ConvertToDatetime,
),
expression.NewInterval(expression.NewLiteral(int64(1), sql.Int64), "DAY"),
),
},
}

table := plan.NewResolvedTable(mem.NewTable("t", nil))

for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
input := plan.NewProject([]sql.Expression{tt.in}, table)
expected := plan.NewProject([]sql.Expression{tt.out}, table)
result, err := convertDates(sql.NewEmptyContext(), nil, input)
require.NoError(t, err)
require.Equal(t, expected, result)
})
}
}

func newDateAdd(l, r sql.Expression) sql.Expression {
e, _ := function.NewDateAdd(l, r)
return e
}

func newDateSub(l, r sql.Expression) sql.Expression {
e, _ := function.NewDateSub(l, r)
return e
}
1 change: 1 addition & 0 deletions sql/analyzer/rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ var DefaultRules = []Rule{
{"reorder_projection", reorderProjection},
{"move_join_conds_to_filter", moveJoinConditionsToFilter},
{"eval_filter", evalFilter},
{"convert_dates", convertDates},
{"optimize_distinct", optimizeDistinct},
}

Expand Down
15 changes: 12 additions & 3 deletions sql/expression/arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,15 @@ func (a *Arithmetic) String() string {
return fmt.Sprintf("%s %s %s", a.Left, a.Op, a.Right)
}

// IsNullable implements the sql.Expression interface.
func (a *Arithmetic) IsNullable() bool {
if a.Type() == sql.Timestamp {
return true
}

return a.BinaryExpression.IsNullable()
}

// Type returns the greatest type for given operation.
func (a *Arithmetic) Type() sql.Type {
switch a.Op {
Expand Down Expand Up @@ -254,12 +263,12 @@ func plus(lval, rval interface{}) (interface{}, error) {
case time.Time:
switch r := rval.(type) {
case *TimeDelta:
return r.Add(l), nil
return sql.ValidateTime(r.Add(l)), nil
}
case *TimeDelta:
switch r := rval.(type) {
case time.Time:
return l.Add(r), nil
return sql.ValidateTime(l.Add(r)), nil
}
}

Expand Down Expand Up @@ -288,7 +297,7 @@ func minus(lval, rval interface{}) (interface{}, error) {
case time.Time:
switch r := rval.(type) {
case *TimeDelta:
return r.Sub(l), nil
return sql.ValidateTime(r.Sub(l)), nil
}
}

Expand Down
16 changes: 14 additions & 2 deletions sql/expression/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,27 @@ func NewConvert(expr sql.Expression, castToType string) *Convert {
}
}

// IsNullable implements the Expression interface.
func (c *Convert) IsNullable() bool {
switch c.castToType {
case ConvertToDate, ConvertToDatetime:
return true
default:
return c.Child.IsNullable()
}
}

// Type implements the Expression interface.
func (c *Convert) Type() sql.Type {
switch c.castToType {
case ConvertToBinary:
return sql.Blob
case ConvertToChar, ConvertToNChar:
return sql.Text
case ConvertToDate, ConvertToDatetime:
case ConvertToDate:
return sql.Date
case ConvertToDatetime:
return sql.Timestamp
case ConvertToDecimal:
return sql.Float64
case ConvertToJSON:
Expand Down Expand Up @@ -143,7 +155,7 @@ func convertValue(val interface{}, castTo string) (interface{}, error) {
}
}

return d, nil
return sql.ValidateTime(d.(time.Time)), nil
case ConvertToDecimal:
d, err := cast.ToFloat64E(val)
if err != nil {
Expand Down
8 changes: 4 additions & 4 deletions sql/expression/function/date.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func (d *DateAdd) Resolved() bool {

// IsNullable implements the sql.Expression interface.
func (d *DateAdd) IsNullable() bool {
return d.Date.IsNullable() || d.Interval.IsNullable()
return true
}

// Type implements the sql.Expression interface.
Expand Down Expand Up @@ -85,7 +85,7 @@ func (d *DateAdd) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
return nil, nil
}

return delta.Add(date.(time.Time)), nil
return sql.ValidateTime(delta.Add(date.(time.Time))), nil
}

func (d *DateAdd) String() string {
Expand Down Expand Up @@ -124,7 +124,7 @@ func (d *DateSub) Resolved() bool {

// IsNullable implements the sql.Expression interface.
func (d *DateSub) IsNullable() bool {
return d.Date.IsNullable() || d.Interval.IsNullable()
return true
}

// Type implements the sql.Expression interface.
Expand Down Expand Up @@ -169,7 +169,7 @@ func (d *DateSub) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
return nil, nil
}

return delta.Sub(date.(time.Time)), nil
return sql.ValidateTime(delta.Sub(date.(time.Time))), nil
}

func (d *DateSub) String() string {
Expand Down
11 changes: 11 additions & 0 deletions sql/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,17 @@ type Type interface {
fmt.Stringer
}

var maxTime = time.Date(9999, time.December, 31, 23, 59, 59, 0, time.UTC)

// ValidateTime receives a time and returns either that time or nil if it's
// not a valid time.
func ValidateTime(t time.Time) interface{} {
if t.After(maxTime) {
return nil
}
return t
}

var (
// Null represents the null type.
Null nullT
Expand Down