Skip to content

Commit

Permalink
util/sqlexec: EscapeSQL support resolve some underlying type (#42465)
Browse files Browse the repository at this point in the history
  • Loading branch information
lance6716 committed Mar 31, 2023
1 parent 6273e22 commit 2d4df7f
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 9 deletions.
45 changes: 36 additions & 9 deletions util/sqlexec/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package sqlexec
import (
"encoding/json"
"io"
"reflect"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -161,11 +162,7 @@ func escapeSQL(sql string, args ...interface{}) ([]byte, error) {
case float64:
buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
case bool:
if v {
buf = append(buf, '1')
} else {
buf = append(buf, '0')
}
buf = appendSQLArgBool(buf, v)
case time.Time:
if v.IsZero() {
buf = append(buf, "'0000-00-00'"...)
Expand All @@ -187,9 +184,7 @@ func escapeSQL(sql string, args ...interface{}) ([]byte, error) {
buf = append(buf, '\'')
}
case string:
buf = append(buf, '\'')
buf = escapeStringBackslash(buf, v)
buf = append(buf, '\'')
buf = appendSQLArgString(buf, v)
case []string:
for i, k := range v {
if i > 0 {
Expand All @@ -214,7 +209,25 @@ func escapeSQL(sql string, args ...interface{}) ([]byte, error) {
buf = strconv.AppendFloat(buf, k, 'g', -1, 64)
}
default:
return nil, errors.Errorf("unsupported %d-th argument: %v", argPos, arg)
// slow path based on reflection
reflectTp := reflect.TypeOf(arg)
kind := reflectTp.Kind()
switch kind {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
buf = strconv.AppendInt(buf, reflect.ValueOf(arg).Int(), 10)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
buf = strconv.AppendUint(buf, reflect.ValueOf(arg).Uint(), 10)
case reflect.Float32:
buf = strconv.AppendFloat(buf, reflect.ValueOf(arg).Float(), 'g', -1, 32)
case reflect.Float64:
buf = strconv.AppendFloat(buf, reflect.ValueOf(arg).Float(), 'g', -1, 64)
case reflect.Bool:
buf = appendSQLArgBool(buf, reflect.ValueOf(arg).Bool())
case reflect.String:
buf = appendSQLArgString(buf, reflect.ValueOf(arg).String())
default:
return nil, errors.Errorf("unsupported %d-th argument: %v", argPos, arg)
}
}
}
i++ // skip specifier
Expand All @@ -228,6 +241,20 @@ func escapeSQL(sql string, args ...interface{}) ([]byte, error) {
return buf, nil
}

func appendSQLArgBool(buf []byte, v bool) []byte {
if v {
return append(buf, '1')
}
return append(buf, '0')
}

func appendSQLArgString(buf []byte, s string) []byte {
buf = append(buf, '\'')
buf = escapeStringBackslash(buf, s)
buf = append(buf, '\'')
return buf
}

// EscapeSQL will escape input arguments into the sql string, doing necessary processing.
// It works like printf() in c, there are following format specifiers:
// 1. %?: automatic conversion by the type of arguments. E.g. []string -> ('s1','s2'..)
Expand Down
39 changes: 39 additions & 0 deletions util/sqlexec/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ func TestEscapeBackslash(t *testing.T) {
}
}

type myInt int
type myStr string

func TestEscapeSQL(t *testing.T) {
type TestCase struct {
name string
Expand Down Expand Up @@ -385,6 +388,18 @@ func TestEscapeSQL(t *testing.T) {
params: []interface{}{[]float64{55.2, 0.66}},
output: "select 55.2,0.66",
},
{
name: "myInt",
input: "select %?",
params: []interface{}{myInt(3)},
output: "select 3",
},
{
name: "myStr",
input: "select %?",
params: []interface{}{myStr("3")},
output: "select '3'",
},
}
for _, v := range tests {
// copy iterator variable into a new variable, see issue #27779
Expand Down Expand Up @@ -451,3 +466,27 @@ func TestEscapeString(t *testing.T) {
require.Equal(t, v.output, EscapeString(v.input))
}
}

func BenchmarkEscapeString(b *testing.B) {
for i := 0; i < b.N; i++ {
escapeSQL("select %?", "3")
}
}

func BenchmarkUnderlyingString(b *testing.B) {
for i := 0; i < b.N; i++ {
escapeSQL("select %?", myStr("3"))
}
}

func BenchmarkEscapeInt(b *testing.B) {
for i := 0; i < b.N; i++ {
escapeSQL("select %?", 3)
}
}

func BenchmarkUnderlyingInt(b *testing.B) {
for i := 0; i < b.N; i++ {
escapeSQL("select %?", myInt(3))
}
}

0 comments on commit 2d4df7f

Please sign in to comment.