Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: fix issue that date_add and date_sub is incompatible with MySQL #9702

Merged
merged 8 commits into from
Mar 29, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions expression/builtin_time.go
Original file line number Diff line number Diff line change
Expand Up @@ -2692,6 +2692,13 @@ func (du *baseDateArithmitical) add(ctx sessionctx.Context, date types.Time, int
}

date.Time = types.FromGoTime(goTime)
overflow, err := types.DateTimeIsOverflow(ctx.GetSessionVars().StmtCtx, date)
if err != nil {
return types.Time{}, true, err
}
if overflow {
return types.Time{}, true, handleInvalidTimeError(ctx, types.ErrDatetimeFunctionOverflow.GenWithStackByArgs("datetime"))
}
return date, false, nil
}

Expand All @@ -2718,6 +2725,13 @@ func (du *baseDateArithmitical) sub(ctx sessionctx.Context, date types.Time, int
}

date.Time = types.FromGoTime(goTime)
overflow, err := types.DateTimeIsOverflow(ctx.GetSessionVars().StmtCtx, date)
if err != nil {
return types.Time{}, true, err
}
if overflow {
return types.Time{}, true, handleInvalidTimeError(ctx, types.ErrDatetimeFunctionOverflow.GenWithStackByArgs("datetime"))
}
return date, false, nil
}

Expand Down
3 changes: 2 additions & 1 deletion expression/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ func init() {
// handleInvalidTimeError reports error or warning depend on the context.
func handleInvalidTimeError(ctx sessionctx.Context, err error) error {
if err == nil || !(terror.ErrorEqual(err, types.ErrInvalidTimeFormat) || types.ErrIncorrectDatetimeValue.Equal(err) ||
types.ErrTruncatedWrongValue.Equal(err) || types.ErrInvalidWeekModeFormat.Equal(err)) {
types.ErrTruncatedWrongValue.Equal(err) || types.ErrInvalidWeekModeFormat.Equal(err) ||
types.ErrDatetimeFunctionOverflow.Equal(err)) {
return err
}
sc := ctx.GetSessionVars().StmtCtx
Expand Down
34 changes: 34 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1957,6 +1957,40 @@ func (s *testIntegrationSuite) TestOpBuiltin(c *C) {
result.Check(testkit.Rows("1 0 -9 -0.001 0.999 <nil> aaa"))
}

func (s *testIntegrationSuite) TestDatetimeOverflow(c *C) {
defer s.cleanEnv(c)
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")

tk.MustExec("create table t1 (d date)")
tk.MustExec("set sql_mode='traditional'")
overflowSQLs := []string{
"insert into t1 (d) select date_add('2000-01-01',interval 8000 year)",
"insert into t1 (d) select date_sub('2000-01-01', INTERVAL 2001 YEAR)",
"insert into t1 (d) select date_add('9999-12-31',interval 1 year)",
"insert into t1 (d) select date_sub('1000-01-01', INTERVAL 1 YEAR)",
"insert into t1 (d) select date_add('9999-12-31',interval 1 day)",
"insert into t1 (d) select date_sub('1000-01-01', INTERVAL 1 day)",
"insert into t1 (d) select date_sub('1000-01-01', INTERVAL 1 second)",
}

for _, sql := range overflowSQLs {
_, err := tk.Exec(sql)
c.Assert(err.Error(), Equals, "[types:1441]Datetime function: datetime field overflow")
}

tk.MustExec("set sql_mode=''")
for _, sql := range overflowSQLs {
tk.MustExec(sql)
}

rows := make([]string, 0, len(overflowSQLs))
for range overflowSQLs {
rows = append(rows, "<nil>")
}
tk.MustQuery("select * from t1").Check(testkit.Rows(rows...))
}

func (s *testIntegrationSuite) TestBuiltin(c *C) {
defer s.cleanEnv(c)
tk := testkit.NewTestKit(c, s.store)
Expand Down
61 changes: 54 additions & 7 deletions types/time.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,14 @@ import (

// Portable analogs of some common call errors.
var (
ErrInvalidTimeFormat = terror.ClassTypes.New(mysql.ErrTruncatedWrongValue, "invalid time format: '%v'")
ErrInvalidWeekModeFormat = terror.ClassTypes.New(mysql.ErrTruncatedWrongValue, "invalid week mode format: '%v'")
ErrInvalidYearFormat = errors.New("invalid year format")
ErrInvalidYear = errors.New("invalid year")
ErrZeroDate = errors.New("datetime zero in date")
ErrIncorrectDatetimeValue = terror.ClassTypes.New(mysql.ErrTruncatedWrongValue, "Incorrect datetime value: '%s'")
ErrTruncatedWrongValue = terror.ClassTypes.New(mysql.ErrTruncatedWrongValue, mysql.MySQLErrName[mysql.ErrTruncatedWrongValue])
ErrInvalidTimeFormat = terror.ClassTypes.New(mysql.ErrTruncatedWrongValue, "invalid time format: '%v'")
ErrInvalidWeekModeFormat = terror.ClassTypes.New(mysql.ErrTruncatedWrongValue, "invalid week mode format: '%v'")
ErrInvalidYearFormat = errors.New("invalid year format")
ErrInvalidYear = errors.New("invalid year")
ErrZeroDate = errors.New("datetime zero in date")
ErrIncorrectDatetimeValue = terror.ClassTypes.New(mysql.ErrTruncatedWrongValue, "Incorrect datetime value: '%s'")
ErrDatetimeFunctionOverflow = terror.ClassTypes.New(mysql.ErrDatetimeFunctionOverflow, mysql.MySQLErrName[mysql.ErrDatetimeFunctionOverflow])
ErrTruncatedWrongValue = terror.ClassTypes.New(mysql.ErrTruncatedWrongValue, mysql.MySQLErrName[mysql.ErrTruncatedWrongValue])
)

// Time format without fractional seconds precision.
Expand Down Expand Up @@ -2598,3 +2599,49 @@ func DateFSP(date string) (fsp int) {
}
return
}

// DateTimeIsOverflow return if this date is overflow.
// See: https://dev.mysql.com/doc/refman/8.0/en/datetime.html
func DateTimeIsOverflow(sc *stmtctx.StatementContext, date Time) (bool, error) {
tz := sc.TimeZone
if tz == nil {
tz = gotime.Local
}

var err error
var b, e, t gotime.Time
switch date.Type {
case mysql.TypeDate, mysql.TypeDatetime:
if b, err = MinDatetime.GoTime(tz); err != nil {
return false, err
}
if e, err = MaxDatetime.GoTime(tz); err != nil {
return false, err
}
case mysql.TypeTimestamp:
minTS, maxTS := MinTimestamp, MaxTimestamp
if tz != gotime.UTC {
if err = minTS.ConvertTimeZone(gotime.UTC, tz); err != nil {
return false, err
}
if err = maxTS.ConvertTimeZone(gotime.UTC, tz); err != nil {
return false, err
}
}
if b, err = minTS.Time.GoTime(tz); err != nil {
return false, err
}
if e, err = maxTS.Time.GoTime(tz); err != nil {
return false, err
}
default:
return false, nil
}

if t, err = date.Time.GoTime(tz); err != nil {
return false, err
}

inRange := (t.After(b) || t.Equal(b)) && (t.Before(e) || t.Equal(e))
return !inRange, nil
}