Skip to content

Commit

Permalink
server, session: use sqlexec escaping library (pingcap#22877)
Browse files Browse the repository at this point in the history
  • Loading branch information
morgo authored Feb 25, 2021
1 parent f6e8cb4 commit b12aa4a
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 9 deletions.
7 changes: 6 additions & 1 deletion server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ import (
"github.com/pingcap/tidb/util/hack"
"github.com/pingcap/tidb/util/logutil"
"github.com/pingcap/tidb/util/memory"
"github.com/pingcap/tidb/util/sqlexec"
"go.uber.org/zap"
"golang.org/x/net/context"
)
Expand Down Expand Up @@ -715,7 +716,11 @@ func (cc *clientConn) dispatch(ctx context.Context, data []byte) error {
func (cc *clientConn) useDB(ctx context.Context, db string) (err error) {
// if input is "use `SELECT`", mysql client just send "SELECT"
// so we add `` around db.
_, err = cc.ctx.Execute(ctx, "use `"+db+"`")
sql, err := sqlexec.EscapeSQL("use %n", db)
if err != nil {
return err
}
_, err = cc.ctx.Execute(ctx, sql)
if err != nil {
return errors.Trace(err)
}
Expand Down
9 changes: 3 additions & 6 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -742,8 +742,7 @@ func (s *session) GetAllSysVars() (map[string]string, error) {
if s.Value(sessionctx.Initing) != nil {
return nil, nil
}
sql := `SELECT VARIABLE_NAME, VARIABLE_VALUE FROM %s.%s;`
sql = fmt.Sprintf(sql, mysql.SystemDB, mysql.GlobalVariablesTable)
sql := sqlexec.MustEscapeSQL("SELECT VARIABLE_NAME, VARIABLE_VALUE FROM %n.%n", mysql.SystemDB, mysql.GlobalVariablesTable)
rows, _, err := s.ExecRestrictedSQL(s, sql)
if err != nil {
return nil, errors.Trace(err)
Expand All @@ -762,8 +761,7 @@ func (s *session) GetGlobalSysVar(name string) (string, error) {
// When running bootstrap or upgrade, we should not access global storage.
return "", nil
}
sql := fmt.Sprintf(`SELECT VARIABLE_VALUE FROM %s.%s WHERE VARIABLE_NAME="%s";`,
mysql.SystemDB, mysql.GlobalVariablesTable, name)
sql := sqlexec.MustEscapeSQL(`SELECT VARIABLE_VALUE FROM %n.%n WHERE VARIABLE_NAME=%?`, mysql.SystemDB, mysql.GlobalVariablesTable, name)
sysVar, err := s.getExecRet(s, sql)
if err != nil {
if executor.ErrResultIsEmpty.Equal(err) {
Expand Down Expand Up @@ -792,8 +790,7 @@ func (s *session) SetGlobalSysVar(name, value string) error {
return errors.Trace(err)
}
name = strings.ToLower(name)
sql := fmt.Sprintf(`REPLACE %s.%s VALUES ('%s', '%s');`,
mysql.SystemDB, mysql.GlobalVariablesTable, name, sVal)
sql := sqlexec.MustEscapeSQL(`REPLACE %n.%n VALUES (%?, %?);`, mysql.SystemDB, mysql.GlobalVariablesTable, name, sVal)
_, _, err = s.ExecRestrictedSQL(s, sql)
return errors.Trace(err)
}
Expand Down
4 changes: 2 additions & 2 deletions util/sqlexec/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,8 @@ func (s *testUtilsSuite) TestEscapeSQL(c *C) {
{
name: "time 3",
input: "select %?",
params: []interface{}{time.Unix(0, 888888888)},
output: "select '1970-01-01 08:00:00.888888'",
params: []interface{}{time.Unix(0, 888888888).UTC()},
output: "select '1970-01-01 00:00:00.888888'",
err: "",
},
{
Expand Down

0 comments on commit b12aa4a

Please sign in to comment.