Skip to content

Commit

Permalink
add more test
Browse files Browse the repository at this point in the history
  • Loading branch information
xiongjiwei committed Nov 24, 2021
1 parent 6149aa8 commit e7dd3f4
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 20 deletions.
27 changes: 13 additions & 14 deletions expression/builtin_convert_charset.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,8 @@ func (b *builtinInternalFromBinarySig) evalString(row chunk.Row) (res string, is
return val, isNull, err
}
transferString := b.getTransferFunc()
res, err = transferString(val)
return res, false, err
tBytes, err := transferString([]byte(val))
return string(tBytes), false, err
}

func (b *builtinInternalFromBinarySig) vectorized() bool {
Expand All @@ -190,36 +190,35 @@ func (b *builtinInternalFromBinarySig) vecEvalString(input *chunk.Chunk, result
transferString := b.getTransferFunc()
result.ReserveString(n)
for i := 0; i < n; i++ {
var str string
if buf.IsNull(i) {
result.AppendNull()
continue
}
str = buf.GetString(i)
str, err = transferString(str)
str, err := transferString(buf.GetBytes(i))
if err != nil {
return err
}
result.AppendString(str)
result.AppendBytes(str)
}
return nil
}

func (b *builtinInternalFromBinarySig) getTransferFunc() func(string) (string, error) {
var transferString func(string) (string, error)
func (b *builtinInternalFromBinarySig) getTransferFunc() func([]byte) ([]byte, error) {
var transferString func([]byte) ([]byte, error)
if b.tp.Charset == charset.CharsetUTF8MB4 || b.tp.Charset == charset.CharsetUTF8 {
transferString = func(s string) (string, error) {
if !utf8.ValidString(s) {
return "", errCannotConvertString.GenWithStackByArgs(fmt.Sprintf("%X", s), charset.CharsetBin, b.tp.Charset)
transferString = func(s []byte) ([]byte, error) {
if !utf8.Valid(s) {
return nil, errCannotConvertString.GenWithStackByArgs(fmt.Sprintf("%X", s), charset.CharsetBin, b.tp.Charset)
}
return s, nil
}
} else {
enc := charset.NewEncoding(b.tp.Charset)
transferString = func(s string) (string, error) {
str, err := enc.DecodeString(s)
var buf []byte
transferString = func(s []byte) ([]byte, error) {
str, err := enc.Decode(buf, s)
if err != nil {
return "", errCannotConvertString.GenWithStackByArgs(fmt.Sprintf("%X", s), charset.CharsetBin, b.tp.Charset)
return nil, errCannotConvertString.GenWithStackByArgs(fmt.Sprintf("%X", s), charset.CharsetBin, b.tp.Charset)
}
return str, nil
}
Expand Down
26 changes: 20 additions & 6 deletions expression/constant_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,11 @@ func newString(value string, collation string) *Constant {
}

func newFunction(funcName string, args ...Expression) Expression {
return newFunctionWithType(funcName, mysql.TypeLonglong, args...)
return newFunctionWithType(funcName, types.NewFieldType(mysql.TypeLonglong), args...)
}

func newFunctionWithType(funcName string, tp byte, args ...Expression) Expression {
ft := types.NewFieldType(tp)
return NewFunctionInternal(mock.NewContext(), funcName, ft, args...)
func newFunctionWithType(funcName string, tp *types.FieldType, args ...Expression) Expression {
return NewFunctionInternal(mock.NewContext(), funcName, tp, args...)
}

func TestConstantPropagation(t *testing.T) {
Expand Down Expand Up @@ -239,16 +238,31 @@ func TestConstantFoldingCharsetConvert(t *testing.T) {
}{
{
condition: newFunction(ast.Length, newFunctionWithType(
InternalFuncToBinary, mysql.TypeVarchar,
InternalFuncToBinary, types.NewFieldType(mysql.TypeVarchar),
newString("中文", "gbk_bin"))),
result: "4",
},
{
condition: newFunction(ast.Length, newFunctionWithType(
InternalFuncToBinary, mysql.TypeVarchar,
InternalFuncToBinary, types.NewFieldType(mysql.TypeVarchar),
newString("中文", "utf8mb4_bin"))),
result: "6",
},
{
condition: newFunction(ast.Concat, newFunctionWithType(
InternalFuncFromBinary, types.NewFieldType(mysql.TypeVarchar),
newString("中文", "binary"))),
result: "中文",
},
{
condition: newFunction(ast.Concat,
newFunctionWithType(
InternalFuncFromBinary, types.NewFieldTypeWithCollation(mysql.TypeVarchar, "gbk_bin", -1),
newString("\xd2\xbb", "binary")),
newString("中文", "gbk_bin"),
),
result: "一中文",
},
}
for _, tt := range tests {
newConds := FoldConstant(tt.condition)
Expand Down

0 comments on commit e7dd3f4

Please sign in to comment.