diff --git a/expression/builtin_string.go b/expression/builtin_string.go index 5fb2b5a3183c8..b8882bbcf4080 100644 --- a/expression/builtin_string.go +++ b/expression/builtin_string.go @@ -273,17 +273,26 @@ func (c *concatFunctionClass) getFunction(ctx sessionctx.Context, args []Express if bf.tp.Flen >= mysql.MaxBlobWidth { bf.tp.Flen = mysql.MaxBlobWidth } - sig := &builtinConcatSig{bf} + + valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket) + maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64) + if err != nil { + return nil, err + } + + sig := &builtinConcatSig{bf, maxAllowedPacket} return sig, nil } type builtinConcatSig struct { baseBuiltinFunc + maxAllowedPacket uint64 } func (b *builtinConcatSig) Clone() builtinFunc { newSig := &builtinConcatSig{} newSig.cloneFrom(&b.baseBuiltinFunc) + newSig.maxAllowedPacket = b.maxAllowedPacket return newSig } @@ -296,6 +305,10 @@ func (b *builtinConcatSig) evalString(row chunk.Row) (d string, isNull bool, err if isNull || err != nil { return d, isNull, err } + if uint64(len(s)+len(d)) > b.maxAllowedPacket { + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenWithStackByArgs("concat", b.maxAllowedPacket)) + return "", true, nil + } s = append(s, []byte(d)...) } return string(s), false, nil @@ -338,17 +351,25 @@ func (c *concatWSFunctionClass) getFunction(ctx sessionctx.Context, args []Expre bf.tp.Flen = mysql.MaxBlobWidth } - sig := &builtinConcatWSSig{bf} + valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket) + maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64) + if err != nil { + return nil, err + } + + sig := &builtinConcatWSSig{bf, maxAllowedPacket} return sig, nil } type builtinConcatWSSig struct { baseBuiltinFunc + maxAllowedPacket uint64 } func (b *builtinConcatWSSig) Clone() builtinFunc { newSig := &builtinConcatWSSig{} newSig.cloneFrom(&b.baseBuiltinFunc) + newSig.maxAllowedPacket = b.maxAllowedPacket return newSig } @@ -358,25 +379,35 @@ func (b *builtinConcatWSSig) evalString(row chunk.Row) (string, bool, error) { args := b.getArgs() strs := make([]string, 0, len(args)) var sep string - for i, arg := range args { - val, isNull, err := arg.EvalString(b.ctx, row) + var targetLength int + + N := len(args) + if N > 0 { + val, isNull, err := args[0].EvalString(b.ctx, row) + if err != nil || isNull { + // If the separator is NULL, the result is NULL. + return val, isNull, err + } + sep = val + } + for i := 1; i < N; i++ { + val, isNull, err := args[i].EvalString(b.ctx, row) if err != nil { return val, isNull, err } - if isNull { - // If the separator is NULL, the result is NULL. - if i == 0 { - return val, isNull, nil - } // CONCAT_WS() does not skip empty strings. However, // it does skip any NULL values after the separator argument. continue } - if i == 0 { - sep = val - continue + targetLength += len(val) + if i > 1 { + targetLength += len(sep) + } + if uint64(targetLength) > b.maxAllowedPacket { + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenWithStackByArgs("concat_ws", b.maxAllowedPacket)) + return "", true, nil } strs = append(strs, val) } diff --git a/expression/builtin_string_test.go b/expression/builtin_string_test.go index 5c9b6debb62e4..b07ebd7d11f13 100644 --- a/expression/builtin_string_test.go +++ b/expression/builtin_string_test.go @@ -171,6 +171,50 @@ func (s *testEvaluatorSuite) TestConcat(c *C) { } } +func (s *testEvaluatorSuite) TestConcatSig(c *C) { + colTypes := []*types.FieldType{ + {Tp: mysql.TypeVarchar}, + {Tp: mysql.TypeVarchar}, + } + resultType := &types.FieldType{Tp: mysql.TypeVarchar, Flen: 1000} + args := []Expression{ + &Column{Index: 0, RetType: colTypes[0]}, + &Column{Index: 1, RetType: colTypes[1]}, + } + base := baseBuiltinFunc{args: args, ctx: s.ctx, tp: resultType} + concat := &builtinConcatSig{base, 5} + + cases := []struct { + args []interface{} + warnings int + res string + }{ + {[]interface{}{"a", "b"}, 0, "ab"}, + {[]interface{}{"aaa", "bbb"}, 1, ""}, + {[]interface{}{"中", "a"}, 0, "中a"}, + {[]interface{}{"中文", "a"}, 2, ""}, + } + + for _, t := range cases { + input := chunk.NewChunkWithCapacity(colTypes, 10) + input.AppendString(0, t.args[0].(string)) + input.AppendString(1, t.args[1].(string)) + + res, isNull, err := concat.evalString(input.GetRow(0)) + c.Assert(res, Equals, t.res) + c.Assert(err, IsNil) + if t.warnings == 0 { + c.Assert(isNull, IsFalse) + } else { + c.Assert(isNull, IsTrue) + warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings() + c.Assert(warnings, HasLen, t.warnings) + lastWarn := warnings[len(warnings)-1] + c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue) + } + } +} + func (s *testEvaluatorSuite) TestConcatWS(c *C) { defer testleak.AfterTest(c)() cases := []struct { @@ -246,6 +290,53 @@ func (s *testEvaluatorSuite) TestConcatWS(c *C) { c.Assert(err, IsNil) } +func (s *testEvaluatorSuite) TestConcatWSSig(c *C) { + colTypes := []*types.FieldType{ + {Tp: mysql.TypeVarchar}, + {Tp: mysql.TypeVarchar}, + {Tp: mysql.TypeVarchar}, + } + resultType := &types.FieldType{Tp: mysql.TypeVarchar, Flen: 1000} + args := []Expression{ + &Column{Index: 0, RetType: colTypes[0]}, + &Column{Index: 1, RetType: colTypes[1]}, + &Column{Index: 2, RetType: colTypes[2]}, + } + base := baseBuiltinFunc{args: args, ctx: s.ctx, tp: resultType} + concat := &builtinConcatWSSig{base, 6} + + cases := []struct { + args []interface{} + warnings int + res string + }{ + {[]interface{}{",", "a", "b"}, 0, "a,b"}, + {[]interface{}{",", "aaa", "bbb"}, 1, ""}, + {[]interface{}{",", "中", "a"}, 0, "中,a"}, + {[]interface{}{",", "中文", "a"}, 2, ""}, + } + + for _, t := range cases { + input := chunk.NewChunkWithCapacity(colTypes, 10) + input.AppendString(0, t.args[0].(string)) + input.AppendString(1, t.args[1].(string)) + input.AppendString(2, t.args[2].(string)) + + res, isNull, err := concat.evalString(input.GetRow(0)) + c.Assert(res, Equals, t.res) + c.Assert(err, IsNil) + if t.warnings == 0 { + c.Assert(isNull, IsFalse) + } else { + c.Assert(isNull, IsTrue) + warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings() + c.Assert(warnings, HasLen, t.warnings) + lastWarn := warnings[len(warnings)-1] + c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue) + } + } +} + func (s *testEvaluatorSuite) TestLeft(c *C) { defer testleak.AfterTest(c)() stmtCtx := s.ctx.GetSessionVars().StmtCtx diff --git a/util/mock/context.go b/util/mock/context.go index c3419792ac857..a7b76b282baf3 100644 --- a/util/mock/context.go +++ b/util/mock/context.go @@ -228,6 +228,9 @@ func NewContext() *Context { sctx.sessionVars.MaxChunkSize = 32 sctx.sessionVars.StmtCtx.TimeZone = time.UTC sctx.sessionVars.GlobalVarsAccessor = variable.NewMockGlobalAccessor() + if err := sctx.GetSessionVars().SetSystemVar(variable.MaxAllowedPacket, "67108864"); err != nil { + panic(err) + } return sctx }