From 5bb5cf73b3fbad68d1581874632c803466a6e6fb Mon Sep 17 00:00:00 2001 From: mengnan Date: Thu, 2 Aug 2018 01:02:45 +0800 Subject: [PATCH 1/4] expression: handle max_allowed_packet warnings for to_base64 functions --- expression/builtin_string.go | 25 ++++++++-- expression/builtin_string_test.go | 78 +++++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+), 5 deletions(-) diff --git a/expression/builtin_string.go b/expression/builtin_string.go index c21ecae092aeb..a65bbd4d6c729 100644 --- a/expression/builtin_string.go +++ b/expression/builtin_string.go @@ -232,7 +232,7 @@ func (b *builtinASCIISig) Clone() builtinFunc { return newSig } -// eval evals a builtinASCIISig. +// evalInt evals a builtinASCIISig. // See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_ascii func (b *builtinASCIISig) evalInt(row chunk.Row) (int64, bool, error) { val, isNull, err := b.args[0].EvalString(b.ctx, row) @@ -285,6 +285,7 @@ func (b *builtinConcatSig) Clone() builtinFunc { return newSig } +// evalString evals a builtinConcatSig // See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_concat func (b *builtinConcatSig) evalString(row chunk.Row) (d string, isNull bool, err error) { var s []byte @@ -568,7 +569,7 @@ func (b *builtinRepeatSig) Clone() builtinFunc { return newSig } -// eval evals a builtinRepeatSig. +// evalString evals a builtinRepeatSig. // See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_repeat func (b *builtinRepeatSig) evalString(row chunk.Row) (d string, isNull bool, err error) { str, isNull, err := b.args[0].EvalString(b.ctx, row) @@ -1504,6 +1505,7 @@ type trimFunctionClass struct { baseFunctionClass } +// getFunction sets trim built-in function signature. // The syntax of trim in mysql is 'TRIM([{BOTH | LEADING | TRAILING} [remstr] FROM] str), TRIM([remstr FROM] str)', // but we wil convert it into trim(str), trim(str, remstr) and trim(str, remstr, direction) in AST. func (c *trimFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) { @@ -2471,8 +2473,8 @@ func (b *builtinOctStringSig) Clone() builtinFunc { return newSig } -// // evalString evals OCT(N). -// // See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_oct +// evalString evals OCT(N). +// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_oct func (b *builtinOctStringSig) evalString(row chunk.Row) (string, bool, error) { val, isNull, err := b.args[0].EvalString(b.ctx, row) if isNull || err != nil { @@ -2988,17 +2990,26 @@ func (c *toBase64FunctionClass) getFunction(ctx sessionctx.Context, args []Expre } bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, types.ETString) bf.tp.Flen = base64NeededEncodedLength(bf.args[0].GetType().Flen) - sig := &builtinToBase64Sig{bf} + + valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket) + maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64) + if err != nil { + return nil, errors.Trace(err) + } + + sig := &builtinToBase64Sig{bf, maxAllowedPacket} return sig, nil } type builtinToBase64Sig struct { baseBuiltinFunc + maxAllowedPacket uint64 } func (b *builtinToBase64Sig) Clone() builtinFunc { newSig := &builtinToBase64Sig{} newSig.cloneFrom(&b.baseBuiltinFunc) + newSig.maxAllowedPacket = b.maxAllowedPacket return newSig } @@ -3033,6 +3044,10 @@ func (b *builtinToBase64Sig) evalString(row chunk.Row) (d string, isNull bool, e return "", isNull, errors.Trace(err) } + if b.tp.Flen*mysql.MaxBytesOfCharacter > int(b.maxAllowedPacket) { + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenByArgs("to_base64", b.maxAllowedPacket)) + return "", true, nil + } if b.tp.Flen == -1 || b.tp.Flen > mysql.MaxBlobWidth { return "", true, nil } diff --git a/expression/builtin_string_test.go b/expression/builtin_string_test.go index 46dde9237fa0d..d29da28e87124 100644 --- a/expression/builtin_string_test.go +++ b/expression/builtin_string_test.go @@ -1835,6 +1835,84 @@ func (s *testEvaluatorSuite) TestToBase64(c *C) { c.Assert(err, IsNil) } +func (s *testEvaluatorSuite) TestToBase64Sig(c *C) { + colTypes := []*types.FieldType{ + {Tp: mysql.TypeVarchar}, + } + + tests := []struct { + args string + expect string + isNil bool + getErr bool + maxAllowPacket uint64 + }{ + {"abc", "YWJj", false, false, 16}, + {"abc", "", true, false, 15}, + { + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/", + "QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVphYmNkZWZnaGlqa2xtbm9wcXJzdHV2d3h5ejAxMjM0\nNTY3ODkrLw==", + false, + false, + 356, + }, + { + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/", + "", + true, + false, + 355, + }, + { + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/", + "QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVphYmNkZWZnaGlqa2xtbm9wcXJzdHV2d3h5ejAxMjM0\nNTY3ODkrL0FCQ0RFRkdISUpLTE1OT1BRUlNUVVZXWFlaYWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4\neXowMTIzNDU2Nzg5Ky9BQkNERUZHSElKS0xNTk9QUVJTVFVWV1hZWmFiY2RlZmdoaWprbG1ub3Bx\ncnN0dXZ3eHl6MDEyMzQ1Njc4OSsv", + false, + false, + 1036, + }, + { + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/", + "", + true, + false, + 1035, + }, + } + + args := []Expression{ + &Column{Index: 0, RetType: colTypes[0]}, + } + + warningCount := 0 + + for _, test := range tests { + resultType := &types.FieldType{Tp: mysql.TypeVarchar, Flen: base64NeededEncodedLength(len(test.args))} + base := baseBuiltinFunc{args: args, ctx: s.ctx, tp: resultType} + toBase64 := &builtinToBase64Sig{base, test.maxAllowPacket} + + input := chunk.NewChunkWithCapacity(colTypes, 1) + input.AppendString(0, test.args) + res, isNull, err := toBase64.evalString(input.GetRow(0)) + if test.getErr { + c.Assert(err, NotNil) + } else { + c.Assert(err, IsNil) + } + if test.isNil { + c.Assert(isNull, IsTrue) + warningCount += 1 + } else { + c.Assert(isNull, IsFalse) + } + c.Assert(res, Equals, test.expect) + } + warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings() + c.Assert(len(warnings), Equals, warningCount) + for _, warn := range warnings { + c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, warn.Err), IsTrue) + } +} + func (s *testEvaluatorSuite) TestStringRight(c *C) { defer testleak.AfterTest(c)() fc := funcs[ast.Right] From 9700a53e653a92a8ea63d0f788158bca26a67926 Mon Sep 17 00:00:00 2001 From: mengnan Date: Fri, 10 Aug 2018 22:23:46 +0800 Subject: [PATCH 2/4] use `need encode length` to decide whether the result of `ToBase64` exceeds max_allowed_packet --- expression/builtin_string.go | 7 +++++-- expression/builtin_string_test.go | 12 ++++++------ 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/expression/builtin_string.go b/expression/builtin_string.go index 40964e0c32911..100e78a2ac6d7 100644 --- a/expression/builtin_string.go +++ b/expression/builtin_string.go @@ -3054,8 +3054,11 @@ func (b *builtinToBase64Sig) evalString(row chunk.Row) (d string, isNull bool, e if isNull || err != nil { return "", isNull, errors.Trace(err) } - - if b.tp.Flen*mysql.MaxBytesOfCharacter > int(b.maxAllowedPacket) { + needEncodeLen := base64NeededEncodedLength(len(str)) + if needEncodeLen == -1 { + return "", true, nil + } + if needEncodeLen > int(b.maxAllowedPacket) { b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenByArgs("to_base64", b.maxAllowedPacket)) return "", true, nil } diff --git a/expression/builtin_string_test.go b/expression/builtin_string_test.go index 39fed7dc28991..5ebdc071b3ee6 100644 --- a/expression/builtin_string_test.go +++ b/expression/builtin_string_test.go @@ -1891,35 +1891,35 @@ func (s *testEvaluatorSuite) TestToBase64Sig(c *C) { getErr bool maxAllowPacket uint64 }{ - {"abc", "YWJj", false, false, 16}, - {"abc", "", true, false, 15}, + {"abc", "YWJj", false, false, 4}, + {"abc", "", true, false, 3}, { "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/", "QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVphYmNkZWZnaGlqa2xtbm9wcXJzdHV2d3h5ejAxMjM0\nNTY3ODkrLw==", false, false, - 356, + 89, }, { "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/", "", true, false, - 355, + 88, }, { "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/", "QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVphYmNkZWZnaGlqa2xtbm9wcXJzdHV2d3h5ejAxMjM0\nNTY3ODkrL0FCQ0RFRkdISUpLTE1OT1BRUlNUVVZXWFlaYWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4\neXowMTIzNDU2Nzg5Ky9BQkNERUZHSElKS0xNTk9QUVJTVFVWV1hZWmFiY2RlZmdoaWprbG1ub3Bx\ncnN0dXZ3eHl6MDEyMzQ1Njc4OSsv", false, false, - 1036, + 259, }, { "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/", "", true, false, - 1035, + 258, }, } From 0467a1f3a06584229fd14281792292079f9e7220 Mon Sep 17 00:00:00 2001 From: mengnan Date: Fri, 10 Aug 2018 23:04:17 +0800 Subject: [PATCH 3/4] remove get err check --- expression/builtin_string_test.go | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/expression/builtin_string_test.go b/expression/builtin_string_test.go index 5ebdc071b3ee6..f1079660d598e 100644 --- a/expression/builtin_string_test.go +++ b/expression/builtin_string_test.go @@ -1888,37 +1888,32 @@ func (s *testEvaluatorSuite) TestToBase64Sig(c *C) { args string expect string isNil bool - getErr bool maxAllowPacket uint64 }{ - {"abc", "YWJj", false, false, 4}, - {"abc", "", true, false, 3}, + {"abc", "YWJj", false, 4}, + {"abc", "", true, 3}, { "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/", "QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVphYmNkZWZnaGlqa2xtbm9wcXJzdHV2d3h5ejAxMjM0\nNTY3ODkrLw==", false, - false, 89, }, { "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/", "", true, - false, 88, }, { "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/", "QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVphYmNkZWZnaGlqa2xtbm9wcXJzdHV2d3h5ejAxMjM0\nNTY3ODkrL0FCQ0RFRkdISUpLTE1OT1BRUlNUVVZXWFlaYWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4\neXowMTIzNDU2Nzg5Ky9BQkNERUZHSElKS0xNTk9QUVJTVFVWV1hZWmFiY2RlZmdoaWprbG1ub3Bx\ncnN0dXZ3eHl6MDEyMzQ1Njc4OSsv", false, - false, 259, }, { "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/", "", true, - false, 258, }, } @@ -1937,11 +1932,7 @@ func (s *testEvaluatorSuite) TestToBase64Sig(c *C) { input := chunk.NewChunkWithCapacity(colTypes, 1) input.AppendString(0, test.args) res, isNull, err := toBase64.evalString(input.GetRow(0)) - if test.getErr { - c.Assert(err, NotNil) - } else { - c.Assert(err, IsNil) - } + c.Assert(err, IsNil) if test.isNil { c.Assert(isNull, IsTrue) warningCount += 1 From 7f158acc06a32a54d1f480d6239bf3d44b3706d2 Mon Sep 17 00:00:00 2001 From: mengnan Date: Fri, 10 Aug 2018 23:40:56 +0800 Subject: [PATCH 4/4] check the exactly warning count and warning content in each test case --- expression/builtin_string_test.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/expression/builtin_string_test.go b/expression/builtin_string_test.go index f1079660d598e..7341f7f1055dc 100644 --- a/expression/builtin_string_test.go +++ b/expression/builtin_string_test.go @@ -23,6 +23,7 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/terror" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/charset" @@ -1922,8 +1923,6 @@ func (s *testEvaluatorSuite) TestToBase64Sig(c *C) { &Column{Index: 0, RetType: colTypes[0]}, } - warningCount := 0 - for _, test := range tests { resultType := &types.FieldType{Tp: mysql.TypeVarchar, Flen: base64NeededEncodedLength(len(test.args))} base := baseBuiltinFunc{args: args, ctx: s.ctx, tp: resultType} @@ -1935,17 +1934,18 @@ func (s *testEvaluatorSuite) TestToBase64Sig(c *C) { c.Assert(err, IsNil) if test.isNil { c.Assert(isNull, IsTrue) - warningCount += 1 + + warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings() + c.Assert(len(warnings), Equals, 1) + lastWarn := warnings[len(warnings)-1] + c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue) + s.ctx.GetSessionVars().StmtCtx.SetWarnings([]stmtctx.SQLWarn{}) + } else { c.Assert(isNull, IsFalse) } c.Assert(res, Equals, test.expect) } - warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings() - c.Assert(len(warnings), Equals, warningCount) - for _, warn := range warnings { - c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, warn.Err), IsTrue) - } } func (s *testEvaluatorSuite) TestStringRight(c *C) {