Skip to content

Commit

Permalink
expression: handle max_allowed_packet warnings for pad functions (#7171)
Browse files Browse the repository at this point in the history
  • Loading branch information
zz-jason authored Jul 31, 2018
1 parent b8a2b1b commit c38f567
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 7 deletions.
50 changes: 46 additions & 4 deletions expression/builtin_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/charset"
"github.com/pingcap/tidb/util/chunk"
Expand Down Expand Up @@ -1760,24 +1761,33 @@ func (c *lpadFunctionClass) getFunction(ctx sessionctx.Context, args []Expressio
bf.tp.Flen = getFlen4LpadAndRpad(bf.ctx, args[1])
SetBinFlagOrBinStr(args[0].GetType(), bf.tp)
SetBinFlagOrBinStr(args[2].GetType(), bf.tp)

valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, errors.Trace(err)
}

if types.IsBinaryStr(args[0].GetType()) || types.IsBinaryStr(args[2].GetType()) {
sig := &builtinLpadBinarySig{bf}
sig := &builtinLpadBinarySig{bf, maxAllowedPacket}
return sig, nil
}
if bf.tp.Flen *= 4; bf.tp.Flen > mysql.MaxBlobWidth {
bf.tp.Flen = mysql.MaxBlobWidth
}
sig := &builtinLpadSig{bf}
sig := &builtinLpadSig{bf, maxAllowedPacket}
return sig, nil
}

type builtinLpadBinarySig struct {
baseBuiltinFunc
maxAllowedPacket uint64
}

func (b *builtinLpadBinarySig) Clone() builtinFunc {
newSig := &builtinLpadBinarySig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.maxAllowedPacket = b.maxAllowedPacket
return newSig
}

Expand All @@ -1796,6 +1806,11 @@ func (b *builtinLpadBinarySig) evalString(row chunk.Row) (string, bool, error) {
}
targetLength := int(length)

if uint64(targetLength) > b.maxAllowedPacket {
b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenByArgs("lpad", b.maxAllowedPacket))
return "", true, nil
}

padStr, isNull, err := b.args[2].EvalString(b.ctx, row)
if isNull || err != nil {
return "", true, errors.Trace(err)
Expand All @@ -1815,11 +1830,13 @@ func (b *builtinLpadBinarySig) evalString(row chunk.Row) (string, bool, error) {

type builtinLpadSig struct {
baseBuiltinFunc
maxAllowedPacket uint64
}

func (b *builtinLpadSig) Clone() builtinFunc {
newSig := &builtinLpadSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.maxAllowedPacket = b.maxAllowedPacket
return newSig
}

Expand All @@ -1838,6 +1855,11 @@ func (b *builtinLpadSig) evalString(row chunk.Row) (string, bool, error) {
}
targetLength := int(length)

if uint64(targetLength)*uint64(mysql.MaxBytesOfCharacter) > b.maxAllowedPacket {
b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenByArgs("lpad", b.maxAllowedPacket))
return "", true, nil
}

padStr, isNull, err := b.args[2].EvalString(b.ctx, row)
if isNull || err != nil {
return "", true, errors.Trace(err)
Expand Down Expand Up @@ -1867,24 +1889,33 @@ func (c *rpadFunctionClass) getFunction(ctx sessionctx.Context, args []Expressio
bf.tp.Flen = getFlen4LpadAndRpad(bf.ctx, args[1])
SetBinFlagOrBinStr(args[0].GetType(), bf.tp)
SetBinFlagOrBinStr(args[2].GetType(), bf.tp)

valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, errors.Trace(err)
}

if types.IsBinaryStr(args[0].GetType()) || types.IsBinaryStr(args[2].GetType()) {
sig := &builtinRpadBinarySig{bf}
sig := &builtinRpadBinarySig{bf, maxAllowedPacket}
return sig, nil
}
if bf.tp.Flen *= 4; bf.tp.Flen > mysql.MaxBlobWidth {
bf.tp.Flen = mysql.MaxBlobWidth
}
sig := &builtinRpadSig{bf}
sig := &builtinRpadSig{bf, maxAllowedPacket}
return sig, nil
}

type builtinRpadBinarySig struct {
baseBuiltinFunc
maxAllowedPacket uint64
}

func (b *builtinRpadBinarySig) Clone() builtinFunc {
newSig := &builtinRpadBinarySig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.maxAllowedPacket = b.maxAllowedPacket
return newSig
}

Expand All @@ -1902,6 +1933,10 @@ func (b *builtinRpadBinarySig) evalString(row chunk.Row) (string, bool, error) {
return "", true, errors.Trace(err)
}
targetLength := int(length)
if uint64(targetLength) > b.maxAllowedPacket {
b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenByArgs("rpad", b.maxAllowedPacket))
return "", true, nil
}

