From adae3b0e3452f5fd14c00101483f3997bc3f6ee4 Mon Sep 17 00:00:00 2001 From: haplone Date: Wed, 13 Feb 2019 10:23:24 +0800 Subject: [PATCH] expression: fix date_add interval month,year diffs from mysql (#8988) (#9284) --- expression/builtin_time.go | 14 ++------- expression/builtin_time_test.go | 49 +++++++++++++++++++++++++++++++ types/mytime.go | 51 ++++++++++++++++++++++++++++++++- types/mytime_test.go | 18 ++++++++++++ 4 files changed, 120 insertions(+), 12 deletions(-) diff --git a/expression/builtin_time.go b/expression/builtin_time.go index de832cc33c488..bbd0411f54267 100644 --- a/expression/builtin_time.go +++ b/expression/builtin_time.go @@ -2632,7 +2632,7 @@ func (du *baseDateArithmitical) add(ctx sessionctx.Context, date types.Time, int duration := time.Duration(dur) goTime = goTime.Add(duration) - goTime = goTime.AddDate(int(year), int(month), int(day)) + goTime = types.AddDate(year, month, day, goTime) if goTime.Nanosecond() == 0 { date.Fsp = 0 @@ -2658,7 +2658,7 @@ func (du *baseDateArithmitical) sub(ctx sessionctx.Context, date types.Time, int duration := time.Duration(dur) goTime = goTime.Add(duration) - goTime = goTime.AddDate(int(year), int(month), int(day)) + goTime = types.AddDate(year, month, day, goTime) if goTime.Nanosecond() == 0 { date.Fsp = 0 @@ -5587,15 +5587,7 @@ func (b *builtinLastDaySig) evalTime(row chunk.Row) (types.Time, bool, error) { if year == 0 && month == 0 && tm.Day() == 0 { return types.Time{}, true, errors.Trace(handleInvalidTimeError(b.ctx, types.ErrIncorrectDatetimeValue.GenWithStackByArgs(arg.String()))) } - if month == 1 || month == 3 || month == 5 || - month == 7 || month == 8 || month == 10 || month == 12 { - day = 31 - } else if month == 2 { - day = 28 - if tm.IsLeapYear() { - day = 29 - } - } + day = types.GetLastDay(year, month) ret := types.Time{ Time: types.FromDate(year, month, day, 0, 0, 0, 0), Type: mysql.TypeDate, diff --git a/expression/builtin_time_test.go b/expression/builtin_time_test.go index b19a2a7ba5040..5f51be8c24dc1 100644 --- a/expression/builtin_time_test.go +++ b/expression/builtin_time_test.go @@ -1656,6 +1656,55 @@ func (s *testEvaluatorSuite) TestDateArithFuncs(c *C) { v, err = evalBuiltinFunc(f, chunk.Row{}) c.Assert(err, IsNil) c.Assert(v.IsNull(), IsTrue) + + testMonths := []struct { + input string + months int + expected string + }{ + {"1900-01-31", 1, "1900-02-28"}, + {"2000-01-31", 1, "2000-02-29"}, + {"2016-01-31", 1, "2016-02-29"}, + {"2018-07-31", 1, "2018-08-31"}, + {"2018-08-31", 1, "2018-09-30"}, + {"2018-07-31", 2, "2018-09-30"}, + {"2016-01-31", 27, "2018-04-30"}, + {"2000-02-29", 12, "2001-02-28"}, + {"2000-11-30", 1, "2000-12-30"}, + } + + for _, test := range testMonths { + args = types.MakeDatums(test.input, test.months, "MONTH") + f, err = fcAdd.getFunction(s.ctx, s.datumsToConstants(args)) + c.Assert(err, IsNil) + c.Assert(f, NotNil) + v, err = evalBuiltinFunc(f, chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(v.GetMysqlTime().String(), Equals, test.expected) + } + + testYears := []struct { + input string + year int + expected string + }{ + {"1899-02-28", 1, "1900-02-28"}, + {"1901-02-28", -1, "1900-02-28"}, + {"2000-02-29", 1, "2001-02-28"}, + {"2001-02-28", -1, "2000-02-28"}, + {"2004-02-29", 1, "2005-02-28"}, + {"2005-02-28", -1, "2004-02-28"}, + } + + for _, test := range testYears { + args = types.MakeDatums(test.input, test.year, "YEAR") + f, err = fcAdd.getFunction(s.ctx, s.datumsToConstants(args)) + c.Assert(err, IsNil) + c.Assert(f, NotNil) + v, err = evalBuiltinFunc(f, chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(v.GetMysqlTime().String(), Equals, test.expected) + } } func (s *testEvaluatorSuite) TestTimestamp(c *C) { diff --git a/types/mytime.go b/types/mytime.go index 6c9564dcbac49..bc90f08c13125 100644 --- a/types/mytime.go +++ b/types/mytime.go @@ -119,7 +119,56 @@ func (t MysqlTime) GoTime(loc *gotime.Location) (gotime.Time, error) { // IsLeapYear returns if it's leap year. func (t MysqlTime) IsLeapYear() bool { - return (t.year%4 == 0 && t.year%100 != 0) || t.year%400 == 0 + return isLeapYear(t.year) +} + +func isLeapYear(year uint16) bool { + return (year%4 == 0 && year%100 != 0) || year%400 == 0 +} + +var daysByMonth = [12]int{31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31} + +// GetLastDay returns the last day of the month +func GetLastDay(year, month int) int { + var day = 0 + if month > 0 && month <= 12 { + day = daysByMonth[month-1] + } + if month == 2 && isLeapYear(uint16(year)) { + day = 29 + } + return day +} + +func getFixDays(year, month, day int, ot gotime.Time) int { + if (year != 0 || month != 0) && day == 0 { + od := ot.Day() + t := ot.AddDate(year, month, day) + td := t.Day() + if od != td { + tm := int(t.Month()) - 1 + tMax := GetLastDay(t.Year(), tm) + dd := tMax - od + return dd + } + } + return 0 +} + +// AddDate fix gap between mysql and golang api +// When we execute select date_add('2018-01-31',interval 1 month) in mysql we got 2018-02-28 +// but in tidb we got 2018-03-03. +// Dig it and we found it's caused by golang api time.Date(year int, month Month, day, hour, min, sec, nsec int, loc *Location) Time , +// it says October 32 converts to November 1 ,it conflits with mysql. +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-add +func AddDate(year, month, day int64, ot gotime.Time) (nt gotime.Time) { + df := getFixDays(int(year), int(month), int(day), ot) + if df != 0 { + nt = ot.AddDate(int(year), int(month), df) + } else { + nt = ot.AddDate(int(year), int(month), int(day)) + } + return nt } func calcTimeFromSec(to *MysqlTime, seconds, microseconds int) { diff --git a/types/mytime_test.go b/types/mytime_test.go index b4b0a8606ab5b..8d9d8d016a68f 100644 --- a/types/mytime_test.go +++ b/types/mytime_test.go @@ -210,3 +210,21 @@ func (s *testMyTimeSuite) TestIsLeapYear(c *C) { c.Assert(tt.T.IsLeapYear(), Equals, tt.Expect) } } +func (s *testMyTimeSuite) TestGetLastDay(c *C) { + tests := []struct { + year int + month int + expectedDay int + }{ + {2000, 1, 31}, + {2000, 2, 29}, + {2000, 4, 30}, + {1900, 2, 28}, + {1996, 2, 29}, + } + + for _, t := range tests { + day := GetLastDay(t.year, t.month) + c.Assert(day, Equals, t.expectedDay) + } +}