diff --git a/br/pkg/lightning/checkpoints/checkpoints.go b/br/pkg/lightning/checkpoints/checkpoints.go index 44f2349b672b2..13817e28eb668 100644 --- a/br/pkg/lightning/checkpoints/checkpoints.go +++ b/br/pkg/lightning/checkpoints/checkpoints.go @@ -517,7 +517,15 @@ func OpenCheckpointsDB(ctx context.Context, cfg *config.Config) (DB, error) { switch cfg.Checkpoint.Driver { case config.CheckpointDriverMySQL: - db, err := common.ConnectMySQL(cfg.Checkpoint.DSN) + var ( + db *sql.DB + err error + ) + if cfg.Checkpoint.MySQLParam != nil { + db, err = cfg.Checkpoint.MySQLParam.Connect() + } else { + db, err = sql.Open("mysql", cfg.Checkpoint.DSN) + } if err != nil { return nil, errors.Trace(err) } @@ -546,7 +554,15 @@ func IsCheckpointsDBExists(ctx context.Context, cfg *config.Config) (bool, error } switch cfg.Checkpoint.Driver { case config.CheckpointDriverMySQL: - db, err := sql.Open("mysql", cfg.Checkpoint.DSN) + var ( + db *sql.DB + err error + ) + if cfg.Checkpoint.MySQLParam != nil { + db, err = cfg.Checkpoint.MySQLParam.Connect() + } else { + db, err = sql.Open("mysql", cfg.Checkpoint.DSN) + } if err != nil { return false, errors.Trace(err) } diff --git a/br/pkg/lightning/common/util.go b/br/pkg/lightning/common/util.go index cc03f0ec68dca..679ba6cc5d48b 100644 --- a/br/pkg/lightning/common/util.go +++ b/br/pkg/lightning/common/util.go @@ -23,7 +23,6 @@ import ( "io" "net" "net/http" - "net/url" "os" "strconv" "strings" @@ -58,28 +57,38 @@ type MySQLConnectParam struct { Vars map[string]string } -func (param *MySQLConnectParam) ToDSN() string { - hostPort := net.JoinHostPort(param.Host, strconv.Itoa(param.Port)) - dsn := fmt.Sprintf("%s:%s@tcp(%s)/?charset=utf8mb4&sql_mode='%s'&maxAllowedPacket=%d&tls=%s", - param.User, param.Password, hostPort, - param.SQLMode, param.MaxAllowedPacket, param.TLS) +func (param *MySQLConnectParam) ToDriverConfig() *mysql.Config { + cfg := mysql.NewConfig() + cfg.Params = make(map[string]string) + + cfg.User = param.User + cfg.Passwd = param.Password + cfg.Net = "tcp" + cfg.Addr = net.JoinHostPort(param.Host, strconv.Itoa(param.Port)) + cfg.Params["charset"] = "utf8mb4" + cfg.Params["sql_mode"] = fmt.Sprintf("'%s'", param.SQLMode) + cfg.MaxAllowedPacket = int(param.MaxAllowedPacket) + cfg.TLSConfig = param.TLS for k, v := range param.Vars { - dsn += fmt.Sprintf("&%s='%s'", k, url.QueryEscape(v)) + cfg.Params[k] = fmt.Sprintf("'%s'", v) } - - return dsn + return cfg } -func tryConnectMySQL(dsn string) (*sql.DB, error) { - driverName := "mysql" - failpoint.Inject("MockMySQLDriver", func(val failpoint.Value) { - driverName = val.(string) +func tryConnectMySQL(cfg *mysql.Config) (*sql.DB, error) { + failpoint.Inject("MustMySQLPassword", func(val failpoint.Value) { + pwd := val.(string) + if cfg.Passwd != pwd { + failpoint.Return(nil, &mysql.MySQLError{Number: tmysql.ErrAccessDenied, Message: "access denied"}) + } + failpoint.Return(nil, nil) }) - db, err := sql.Open(driverName, dsn) + c, err := mysql.NewConnector(cfg) if err != nil { return nil, errors.Trace(err) } + db := sql.OpenDB(c) if err = db.Ping(); err != nil { _ = db.Close() return nil, errors.Trace(err) @@ -89,13 +98,9 @@ func tryConnectMySQL(dsn string) (*sql.DB, error) { // ConnectMySQL connects MySQL with the dsn. If access is denied and the password is a valid base64 encoding, // we will try to connect MySQL with the base64 decoding of the password. -func ConnectMySQL(dsn string) (*sql.DB, error) { - cfg, err := mysql.ParseDSN(dsn) - if err != nil { - return nil, errors.Trace(err) - } +func ConnectMySQL(cfg *mysql.Config) (*sql.DB, error) { // Try plain password first. - db, firstErr := tryConnectMySQL(dsn) + db, firstErr := tryConnectMySQL(cfg) if firstErr == nil { return db, nil } @@ -104,9 +109,9 @@ func ConnectMySQL(dsn string) (*sql.DB, error) { // If password is encoded by base64, try the decoded string as well. if password, decodeErr := base64.StdEncoding.DecodeString(cfg.Passwd); decodeErr == nil && string(password) != cfg.Passwd { cfg.Passwd = string(password) - db, err = tryConnectMySQL(cfg.FormatDSN()) + db2, err := tryConnectMySQL(cfg) if err == nil { - return db, nil + return db2, nil } } } @@ -115,7 +120,7 @@ func ConnectMySQL(dsn string) (*sql.DB, error) { } func (param *MySQLConnectParam) Connect() (*sql.DB, error) { - db, err := ConnectMySQL(param.ToDSN()) + db, err := ConnectMySQL(param.ToDriverConfig()) if err != nil { return nil, errors.Trace(err) } diff --git a/br/pkg/lightning/common/util_test.go b/br/pkg/lightning/common/util_test.go index c7c95b44f69bf..a192ecea11906 100644 --- a/br/pkg/lightning/common/util_test.go +++ b/br/pkg/lightning/common/util_test.go @@ -16,16 +16,12 @@ package common_test import ( "context" - "database/sql" - "database/sql/driver" "encoding/base64" "encoding/json" "fmt" "io" - "math/rand" "net/http" "net/http/httptest" - "strconv" "testing" "time" @@ -35,7 +31,6 @@ import ( "github.com/pingcap/failpoint" "github.com/pingcap/tidb/br/pkg/lightning/common" "github.com/pingcap/tidb/br/pkg/lightning/log" - tmysql "github.com/pingcap/tidb/errno" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -85,66 +80,14 @@ func TestGetJSON(t *testing.T) { require.Regexp(t, ".*http status code != 200.*", err.Error()) } -func TestToDSN(t *testing.T) { - param := common.MySQLConnectParam{ - Host: "127.0.0.1", - Port: 4000, - User: "root", - Password: "123456", - SQLMode: "strict", - MaxAllowedPacket: 1234, - TLS: "cluster", - Vars: map[string]string{ - "tidb_distsql_scan_concurrency": "1", - }, - } - require.Equal(t, "root:123456@tcp(127.0.0.1:4000)/?charset=utf8mb4&sql_mode='strict'&maxAllowedPacket=1234&tls=cluster&tidb_distsql_scan_concurrency='1'", param.ToDSN()) - - param.Host = "::1" - require.Equal(t, "root:123456@tcp([::1]:4000)/?charset=utf8mb4&sql_mode='strict'&maxAllowedPacket=1234&tls=cluster&tidb_distsql_scan_concurrency='1'", param.ToDSN()) -} - -type mockDriver struct { - driver.Driver - plainPsw string -} - -func (m *mockDriver) Open(dsn string) (driver.Conn, error) { - cfg, err := mysql.ParseDSN(dsn) - if err != nil { - return nil, err - } - accessDenied := cfg.Passwd != m.plainPsw - return &mockConn{accessDenied: accessDenied}, nil -} - -type mockConn struct { - driver.Conn - driver.Pinger - accessDenied bool -} - -func (c *mockConn) Ping(ctx context.Context) error { - if c.accessDenied { - return &mysql.MySQLError{Number: tmysql.ErrAccessDenied, Message: "access denied"} - } - return nil -} - -func (c *mockConn) Close() error { - return nil -} - func TestConnect(t *testing.T) { plainPsw := "dQAUoDiyb1ucWZk7" - driverName := "mysql-mock-" + strconv.Itoa(rand.Int()) - sql.Register(driverName, &mockDriver{plainPsw: plainPsw}) require.NoError(t, failpoint.Enable( - "github.com/pingcap/tidb/br/pkg/lightning/common/MockMySQLDriver", - fmt.Sprintf("return(\"%s\")", driverName))) + "github.com/pingcap/tidb/br/pkg/lightning/common/MustMySQLPassword", + fmt.Sprintf("return(\"%s\")", plainPsw))) defer func() { - require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/br/pkg/lightning/common/MockMySQLDriver")) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/br/pkg/lightning/common/MustMySQLPassword")) }() param := common.MySQLConnectParam{ @@ -155,13 +98,11 @@ func TestConnect(t *testing.T) { SQLMode: "strict", MaxAllowedPacket: 1234, } - db, err := param.Connect() + _, err := param.Connect() require.NoError(t, err) - require.NoError(t, db.Close()) param.Password = base64.StdEncoding.EncodeToString([]byte(plainPsw)) - db, err = param.Connect() + _, err = param.Connect() require.NoError(t, err) - require.NoError(t, db.Close()) } func TestIsContextCanceledError(t *testing.T) { diff --git a/br/pkg/lightning/config/config.go b/br/pkg/lightning/config/config.go index 4c1af0d2baff3..638784ff3ed1e 100644 --- a/br/pkg/lightning/config/config.go +++ b/br/pkg/lightning/config/config.go @@ -553,11 +553,12 @@ type TikvImporter struct { } type Checkpoint struct { - Schema string `toml:"schema" json:"schema"` - DSN string `toml:"dsn" json:"-"` // DSN may contain password, don't expose this to JSON. - Driver string `toml:"driver" json:"driver"` - Enable bool `toml:"enable" json:"enable"` - KeepAfterSuccess CheckpointKeepStrategy `toml:"keep-after-success" json:"keep-after-success"` + Schema string `toml:"schema" json:"schema"` + DSN string `toml:"dsn" json:"-"` // DSN may contain password, don't expose this to JSON. + MySQLParam *common.MySQLConnectParam `toml:"-" json:"-"` // For some security reason, we use MySQLParam instead of DSN. + Driver string `toml:"driver" json:"driver"` + Enable bool `toml:"enable" json:"enable"` + KeepAfterSuccess CheckpointKeepStrategy `toml:"keep-after-success" json:"keep-after-success"` } type Cron struct { @@ -1142,7 +1143,7 @@ func (cfg *Config) AdjustCheckPoint() { MaxAllowedPacket: defaultMaxAllowedPacket, TLS: cfg.TiDB.TLS, } - cfg.Checkpoint.DSN = param.ToDSN() + cfg.Checkpoint.MySQLParam = ¶m case CheckpointDriverFile: cfg.Checkpoint.DSN = "/tmp/" + cfg.Checkpoint.Schema + ".pb" } diff --git a/br/pkg/lightning/config/config_test.go b/br/pkg/lightning/config/config_test.go index 2a4dcbe7cdad9..e74094a6b9066 100644 --- a/br/pkg/lightning/config/config_test.go +++ b/br/pkg/lightning/config/config_test.go @@ -32,7 +32,6 @@ import ( "github.com/BurntSushi/toml" "github.com/pingcap/tidb/br/pkg/lightning/common" "github.com/pingcap/tidb/br/pkg/lightning/config" - "github.com/pingcap/tidb/parser/mysql" "github.com/stretchr/testify/require" ) @@ -626,7 +625,9 @@ func TestLoadConfig(t *testing.T) { taskCfg.TiDB.DistSQLScanConcurrency = 1 err = taskCfg.Adjust(context.Background()) require.NoError(t, err) - require.Equal(t, "guest:12345@tcp(172.16.30.11:4001)/?charset=utf8mb4&sql_mode='"+mysql.DefaultSQLMode+"'&maxAllowedPacket=67108864&tls=false", taskCfg.Checkpoint.DSN) + equivalentDSN := taskCfg.Checkpoint.MySQLParam.ToDriverConfig().FormatDSN() + expectedDSN := "guest:12345@tcp(172.16.30.11:4001)/?tls=false&maxAllowedPacket=67108864&charset=utf8mb4&sql_mode=%27ONLY_FULL_GROUP_BY%2CSTRICT_TRANS_TABLES%2CNO_ZERO_IN_DATE%2CNO_ZERO_DATE%2CERROR_FOR_DIVISION_BY_ZERO%2CNO_AUTO_CREATE_USER%2CNO_ENGINE_SUBSTITUTION%27" + require.Equal(t, expectedDSN, equivalentDSN) result := taskCfg.String() require.Regexp(t, `.*"pd-addr":"172.16.30.11:2379,172.16.30.12:2379".*`, result) diff --git a/cmd/importer/db.go b/cmd/importer/db.go index 8b0d7353b9adf..b8ecf83abfc4b 100644 --- a/cmd/importer/db.go +++ b/cmd/importer/db.go @@ -22,7 +22,7 @@ import ( "strconv" "strings" - _ "github.com/go-sql-driver/mysql" + mysql2 "github.com/go-sql-driver/mysql" "github.com/pingcap/errors" "github.com/pingcap/log" "github.com/pingcap/tidb/parser/mysql" @@ -318,13 +318,18 @@ func execSQL(db *sql.DB, sql string) error { } func createDB(cfg DBConfig) (*sql.DB, error) { - dbDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8", cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Name) - db, err := sql.Open("mysql", dbDSN) + driverCfg := mysql2.NewConfig() + driverCfg.User = cfg.User + driverCfg.Passwd = cfg.Password + driverCfg.Net = "tcp" + driverCfg.Addr = cfg.Host + ":" + strconv.Itoa(cfg.Port) + driverCfg.DBName = cfg.Name + + c, err := mysql2.NewConnector(driverCfg) if err != nil { return nil, errors.Trace(err) } - - return db, nil + return sql.OpenDB(c), nil } func closeDB(db *sql.DB) error { diff --git a/dumpling/export/config.go b/dumpling/export/config.go index 980de0d8807f5..b92d2922d2572 100644 --- a/dumpling/export/config.go +++ b/dumpling/export/config.go @@ -218,6 +218,31 @@ func (conf *Config) GetDSN(db string) string { return dsn } +// GetDriverConfig returns the MySQL driver config from Config. +func (conf *Config) GetDriverConfig(db string) *mysql.Config { + driverCfg := mysql.NewConfig() + // maxAllowedPacket=0 can be used to automatically fetch the max_allowed_packet variable from server on every connection. + // https://github.com/go-sql-driver/mysql#maxallowedpacket + hostPort := net.JoinHostPort(conf.Host, strconv.Itoa(conf.Port)) + driverCfg.User = conf.User + driverCfg.Passwd = conf.Password + driverCfg.Net = "tcp" + driverCfg.Addr = hostPort + driverCfg.DBName = db + driverCfg.Collation = "utf8mb4_general_ci" + driverCfg.ReadTimeout = conf.ReadTimeout + driverCfg.WriteTimeout = 30 * time.Second + driverCfg.InterpolateParams = true + driverCfg.MaxAllowedPacket = 0 + if conf.Security.DriveTLSName != "" { + driverCfg.TLSConfig = conf.Security.DriveTLSName + } + if conf.AllowCleartextPasswords { + driverCfg.AllowCleartextPasswords = true + } + return driverCfg +} + func timestampDirName() string { return fmt.Sprintf("./export-%s", time.Now().Format(time.RFC3339)) } diff --git a/dumpling/export/dump.go b/dumpling/export/dump.go index cdc91e6e4a389..857ef5d7470fb 100644 --- a/dumpling/export/dump.go +++ b/dumpling/export/dump.go @@ -37,7 +37,7 @@ import ( "golang.org/x/sync/errgroup" ) -var openDBFunc = sql.Open +var openDBFunc = openDB var errEmptyHandleVals = errors.New("empty handleVals for TiDB table") @@ -1309,11 +1309,11 @@ func startHTTPService(d *Dumper) error { // openSQLDB is an initialization step of Dumper. func openSQLDB(d *Dumper) error { conf := d.conf - pool, err := sql.Open("mysql", conf.GetDSN("")) + c, err := mysql.NewConnector(conf.GetDriverConfig("")) if err != nil { return errors.Trace(err) } - d.dbHandle = pool + d.dbHandle = sql.OpenDB(c) return nil } @@ -1510,12 +1510,20 @@ func setSessionParam(d *Dumper) error { } } } - if d.dbHandle, err = resetDBWithSessionParams(d.tctx, pool, conf.GetDSN(""), conf.SessionParams); err != nil { + if d.dbHandle, err = resetDBWithSessionParams(d.tctx, pool, conf.GetDriverConfig(""), conf.SessionParams); err != nil { return errors.Trace(err) } return nil } +func openDB(cfg *mysql.Config) (*sql.DB, error) { + c, err := mysql.NewConnector(cfg) + if err != nil { + return nil, errors.Trace(err) + } + return sql.OpenDB(c), nil +} + func (d *Dumper) renewSelectTableRegionFuncForLowerTiDB(tctx *tcontext.Context) error { conf := d.conf if !(conf.ServerInfo.ServerType == version.ServerTypeTiDB && conf.ServerInfo.ServerVersion != nil && conf.ServerInfo.HasTiKV && @@ -1532,7 +1540,7 @@ func (d *Dumper) renewSelectTableRegionFuncForLowerTiDB(tctx *tcontext.Context) d.selectTiDBTableRegionFunc = func(_ *tcontext.Context, _ *BaseConn, meta TableMeta) (pkFields []string, pkVals [][]string, err error) { return nil, nil, errors.Annotatef(errEmptyHandleVals, "table: `%s`.`%s`", escapeString(meta.DatabaseName()), escapeString(meta.TableName())) } - dbHandle, err := openDBFunc("mysql", conf.GetDSN("")) + dbHandle, err := openDBFunc(conf.GetDriverConfig("")) if err != nil { return errors.Trace(err) } diff --git a/dumpling/export/sql.go b/dumpling/export/sql.go index 83655df99e330..837bec568b9a7 100644 --- a/dumpling/export/sql.go +++ b/dumpling/export/sql.go @@ -10,7 +10,6 @@ import ( "fmt" "io" "math" - "net/url" "strconv" "strings" @@ -834,7 +833,7 @@ func isUnknownSystemVariableErr(err error) bool { // resetDBWithSessionParams will return a new sql.DB as a replacement for input `db` with new session parameters. // If returned error is nil, the input `db` will be closed. -func resetDBWithSessionParams(tctx *tcontext.Context, db *sql.DB, dsn string, params map[string]interface{}) (*sql.DB, error) { +func resetDBWithSessionParams(tctx *tcontext.Context, db *sql.DB, cfg *mysql.Config, params map[string]interface{}) (*sql.DB, error) { support := make(map[string]interface{}) for k, v := range params { var pv interface{} @@ -862,6 +861,10 @@ func resetDBWithSessionParams(tctx *tcontext.Context, db *sql.DB, dsn string, pa support[k] = pv } + if cfg.Params == nil { + cfg.Params = make(map[string]string) + } + for k, v := range support { var s string // Wrap string with quote to handle string with space. For example, '2020-10-20 13:41:40' @@ -871,19 +874,21 @@ func resetDBWithSessionParams(tctx *tcontext.Context, db *sql.DB, dsn string, pa } else { s = fmt.Sprintf("%v", v) } - dsn += fmt.Sprintf("&%s=%s", k, url.QueryEscape(s)) + cfg.Params[k] = s } db.Close() - newDB, err := sql.Open("mysql", dsn) - if err == nil { - // ping to make sure all session parameters are set correctly - err = newDB.PingContext(tctx) - if err != nil { - newDB.Close() - } + c, err := mysql.NewConnector(cfg) + if err != nil { + return nil, errors.Trace(err) + } + newDB := sql.OpenDB(c) + // ping to make sure all session parameters are set correctly + err = newDB.PingContext(tctx) + if err != nil { + newDB.Close() } - return newDB, errors.Trace(err) + return newDB, nil } func createConnWithConsistency(ctx context.Context, db *sql.DB, repeatableRead bool) (*sql.Conn, error) { diff --git a/dumpling/export/sql_test.go b/dumpling/export/sql_test.go index d98a8a3c76a64..04615637be8f1 100644 --- a/dumpling/export/sql_test.go +++ b/dumpling/export/sql_test.go @@ -1345,7 +1345,7 @@ func TestBuildVersion3RegionQueries(t *testing.T) { defer func() { openDBFunc = oldOpenFunc }() - openDBFunc = func(_, _ string) (*sql.DB, error) { + openDBFunc = func(*mysql.Config) (*sql.DB, error) { return db, nil } diff --git a/dumpling/tests/s3/import.go b/dumpling/tests/s3/import.go index 0489be3fa7a80..30dc95fae84b1 100644 --- a/dumpling/tests/s3/import.go +++ b/dumpling/tests/s3/import.go @@ -6,7 +6,9 @@ import ( "context" "database/sql" "fmt" + "net" "os" + "strconv" _ "github.com/go-sql-driver/mysql" "github.com/pingcap/errors" @@ -48,7 +50,7 @@ func main() { return errors.Trace(err) } - dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4", "root", "", "127.0.0.1", port, database) + dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4", "root", "", net.JoinHostPort("127.0.0.1", strconv.Itoa(port)), database) db, err := sql.Open("mysql", dsn) if err != nil { return errors.Trace(err) diff --git a/util/dbutil/common.go b/util/dbutil/common.go index 37b6da5fd1f49..df54e18bd6909 100644 --- a/util/dbutil/common.go +++ b/util/dbutil/common.go @@ -19,7 +19,7 @@ import ( "database/sql" "encoding/json" "fmt" - "net/url" + "net" "os" "strconv" "strings" @@ -107,26 +107,31 @@ func GetDBConfigFromEnv(schema string) DBConfig { // OpenDB opens a mysql connection FD func OpenDB(cfg DBConfig, vars map[string]string) (*sql.DB, error) { - var dbDSN string + driverCfg := mysql.NewConfig() + driverCfg.Params = make(map[string]string) + driverCfg.User = cfg.User + driverCfg.Passwd = cfg.Password + driverCfg.Net = "tcp" + driverCfg.Addr = net.JoinHostPort(cfg.Host, strconv.Itoa(cfg.Port)) + driverCfg.Params["charset"] = "utf8mb4" + if len(cfg.Snapshot) != 0 { log.Info("create connection with snapshot", zap.String("snapshot", cfg.Snapshot)) - dbDSN = fmt.Sprintf("%s:%s@tcp(%s:%d)/?charset=utf8mb4&tidb_snapshot=%s", cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Snapshot) - } else { - dbDSN = fmt.Sprintf("%s:%s@tcp(%s:%d)/?charset=utf8mb4", cfg.User, cfg.Password, cfg.Host, cfg.Port) + driverCfg.Params["tidb_snapshot"] = cfg.Snapshot } for key, val := range vars { // key='val'. add single quote for better compatibility. - dbDSN += fmt.Sprintf("&%s=%%27%s%%27", key, url.QueryEscape(val)) + driverCfg.Params[key] = fmt.Sprintf("'%s'", val) } - dbConn, err := sql.Open("mysql", dbDSN) + c, err := mysql.NewConnector(driverCfg) if err != nil { return nil, errors.Trace(err) } - - err = dbConn.Ping() - return dbConn, errors.Trace(err) + db := sql.OpenDB(c) + err = db.Ping() + return db, errors.Trace(err) } // CloseDB closes the mysql fd