padStr, isNull, err := b.args[2].EvalString(b.ctx, row)
if isNull || err != nil {
Expand All @@ -1922,11 +1957,13 @@ func (b *builtinRpadBinarySig) evalString(row chunk.Row) (string, bool, error) {

type builtinRpadSig struct {
baseBuiltinFunc
maxAllowedPacket uint64
}

func (b *builtinRpadSig) Clone() builtinFunc {
newSig := &builtinRpadSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.maxAllowedPacket = b.maxAllowedPacket
return newSig
}

Expand All @@ -1945,6 +1982,11 @@ func (b *builtinRpadSig) evalString(row chunk.Row) (string, bool, error) {
}
targetLength := int(length)

if uint64(targetLength)*uint64(mysql.MaxBytesOfCharacter) > b.maxAllowedPacket {
b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenByArgs("rpad", b.maxAllowedPacket))
return "", true, nil
}

padStr, isNull, err := b.args[2].EvalString(b.ctx, row)
if isNull || err != nil {
return "", true, errors.Trace(err)
Expand Down
42 changes: 42 additions & 0 deletions expression/builtin_string_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
. "github.com/pingcap/check"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/terror"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/charset"
"github.com/pingcap/tidb/util/chunk"
Expand Down Expand Up @@ -1266,6 +1267,47 @@ func (s *testEvaluatorSuite) TestRpad(c *C) {
}
}

func (s *testEvaluatorSuite) TestRpadSig(c *C) {
colTypes := []*types.FieldType{
{Tp: mysql.TypeVarchar},
{Tp: mysql.TypeLonglong},
{Tp: mysql.TypeVarchar},
}
resultType := &types.FieldType{Tp: mysql.TypeVarchar, Flen: 1000}

args := []Expression{
&Column{Index: 0, RetType: colTypes[0]},
&Column{Index: 1, RetType: colTypes[1]},
&Column{Index: 2, RetType: colTypes[2]},
}

base := baseBuiltinFunc{args: args, ctx: s.ctx, tp: resultType}
rpad := &builtinRpadSig{base, 1000}

input := chunk.NewChunkWithCapacity(colTypes, 10)
input.AppendString(0, "abc")
input.AppendString(0, "abc")
input.AppendInt64(1, 6)
input.AppendInt64(1, 10000)
input.AppendString(2, "123")
input.AppendString(2, "123")

res, isNull, err := rpad.evalString(input.GetRow(0))
c.Assert(res, Equals, "abc123")
c.Assert(isNull, IsFalse)
c.Assert(err, IsNil)

res, isNull, err = rpad.evalString(input.GetRow(1))
c.Assert(res, Equals, "")
c.Assert(isNull, IsTrue)
c.Assert(err, IsNil)

warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings()
c.Assert(len(warnings), Equals, 1)
lastWarn := warnings[len(warnings)-1]
c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue)
}

