Skip to content

Commit

Permalink
expression: handle empty input and improve compatibility for format (
Browse files Browse the repository at this point in the history
  • Loading branch information
eurekaka authored and zz-jason committed Jan 31, 2019
1 parent d97b8ed commit 2b6dc60
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 54 deletions.
76 changes: 54 additions & 22 deletions expression/builtin_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -2917,7 +2917,13 @@ func (c *formatFunctionClass) getFunction(ctx sessionctx.Context, args []Express
return nil, errors.Trace(err)
}
argTps := make([]types.EvalType, 2, 3)
argTps[0], argTps[1] = types.ETString, types.ETString
argTps[1] = types.ETInt
argTp := args[0].GetType().EvalType()
if argTp == types.ETDecimal || argTp == types.ETInt {
argTps[0] = types.ETDecimal
} else {
argTps[0] = types.ETReal
}
if len(args) == 3 {
argTps = append(argTps, types.ETString)
}
Expand All @@ -2932,6 +2938,41 @@ func (c *formatFunctionClass) getFunction(ctx sessionctx.Context, args []Express
return sig, nil
}

// formatMaxDecimals limits the maximum number of decimal digits for result of
// function `format`, this value is same as `FORMAT_MAX_DECIMALS` in MySQL source code.
const formatMaxDecimals int64 = 30

// evalNumDecArgsForFormat evaluates first 2 arguments, i.e, x and d, for function `format`.
func evalNumDecArgsForFormat(f builtinFunc, row chunk.Row) (string, string, bool, error) {
var xStr string
arg0, arg1 := f.getArgs()[0], f.getArgs()[1]
ctx := f.getCtx()
if arg0.GetType().EvalType() == types.ETDecimal {
x, isNull, err := arg0.EvalDecimal(ctx, row)
if isNull || err != nil {
return "", "", isNull, err
}
xStr = x.String()
} else {
x, isNull, err := arg0.EvalReal(ctx, row)
if isNull || err != nil {
return "", "", isNull, err
}
xStr = strconv.FormatFloat(x, 'f', -1, 64)
}
d, isNull, err := arg1.EvalInt(ctx, row)
if isNull || err != nil {
return "", "", isNull, err
}
if d < 0 {
d = 0
} else if d > formatMaxDecimals {
d = formatMaxDecimals
}
dStr := strconv.FormatInt(d, 10)
return xStr, dStr, false, nil
}

type builtinFormatWithLocaleSig struct {
baseBuiltinFunc
}
Expand All @@ -2945,23 +2986,20 @@ func (b *builtinFormatWithLocaleSig) Clone() builtinFunc {
// evalString evals FORMAT(X,D,locale).
// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_format
func (b *builtinFormatWithLocaleSig) evalString(row chunk.Row) (string, bool, error) {
x, isNull, err := b.args[0].EvalString(b.ctx, row)
if isNull || err != nil {
return "", true, errors.Trace(err)
}

d, isNull, err := b.args[1].EvalString(b.ctx, row)
x, d, isNull, err := evalNumDecArgsForFormat(b, row)
if isNull || err != nil {
return "", true, errors.Trace(err)
return "", isNull, err
}

locale, isNull, err := b.args[2].EvalString(b.ctx, row)
if isNull || err != nil {
return "", true, errors.Trace(err)
if err != nil {
return "", false, err
}
if isNull {
b.ctx.GetSessionVars().StmtCtx.AppendWarning(errUnknownLocale.GenWithStackByArgs("NULL"))
locale = "en_US"
}

formatString, err := mysql.GetLocaleFormatFunction(locale)(x, d)
return formatString, err != nil, errors.Trace(err)
return formatString, false, err
}

type builtinFormatSig struct {
Expand All @@ -2977,18 +3015,12 @@ func (b *builtinFormatSig) Clone() builtinFunc {
// evalString evals FORMAT(X,D).
// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_format
func (b *builtinFormatSig) evalString(row chunk.Row) (string, bool, error) {
x, isNull, err := b.args[0].EvalString(b.ctx, row)
if isNull || err != nil {
return "", true, errors.Trace(err)
}

d, isNull, err := b.args[1].EvalString(b.ctx, row)
x, d, isNull, err := evalNumDecArgsForFormat(b, row)
if isNull || err != nil {
return "", true, errors.Trace(err)
return "", isNull, err
}

formatString, err := mysql.GetLocaleFormatFunction("en_US")(x, d)
return formatString, err != nil, errors.Trace(err)
return formatString, false, err
}

type fromBase64FunctionClass struct {
Expand Down
93 changes: 61 additions & 32 deletions expression/builtin_string_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1599,38 +1599,42 @@ func (s *testEvaluatorSuite) TestFormat(c *C) {
locale string
ret interface{}
}{
{12332.1234561111111111111111111111111111111111111, 4, "en_US", "12,332.1234"},
{12332.12341111111111111111111111111111111111111, 4, "en_US", "12,332.1234"},
{nil, 22, "en_US", nil},
}
formatTests1 := []struct {
number interface{}
precision interface{}
ret interface{}
warnings int
}{
{12332.123456, 4, "12,332.1234"},
{12332.123456, 0, "12,332"},
{12332.123456, -4, "12,332"},
{-12332.123456, 4, "-12,332.1234"},
{-12332.123456, 0, "-12,332"},
{-12332.123456, -4, "-12,332"},
{"12332.123456", "4", "12,332.1234"},
{"12332.123456A", "4", "12,332.1234"},
{"-12332.123456", "4", "-12,332.1234"},
{"-12332.123456A", "4", "-12,332.1234"},
{"A123345", "4", "0.0000"},
{"-A123345", "4", "0.0000"},
{"-12332.123456", "A", "-12,332"},
{"12332.123456", "A", "12,332"},
{"-12332.123456", "4A", "-12,332.1234"},
{"12332.123456", "4A", "12,332.1234"},
{"-A12332.123456", "A", "0"},
{"A12332.123456", "A", "0"},
{"-A12332.123456", "4A", "0.0000"},
{"A12332.123456", "4A", "0.0000"},
{"-.12332.123456", "4A", "-0.1233"},
{".12332.123456", "4A", "0.1233"},
{"12332.1234567890123456789012345678901", 22, "12,332.1234567890123456789012"},
{nil, 22, nil},
{12332.123444, 4, "12,332.1234", 0},
{12332.123444, 0, "12,332", 0},
{12332.123444, -4, "12,332", 0},
{-12332.123444, 4, "-12,332.1234", 0},
{-12332.123444, 0, "-12,332", 0},
{-12332.123444, -4, "-12,332", 0},
{"12332.123444", "4", "12,332.1234", 0},
{"12332.123444A", "4", "12,332.1234", 1},
{"-12332.123444", "4", "-12,332.1234", 0},
{"-12332.123444A", "4", "-12,332.1234", 1},
{"A123345", "4", "0.0000", 1},
{"-A123345", "4", "0.0000", 1},
{"-12332.123444", "A", "-12,332", 1},
{"12332.123444", "A", "12,332", 1},
{"-12332.123444", "4A", "-12,332.1234", 1},
{"12332.123444", "4A", "12,332.1234", 1},
{"-A12332.123444", "A", "0", 2},
{"A12332.123444", "A", "0", 2},
{"-A12332.123444", "4A", "0.0000", 2},
{"A12332.123444", "4A", "0.0000", 2},
{"-.12332.123444", "4A", "-0.1233", 2},
{".12332.123444", "4A", "0.1233", 2},
{"12332.1234567890123456789012345678901", 22, "12,332.1234567890110000000000", 0},
{nil, 22, nil, 0},
{1, 1024, "1.000000000000000000000000000000", 0},
{"", 1, "0.0", 1},
{1, "", "1", 1},
}
formatTests2 := struct {
number interface{}
Expand All @@ -1644,9 +1648,15 @@ func (s *testEvaluatorSuite) TestFormat(c *C) {
locale string
ret interface{}
}{"-12332.123456", "4", "de_GE", nil}
formatTests4 := struct {
number interface{}
precision interface{}
locale interface{}
ret interface{}
}{1, 4, nil, "1.0000"}

fc := funcs[ast.Format]
for _, tt := range formatTests {
fc := funcs[ast.Format]
f, err := fc.getFunction(s.ctx, s.datumsToConstants(types.MakeDatums(tt.number, tt.precision, tt.locale)))
c.Assert(err, IsNil)
c.Assert(f, NotNil)
Expand All @@ -1655,31 +1665,50 @@ func (s *testEvaluatorSuite) TestFormat(c *C) {
c.Assert(r, testutil.DatumEquals, types.NewDatum(tt.ret))
}

origConfig := s.ctx.GetSessionVars().StmtCtx.TruncateAsWarning
s.ctx.GetSessionVars().StmtCtx.TruncateAsWarning = true
for _, tt := range formatTests1 {
fc := funcs[ast.Format]
f, err := fc.getFunction(s.ctx, s.datumsToConstants(types.MakeDatums(tt.number, tt.precision)))
c.Assert(err, IsNil)
c.Assert(f, NotNil)
r, err := evalBuiltinFunc(f, chunk.Row{})
c.Assert(err, IsNil)
c.Assert(r, testutil.DatumEquals, types.NewDatum(tt.ret))
c.Assert(r, testutil.DatumEquals, types.NewDatum(tt.ret), Commentf("test %v", tt))
if tt.warnings > 0 {
warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings()
c.Assert(len(warnings), Equals, tt.warnings, Commentf("test %v", tt))
for i := 0; i < tt.warnings; i++ {
c.Assert(terror.ErrorEqual(types.ErrTruncated, warnings[i].Err), IsTrue, Commentf("test %v", tt))
}
s.ctx.GetSessionVars().StmtCtx.SetWarnings([]stmtctx.SQLWarn{})
}
}
s.ctx.GetSessionVars().StmtCtx.TruncateAsWarning = origConfig

fc2 := funcs[ast.Format]
f2, err := fc2.getFunction(s.ctx, s.datumsToConstants(types.MakeDatums(formatTests2.number, formatTests2.precision, formatTests2.locale)))
f2, err := fc.getFunction(s.ctx, s.datumsToConstants(types.MakeDatums(formatTests2.number, formatTests2.precision, formatTests2.locale)))
c.Assert(err, IsNil)
c.Assert(f2, NotNil)
r2, err := evalBuiltinFunc(f2, chunk.Row{})
c.Assert(types.NewDatum(err), testutil.DatumEquals, types.NewDatum(errors.New("not implemented")))
c.Assert(r2, testutil.DatumEquals, types.NewDatum(formatTests2.ret))

fc3 := funcs[ast.Format]
f3, err := fc3.getFunction(s.ctx, s.datumsToConstants(types.MakeDatums(formatTests3.number, formatTests3.precision, formatTests3.locale)))
f3, err := fc.getFunction(s.ctx, s.datumsToConstants(types.MakeDatums(formatTests3.number, formatTests3.precision, formatTests3.locale)))
c.Assert(err, IsNil)
c.Assert(f3, NotNil)
r3, err := evalBuiltinFunc(f3, chunk.Row{})
c.Assert(types.NewDatum(err), testutil.DatumEquals, types.NewDatum(errors.New("not support for the specific locale")))
c.Assert(r3, testutil.DatumEquals, types.NewDatum(formatTests3.ret))

f4, err := fc.getFunction(s.ctx, s.datumsToConstants(types.MakeDatums(formatTests4.number, formatTests4.precision, formatTests4.locale)))
c.Assert(err, IsNil)
c.Assert(f4, NotNil)
r4, err := evalBuiltinFunc(f4, chunk.Row{})
c.Assert(err, IsNil)
c.Assert(r4, testutil.DatumEquals, types.NewDatum(formatTests4.ret))
warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings()
c.Assert(len(warnings), Equals, 1)
c.Assert(terror.ErrorEqual(errUnknownLocale, warnings[0].Err), IsTrue)
s.ctx.GetSessionVars().StmtCtx.SetWarnings([]stmtctx.SQLWarn{})
}

func (s *testEvaluatorSuite) TestFromBase64(c *C) {
Expand Down
2 changes: 2 additions & 0 deletions expression/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ var (
errWarnAllowedPacketOverflowed = terror.ClassExpression.New(mysql.ErrWarnAllowedPacketOverflowed, mysql.MySQLErrName[mysql.ErrWarnAllowedPacketOverflowed])
errWarnOptionIgnored = terror.ClassExpression.New(mysql.WarnOptionIgnored, mysql.MySQLErrName[mysql.WarnOptionIgnored])
errTruncatedWrongValue = terror.ClassExpression.New(mysql.ErrTruncatedWrongValue, mysql.MySQLErrName[mysql.ErrTruncatedWrongValue])
errUnknownLocale = terror.ClassExpression.New(mysql.ErrUnknownLocale, mysql.MySQLErrName[mysql.ErrUnknownLocale])
)

func init() {
Expand All @@ -60,6 +61,7 @@ func init() {
mysql.ErrWarnAllowedPacketOverflowed: mysql.ErrWarnAllowedPacketOverflowed,
mysql.WarnOptionIgnored: mysql.WarnOptionIgnored,
mysql.ErrTruncatedWrongValue: mysql.ErrTruncatedWrongValue,
mysql.ErrUnknownLocale: mysql.ErrUnknownLocale,
}
terror.ErrClassToMySQLCodes[terror.ClassExpression] = expressionMySQLErrCodes
}
Expand Down

0 comments on commit 2b6dc60

Please sign in to comment.