Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: fix the behavior when adding date with big interval | tidb-test=pr/2260 #49228

Merged
merged 7 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pkg/expression/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//pkg/config",
"//pkg/errctx",
"//pkg/errno",
"//pkg/extension",
"//pkg/kv",
Expand Down
117 changes: 82 additions & 35 deletions pkg/expression/builtin_time.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/pkg/config"
"github.com/pingcap/tidb/pkg/errctx"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/parser/terror"
Expand Down Expand Up @@ -2754,7 +2755,7 @@ type baseDateArithmetical struct {

func newDateArithmeticalUtil() baseDateArithmetical {
return baseDateArithmetical{
intervalRegexp: regexp.MustCompile(`-?[\d]+`),
intervalRegexp: regexp.MustCompile(`^[+-]?[\d]+`),
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

XXX1 should be parsed as "", not "1"

}
}

Expand Down Expand Up @@ -2864,17 +2865,58 @@ func (du *baseDateArithmetical) getIntervalFromString(ctx sessionctx.Context, ar
if isNull || err != nil {
return "", true, err
}
// unit "DAY" and "HOUR" has to be specially handled.
if toLower := strings.ToLower(unit); toLower == "day" || toLower == "hour" {
if strings.ToLower(interval) == "true" {
interval = "1"
} else if strings.ToLower(interval) == "false" {

interval, err = du.intervalReformatString(ctx.GetSessionVars().StmtCtx.ErrCtx(), interval, unit)
return interval, false, err
}

func (du *baseDateArithmetical) intervalReformatString(ec errctx.Context, str string, unit string) (interval string, err error) {
switch strings.ToUpper(unit) {
case "MICROSECOND", "MINUTE", "HOUR", "DAY", "WEEK", "MONTH", "QUARTER", "YEAR":
str = strings.TrimSpace(str)
// a single unit value has to be specially handled.
lcwangchao marked this conversation as resolved.
Show resolved Hide resolved
interval = du.intervalRegexp.FindString(str)
if interval == "" {
interval = "0"
} else {
interval = du.intervalRegexp.FindString(interval)
}

if interval != str {
err = ec.HandleError(types.ErrTruncatedWrongVal.GenWithStackByArgs("DECIMAL", str))
}
case "SECOND":
// The unit SECOND is specially handled, for example:
// date + INTERVAL "1e2" SECOND = date + INTERVAL 100 second
// date + INTERVAL "1.6" SECOND = date + INTERVAL 1.6 second
// But:
// date + INTERVAL "1e2" MINUTE = date + INTERVAL 1 MINUTE
// date + INTERVAL "1.6" MINUTE = date + INTERVAL 1 MINUTE
var dec types.MyDecimal
if err = dec.FromString([]byte(str)); err != nil {
truncatedErr := types.ErrTruncatedWrongVal.GenWithStackByArgs("DECIMAL", str)
err = ec.HandleErrorWithAlias(err, truncatedErr, truncatedErr)
}
interval = string(dec.ToString())
default:
interval = str
}
return interval, false, nil
return interval, err
}

func (du *baseDateArithmetical) intervalDecimalToString(ec errctx.Context, dec *types.MyDecimal) (string, error) {
var rounded types.MyDecimal
err := dec.Round(&rounded, 0, types.ModeHalfUp)
if err != nil {
return "", err
}

intVal, err := rounded.ToInt()
if err != nil {
if err = ec.HandleError(types.ErrTruncatedWrongVal.GenWithStackByArgs("DECIMAL", dec.String())); err != nil {
return "", err
}
}

return strconv.FormatInt(intVal, 10), nil
}

func (du *baseDateArithmetical) getIntervalFromDecimal(ctx sessionctx.Context, args []Expression, row chunk.Row, unit string) (string, bool, error) {
Expand Down Expand Up @@ -2921,9 +2963,8 @@ func (du *baseDateArithmetical) getIntervalFromDecimal(ctx sessionctx.Context, a
// interval is already like the %f format.
default:
// YEAR, QUARTER, MONTH, WEEK, DAY, HOUR, MINUTE, MICROSECOND
castExpr := WrapWithCastAsString(ctx, WrapWithCastAsInt(ctx, args[1]))
interval, isNull, err = castExpr.EvalString(ctx, row)
if isNull || err != nil {
interval, err = du.intervalDecimalToString(ctx.GetSessionVars().StmtCtx.ErrCtx(), decimal)
if err != nil {
return "", true, err
}
}
Expand All @@ -2936,6 +2977,11 @@ func (du *baseDateArithmetical) getIntervalFromInt(ctx sessionctx.Context, args
if isNull || err != nil {
return "", true, err
}

if mysql.HasUnsignedFlag(args[1].GetType().GetFlag()) {
lcwangchao marked this conversation as resolved.
Show resolved Hide resolved
return strconv.FormatUint(uint64(interval), 10), false, nil
}

return strconv.FormatInt(interval, 10), false, nil
}

Expand All @@ -2962,7 +3008,10 @@ func (du *baseDateArithmetical) addDate(ctx sessionctx.Context, date types.Time,
}

goTime = goTime.Add(time.Duration(nano))
goTime = types.AddDate(year, month, day, goTime)
goTime, err = types.AddDate(year, month, day, goTime)
if err != nil {
return types.ZeroTime, true, handleInvalidTimeError(ctx, types.ErrDatetimeFunctionOverflow.GenWithStackByArgs("datetime"))
}

// Adjust fsp as required by outer - always respect type inference.
date.SetFsp(resultFsp)
Expand All @@ -2974,10 +3023,6 @@ func (du *baseDateArithmetical) addDate(ctx sessionctx.Context, date types.Time,
return date, false, nil
}

if goTime.Year() < 0 || goTime.Year() > 9999 {
return types.ZeroTime, true, handleInvalidTimeError(ctx, types.ErrDatetimeFunctionOverflow.GenWithStackByArgs("datetime"))
}

date.SetCoreTime(types.FromGoTime(goTime))
overflow, err := types.DateTimeIsOverflow(ctx.GetSessionVars().StmtCtx.TypeCtx(), date)
if err := handleInvalidTimeError(ctx, err); err != nil {
Expand Down Expand Up @@ -3236,28 +3281,19 @@ func (du *baseDateArithmetical) vecGetIntervalFromString(b *baseBuiltinFunc, ctx
return err
}

amendInterval := func(val string) string {
return val
}
if unitLower := strings.ToLower(unit); unitLower == "day" || unitLower == "hour" {
amendInterval = func(val string) string {
if intervalLower := strings.ToLower(val); intervalLower == "true" {
return "1"
} else if intervalLower == "false" {
return "0"
}
return du.intervalRegexp.FindString(val)
}
}

ec := ctx.GetSessionVars().StmtCtx.ErrCtx()
result.ReserveString(n)
for i := 0; i < n; i++ {
if buf.IsNull(i) {
result.AppendNull()
continue
}

result.AppendString(amendInterval(buf.GetString(i)))
interval, err := du.intervalReformatString(ec, buf.GetString(i), unit)
if err != nil {
return err
}
result.AppendString(interval)
}
return nil
}
Expand Down Expand Up @@ -3325,10 +3361,18 @@ func (du *baseDateArithmetical) vecGetIntervalFromDecimal(b *baseBuiltinFunc, ct
/* keep interval as original decimal */
default:
// YEAR, QUARTER, MONTH, WEEK, DAY, HOUR, MINUTE, MICROSECOND
castExpr := WrapWithCastAsString(ctx, WrapWithCastAsInt(ctx, b.args[1]))
amendInterval = func(_ string, row *chunk.Row) (string, bool, error) {
interval, isNull, err := castExpr.EvalString(ctx, *row)
return interval, isNull || err != nil, err
dec, isNull, err := b.args[1].EvalDecimal(ctx, *row)
if isNull || err != nil {
return "", true, err
}

str, err := du.intervalDecimalToString(ctx.GetSessionVars().StmtCtx.ErrCtx(), dec)
if err != nil {
return "", true, err
}

return str, false, nil
}
}

Expand Down Expand Up @@ -3376,9 +3420,12 @@ func (du *baseDateArithmetical) vecGetIntervalFromInt(b *baseBuiltinFunc, ctx se

result.ReserveString(n)
i64s := buf.Int64s()
unsigned := mysql.HasUnsignedFlag(b.args[1].GetType().GetFlag())
for i := 0; i < n; i++ {
if buf.IsNull(i) {
result.AppendNull()
} else if unsigned {
result.AppendString(strconv.FormatUint(uint64(i64s[i]), 10))
} else {
result.AppendString(strconv.FormatInt(i64s[i], 10))
}
Expand Down
8 changes: 4 additions & 4 deletions pkg/expression/integration_test/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2247,11 +2247,11 @@ func TestTimeBuiltin(t *testing.T) {
{"\"2011-11-11 10:10:10\"", "\"20\"", "DAY", "2011-12-01 10:10:10", "2011-10-22 10:10:10"},
{"\"2011-11-11 10:10:10\"", "19.88", "DAY", "2011-12-01 10:10:10", "2011-10-22 10:10:10"},
{"\"2011-11-11 10:10:10\"", "\"19.88\"", "DAY", "2011-11-30 10:10:10", "2011-10-23 10:10:10"},
{"\"2011-11-11 10:10:10\"", "\"prefix19suffix\"", "DAY", "2011-11-30 10:10:10", "2011-10-23 10:10:10"},
{"\"2011-11-11 10:10:10\"", "\"prefix19suffix\"", "DAY", "2011-11-11 10:10:10", "2011-11-11 10:10:10"},
{"\"2011-11-11 10:10:10\"", "\"20-11\"", "DAY", "2011-12-01 10:10:10", "2011-10-22 10:10:10"},
{"\"2011-11-11 10:10:10\"", "\"20,11\"", "daY", "2011-12-01 10:10:10", "2011-10-22 10:10:10"},
{"\"2011-11-11 10:10:10\"", "\"1000\"", "dAy", "2014-08-07 10:10:10", "2009-02-14 10:10:10"},
{"\"2011-11-11 10:10:10\"", "\"true\"", "Day", "2011-11-12 10:10:10", "2011-11-10 10:10:10"},
{"\"2011-11-11 10:10:10\"", "\"true\"", "Day", "2011-11-11 10:10:10", "2011-11-11 10:10:10"},
{"\"2011-11-11 10:10:10\"", "true", "Day", "2011-11-12 10:10:10", "2011-11-10 10:10:10"},
{"\"2011-11-11\"", "1", "DAY", "2011-11-12", "2011-11-10"},
{"\"2011-11-11\"", "10", "HOUR", "2011-11-11 10:00:00", "2011-11-10 14:00:00"},
Expand Down Expand Up @@ -2329,8 +2329,8 @@ func TestTimeBuiltin(t *testing.T) {
{"\"2009-01-01\"", "6/0", "HOUR_MINUTE", "<nil>", "<nil>"},
{"\"1970-01-01 12:00:00\"", "CAST(6/4 AS DECIMAL(3,1))", "HOUR_MINUTE", "1970-01-01 13:05:00", "1970-01-01 10:55:00"},
// for issue #8077
{"\"2012-01-02\"", "\"prefix8\"", "HOUR", "2012-01-02 08:00:00", "2012-01-01 16:00:00"},
{"\"2012-01-02\"", "\"prefix8prefix\"", "HOUR", "2012-01-02 08:00:00", "2012-01-01 16:00:00"},
{"\"2012-01-02\"", "\"prefix8\"", "HOUR", "2012-01-02 00:00:00", "2012-01-02 00:00:00"},
{"\"2012-01-02\"", "\"prefix8prefix\"", "HOUR", "2012-01-02 00:00:00", "2012-01-02 00:00:00"},
{"\"2012-01-02\"", "\"8:00\"", "HOUR", "2012-01-02 08:00:00", "2012-01-01 16:00:00"},
{"\"2012-01-02\"", "\"8:00:00\"", "HOUR", "2012-01-02 08:00:00", "2012-01-01 16:00:00"},
}
Expand Down
20 changes: 18 additions & 2 deletions pkg/types/core_time.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,14 +280,30 @@ func compareTime(a, b CoreTime) int {
// Dig it and we found it's caused by golang api time.Date(year int, month Month, day, hour, min, sec, nsec int, loc *Location) Time ,
// it says October 32 converts to November 1 ,it conflicts with mysql.
// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-add
func AddDate(year, month, day int64, ot gotime.Time) (nt gotime.Time) {
func AddDate(year, month, day int64, ot gotime.Time) (nt gotime.Time, _ error) {
// We must limit the range of year, month and day to avoid overflow.
// The datetime range is from '1000-01-01 00:00:00.000000' to '9999-12-31 23:59:59.499999',
// so it is safe to limit the added value from -10000*365 to 10000*365.
const maxAdd = 10000 * 365
const minAdd = -maxAdd
if year > maxAdd || year < minAdd ||
month > maxAdd || month < minAdd ||
day > maxAdd || day < minAdd {
return nt, ErrDatetimeFunctionOverflow.GenWithStackByArgs("datetime")
}

df := getFixDays(int(year), int(month), int(day), ot)
if df != 0 {
nt = ot.AddDate(int(year), int(month), df)
} else {
nt = ot.AddDate(int(year), int(month), int(day))
}
return nt

if nt.Year() < 0 || nt.Year() > 9999 {
return nt, ErrDatetimeFunctionOverflow.GenWithStackByArgs("datetime")
}

return nt, nil
}

func calcTimeFromSec(to *CoreTime, seconds, microseconds int) {
Expand Down
29 changes: 23 additions & 6 deletions pkg/types/core_time_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,16 +263,33 @@ func TestAddDate(t *testing.T) {
month int
day int
ot time.Time
err bool
}{
{01, 1, 0, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC)},
{02, 1, 12, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC)},
{03, 1, 12, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC)},
{04, 2, 24, time.Date(2000, 2, 10, 0, 0, 0, 0, time.UTC)},
{01, 04, 05, time.Date(2019, 04, 01, 1, 2, 3, 4, time.UTC)},
{01, 1, 0, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), false},
{02, 1, 12, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), false},
{03, 1, 12, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), false},
{04, 2, 24, time.Date(2000, 2, 10, 0, 0, 0, 0, time.UTC), false},
{01, 04, 05, time.Date(2019, 04, 01, 1, 2, 3, 4, time.UTC), false},
{7999, 1, 1, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), false},
{-2000, 1, 1, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), false},
{8000, 1, 1, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), true},
{10001 * 365, 1, 1, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), true},
{01, 10001 * 36, 1, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), true},
{01, 1, 10001 * 365, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), true},
{-2001, 1, 1, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), true},
{-10001 * 365, 1, 1, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), true},
{01, -10001 * 36, 1, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), true},
{01, 1, -10001 * 365, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), true},
}

for _, tt := range tests {
res := AddDate(int64(tt.year), int64(tt.month), int64(tt.day), tt.ot)
res, err := AddDate(int64(tt.year), int64(tt.month), int64(tt.day), tt.ot)
if tt.err {
require.EqualError(t, err, ErrDatetimeFunctionOverflow.GenWithStackByArgs("datetime").Error())
require.True(t, ErrDatetimeFunctionOverflow.Equal(err))
continue
}
require.NoError(t, err)
require.Equal(t, tt.year+tt.ot.Year(), res.Year())
}
}
Expand Down
Loading