diff --git a/util/sqlexec/utils.go b/util/sqlexec/utils.go index 05d87ce5bbd37..4854ec2d560b7 100644 --- a/util/sqlexec/utils.go +++ b/util/sqlexec/utils.go @@ -17,6 +17,7 @@ package sqlexec import ( "encoding/json" "io" + "reflect" "strconv" "strings" "time" @@ -161,11 +162,7 @@ func escapeSQL(sql string, args ...interface{}) ([]byte, error) { case float64: buf = strconv.AppendFloat(buf, v, 'g', -1, 64) case bool: - if v { - buf = append(buf, '1') - } else { - buf = append(buf, '0') - } + buf = appendSQLArgBool(buf, v) case time.Time: if v.IsZero() { buf = append(buf, "'0000-00-00'"...) @@ -187,9 +184,7 @@ func escapeSQL(sql string, args ...interface{}) ([]byte, error) { buf = append(buf, '\'') } case string: - buf = append(buf, '\'') - buf = escapeStringBackslash(buf, v) - buf = append(buf, '\'') + buf = appendSQLArgString(buf, v) case []string: for i, k := range v { if i > 0 { @@ -214,7 +209,25 @@ func escapeSQL(sql string, args ...interface{}) ([]byte, error) { buf = strconv.AppendFloat(buf, k, 'g', -1, 64) } default: - return nil, errors.Errorf("unsupported %d-th argument: %v", argPos, arg) + // slow path based on reflection + reflectTp := reflect.TypeOf(arg) + kind := reflectTp.Kind() + switch kind { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + buf = strconv.AppendInt(buf, reflect.ValueOf(arg).Int(), 10) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + buf = strconv.AppendUint(buf, reflect.ValueOf(arg).Uint(), 10) + case reflect.Float32: + buf = strconv.AppendFloat(buf, reflect.ValueOf(arg).Float(), 'g', -1, 32) + case reflect.Float64: + buf = strconv.AppendFloat(buf, reflect.ValueOf(arg).Float(), 'g', -1, 64) + case reflect.Bool: + buf = appendSQLArgBool(buf, reflect.ValueOf(arg).Bool()) + case reflect.String: + buf = appendSQLArgString(buf, reflect.ValueOf(arg).String()) + default: + return nil, errors.Errorf("unsupported %d-th argument: %v", argPos, arg) + } } } i++ // skip specifier @@ -228,6 +241,20 @@ func escapeSQL(sql string, args ...interface{}) ([]byte, error) { return buf, nil } +func appendSQLArgBool(buf []byte, v bool) []byte { + if v { + return append(buf, '1') + } + return append(buf, '0') +} + +func appendSQLArgString(buf []byte, s string) []byte { + buf = append(buf, '\'') + buf = escapeStringBackslash(buf, s) + buf = append(buf, '\'') + return buf +} + // EscapeSQL will escape input arguments into the sql string, doing necessary processing. // It works like printf() in c, there are following format specifiers: // 1. %?: automatic conversion by the type of arguments. E.g. []string -> ('s1','s2'..) diff --git a/util/sqlexec/utils_test.go b/util/sqlexec/utils_test.go index b9d7c21921224..80a996385d3bd 100644 --- a/util/sqlexec/utils_test.go +++ b/util/sqlexec/utils_test.go @@ -105,6 +105,9 @@ func TestEscapeBackslash(t *testing.T) { } } +type myInt int +type myStr string + func TestEscapeSQL(t *testing.T) { type TestCase struct { name string @@ -385,6 +388,18 @@ func TestEscapeSQL(t *testing.T) { params: []interface{}{[]float64{55.2, 0.66}}, output: "select 55.2,0.66", }, + { + name: "myInt", + input: "select %?", + params: []interface{}{myInt(3)}, + output: "select 3", + }, + { + name: "myStr", + input: "select %?", + params: []interface{}{myStr("3")}, + output: "select '3'", + }, } for _, v := range tests { // copy iterator variable into a new variable, see issue #27779 @@ -451,3 +466,27 @@ func TestEscapeString(t *testing.T) { require.Equal(t, v.output, EscapeString(v.input)) } } + +func BenchmarkEscapeString(b *testing.B) { + for i := 0; i < b.N; i++ { + escapeSQL("select %?", "3") + } +} + +func BenchmarkUnderlyingString(b *testing.B) { + for i := 0; i < b.N; i++ { + escapeSQL("select %?", myStr("3")) + } +} + +func BenchmarkEscapeInt(b *testing.B) { + for i := 0; i < b.N; i++ { + escapeSQL("select %?", 3) + } +} + +func BenchmarkUnderlyingInt(b *testing.B) { + for i := 0; i < b.N; i++ { + escapeSQL("select %?", myInt(3)) + } +}