From 0ef3fa2e1b8797809c2c9fcc2a24efaf719643d5 Mon Sep 17 00:00:00 2001 From: Noble Mittal Date: Thu, 15 Feb 2024 16:49:20 +0530 Subject: [PATCH 1/9] evalEngine: Implement ELT Signed-off-by: Noble Mittal --- go/vt/vtgate/evalengine/cached_size.go | 12 ++ go/vt/vtgate/evalengine/compiler_asm.go | 17 +++ go/vt/vtgate/evalengine/fn_string.go | 114 +++++++++++++++++++ go/vt/vtgate/evalengine/testcases/cases.go | 51 +++++++++ go/vt/vtgate/evalengine/translate_builtin.go | 5 + 5 files changed, 199 insertions(+) diff --git a/go/vt/vtgate/evalengine/cached_size.go b/go/vt/vtgate/evalengine/cached_size.go index 7525bfdaec4..9989b7999c6 100644 --- a/go/vt/vtgate/evalengine/cached_size.go +++ b/go/vt/vtgate/evalengine/cached_size.go @@ -799,6 +799,18 @@ func (cached *builtinDegrees) CachedSize(alloc bool) int64 { size += cached.CallExpr.CachedSize(false) return size } +func (cached *builtinElt) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(48) + } + // field CallExpr vitess.io/vitess/go/vt/vtgate/evalengine.CallExpr + size += cached.CallExpr.CachedSize(false) + return size +} func (cached *builtinExp) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index e017a949a07..d9ce7ae08ea 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -2345,6 +2345,23 @@ func (asm *assembler) Fn_BIT_LENGTH() { }, "FN BIT_LENGTH VARCHAR(SP-1)") } +func (asm *assembler) Fn_ELT(args int, tt sqltypes.Type, tc collations.TypedCollation) { + asm.adjustStack(-args + 1) + asm.emit(func(env *ExpressionEnv) int { + i := env.vm.stack[env.vm.sp-args].(*evalInt64) + + if i.i < 1 || int(i.i) >= args || env.vm.stack[env.vm.sp-args+int(i.i)] == nil { + env.vm.stack[env.vm.sp-args] = nil + } else { + b := env.vm.stack[env.vm.sp-args+int(i.i)].(*evalBytes) + env.vm.stack[env.vm.sp-args] = env.vm.arena.newEvalRaw(b.bytes, tt, tc) + } + + env.vm.sp -= args - 1 + return 1 + }, "FN ELT INT64(SP-%d) VARCHAR(SP-%d)...VARCHAR(SP-1)", args, args-1) +} + func (asm *assembler) Fn_INSERT(col collations.TypedCollation) { asm.adjustStack(-3) diff --git a/go/vt/vtgate/evalengine/fn_string.go b/go/vt/vtgate/evalengine/fn_string.go index 23ff1cbdca3..62fe744c5ca 100644 --- a/go/vt/vtgate/evalengine/fn_string.go +++ b/go/vt/vtgate/evalengine/fn_string.go @@ -30,6 +30,11 @@ import ( ) type ( + builtinElt struct { + CallExpr + collate collations.ID + } + builtinInsert struct { CallExpr collate collations.ID @@ -112,6 +117,7 @@ type ( } ) +var _ IR = (*builtinElt)(nil) var _ IR = (*builtinInsert)(nil) var _ IR = (*builtinChangeCase)(nil) var _ IR = (*builtinCharLength)(nil) @@ -127,6 +133,114 @@ var _ IR = (*builtinLeftRight)(nil) var _ IR = (*builtinPad)(nil) var _ IR = (*builtinTrim)(nil) +func (call *builtinElt) eval(env *ExpressionEnv) (eval, error) { + var ca collationAggregation + tt := sqltypes.VarChar + + args, err := call.args(env) + if err != nil { + return nil, err + } + + if args[0] == nil { + return nil, nil + } + + i := evalToInt64(args[0]).i + if i < 1 || i >= int64(len(args)) || args[i] == nil { + return nil, nil + } + + for _, arg := range args[1:] { + if arg == nil { + continue + } + + tt = concatSQLType(arg.SQLType(), tt) + err = ca.add(evalCollation(arg), env.collationEnv) + if err != nil { + return nil, err + } + } + + tc := ca.result() + // If we only had numbers, we instead fall back to the default + // collation instead of using the numeric collation. + if tc.Coercibility == collations.CoerceNumeric { + tc = typedCoercionCollation(tt, call.collate) + } + + b, err := evalToVarchar(args[i], tc.Collation, true) + if err != nil { + return nil, err + } + + return newEvalRaw(tt, b.bytes, b.col), nil +} + +func (call *builtinElt) compile(c *compiler) (ctype, error) { + args := make([]ctype, len(call.Arguments)) + + var ca collationAggregation + tt := sqltypes.VarChar + + var skip *jump + for i, arg := range call.Arguments { + var err error + args[i], err = arg.compile(c) + if err != nil { + return ctype{}, nil + } + + if i == 0 { + skip = c.compileNullCheck1(args[i]) + continue + } + + tt = concatSQLType(args[i].Type, tt) + err = ca.add(args[i].Col, c.env.CollationEnv()) + if err != nil { + return ctype{}, err + } + } + + tc := ca.result() + // If we only had numbers, we instead fall back to the default + // collation instead of using the numeric collation. + if tc.Coercibility == collations.CoerceNumeric { + tc = typedCoercionCollation(tt, call.collate) + } + + _ = c.compileToInt64(args[0], len(args)) + + for i, arg := range args[1:] { + offset := len(args) - (i + 1) + skip := c.compileNullCheckOffset(arg, offset) + + switch arg.Type { + case sqltypes.VarBinary, sqltypes.Binary, sqltypes.Blob: + if tc.Collation != collations.CollationBinaryID { + c.asm.Convert_xce(offset, arg.Type, tc.Collation) + } + case sqltypes.VarChar, sqltypes.Char, sqltypes.Text: + fromCharset := colldata.Lookup(arg.Col.Collation).Charset() + toCharset := colldata.Lookup(tc.Collation).Charset() + if fromCharset != toCharset && !toCharset.IsSuperset(fromCharset) { + c.asm.Convert_xce(offset, arg.Type, tc.Collation) + } + default: + c.asm.Convert_xce(offset, arg.Type, tc.Collation) + } + + c.asm.jumpDestination(skip) + } + + c.asm.Fn_ELT(len(args), tt, tc) + c.asm.jumpDestination(skip) + + return ctype{Type: tt, Col: tc, Flag: flagNullable}, nil +} + func insert(str, newstr *evalBytes, pos, l int) []byte { pos-- diff --git a/go/vt/vtgate/evalengine/testcases/cases.go b/go/vt/vtgate/evalengine/testcases/cases.go index 9d9cdfa248e..01dd398ecaa 100644 --- a/go/vt/vtgate/evalengine/testcases/cases.go +++ b/go/vt/vtgate/evalengine/testcases/cases.go @@ -63,6 +63,7 @@ var Cases = []TestCase{ {Run: TupleComparisons}, {Run: Comparisons}, {Run: InStatement}, + {Run: FnElt}, {Run: FnInsert}, {Run: FnLower}, {Run: FnUpper}, @@ -1315,6 +1316,56 @@ var JSONExtract_Schema = []*querypb.Field{ }, } +func FnElt(yield Query) { + for _, s1 := range inputStrings { + for _, n := range inputBitwise { + yield(fmt.Sprintf("ELT(%s, %s)", n, s1), nil) + } + } + + for _, s1 := range inputStrings { + for _, s2 := range inputStrings { + for _, n := range inputBitwise { + yield(fmt.Sprintf("ELT(%s, %s, %s)", n, s1, s2), nil) + } + } + } + + for _, s1 := range inputStrings { + for _, s2 := range inputStrings { + for _, s3 := range inputStrings { + for _, n := range inputBitwise { + yield(fmt.Sprintf("ELT(%s, %s, %s, %s)", n, s1, s2, s3), nil) + } + } + } + } + + validIndex := []string{ + "1", + "2", + "3", + } + for _, s1 := range inputStrings { + for _, s2 := range inputStrings { + for _, s3 := range inputStrings { + for _, n := range validIndex { + yield(fmt.Sprintf("ELT(%s, %s, %s, %s)", n, s1, s2, s3), nil) + } + } + } + } + + mysqlDocSamples := []string{ + "ELT(1, 'Aa', 'Bb', 'Cc', 'Dd')", + "ELT(4, 'Aa', 'Bb', 'Cc', 'Dd')", + } + + for _, q := range mysqlDocSamples { + yield(q, nil) + } +} + func FnInsert(yield Query) { for _, s := range insertStrings { for _, ns := range insertStrings { diff --git a/go/vt/vtgate/evalengine/translate_builtin.go b/go/vt/vtgate/evalengine/translate_builtin.go index 11618bb1d1a..54debfa89ef 100644 --- a/go/vt/vtgate/evalengine/translate_builtin.go +++ b/go/vt/vtgate/evalengine/translate_builtin.go @@ -265,6 +265,11 @@ func (ast *astCompiler) translateFuncExpr(fn *sqlparser.FuncExpr) (IR, error) { return nil, argError(method) } return &builtinPad{CallExpr: call, collate: ast.cfg.Collation, left: method == "lpad"}, nil + case "elt": + if len(args) < 2 { + return nil, argError(method) + } + return &builtinElt{CallExpr: call, collate: ast.cfg.Collation}, nil case "lower", "lcase": if len(args) != 1 { return nil, argError(method) From 8147e88a1fd9ada917657da31329f377abab752e Mon Sep 17 00:00:00 2001 From: Noble Mittal Date: Mon, 19 Feb 2024 13:14:44 +0530 Subject: [PATCH 2/9] evalEngine: Implement FIELD Signed-off-by: Noble Mittal --- go/vt/vtgate/evalengine/compiler_asm.go | 31 +++++++ go/vt/vtgate/evalengine/fn_string.go | 88 ++++++++++++++++++++ go/vt/vtgate/evalengine/testcases/cases.go | 18 ++++ go/vt/vtgate/evalengine/testcases/inputs.go | 18 ++++ go/vt/vtgate/evalengine/translate_builtin.go | 5 ++ 5 files changed, 160 insertions(+) diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index d9ce7ae08ea..a4b6d49e154 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -30,6 +30,7 @@ import ( "math" "math/bits" "net/netip" + "reflect" "strconv" "time" @@ -2345,6 +2346,36 @@ func (asm *assembler) Fn_BIT_LENGTH() { }, "FN BIT_LENGTH VARCHAR(SP-1)") } +func (asm *assembler) Fn_FIELD(args int) { + asm.adjustStack(-args + 1) + asm.emit(func(env *ExpressionEnv) int { + if env.vm.stack[env.vm.sp-args] == nil { + env.vm.stack[env.vm.sp-args] = env.vm.arena.newEvalInt64(0) + env.vm.sp -= args - 1 + return 1 + } + + tar := env.vm.stack[env.vm.sp-args].(*evalBytes) + + for i := range args - 1 { + if env.vm.stack[env.vm.sp-args+i+1] == nil { + continue + } + + str := env.vm.stack[env.vm.sp-args+i+1].(*evalBytes) + if reflect.DeepEqual(str, tar) { + env.vm.stack[env.vm.sp-args] = env.vm.arena.newEvalInt64(int64(i + 1)) + env.vm.sp -= args - 1 + return 1 + } + } + + env.vm.stack[env.vm.sp-args] = env.vm.arena.newEvalInt64(0) + env.vm.sp -= args - 1 + return 1 + }, "FN FIELD VARCHAR(SP-%d)...VARCHAR(SP-1)", args) +} + func (asm *assembler) Fn_ELT(args int, tt sqltypes.Type, tc collations.TypedCollation) { asm.adjustStack(-args + 1) asm.emit(func(env *ExpressionEnv) int { diff --git a/go/vt/vtgate/evalengine/fn_string.go b/go/vt/vtgate/evalengine/fn_string.go index 62fe744c5ca..cb6f7b86a1e 100644 --- a/go/vt/vtgate/evalengine/fn_string.go +++ b/go/vt/vtgate/evalengine/fn_string.go @@ -19,6 +19,7 @@ package evalengine import ( "bytes" "math" + "reflect" "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/mysql/collations/charset" @@ -30,6 +31,11 @@ import ( ) type ( + builtinField struct { + CallExpr + collate collations.ID + } + builtinElt struct { CallExpr collate collations.ID @@ -117,6 +123,7 @@ type ( } ) +var _ IR = (*builtinField)(nil) var _ IR = (*builtinElt)(nil) var _ IR = (*builtinInsert)(nil) var _ IR = (*builtinChangeCase)(nil) @@ -133,6 +140,87 @@ var _ IR = (*builtinLeftRight)(nil) var _ IR = (*builtinPad)(nil) var _ IR = (*builtinTrim)(nil) +func (call *builtinField) eval(env *ExpressionEnv) (eval, error) { + args, err := call.args(env) + if err != nil { + return nil, err + } + + strs := make([]*evalBytes, len(args)) + + tt := "" + for _, arg := range args { + switch arg.(type) { + case *evalBytes: + if tt == "n" { + tt = "f" + } else { + tt = "s" + } + case evalNumeric: + if tt == "s" { + tt = "f" + } else { + tt = "n" + } + } + } + + // switch tt { + // case "s": + // for i, arg := range args { + // strs[i], err = evalToVarchar(arg, call.collate, false) + // if err != nil { + // return nil, err + // } + // } + // case "n": + // for i, arg := range args { + // strs[i], err = evalToIn(arg, call.collate, false) + // if err != nil { + // return nil, err + // } + // } + // } + + for i, str := range strs[1:] { + if reflect.DeepEqual(str, strs[0]) { + return newEvalInt64(int64(i + 1)), nil + } + } + + return newEvalInt64(0), nil +} + +func (call *builtinField) compile(c *compiler) (ctype, error) { + strs := make([]ctype, len(call.Arguments)) + + for i, arg := range call.Arguments { + var err error + strs[i], err = arg.compile(c) + if err != nil { + return ctype{}, err + } + } + + for i, str := range strs { + offset := len(strs) - i + skip := c.compileNullCheckOffset(str, offset) + + switch { + case str.isTextual(): + default: + c.asm.Convert_xce(offset, sqltypes.VarChar, c.collation) + } + + c.asm.jumpDestination(skip) + } + + c.asm.Fn_FIELD(len(call.Arguments)) + + return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil +} + func (call *builtinElt) eval(env *ExpressionEnv) (eval, error) { var ca collationAggregation tt := sqltypes.VarChar diff --git a/go/vt/vtgate/evalengine/testcases/cases.go b/go/vt/vtgate/evalengine/testcases/cases.go index 01dd398ecaa..564003b8ab1 100644 --- a/go/vt/vtgate/evalengine/testcases/cases.go +++ b/go/vt/vtgate/evalengine/testcases/cases.go @@ -63,6 +63,7 @@ var Cases = []TestCase{ {Run: TupleComparisons}, {Run: Comparisons}, {Run: InStatement}, + {Run: FnField}, {Run: FnElt}, {Run: FnInsert}, {Run: FnLower}, @@ -1316,6 +1317,23 @@ var JSONExtract_Schema = []*querypb.Field{ }, } +func FnField(yield Query) { + // for _, s1 := range inputFieldStrings { + // for _, s2 := range inputFieldStrings { + // for _, s3 := range inputFieldStrings { + // yield(fmt.Sprintf("FIELD(%s, %s, %s)", s1, s2, s3), nil) + // } + // } + // } + + mysqlDocSamples := []string{ + "FIELD('πŸ˜ŠπŸ˜‚πŸ€’', 'πŸ˜‚πŸ€’', 'πŸ˜ŠπŸ˜‚', 'πŸ˜ŠπŸ˜‚πŸ€’')", + } + for _, q := range mysqlDocSamples { + yield(q, nil) + } +} + func FnElt(yield Query) { for _, s1 := range inputStrings { for _, n := range inputBitwise { diff --git a/go/vt/vtgate/evalengine/testcases/inputs.go b/go/vt/vtgate/evalengine/testcases/inputs.go index c453f904c96..2346c30e6e6 100644 --- a/go/vt/vtgate/evalengine/testcases/inputs.go +++ b/go/vt/vtgate/evalengine/testcases/inputs.go @@ -199,6 +199,24 @@ var inputStrings = []string{ // "_ucs2 'AabcΓ…Γ₯'", } +var inputFieldStrings = []string{ + "NULL", + "\"\"", + "\"a\"", + "\"abc\"", + "1", + "-1", + "0123", + "0xAACC", + "3.1415926", + "'123'", + "9223372036854775807", + "-9223372036854775808", + "999999999999999999999999", + "-999999999999999999999999", + "_binary 'MΓΌller' ", +} + var insertStrings = []string{ "NULL", "\"\"", diff --git a/go/vt/vtgate/evalengine/translate_builtin.go b/go/vt/vtgate/evalengine/translate_builtin.go index 54debfa89ef..e094c8bd57a 100644 --- a/go/vt/vtgate/evalengine/translate_builtin.go +++ b/go/vt/vtgate/evalengine/translate_builtin.go @@ -265,6 +265,11 @@ func (ast *astCompiler) translateFuncExpr(fn *sqlparser.FuncExpr) (IR, error) { return nil, argError(method) } return &builtinPad{CallExpr: call, collate: ast.cfg.Collation, left: method == "lpad"}, nil + case "field": + if len(args) < 2 { + return nil, argError(method) + } + return &builtinField{CallExpr: call, collate: ast.cfg.Collation}, nil case "elt": if len(args) < 2 { return nil, argError(method) From 55eb73a327f007834cd2603195b1791082120eef Mon Sep 17 00:00:00 2001 From: Noble Mittal Date: Thu, 29 Feb 2024 22:44:00 +0530 Subject: [PATCH 3/9] Fix FIELD for DOUBLE conversions Signed-off-by: Noble Mittal --- go/vt/vtgate/evalengine/compiler_asm.go | 67 +++++++-- go/vt/vtgate/evalengine/fn_string.go | 154 ++++++++++++++------ go/vt/vtgate/evalengine/testcases/cases.go | 35 +++-- go/vt/vtgate/evalengine/testcases/inputs.go | 18 --- 4 files changed, 183 insertions(+), 91 deletions(-) diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index a4b6d49e154..7809d48ca5b 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -30,7 +30,6 @@ import ( "math" "math/bits" "net/netip" - "reflect" "strconv" "time" @@ -2346,7 +2345,7 @@ func (asm *assembler) Fn_BIT_LENGTH() { }, "FN BIT_LENGTH VARCHAR(SP-1)") } -func (asm *assembler) Fn_FIELD(args int) { +func (asm *assembler) Fn_FIELD(args int, containsOnlyString, containsOnlyInt64 bool) { asm.adjustStack(-args + 1) asm.emit(func(env *ExpressionEnv) int { if env.vm.stack[env.vm.sp-args] == nil { @@ -2355,18 +2354,64 @@ func (asm *assembler) Fn_FIELD(args int) { return 1 } - tar := env.vm.stack[env.vm.sp-args].(*evalBytes) + if containsOnlyInt64 { + tar := env.vm.stack[env.vm.sp-args].(*evalInt64) - for i := range args - 1 { - if env.vm.stack[env.vm.sp-args+i+1] == nil { - continue + for i := range args - 1 { + if env.vm.stack[env.vm.sp-args+i+1] == nil { + continue + } + + arg := env.vm.stack[env.vm.sp-args+i+1].(*evalInt64) + + if tar.i == arg.i { + env.vm.stack[env.vm.sp-args] = env.vm.arena.newEvalInt64(int64(i + 1)) + env.vm.sp -= args - 1 + return 1 + } } + } else if containsOnlyString { + tar := env.vm.stack[env.vm.sp-args].(*evalBytes) - str := env.vm.stack[env.vm.sp-args+i+1].(*evalBytes) - if reflect.DeepEqual(str, tar) { - env.vm.stack[env.vm.sp-args] = env.vm.arena.newEvalInt64(int64(i + 1)) - env.vm.sp -= args - 1 - return 1 + for i := range args - 1 { + if env.vm.stack[env.vm.sp-args+i+1] == nil { + continue + } + + str := env.vm.stack[env.vm.sp-args+i+1].(*evalBytes) + + // Compare target and current string + if len(tar.bytes) == len(str.bytes) { + eq := true + for i, b := range tar.bytes { + if str.bytes[i] != b { + eq = false + break + } + } + + if eq { + env.vm.stack[env.vm.sp-args] = env.vm.arena.newEvalInt64(int64(i + 1)) + env.vm.sp -= args - 1 + return 1 + } + } + } + } else { + tar := env.vm.stack[env.vm.sp-args].(*evalFloat) + + for i := range args - 1 { + if env.vm.stack[env.vm.sp-args+i+1] == nil { + continue + } + + arg := env.vm.stack[env.vm.sp-args+i+1].(*evalFloat) + + if tar.f == arg.f { + env.vm.stack[env.vm.sp-args] = env.vm.arena.newEvalInt64(int64(i + 1)) + env.vm.sp -= args - 1 + return 1 + } } } diff --git a/go/vt/vtgate/evalengine/fn_string.go b/go/vt/vtgate/evalengine/fn_string.go index cb6f7b86a1e..06bcbaaf243 100644 --- a/go/vt/vtgate/evalengine/fn_string.go +++ b/go/vt/vtgate/evalengine/fn_string.go @@ -19,7 +19,6 @@ package evalengine import ( "bytes" "math" - "reflect" "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/mysql/collations/charset" @@ -145,47 +144,85 @@ func (call *builtinField) eval(env *ExpressionEnv) (eval, error) { if err != nil { return nil, err } + if args[0] == nil { + return newEvalInt64(0), nil + } - strs := make([]*evalBytes, len(args)) - - tt := "" + // If the arguments contain both integral and string values + // MySQL converts all the arguments to DOUBLE + containsOnlyInt64, containsOnlyString := true, true for _, arg := range args { - switch arg.(type) { - case *evalBytes: - if tt == "n" { - tt = "f" - } else { - tt = "s" + if arg == nil { + continue + } + + containsOnlyInt64 = sqltypes.IsIntegral(arg.SQLType()) && containsOnlyInt64 + containsOnlyString = !sqltypes.IsNumber(arg.SQLType()) && containsOnlyString + } + + if containsOnlyInt64 { + tar := evalToInt64(args[0]) + + for i, arg := range args[1:] { + if arg == nil { + continue } - case evalNumeric: - if tt == "s" { - tt = "f" - } else { - tt = "n" + + e := evalToInt64(arg) + if tar.i == e.i { + return newEvalInt64(int64(i + 1)), nil + } + } + } else if containsOnlyString { + tar, ok := args[0].(*evalBytes) + if !ok { + tar, err = evalToVarchar(args[0], call.collate, true) + if err != nil { + return nil, err } } - } - // switch tt { - // case "s": - // for i, arg := range args { - // strs[i], err = evalToVarchar(arg, call.collate, false) - // if err != nil { - // return nil, err - // } - // } - // case "n": - // for i, arg := range args { - // strs[i], err = evalToIn(arg, call.collate, false) - // if err != nil { - // return nil, err - // } - // } - // } + for i, arg := range args[1:] { + if arg == nil { + continue + } - for i, str := range strs[1:] { - if reflect.DeepEqual(str, strs[0]) { - return newEvalInt64(int64(i + 1)), nil + var ok bool + e, ok := arg.(*evalBytes) + if !ok { + e, err = evalToVarchar(arg, call.collate, true) + if err != nil { + return nil, err + } + } + + // Compare target and current string + if len(tar.bytes) == len(e.bytes) { + eq := true + for i, b := range tar.bytes { + if e.bytes[i] != b { + eq = false + break + } + } + + if eq { + return newEvalInt64(int64(i + 1)), nil + } + } + } + } else { + tar, _ := evalToFloat(args[0]) + + for i, arg := range args[1:] { + if arg == nil { + continue + } + + e, _ := evalToFloat(arg) + if tar.f == e.f { + return newEvalInt64(int64(i + 1)), nil + } } } @@ -203,20 +240,49 @@ func (call *builtinField) compile(c *compiler) (ctype, error) { } } - for i, str := range strs { - offset := len(strs) - i - skip := c.compileNullCheckOffset(str, offset) + // If the arguments contain both integral and string values + // MySQL converts all the arguments to DOUBLE + containsOnlyString, containsOnlyInt64 := true, true + for _, str := range strs { + if sqltypes.IsNull(str.Type) { + continue + } + + containsOnlyString = !sqltypes.IsNumber(str.Type) && containsOnlyString + containsOnlyInt64 = sqltypes.IsIntegral(str.Type) && containsOnlyInt64 + } - switch { - case str.isTextual(): - default: - c.asm.Convert_xce(offset, sqltypes.VarChar, c.collation) + if containsOnlyInt64 { + for i, str := range strs { + offset := len(strs) - i + skip := c.compileNullCheckOffset(str, offset) + + _ = c.compileToInt64(str, offset) + c.asm.jumpDestination(skip) + } + } else if containsOnlyString { + for i, str := range strs { + offset := len(strs) - i + skip := c.compileNullCheckOffset(str, offset) + + switch { + case str.isTextual(): + default: + c.asm.Convert_xce(offset, sqltypes.VarChar, call.collate) + } + c.asm.jumpDestination(skip) } + } else { + for i, str := range strs { + offset := len(strs) - i + skip := c.compileNullCheckOffset(str, offset) - c.asm.jumpDestination(skip) + c.asm.Convert_xf(offset) + c.asm.jumpDestination(skip) + } } - c.asm.Fn_FIELD(len(call.Arguments)) + c.asm.Fn_FIELD(len(call.Arguments), containsOnlyString, containsOnlyInt64) return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil } diff --git a/go/vt/vtgate/evalengine/testcases/cases.go b/go/vt/vtgate/evalengine/testcases/cases.go index 564003b8ab1..7a09447b57f 100644 --- a/go/vt/vtgate/evalengine/testcases/cases.go +++ b/go/vt/vtgate/evalengine/testcases/cases.go @@ -1318,16 +1318,25 @@ var JSONExtract_Schema = []*querypb.Field{ } func FnField(yield Query) { - // for _, s1 := range inputFieldStrings { - // for _, s2 := range inputFieldStrings { - // for _, s3 := range inputFieldStrings { - // yield(fmt.Sprintf("FIELD(%s, %s, %s)", s1, s2, s3), nil) - // } - // } - // } + for _, s1 := range inputStrings { + for _, s2 := range inputStrings { + for _, s3 := range inputStrings { + yield(fmt.Sprintf("FIELD(%s, %s, %s)", s1, s2, s3), nil) + } + } + } + + for _, s1 := range radianInputs { + for _, s2 := range radianInputs { + for _, s3 := range radianInputs { + yield(fmt.Sprintf("FIELD(%s, %s, %s)", s1, s2, s3), nil) + } + } + } mysqlDocSamples := []string{ - "FIELD('πŸ˜ŠπŸ˜‚πŸ€’', 'πŸ˜‚πŸ€’', 'πŸ˜ŠπŸ˜‚', 'πŸ˜ŠπŸ˜‚πŸ€’')", + "FIELD('Bb', 'Aa', 'Bb', 'Cc', 'Dd', 'Ff')", + "FIELD('Gg', 'Aa', 'Bb', 'Cc', 'Dd', 'Ff')", } for _, q := range mysqlDocSamples { yield(q, nil) @@ -1349,16 +1358,6 @@ func FnElt(yield Query) { } } - for _, s1 := range inputStrings { - for _, s2 := range inputStrings { - for _, s3 := range inputStrings { - for _, n := range inputBitwise { - yield(fmt.Sprintf("ELT(%s, %s, %s, %s)", n, s1, s2, s3), nil) - } - } - } - } - validIndex := []string{ "1", "2", diff --git a/go/vt/vtgate/evalengine/testcases/inputs.go b/go/vt/vtgate/evalengine/testcases/inputs.go index 2346c30e6e6..c453f904c96 100644 --- a/go/vt/vtgate/evalengine/testcases/inputs.go +++ b/go/vt/vtgate/evalengine/testcases/inputs.go @@ -199,24 +199,6 @@ var inputStrings = []string{ // "_ucs2 'AabcΓ…Γ₯'", } -var inputFieldStrings = []string{ - "NULL", - "\"\"", - "\"a\"", - "\"abc\"", - "1", - "-1", - "0123", - "0xAACC", - "3.1415926", - "'123'", - "9223372036854775807", - "-9223372036854775808", - "999999999999999999999999", - "-999999999999999999999999", - "_binary 'MΓΌller' ", -} - var insertStrings = []string{ "NULL", "\"\"", From d09872d384195b50d4efe0db8bfeaad5812a4b90 Mon Sep 17 00:00:00 2001 From: Noble Mittal Date: Thu, 29 Feb 2024 22:46:12 +0530 Subject: [PATCH 4/9] Add cached_size changes for FIELD Signed-off-by: Noble Mittal --- go/vt/vtgate/evalengine/cached_size.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/go/vt/vtgate/evalengine/cached_size.go b/go/vt/vtgate/evalengine/cached_size.go index 9989b7999c6..d9420d8264f 100644 --- a/go/vt/vtgate/evalengine/cached_size.go +++ b/go/vt/vtgate/evalengine/cached_size.go @@ -823,6 +823,18 @@ func (cached *builtinExp) CachedSize(alloc bool) int64 { size += cached.CallExpr.CachedSize(false) return size } +func (cached *builtinField) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(48) + } + // field CallExpr vitess.io/vitess/go/vt/vtgate/evalengine.CallExpr + size += cached.CallExpr.CachedSize(false) + return size +} func (cached *builtinFloor) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) From f15f6774b4c735adc943df7d05b0bd5c1074356c Mon Sep 17 00:00:00 2001 From: Noble Mittal Date: Mon, 18 Mar 2024 13:45:08 +0530 Subject: [PATCH 5/9] Split Fn_FIELD into separate functions Signed-off-by: Noble Mittal --- go/vt/vtgate/evalengine/compiler_asm.go | 114 +++++++++++++++--------- go/vt/vtgate/evalengine/fn_string.go | 8 +- 2 files changed, 77 insertions(+), 45 deletions(-) diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index 7809d48ca5b..f64793259fb 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -2345,7 +2345,7 @@ func (asm *assembler) Fn_BIT_LENGTH() { }, "FN BIT_LENGTH VARCHAR(SP-1)") } -func (asm *assembler) Fn_FIELD(args int, containsOnlyString, containsOnlyInt64 bool) { +func (asm *assembler) Fn_FIELD_i(args int) { asm.adjustStack(-args + 1) asm.emit(func(env *ExpressionEnv) int { if env.vm.stack[env.vm.sp-args] == nil { @@ -2354,60 +2354,57 @@ func (asm *assembler) Fn_FIELD(args int, containsOnlyString, containsOnlyInt64 b return 1 } - if containsOnlyInt64 { - tar := env.vm.stack[env.vm.sp-args].(*evalInt64) + tar := env.vm.stack[env.vm.sp-args].(*evalInt64) - for i := range args - 1 { - if env.vm.stack[env.vm.sp-args+i+1] == nil { - continue - } + for i := range args - 1 { + if env.vm.stack[env.vm.sp-args+i+1] == nil { + continue + } - arg := env.vm.stack[env.vm.sp-args+i+1].(*evalInt64) + arg := env.vm.stack[env.vm.sp-args+i+1].(*evalInt64) - if tar.i == arg.i { - env.vm.stack[env.vm.sp-args] = env.vm.arena.newEvalInt64(int64(i + 1)) - env.vm.sp -= args - 1 - return 1 - } + if tar.i == arg.i { + env.vm.stack[env.vm.sp-args] = env.vm.arena.newEvalInt64(int64(i + 1)) + env.vm.sp -= args - 1 + return 1 } - } else if containsOnlyString { - tar := env.vm.stack[env.vm.sp-args].(*evalBytes) + } - for i := range args - 1 { - if env.vm.stack[env.vm.sp-args+i+1] == nil { - continue - } + env.vm.stack[env.vm.sp-args] = env.vm.arena.newEvalInt64(0) + env.vm.sp -= args - 1 + return 1 + }, "FN FIELD INT64(SP-%d)...INT64(SP-1)", args) +} - str := env.vm.stack[env.vm.sp-args+i+1].(*evalBytes) +func (asm *assembler) Fn_FIELD_b(args int) { + asm.adjustStack(-args + 1) + asm.emit(func(env *ExpressionEnv) int { + if env.vm.stack[env.vm.sp-args] == nil { + env.vm.stack[env.vm.sp-args] = env.vm.arena.newEvalInt64(0) + env.vm.sp -= args - 1 + return 1 + } - // Compare target and current string - if len(tar.bytes) == len(str.bytes) { - eq := true - for i, b := range tar.bytes { - if str.bytes[i] != b { - eq = false - break - } - } + tar := env.vm.stack[env.vm.sp-args].(*evalBytes) - if eq { - env.vm.stack[env.vm.sp-args] = env.vm.arena.newEvalInt64(int64(i + 1)) - env.vm.sp -= args - 1 - return 1 - } - } + for i := range args - 1 { + if env.vm.stack[env.vm.sp-args+i+1] == nil { + continue } - } else { - tar := env.vm.stack[env.vm.sp-args].(*evalFloat) - for i := range args - 1 { - if env.vm.stack[env.vm.sp-args+i+1] == nil { - continue - } + str := env.vm.stack[env.vm.sp-args+i+1].(*evalBytes) - arg := env.vm.stack[env.vm.sp-args+i+1].(*evalFloat) + // Compare target and current string + if len(tar.bytes) == len(str.bytes) { + eq := true + for i, b := range tar.bytes { + if str.bytes[i] != b { + eq = false + break + } + } - if tar.f == arg.f { + if eq { env.vm.stack[env.vm.sp-args] = env.vm.arena.newEvalInt64(int64(i + 1)) env.vm.sp -= args - 1 return 1 @@ -2421,6 +2418,37 @@ func (asm *assembler) Fn_FIELD(args int, containsOnlyString, containsOnlyInt64 b }, "FN FIELD VARCHAR(SP-%d)...VARCHAR(SP-1)", args) } +func (asm *assembler) Fn_FIELD_f(args int) { + asm.adjustStack(-args + 1) + asm.emit(func(env *ExpressionEnv) int { + if env.vm.stack[env.vm.sp-args] == nil { + env.vm.stack[env.vm.sp-args] = env.vm.arena.newEvalInt64(0) + env.vm.sp -= args - 1 + return 1 + } + + tar := env.vm.stack[env.vm.sp-args].(*evalFloat) + + for i := range args - 1 { + if env.vm.stack[env.vm.sp-args+i+1] == nil { + continue + } + + arg := env.vm.stack[env.vm.sp-args+i+1].(*evalFloat) + + if tar.f == arg.f { + env.vm.stack[env.vm.sp-args] = env.vm.arena.newEvalInt64(int64(i + 1)) + env.vm.sp -= args - 1 + return 1 + } + } + + env.vm.stack[env.vm.sp-args] = env.vm.arena.newEvalInt64(0) + env.vm.sp -= args - 1 + return 1 + }, "FN FIELD VARCHAR(SP-%d)...VARCHAR(SP-1)", args) +} + func (asm *assembler) Fn_ELT(args int, tt sqltypes.Type, tc collations.TypedCollation) { asm.adjustStack(-args + 1) asm.emit(func(env *ExpressionEnv) int { diff --git a/go/vt/vtgate/evalengine/fn_string.go b/go/vt/vtgate/evalengine/fn_string.go index 06bcbaaf243..9ab8c63c623 100644 --- a/go/vt/vtgate/evalengine/fn_string.go +++ b/go/vt/vtgate/evalengine/fn_string.go @@ -260,6 +260,8 @@ func (call *builtinField) compile(c *compiler) (ctype, error) { _ = c.compileToInt64(str, offset) c.asm.jumpDestination(skip) } + + c.asm.Fn_FIELD_i(len(call.Arguments)) } else if containsOnlyString { for i, str := range strs { offset := len(strs) - i @@ -272,6 +274,8 @@ func (call *builtinField) compile(c *compiler) (ctype, error) { } c.asm.jumpDestination(skip) } + + c.asm.Fn_FIELD_b(len(call.Arguments)) } else { for i, str := range strs { offset := len(strs) - i @@ -280,9 +284,9 @@ func (call *builtinField) compile(c *compiler) (ctype, error) { c.asm.Convert_xf(offset) c.asm.jumpDestination(skip) } - } - c.asm.Fn_FIELD(len(call.Arguments), containsOnlyString, containsOnlyInt64) + c.asm.Fn_FIELD_f(len(call.Arguments)) + } return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil } From 14944a2eb664344b1a1a9d36a1608a8fbd236aff Mon Sep 17 00:00:00 2001 From: Noble Mittal Date: Wed, 20 Mar 2024 01:46:22 +0530 Subject: [PATCH 6/9] Make use of bytes.Equal Signed-off-by: Noble Mittal --- go/vt/vtgate/evalengine/compiler_asm.go | 20 +++++--------------- go/vt/vtgate/evalengine/fn_string.go | 21 ++++++++------------- 2 files changed, 13 insertions(+), 28 deletions(-) diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index f64793259fb..8981a33a136 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -2395,20 +2395,10 @@ func (asm *assembler) Fn_FIELD_b(args int) { str := env.vm.stack[env.vm.sp-args+i+1].(*evalBytes) // Compare target and current string - if len(tar.bytes) == len(str.bytes) { - eq := true - for i, b := range tar.bytes { - if str.bytes[i] != b { - eq = false - break - } - } - - if eq { - env.vm.stack[env.vm.sp-args] = env.vm.arena.newEvalInt64(int64(i + 1)) - env.vm.sp -= args - 1 - return 1 - } + if bytes.Equal(tar.bytes, str.bytes) { + env.vm.stack[env.vm.sp-args] = env.vm.arena.newEvalInt64(int64(i + 1)) + env.vm.sp -= args - 1 + return 1 } } @@ -2446,7 +2436,7 @@ func (asm *assembler) Fn_FIELD_f(args int) { env.vm.stack[env.vm.sp-args] = env.vm.arena.newEvalInt64(0) env.vm.sp -= args - 1 return 1 - }, "FN FIELD VARCHAR(SP-%d)...VARCHAR(SP-1)", args) + }, "FN FIELD FLOAT64(SP-%d)...FLOAT64(SP-1)", args) } func (asm *assembler) Fn_ELT(args int, tt sqltypes.Type, tc collations.TypedCollation) { diff --git a/go/vt/vtgate/evalengine/fn_string.go b/go/vt/vtgate/evalengine/fn_string.go index 9ab8c63c623..a46d3f0eeb7 100644 --- a/go/vt/vtgate/evalengine/fn_string.go +++ b/go/vt/vtgate/evalengine/fn_string.go @@ -197,18 +197,8 @@ func (call *builtinField) eval(env *ExpressionEnv) (eval, error) { } // Compare target and current string - if len(tar.bytes) == len(e.bytes) { - eq := true - for i, b := range tar.bytes { - if e.bytes[i] != b { - eq = false - break - } - } - - if eq { - return newEvalInt64(int64(i + 1)), nil - } + if bytes.Equal(tar.bytes, e.bytes) { + return newEvalInt64(int64(i + 1)), nil } } } else { @@ -281,7 +271,12 @@ func (call *builtinField) compile(c *compiler) (ctype, error) { offset := len(strs) - i skip := c.compileNullCheckOffset(str, offset) - c.asm.Convert_xf(offset) + switch str.Type { + case sqltypes.Float64: + default: + c.asm.Convert_xf(offset) + } + c.asm.jumpDestination(skip) } From 660d5c5b0f0b55d34ca5b3199798c72675d016ee Mon Sep 17 00:00:00 2001 From: Noble Mittal Date: Sat, 23 Mar 2024 02:19:22 +0530 Subject: [PATCH 7/9] Fix decimal cases for FIELD Signed-off-by: Noble Mittal --- go/vt/vtgate/evalengine/compiler_asm.go | 31 ++++++++++ go/vt/vtgate/evalengine/fn_string.go | 69 ++++++++++++++++++---- go/vt/vtgate/evalengine/testcases/cases.go | 18 ++++++ 3 files changed, 106 insertions(+), 12 deletions(-) diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index 8981a33a136..51644769282 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -2408,6 +2408,37 @@ func (asm *assembler) Fn_FIELD_b(args int) { }, "FN FIELD VARCHAR(SP-%d)...VARCHAR(SP-1)", args) } +func (asm *assembler) Fn_FIELD_d(args int) { + asm.adjustStack(-args + 1) + asm.emit(func(env *ExpressionEnv) int { + if env.vm.stack[env.vm.sp-args] == nil { + env.vm.stack[env.vm.sp-args] = env.vm.arena.newEvalInt64(0) + env.vm.sp -= args - 1 + return 1 + } + + tar := env.vm.stack[env.vm.sp-args].(*evalDecimal) + + for i := range args - 1 { + if env.vm.stack[env.vm.sp-args+i+1] == nil { + continue + } + + arg := env.vm.stack[env.vm.sp-args+i+1].(*evalDecimal) + + if tar.dec.Equal(arg.dec) { + env.vm.stack[env.vm.sp-args] = env.vm.arena.newEvalInt64(int64(i + 1)) + env.vm.sp -= args - 1 + return 1 + } + } + + env.vm.stack[env.vm.sp-args] = env.vm.arena.newEvalInt64(0) + env.vm.sp -= args - 1 + return 1 + }, "FN FIELD DECIMAL(SP-%d)...DECIMAL(SP-1)", args) +} + func (asm *assembler) Fn_FIELD_f(args int) { asm.adjustStack(-args + 1) asm.emit(func(env *ExpressionEnv) int { diff --git a/go/vt/vtgate/evalengine/fn_string.go b/go/vt/vtgate/evalengine/fn_string.go index a46d3f0eeb7..b67d64b390e 100644 --- a/go/vt/vtgate/evalengine/fn_string.go +++ b/go/vt/vtgate/evalengine/fn_string.go @@ -139,6 +139,20 @@ var _ IR = (*builtinLeftRight)(nil) var _ IR = (*builtinPad)(nil) var _ IR = (*builtinTrim)(nil) +func fieldSQLType(arg sqltypes.Type, tt sqltypes.Type) sqltypes.Type { + if typeIsTextual(arg) && typeIsTextual(tt) { + return sqltypes.VarChar + } else if sqltypes.IsIntegral(arg) && sqltypes.IsIntegral(tt) { + return sqltypes.Int64 + } + + if (sqltypes.IsIntegral(arg) || sqltypes.IsDecimal(arg)) && (sqltypes.IsIntegral(tt) || sqltypes.IsDecimal(tt)) { + return sqltypes.Decimal + } + + return sqltypes.Float64 +} + func (call *builtinField) eval(env *ExpressionEnv) (eval, error) { args, err := call.args(env) if err != nil { @@ -150,17 +164,17 @@ func (call *builtinField) eval(env *ExpressionEnv) (eval, error) { // If the arguments contain both integral and string values // MySQL converts all the arguments to DOUBLE - containsOnlyInt64, containsOnlyString := true, true - for _, arg := range args { + tt := args[0].SQLType() + + for _, arg := range args[1:] { if arg == nil { continue } - containsOnlyInt64 = sqltypes.IsIntegral(arg.SQLType()) && containsOnlyInt64 - containsOnlyString = !sqltypes.IsNumber(arg.SQLType()) && containsOnlyString + tt = fieldSQLType(arg.SQLType(), tt) } - if containsOnlyInt64 { + if tt == sqltypes.Int64 { tar := evalToInt64(args[0]) for i, arg := range args[1:] { @@ -173,7 +187,7 @@ func (call *builtinField) eval(env *ExpressionEnv) (eval, error) { return newEvalInt64(int64(i + 1)), nil } } - } else if containsOnlyString { + } else if tt == sqltypes.VarChar { tar, ok := args[0].(*evalBytes) if !ok { tar, err = evalToVarchar(args[0], call.collate, true) @@ -201,6 +215,19 @@ func (call *builtinField) eval(env *ExpressionEnv) (eval, error) { return newEvalInt64(int64(i + 1)), nil } } + } else if tt == sqltypes.Decimal { + tar := evalToDecimal(args[0], 0, 0) + + for i, arg := range args[1:] { + if arg == nil { + continue + } + + e := evalToDecimal(arg, 0, 0) + if tar.dec.Equal(e.dec) { + return newEvalInt64(int64(i + 1)), nil + } + } } else { tar, _ := evalToFloat(args[0]) @@ -232,27 +259,31 @@ func (call *builtinField) compile(c *compiler) (ctype, error) { // If the arguments contain both integral and string values // MySQL converts all the arguments to DOUBLE - containsOnlyString, containsOnlyInt64 := true, true + tt := strs[0].Type + for _, str := range strs { if sqltypes.IsNull(str.Type) { continue } - containsOnlyString = !sqltypes.IsNumber(str.Type) && containsOnlyString - containsOnlyInt64 = sqltypes.IsIntegral(str.Type) && containsOnlyInt64 + tt = fieldSQLType(str.Type, tt) } - if containsOnlyInt64 { + if tt == sqltypes.Int64 { for i, str := range strs { offset := len(strs) - i skip := c.compileNullCheckOffset(str, offset) - _ = c.compileToInt64(str, offset) + switch str.Type { + case sqltypes.Int64: + default: + c.asm.Convert_xi(offset) + } c.asm.jumpDestination(skip) } c.asm.Fn_FIELD_i(len(call.Arguments)) - } else if containsOnlyString { + } else if tt == sqltypes.VarChar { for i, str := range strs { offset := len(strs) - i skip := c.compileNullCheckOffset(str, offset) @@ -266,6 +297,20 @@ func (call *builtinField) compile(c *compiler) (ctype, error) { } c.asm.Fn_FIELD_b(len(call.Arguments)) + } else if tt == sqltypes.Decimal { + for i, str := range strs { + offset := len(strs) - i + skip := c.compileNullCheckOffset(str, offset) + + switch str.Type { + case sqltypes.Decimal: + default: + c.asm.Convert_xd(offset, 0, 0) + } + c.asm.jumpDestination(skip) + } + + c.asm.Fn_FIELD_d(len(call.Arguments)) } else { for i, str := range strs { offset := len(strs) - i diff --git a/go/vt/vtgate/evalengine/testcases/cases.go b/go/vt/vtgate/evalengine/testcases/cases.go index 7a09447b57f..56fe576b6be 100644 --- a/go/vt/vtgate/evalengine/testcases/cases.go +++ b/go/vt/vtgate/evalengine/testcases/cases.go @@ -1334,6 +1334,24 @@ func FnField(yield Query) { } } + // Contains failing testcases + for _, s1 := range inputStrings { + for _, s2 := range radianInputs { + for _, s3 := range inputStrings { + yield(fmt.Sprintf("FIELD(%s, %s, %s)", s1, s2, s3), nil) + } + } + } + + // Contains failing testcases + for _, s1 := range inputBitwise { + for _, s2 := range inputBitwise { + for _, s3 := range inputBitwise { + yield(fmt.Sprintf("FIELD(%s, %s, %s)", s1, s2, s3), nil) + } + } + } + mysqlDocSamples := []string{ "FIELD('Bb', 'Aa', 'Bb', 'Cc', 'Dd', 'Ff')", "FIELD('Gg', 'Aa', 'Bb', 'Cc', 'Dd', 'Ff')", From 67d898bbaa7854f31d3d7563154b49483d8581fc Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Thu, 28 Mar 2024 21:17:46 +0100 Subject: [PATCH 8/9] Fix literal NULL in list with numerical types If there's a literal NULL with only numerical types, we also have to convert to DOUBLE instead of ignoring the NULL. Signed-off-by: Dirkjan Bussink --- go/vt/vtgate/evalengine/fn_string.go | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/go/vt/vtgate/evalengine/fn_string.go b/go/vt/vtgate/evalengine/fn_string.go index b67d64b390e..d96ffa0f9ce 100644 --- a/go/vt/vtgate/evalengine/fn_string.go +++ b/go/vt/vtgate/evalengine/fn_string.go @@ -140,6 +140,15 @@ var _ IR = (*builtinPad)(nil) var _ IR = (*builtinTrim)(nil) func fieldSQLType(arg sqltypes.Type, tt sqltypes.Type) sqltypes.Type { + if sqltypes.IsNull(arg) { + // If we have a NULL combined with only so far numerical types, + // we have to convert it all to DOUBLE. + if sqltypes.IsIntegral(tt) || sqltypes.IsDecimal(tt) { + return sqltypes.Float64 + } + return tt + } + if typeIsTextual(arg) && typeIsTextual(tt) { return sqltypes.VarChar } else if sqltypes.IsIntegral(arg) && sqltypes.IsIntegral(tt) { @@ -167,11 +176,14 @@ func (call *builtinField) eval(env *ExpressionEnv) (eval, error) { tt := args[0].SQLType() for _, arg := range args[1:] { + var at sqltypes.Type if arg == nil { - continue + at = sqltypes.Null + } else { + at = arg.SQLType() } - tt = fieldSQLType(arg.SQLType(), tt) + tt = fieldSQLType(at, tt) } if tt == sqltypes.Int64 { @@ -262,10 +274,6 @@ func (call *builtinField) compile(c *compiler) (ctype, error) { tt := strs[0].Type for _, str := range strs { - if sqltypes.IsNull(str.Type) { - continue - } - tt = fieldSQLType(str.Type, tt) } From baf6e3cb6a35e0adf7f86cb9313390b9d9b8cd08 Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Thu, 28 Mar 2024 22:05:28 +0100 Subject: [PATCH 9/9] Fix collation handling for FIELD Signed-off-by: Dirkjan Bussink --- go/vt/vtgate/evalengine/compiler_asm.go | 23 +++++++++++++-- go/vt/vtgate/evalengine/fn_string.go | 38 +++++++------------------ 2 files changed, 31 insertions(+), 30 deletions(-) diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index 51644769282..8329f70e560 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -2376,7 +2376,7 @@ func (asm *assembler) Fn_FIELD_i(args int) { }, "FN FIELD INT64(SP-%d)...INT64(SP-1)", args) } -func (asm *assembler) Fn_FIELD_b(args int) { +func (asm *assembler) Fn_FIELD_b(args int, col colldata.Collation) { asm.adjustStack(-args + 1) asm.emit(func(env *ExpressionEnv) int { if env.vm.stack[env.vm.sp-args] == nil { @@ -2394,8 +2394,27 @@ func (asm *assembler) Fn_FIELD_b(args int) { str := env.vm.stack[env.vm.sp-args+i+1].(*evalBytes) + // We cannot do these comparison earlier in the compilation, + // because if we convert everything first, we error on cases + // where there is a match. MySQL will do an element for element + // comparison where if there's a match already, it doesn't matter + // if there was an invalid conversion later on. + // + // This means we also must convert here in this compiler function + // and can't eagerly do the conversion. + toCharset := col.Charset() + fromCharset := colldata.Lookup(str.col.Collation).Charset() + if fromCharset != toCharset && !toCharset.IsSuperset(fromCharset) { + str, env.vm.err = evalToVarchar(str, col.ID(), true) + if env.vm.err != nil { + env.vm.stack[env.vm.sp-args] = nil + env.vm.sp -= args - 1 + return 1 + } + } + // Compare target and current string - if bytes.Equal(tar.bytes, str.bytes) { + if col.Collate(tar.bytes, str.bytes, false) == 0 { env.vm.stack[env.vm.sp-args] = env.vm.arena.newEvalInt64(int64(i + 1)) env.vm.sp -= args - 1 return 1 diff --git a/go/vt/vtgate/evalengine/fn_string.go b/go/vt/vtgate/evalengine/fn_string.go index d96ffa0f9ce..9543c8befa8 100644 --- a/go/vt/vtgate/evalengine/fn_string.go +++ b/go/vt/vtgate/evalengine/fn_string.go @@ -200,30 +200,22 @@ func (call *builtinField) eval(env *ExpressionEnv) (eval, error) { } } } else if tt == sqltypes.VarChar { - tar, ok := args[0].(*evalBytes) - if !ok { - tar, err = evalToVarchar(args[0], call.collate, true) - if err != nil { - return nil, err - } - } + col := evalCollation(args[0]) + collation := colldata.Lookup(col.Collation) + tar := args[0].(*evalBytes) for i, arg := range args[1:] { if arg == nil { continue } - var ok bool - e, ok := arg.(*evalBytes) - if !ok { - e, err = evalToVarchar(arg, call.collate, true) - if err != nil { - return nil, err - } + e, err := evalToVarchar(arg, col.Collation, true) + if err != nil { + return nil, err } // Compare target and current string - if bytes.Equal(tar.bytes, e.bytes) { + if collation.Collate(tar.bytes, e.bytes, false) == 0 { return newEvalInt64(int64(i + 1)), nil } } @@ -272,6 +264,7 @@ func (call *builtinField) compile(c *compiler) (ctype, error) { // If the arguments contain both integral and string values // MySQL converts all the arguments to DOUBLE tt := strs[0].Type + col := strs[0].Col for _, str := range strs { tt = fieldSQLType(str.Type, tt) @@ -292,19 +285,8 @@ func (call *builtinField) compile(c *compiler) (ctype, error) { c.asm.Fn_FIELD_i(len(call.Arguments)) } else if tt == sqltypes.VarChar { - for i, str := range strs { - offset := len(strs) - i - skip := c.compileNullCheckOffset(str, offset) - - switch { - case str.isTextual(): - default: - c.asm.Convert_xce(offset, sqltypes.VarChar, call.collate) - } - c.asm.jumpDestination(skip) - } - - c.asm.Fn_FIELD_b(len(call.Arguments)) + collation := colldata.Lookup(col.Collation) + c.asm.Fn_FIELD_b(len(call.Arguments), collation) } else if tt == sqltypes.Decimal { for i, str := range strs { offset := len(strs) - i