From b12aa4a9e00c26cf764d8834e2511946706707f2 Mon Sep 17 00:00:00 2001 From: Morgan Tocker Date: Thu, 25 Feb 2021 09:22:04 -0700 Subject: [PATCH] server, session: use sqlexec escaping library (#22877) --- server/conn.go | 7 ++++++- session/session.go | 9 +++------ util/sqlexec/utils_test.go | 4 ++-- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/server/conn.go b/server/conn.go index ae6f769d01e02..18ca216fae3d8 100644 --- a/server/conn.go +++ b/server/conn.go @@ -64,6 +64,7 @@ import ( "github.com/pingcap/tidb/util/hack" "github.com/pingcap/tidb/util/logutil" "github.com/pingcap/tidb/util/memory" + "github.com/pingcap/tidb/util/sqlexec" "go.uber.org/zap" "golang.org/x/net/context" ) @@ -715,7 +716,11 @@ func (cc *clientConn) dispatch(ctx context.Context, data []byte) error { func (cc *clientConn) useDB(ctx context.Context, db string) (err error) { // if input is "use `SELECT`", mysql client just send "SELECT" // so we add `` around db. - _, err = cc.ctx.Execute(ctx, "use `"+db+"`") + sql, err := sqlexec.EscapeSQL("use %n", db) + if err != nil { + return err + } + _, err = cc.ctx.Execute(ctx, sql) if err != nil { return errors.Trace(err) } diff --git a/session/session.go b/session/session.go index aeb627fa3c303..08b938f608c2f 100644 --- a/session/session.go +++ b/session/session.go @@ -742,8 +742,7 @@ func (s *session) GetAllSysVars() (map[string]string, error) { if s.Value(sessionctx.Initing) != nil { return nil, nil } - sql := `SELECT VARIABLE_NAME, VARIABLE_VALUE FROM %s.%s;` - sql = fmt.Sprintf(sql, mysql.SystemDB, mysql.GlobalVariablesTable) + sql := sqlexec.MustEscapeSQL("SELECT VARIABLE_NAME, VARIABLE_VALUE FROM %n.%n", mysql.SystemDB, mysql.GlobalVariablesTable) rows, _, err := s.ExecRestrictedSQL(s, sql) if err != nil { return nil, errors.Trace(err) @@ -762,8 +761,7 @@ func (s *session) GetGlobalSysVar(name string) (string, error) { // When running bootstrap or upgrade, we should not access global storage. return "", nil } - sql := fmt.Sprintf(`SELECT VARIABLE_VALUE FROM %s.%s WHERE VARIABLE_NAME="%s";`, - mysql.SystemDB, mysql.GlobalVariablesTable, name) + sql := sqlexec.MustEscapeSQL(`SELECT VARIABLE_VALUE FROM %n.%n WHERE VARIABLE_NAME=%?`, mysql.SystemDB, mysql.GlobalVariablesTable, name) sysVar, err := s.getExecRet(s, sql) if err != nil { if executor.ErrResultIsEmpty.Equal(err) { @@ -792,8 +790,7 @@ func (s *session) SetGlobalSysVar(name, value string) error { return errors.Trace(err) } name = strings.ToLower(name) - sql := fmt.Sprintf(`REPLACE %s.%s VALUES ('%s', '%s');`, - mysql.SystemDB, mysql.GlobalVariablesTable, name, sVal) + sql := sqlexec.MustEscapeSQL(`REPLACE %n.%n VALUES (%?, %?);`, mysql.SystemDB, mysql.GlobalVariablesTable, name, sVal) _, _, err = s.ExecRestrictedSQL(s, sql) return errors.Trace(err) } diff --git a/util/sqlexec/utils_test.go b/util/sqlexec/utils_test.go index d5dd0542b6bd5..a8a912a33978f 100644 --- a/util/sqlexec/utils_test.go +++ b/util/sqlexec/utils_test.go @@ -276,8 +276,8 @@ func (s *testUtilsSuite) TestEscapeSQL(c *C) { { name: "time 3", input: "select %?", - params: []interface{}{time.Unix(0, 888888888)}, - output: "select '1970-01-01 08:00:00.888888'", + params: []interface{}{time.Unix(0, 888888888).UTC()}, + output: "select '1970-01-01 00:00:00.888888'", err: "", }, {