func (s *testEvaluatorSuite) TestInstr(c *C) {
defer testleak.AfterTest(c)()
tbl := []struct {
Expand Down
8 changes: 6 additions & 2 deletions expression/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,22 @@ import (

// Error instances.
var (
// All the exported errors are defined here:
ErrIncorrectParameterCount = terror.ClassExpression.New(mysql.ErrWrongParamcountToNativeFct, mysql.MySQLErrName[mysql.ErrWrongParamcountToNativeFct])
ErrDivisionByZero = terror.ClassExpression.New(mysql.ErrDivisionByZero, mysql.MySQLErrName[mysql.ErrDivisionByZero])
ErrRegexp = terror.ClassExpression.New(mysql.ErrRegexp, mysql.MySQLErrName[mysql.ErrRegexp])
ErrOperandColumns = terror.ClassExpression.New(mysql.ErrOperandColumns, mysql.MySQLErrName[mysql.ErrOperandColumns])
ErrCutValueGroupConcat = terror.ClassExpression.New(mysql.ErrCutValueGroupConcat, mysql.MySQLErrName[mysql.ErrCutValueGroupConcat])

// All the un-exported errors are defined here:
errFunctionNotExists = terror.ClassExpression.New(mysql.ErrSpDoesNotExist, mysql.MySQLErrName[mysql.ErrSpDoesNotExist])
errZlibZData = terror.ClassTypes.New(mysql.ErrZlibZData, mysql.MySQLErrName[mysql.ErrZlibZData])
errIncorrectArgs = terror.ClassExpression.New(mysql.ErrWrongArguments, mysql.MySQLErrName[mysql.ErrWrongArguments])
errUnknownCharacterSet = terror.ClassExpression.New(mysql.ErrUnknownCharacterSet, mysql.MySQLErrName[mysql.ErrUnknownCharacterSet])
errDefaultValue = terror.ClassExpression.New(mysql.ErrInvalidDefault, "invalid default value")
errDeprecatedSyntaxNoReplacement = terror.ClassExpression.New(mysql.ErrWarnDeprecatedSyntaxNoReplacement, mysql.MySQLErrName[mysql.ErrWarnDeprecatedSyntaxNoReplacement])
errBadField = terror.ClassExpression.New(mysql.ErrBadField, mysql.MySQLErrName[mysql.ErrBadField])
ErrOperandColumns = terror.ClassExpression.New(mysql.ErrOperandColumns, mysql.MySQLErrName[mysql.ErrOperandColumns])
ErrCutValueGroupConcat = terror.ClassExpression.New(mysql.ErrCutValueGroupConcat, mysql.MySQLErrName[mysql.ErrCutValueGroupConcat])
errWarnAllowedPacketOverflowed = terror.ClassExpression.New(mysql.ErrWarnAllowedPacketOverflowed, mysql.MySQLErrName[mysql.ErrWarnAllowedPacketOverflowed])
)

func init() {
Expand All @@ -49,6 +52,7 @@ func init() {
mysql.ErrWarnDeprecatedSyntaxNoReplacement: mysql.ErrWarnDeprecatedSyntaxNoReplacement,
mysql.ErrOperandColumns: mysql.ErrOperandColumns,
mysql.ErrRegexp: mysql.ErrRegexp,
mysql.ErrWarnAllowedPacketOverflowed: mysql.ErrWarnAllowedPacketOverflowed,
}
terror.ErrClassToMySQLCodes[terror.ClassExpression] = expressionMySQLErrCodes
}
Expand Down
2 changes: 2 additions & 0 deletions expression/evaluator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ func (s *testEvaluatorSuite) SetUpSuite(c *C) {
s.Parser = parser.New()
s.ctx = mock.NewContext()
s.ctx.GetSessionVars().StmtCtx.TimeZone = time.Local
s.ctx.GetSessionVars().SetSystemVar("max_allowed_packet", "67108864")
}

func (s *testEvaluatorSuite) TearDownSuite(c *C) {
Expand All @@ -58,6 +59,7 @@ func (s *testEvaluatorSuite) SetUpTest(c *C) {
}

func (s *testEvaluatorSuite) TearDownTest(c *C) {
s.ctx.GetSessionVars().StmtCtx.SetWarnings(nil)
testleak.AfterTest(c)()
}

Expand Down
2 changes: 1 addition & 1 deletion mysql/errname.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ var MySQLErrName = map[uint16]string{
ErrUnknownTimeZone: "Unknown or incorrect time zone: '%-.64s'",
ErrWarnInvalidTimestamp: "Invalid TIMESTAMP value in column '%s' at row %d",
ErrInvalidCharacterString: "Invalid %s character string: '%.64s'",
ErrWarnAllowedPacketOverflowed: "Result of %s() was larger than maxAllowedPacket (%d) - truncated",
ErrWarnAllowedPacketOverflowed: "Result of %s() was larger than max_allowed_packet (%d) - truncated",
ErrConflictingDeclarations: "Conflicting declarations: '%s%s' and '%s%s'",
ErrSpNoRecursiveCreate: "Can't create a %s from within another stored routine",
ErrSpAlreadyExists: "%s %s already exists",
Expand Down

0 comments on commit c38f567

Please sign in to comment.