Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: add max_allowed_packet check in concat/concat_ws (#11137) #11275

Merged
merged 17 commits into from
Jul 23, 2019
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 43 additions & 12 deletions expression/builtin_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,17 +273,26 @@ func (c *concatFunctionClass) getFunction(ctx sessionctx.Context, args []Express
if bf.tp.Flen >= mysql.MaxBlobWidth {
bf.tp.Flen = mysql.MaxBlobWidth
}
sig := &builtinConcatSig{bf}

valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, err
}

sig := &builtinConcatSig{bf, maxAllowedPacket}
return sig, nil
}

type builtinConcatSig struct {
baseBuiltinFunc
maxAllowedPacket uint64
}

func (b *builtinConcatSig) Clone() builtinFunc {
newSig := &builtinConcatSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.maxAllowedPacket = b.maxAllowedPacket
return newSig
}

Expand All @@ -296,6 +305,10 @@ func (b *builtinConcatSig) evalString(row chunk.Row) (d string, isNull bool, err
if isNull || err != nil {
return d, isNull, err
}
if uint64(len(s)+len(d)) > b.maxAllowedPacket {
b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenWithStackByArgs("concat", b.maxAllowedPacket))
return "", true, nil
}
s = append(s, []byte(d)...)
}
return string(s), false, nil
Expand Down Expand Up @@ -338,17 +351,25 @@ func (c *concatWSFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
bf.tp.Flen = mysql.MaxBlobWidth
}

sig := &builtinConcatWSSig{bf}
valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, err
}

sig := &builtinConcatWSSig{bf, maxAllowedPacket}
return sig, nil
}

type builtinConcatWSSig struct {
baseBuiltinFunc
maxAllowedPacket uint64
}

func (b *builtinConcatWSSig) Clone() builtinFunc {
newSig := &builtinConcatWSSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.maxAllowedPacket = b.maxAllowedPacket
return newSig
}

Expand All @@ -358,25 +379,35 @@ func (b *builtinConcatWSSig) evalString(row chunk.Row) (string, bool, error) {
args := b.getArgs()
strs := make([]string, 0, len(args))
var sep string
for i, arg := range args {
val, isNull, err := arg.EvalString(b.ctx, row)
var targetLength int

N := len(args)
if N > 0 {
val, isNull, err := args[0].EvalString(b.ctx, row)
if err != nil || isNull {
// If the separator is NULL, the result is NULL.
return val, isNull, err
}
sep = val
}
for i := 1; i < N; i++ {
val, isNull, err := args[i].EvalString(b.ctx, row)
if err != nil {
return val, isNull, err
}

if isNull {
// If the separator is NULL, the result is NULL.
if i == 0 {
return val, isNull, nil
}
// CONCAT_WS() does not skip empty strings. However,
// it does skip any NULL values after the separator argument.
continue
}

if i == 0 {
sep = val
continue
targetLength += len(val)
if i > 1 {
targetLength += len(sep)
}
if uint64(targetLength) > b.maxAllowedPacket {
b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenWithStackByArgs("concat_ws", b.maxAllowedPacket))
return "", true, nil
}
strs = append(strs, val)
}
Expand Down
91 changes: 91 additions & 0 deletions expression/builtin_string_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,50 @@ func (s *testEvaluatorSuite) TestConcat(c *C) {
}
}

func (s *testEvaluatorSuite) TestConcatSig(c *C) {
colTypes := []*types.FieldType{
{Tp: mysql.TypeVarchar},
{Tp: mysql.TypeVarchar},
}
resultType := &types.FieldType{Tp: mysql.TypeVarchar, Flen: 1000}
args := []Expression{
&Column{Index: 0, RetType: colTypes[0]},
&Column{Index: 1, RetType: colTypes[1]},
}
base := baseBuiltinFunc{args: args, ctx: s.ctx, tp: resultType}
concat := &builtinConcatSig{base, 5}

cases := []struct {
args []interface{}
warnings int
res string
}{
{[]interface{}{"a", "b"}, 0, "ab"},
{[]interface{}{"aaa", "bbb"}, 1, ""},
{[]interface{}{"中", "a"}, 0, "中a"},
{[]interface{}{"中文", "a"}, 2, ""},
}

for _, t := range cases {
input := chunk.NewChunkWithCapacity(colTypes, 10)
input.AppendString(0, t.args[0].(string))
input.AppendString(1, t.args[1].(string))

res, isNull, err := concat.evalString(input.GetRow(0))
c.Assert(res, Equals, t.res)
c.Assert(err, IsNil)
if t.warnings == 0 {
c.Assert(isNull, IsFalse)
} else {
c.Assert(isNull, IsTrue)
warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings()
c.Assert(warnings, HasLen, t.warnings)
lastWarn := warnings[len(warnings)-1]
c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue)
}
}
}

func (s *testEvaluatorSuite) TestConcatWS(c *C) {
defer testleak.AfterTest(c)()
cases := []struct {
Expand Down Expand Up @@ -246,6 +290,53 @@ func (s *testEvaluatorSuite) TestConcatWS(c *C) {
c.Assert(err, IsNil)
}

func (s *testEvaluatorSuite) TestConcatWSSig(c *C) {
colTypes := []*types.FieldType{
{Tp: mysql.TypeVarchar},
{Tp: mysql.TypeVarchar},
{Tp: mysql.TypeVarchar},
}
resultType := &types.FieldType{Tp: mysql.TypeVarchar, Flen: 1000}
args := []Expression{
&Column{Index: 0, RetType: colTypes[0]},
&Column{Index: 1, RetType: colTypes[1]},
&Column{Index: 2, RetType: colTypes[2]},
}
base := baseBuiltinFunc{args: args, ctx: s.ctx, tp: resultType}
concat := &builtinConcatWSSig{base, 6}

cases := []struct {
args []interface{}
warnings int
res string
}{
{[]interface{}{",", "a", "b"}, 0, "a,b"},
{[]interface{}{",", "aaa", "bbb"}, 1, ""},
{[]interface{}{",", "中", "a"}, 0, "中,a"},
{[]interface{}{",", "中文", "a"}, 2, ""},
}

for _, t := range cases {
input := chunk.NewChunkWithCapacity(colTypes, 10)
input.AppendString(0, t.args[0].(string))
input.AppendString(1, t.args[1].(string))
input.AppendString(2, t.args[2].(string))

res, isNull, err := concat.evalString(input.GetRow(0))
c.Assert(res, Equals, t.res)
c.Assert(err, IsNil)
if t.warnings == 0 {
c.Assert(isNull, IsFalse)
} else {
c.Assert(isNull, IsTrue)
warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings()
c.Assert(warnings, HasLen, t.warnings)
lastWarn := warnings[len(warnings)-1]
c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue)
}
}
}

func (s *testEvaluatorSuite) TestLeft(c *C) {
defer testleak.AfterTest(c)()
stmtCtx := s.ctx.GetSessionVars().StmtCtx
Expand Down
3 changes: 3 additions & 0 deletions util/mock/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,9 @@ func NewContext() *Context {
sctx.sessionVars.MaxChunkSize = 32
sctx.sessionVars.StmtCtx.TimeZone = time.UTC
sctx.sessionVars.GlobalVarsAccessor = variable.NewMockGlobalAccessor()
if err := sctx.GetSessionVars().SetSystemVar(variable.MaxAllowedPacket, "67108864"); err != nil {
panic(err)
}
return sctx
}

Expand Down