From b1afbfdfdafa09e77e4577ec05d02916022cf715 Mon Sep 17 00:00:00 2001 From: lance6716 Date: Wed, 19 Oct 2022 10:27:52 +0800 Subject: [PATCH 1/2] This is an automated cherry-pick of #38342 Signed-off-by: ti-chi-bot --- br/pkg/lightning/checkpoints/checkpoints.go | 20 +++++- br/pkg/lightning/common/util.go | 51 ++++++++------- br/pkg/lightning/common/util_test.go | 69 ++------------------- br/pkg/lightning/config/config.go | 13 ++-- br/pkg/lightning/config/config_test.go | 5 +- cmd/importer/db.go | 15 +++-- dumpling/export/config.go | 25 ++++++++ dumpling/export/dump.go | 18 ++++-- dumpling/export/sql.go | 24 +++++-- dumpling/export/sql_test.go | 2 +- dumpling/tests/s3/import.go | 4 +- util/dbutil/common.go | 25 +++++--- 12 files changed, 148 insertions(+), 123 deletions(-) diff --git a/br/pkg/lightning/checkpoints/checkpoints.go b/br/pkg/lightning/checkpoints/checkpoints.go index 36cefed180ae3..ebda0d2f87fac 100644 --- a/br/pkg/lightning/checkpoints/checkpoints.go +++ b/br/pkg/lightning/checkpoints/checkpoints.go @@ -516,7 +516,15 @@ func OpenCheckpointsDB(ctx context.Context, cfg *config.Config) (DB, error) { switch cfg.Checkpoint.Driver { case config.CheckpointDriverMySQL: - db, err := common.ConnectMySQL(cfg.Checkpoint.DSN) + var ( + db *sql.DB + err error + ) + if cfg.Checkpoint.MySQLParam != nil { + db, err = cfg.Checkpoint.MySQLParam.Connect() + } else { + db, err = sql.Open("mysql", cfg.Checkpoint.DSN) + } if err != nil { return nil, errors.Trace(err) } @@ -545,7 +553,15 @@ func IsCheckpointsDBExists(ctx context.Context, cfg *config.Config) (bool, error } switch cfg.Checkpoint.Driver { case config.CheckpointDriverMySQL: - db, err := sql.Open("mysql", cfg.Checkpoint.DSN) + var ( + db *sql.DB + err error + ) + if cfg.Checkpoint.MySQLParam != nil { + db, err = cfg.Checkpoint.MySQLParam.Connect() + } else { + db, err = sql.Open("mysql", cfg.Checkpoint.DSN) + } if err != nil { return false, errors.Trace(err) } diff --git a/br/pkg/lightning/common/util.go b/br/pkg/lightning/common/util.go index 67a26fb3ab411..57afc1fb7eac0 100644 --- a/br/pkg/lightning/common/util.go +++ b/br/pkg/lightning/common/util.go @@ -23,7 +23,6 @@ import ( "io" "net" "net/http" - "net/url" "os" "strconv" "strings" @@ -58,28 +57,38 @@ type MySQLConnectParam struct { Vars map[string]string } -func (param *MySQLConnectParam) ToDSN() string { - hostPort := net.JoinHostPort(param.Host, strconv.Itoa(param.Port)) - dsn := fmt.Sprintf("%s:%s@tcp(%s)/?charset=utf8mb4&sql_mode='%s'&maxAllowedPacket=%d&tls=%s", - param.User, param.Password, hostPort, - param.SQLMode, param.MaxAllowedPacket, param.TLS) +func (param *MySQLConnectParam) ToDriverConfig() *mysql.Config { + cfg := mysql.NewConfig() + cfg.Params = make(map[string]string) + + cfg.User = param.User + cfg.Passwd = param.Password + cfg.Net = "tcp" + cfg.Addr = net.JoinHostPort(param.Host, strconv.Itoa(param.Port)) + cfg.Params["charset"] = "utf8mb4" + cfg.Params["sql_mode"] = fmt.Sprintf("'%s'", param.SQLMode) + cfg.MaxAllowedPacket = int(param.MaxAllowedPacket) + cfg.TLSConfig = param.TLS for k, v := range param.Vars { - dsn += fmt.Sprintf("&%s='%s'", k, url.QueryEscape(v)) + cfg.Params[k] = fmt.Sprintf("'%s'", v) } - - return dsn + return cfg } -func tryConnectMySQL(dsn string) (*sql.DB, error) { - driverName := "mysql" - failpoint.Inject("MockMySQLDriver", func(val failpoint.Value) { - driverName = val.(string) +func tryConnectMySQL(cfg *mysql.Config) (*sql.DB, error) { + failpoint.Inject("MustMySQLPassword", func(val failpoint.Value) { + pwd := val.(string) + if cfg.Passwd != pwd { + failpoint.Return(nil, &mysql.MySQLError{Number: tmysql.ErrAccessDenied, Message: "access denied"}) + } + failpoint.Return(nil, nil) }) - db, err := sql.Open(driverName, dsn) + c, err := mysql.NewConnector(cfg) if err != nil { return nil, errors.Trace(err) } + db := sql.OpenDB(c) if err = db.Ping(); err != nil { _ = db.Close() return nil, errors.Trace(err) @@ -89,13 +98,9 @@ func tryConnectMySQL(dsn string) (*sql.DB, error) { // ConnectMySQL connects MySQL with the dsn. If access is denied and the password is a valid base64 encoding, // we will try to connect MySQL with the base64 decoding of the password. -func ConnectMySQL(dsn string) (*sql.DB, error) { - cfg, err := mysql.ParseDSN(dsn) - if err != nil { - return nil, errors.Trace(err) - } +func ConnectMySQL(cfg *mysql.Config) (*sql.DB, error) { // Try plain password first. - db, firstErr := tryConnectMySQL(dsn) + db, firstErr := tryConnectMySQL(cfg) if firstErr == nil { return db, nil } @@ -104,9 +109,9 @@ func ConnectMySQL(dsn string) (*sql.DB, error) { // If password is encoded by base64, try the decoded string as well. if password, decodeErr := base64.StdEncoding.DecodeString(cfg.Passwd); decodeErr == nil && string(password) != cfg.Passwd { cfg.Passwd = string(password) - db, err = tryConnectMySQL(cfg.FormatDSN()) + db2, err := tryConnectMySQL(cfg) if err == nil { - return db, nil + return db2, nil } } } @@ -115,7 +120,7 @@ func ConnectMySQL(dsn string) (*sql.DB, error) { } func (param *MySQLConnectParam) Connect() (*sql.DB, error) { - db, err := ConnectMySQL(param.ToDSN()) + db, err := ConnectMySQL(param.ToDriverConfig()) if err != nil { return nil, errors.Trace(err) } diff --git a/br/pkg/lightning/common/util_test.go b/br/pkg/lightning/common/util_test.go index c7c95b44f69bf..a192ecea11906 100644 --- a/br/pkg/lightning/common/util_test.go +++ b/br/pkg/lightning/common/util_test.go @@ -16,16 +16,12 @@ package common_test import ( "context" - "database/sql" - "database/sql/driver" "encoding/base64" "encoding/json" "fmt" "io" - "math/rand" "net/http" "net/http/httptest" - "strconv" "testing" "time" @@ -35,7 +31,6 @@ import ( "github.com/pingcap/failpoint" "github.com/pingcap/tidb/br/pkg/lightning/common" "github.com/pingcap/tidb/br/pkg/lightning/log" - tmysql "github.com/pingcap/tidb/errno" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -85,66 +80,14 @@ func TestGetJSON(t *testing.T) { require.Regexp(t, ".*http status code != 200.*", err.Error()) } -func TestToDSN(t *testing.T) { - param := common.MySQLConnectParam{ - Host: "127.0.0.1", - Port: 4000, - User: "root", - Password: "123456", - SQLMode: "strict", - MaxAllowedPacket: 1234, - TLS: "cluster", - Vars: map[string]string{ - "tidb_distsql_scan_concurrency": "1", - }, - } - require.Equal(t, "root:123456@tcp(127.0.0.1:4000)/?charset=utf8mb4&sql_mode='strict'&maxAllowedPacket=1234&tls=cluster&tidb_distsql_scan_concurrency='1'", param.ToDSN()) - - param.Host = "::1" - require.Equal(t, "root:123456@tcp([::1]:4000)/?charset=utf8mb4&sql_mode='strict'&maxAllowedPacket=1234&tls=cluster&tidb_distsql_scan_concurrency='1'", param.ToDSN()) -} - -type mockDriver struct { - driver.Driver - plainPsw string -} - -func (m *mockDriver) Open(dsn string) (driver.Conn, error) { - cfg, err := mysql.ParseDSN(dsn) - if err != nil { - return nil, err - } - accessDenied := cfg.Passwd != m.plainPsw - return &mockConn{accessDenied: accessDenied}, nil -} - -type mockConn struct { - driver.Conn - driver.Pinger - accessDenied bool -} - -func (c *mockConn) Ping(ctx context.Context) error { - if c.accessDenied { - return &mysql.MySQLError{Number: tmysql.ErrAccessDenied, Message: "access denied"} - } - return nil -} - -func (c *mockConn) Close() error { - return nil -} - func TestConnect(t *testing.T) { plainPsw := "dQAUoDiyb1ucWZk7" - driverName := "mysql-mock-" + strconv.Itoa(rand.Int()) - sql.Register(driverName, &mockDriver{plainPsw: plainPsw}) require.NoError(t, failpoint.Enable( - "github.com/pingcap/tidb/br/pkg/lightning/common/MockMySQLDriver", - fmt.Sprintf("return(\"%s\")", driverName))) + "github.com/pingcap/tidb/br/pkg/lightning/common/MustMySQLPassword", + fmt.Sprintf("return(\"%s\")", plainPsw))) defer func() { - require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/br/pkg/lightning/common/MockMySQLDriver")) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/br/pkg/lightning/common/MustMySQLPassword")) }() param := common.MySQLConnectParam{ @@ -155,13 +98,11 @@ func TestConnect(t *testing.T) { SQLMode: "strict", MaxAllowedPacket: 1234, } - db, err := param.Connect() + _, err := param.Connect() require.NoError(t, err) - require.NoError(t, db.Close()) param.Password = base64.StdEncoding.EncodeToString([]byte(plainPsw)) - db, err = param.Connect() + _, err = param.Connect() require.NoError(t, err) - require.NoError(t, db.Close()) } func TestIsContextCanceledError(t *testing.T) { diff --git a/br/pkg/lightning/config/config.go b/br/pkg/lightning/config/config.go index 4e6d51647db9d..9c3c7b8c24ed1 100644 --- a/br/pkg/lightning/config/config.go +++ b/br/pkg/lightning/config/config.go @@ -545,11 +545,12 @@ type TikvImporter struct { } type Checkpoint struct { - Schema string `toml:"schema" json:"schema"` - DSN string `toml:"dsn" json:"-"` // DSN may contain password, don't expose this to JSON. - Driver string `toml:"driver" json:"driver"` - Enable bool `toml:"enable" json:"enable"` - KeepAfterSuccess CheckpointKeepStrategy `toml:"keep-after-success" json:"keep-after-success"` + Schema string `toml:"schema" json:"schema"` + DSN string `toml:"dsn" json:"-"` // DSN may contain password, don't expose this to JSON. + MySQLParam *common.MySQLConnectParam `toml:"-" json:"-"` // For some security reason, we use MySQLParam instead of DSN. + Driver string `toml:"driver" json:"driver"` + Enable bool `toml:"enable" json:"enable"` + KeepAfterSuccess CheckpointKeepStrategy `toml:"keep-after-success" json:"keep-after-success"` } type Cron struct { @@ -1126,7 +1127,7 @@ func (cfg *Config) AdjustCheckPoint() { MaxAllowedPacket: defaultMaxAllowedPacket, TLS: cfg.TiDB.TLS, } - cfg.Checkpoint.DSN = param.ToDSN() + cfg.Checkpoint.MySQLParam = ¶m case CheckpointDriverFile: cfg.Checkpoint.DSN = "/tmp/" + cfg.Checkpoint.Schema + ".pb" } diff --git a/br/pkg/lightning/config/config_test.go b/br/pkg/lightning/config/config_test.go index 252fd7f01bb07..44b0632c0dd0c 100644 --- a/br/pkg/lightning/config/config_test.go +++ b/br/pkg/lightning/config/config_test.go @@ -31,7 +31,6 @@ import ( "github.com/BurntSushi/toml" "github.com/pingcap/tidb/br/pkg/lightning/config" - "github.com/pingcap/tidb/parser/mysql" "github.com/stretchr/testify/require" ) @@ -625,7 +624,9 @@ func TestLoadConfig(t *testing.T) { taskCfg.TiDB.DistSQLScanConcurrency = 1 err = taskCfg.Adjust(context.Background()) require.NoError(t, err) - require.Equal(t, "guest:12345@tcp(172.16.30.11:4001)/?charset=utf8mb4&sql_mode='"+mysql.DefaultSQLMode+"'&maxAllowedPacket=67108864&tls=false", taskCfg.Checkpoint.DSN) + equivalentDSN := taskCfg.Checkpoint.MySQLParam.ToDriverConfig().FormatDSN() + expectedDSN := "guest:12345@tcp(172.16.30.11:4001)/?tls=false&maxAllowedPacket=67108864&charset=utf8mb4&sql_mode=%27ONLY_FULL_GROUP_BY%2CSTRICT_TRANS_TABLES%2CNO_ZERO_IN_DATE%2CNO_ZERO_DATE%2CERROR_FOR_DIVISION_BY_ZERO%2CNO_AUTO_CREATE_USER%2CNO_ENGINE_SUBSTITUTION%27" + require.Equal(t, expectedDSN, equivalentDSN) result := taskCfg.String() require.Regexp(t, `.*"pd-addr":"172.16.30.11:2379,172.16.30.12:2379".*`, result) diff --git a/cmd/importer/db.go b/cmd/importer/db.go index 49f3d0ec67ad5..f2aca0bdeb83d 100644 --- a/cmd/importer/db.go +++ b/cmd/importer/db.go @@ -22,7 +22,7 @@ import ( "strconv" "strings" - _ "github.com/go-sql-driver/mysql" + mysql2 "github.com/go-sql-driver/mysql" "github.com/pingcap/errors" "github.com/pingcap/log" "github.com/pingcap/tidb/parser/mysql" @@ -318,13 +318,18 @@ func execSQL(db *sql.DB, sql string) error { } func createDB(cfg DBConfig) (*sql.DB, error) { - dbDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8", cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Name) - db, err := sql.Open("mysql", dbDSN) + driverCfg := mysql2.NewConfig() + driverCfg.User = cfg.User + driverCfg.Passwd = cfg.Password + driverCfg.Net = "tcp" + driverCfg.Addr = cfg.Host + ":" + strconv.Itoa(cfg.Port) + driverCfg.DBName = cfg.Name + + c, err := mysql2.NewConnector(driverCfg) if err != nil { return nil, errors.Trace(err) } - - return db, nil + return sql.OpenDB(c), nil } func closeDB(db *sql.DB) error { diff --git a/dumpling/export/config.go b/dumpling/export/config.go index 9c9e2484baef8..338ef78dbcbb5 100644 --- a/dumpling/export/config.go +++ b/dumpling/export/config.go @@ -212,6 +212,31 @@ func (conf *Config) GetDSN(db string) string { return dsn } +// GetDriverConfig returns the MySQL driver config from Config. +func (conf *Config) GetDriverConfig(db string) *mysql.Config { + driverCfg := mysql.NewConfig() + // maxAllowedPacket=0 can be used to automatically fetch the max_allowed_packet variable from server on every connection. + // https://github.com/go-sql-driver/mysql#maxallowedpacket + hostPort := net.JoinHostPort(conf.Host, strconv.Itoa(conf.Port)) + driverCfg.User = conf.User + driverCfg.Passwd = conf.Password + driverCfg.Net = "tcp" + driverCfg.Addr = hostPort + driverCfg.DBName = db + driverCfg.Collation = "utf8mb4_general_ci" + driverCfg.ReadTimeout = conf.ReadTimeout + driverCfg.WriteTimeout = 30 * time.Second + driverCfg.InterpolateParams = true + driverCfg.MaxAllowedPacket = 0 + if conf.Security.DriveTLSName != "" { + driverCfg.TLSConfig = conf.Security.DriveTLSName + } + if conf.AllowCleartextPasswords { + driverCfg.AllowCleartextPasswords = true + } + return driverCfg +} + func timestampDirName() string { return fmt.Sprintf("./export-%s", time.Now().Format(time.RFC3339)) } diff --git a/dumpling/export/dump.go b/dumpling/export/dump.go index 76a0c737fb5e7..fa367fd4facd5 100755 --- a/dumpling/export/dump.go +++ b/dumpling/export/dump.go @@ -38,7 +38,7 @@ import ( "github.com/pingcap/tidb/util/codec" ) -var openDBFunc = sql.Open +var openDBFunc = openDB var emptyHandleValsErr = errors.New("empty handleVals for TiDB table") @@ -1293,11 +1293,11 @@ func startHTTPService(d *Dumper) error { // openSQLDB is an initialization step of Dumper. func openSQLDB(d *Dumper) error { conf := d.conf - pool, err := sql.Open("mysql", conf.GetDSN("")) + c, err := mysql.NewConnector(conf.GetDriverConfig("")) if err != nil { return errors.Trace(err) } - d.dbHandle = pool + d.dbHandle = sql.OpenDB(c) return nil } @@ -1470,12 +1470,20 @@ func setSessionParam(d *Dumper) error { } } } - if d.dbHandle, err = resetDBWithSessionParams(d.tctx, pool, conf.GetDSN(""), conf.SessionParams); err != nil { + if d.dbHandle, err = resetDBWithSessionParams(d.tctx, pool, conf.GetDriverConfig(""), conf.SessionParams); err != nil { return errors.Trace(err) } return nil } +func openDB(cfg *mysql.Config) (*sql.DB, error) { + c, err := mysql.NewConnector(cfg) + if err != nil { + return nil, errors.Trace(err) + } + return sql.OpenDB(c), nil +} + func (d *Dumper) renewSelectTableRegionFuncForLowerTiDB(tctx *tcontext.Context) error { conf := d.conf if !(conf.ServerInfo.ServerType == version.ServerTypeTiDB && conf.ServerInfo.ServerVersion != nil && conf.ServerInfo.HasTiKV && @@ -1492,7 +1500,7 @@ func (d *Dumper) renewSelectTableRegionFuncForLowerTiDB(tctx *tcontext.Context) d.selectTiDBTableRegionFunc = func(_ *tcontext.Context, _ *BaseConn, meta TableMeta) (pkFields []string, pkVals [][]string, err error) { return nil, nil, errors.Annotatef(emptyHandleValsErr, "table: `%s`.`%s`", escapeString(meta.DatabaseName()), escapeString(meta.TableName())) } - dbHandle, err := openDBFunc("mysql", conf.GetDSN("")) + dbHandle, err := openDBFunc(conf.GetDriverConfig("")) if err != nil { return errors.Trace(err) } diff --git a/dumpling/export/sql.go b/dumpling/export/sql.go index 430068a434021..87b9c24a452b0 100644 --- a/dumpling/export/sql.go +++ b/dumpling/export/sql.go @@ -10,7 +10,6 @@ import ( "fmt" "io" "math" - "net/url" "strconv" "strings" @@ -833,7 +832,7 @@ func isUnknownSystemVariableErr(err error) bool { // resetDBWithSessionParams will return a new sql.DB as a replacement for input `db` with new session parameters. // If returned error is nil, the input `db` will be closed. -func resetDBWithSessionParams(tctx *tcontext.Context, db *sql.DB, dsn string, params map[string]interface{}) (*sql.DB, error) { +func resetDBWithSessionParams(tctx *tcontext.Context, db *sql.DB, cfg *mysql.Config, params map[string]interface{}) (*sql.DB, error) { support := make(map[string]interface{}) for k, v := range params { var pv interface{} @@ -861,6 +860,10 @@ func resetDBWithSessionParams(tctx *tcontext.Context, db *sql.DB, dsn string, pa support[k] = pv } + if cfg.Params == nil { + cfg.Params = make(map[string]string) + } + for k, v := range support { var s string // Wrap string with quote to handle string with space. For example, '2020-10-20 13:41:40' @@ -870,14 +873,27 @@ func resetDBWithSessionParams(tctx *tcontext.Context, db *sql.DB, dsn string, pa } else { s = fmt.Sprintf("%v", v) } - dsn += fmt.Sprintf("&%s=%s", k, url.QueryEscape(s)) + cfg.Params[k] = s } +<<<<<<< HEAD newDB, err := sql.Open("mysql", dsn) if err == nil { db.Close() +======= + db.Close() + c, err := mysql.NewConnector(cfg) + if err != nil { + return nil, errors.Trace(err) + } + newDB := sql.OpenDB(c) + // ping to make sure all session parameters are set correctly + err = newDB.PingContext(tctx) + if err != nil { + newDB.Close() +>>>>>>> d0376379d6 (*: don't use DSN to avoid some security problems (#38342)) } - return newDB, errors.Trace(err) + return newDB, nil } func createConnWithConsistency(ctx context.Context, db *sql.DB, repeatableRead bool) (*sql.Conn, error) { diff --git a/dumpling/export/sql_test.go b/dumpling/export/sql_test.go index 74df4557c6caf..1fd0d0621052c 100644 --- a/dumpling/export/sql_test.go +++ b/dumpling/export/sql_test.go @@ -1341,7 +1341,7 @@ func TestBuildVersion3RegionQueries(t *testing.T) { defer func() { openDBFunc = oldOpenFunc }() - openDBFunc = func(_, _ string) (*sql.DB, error) { + openDBFunc = func(*mysql.Config) (*sql.DB, error) { return db, nil } diff --git a/dumpling/tests/s3/import.go b/dumpling/tests/s3/import.go index 0489be3fa7a80..30dc95fae84b1 100644 --- a/dumpling/tests/s3/import.go +++ b/dumpling/tests/s3/import.go @@ -6,7 +6,9 @@ import ( "context" "database/sql" "fmt" + "net" "os" + "strconv" _ "github.com/go-sql-driver/mysql" "github.com/pingcap/errors" @@ -48,7 +50,7 @@ func main() { return errors.Trace(err) } - dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4", "root", "", "127.0.0.1", port, database) + dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4", "root", "", net.JoinHostPort("127.0.0.1", strconv.Itoa(port)), database) db, err := sql.Open("mysql", dsn) if err != nil { return errors.Trace(err) diff --git a/util/dbutil/common.go b/util/dbutil/common.go index eadf0714aea6e..7ee717f090f19 100644 --- a/util/dbutil/common.go +++ b/util/dbutil/common.go @@ -19,7 +19,7 @@ import ( "database/sql" "encoding/json" "fmt" - "net/url" + "net" "os" "strconv" "strings" @@ -112,26 +112,31 @@ func GetDBConfigFromEnv(schema string) DBConfig { // OpenDB opens a mysql connection FD func OpenDB(cfg DBConfig, vars map[string]string) (*sql.DB, error) { - var dbDSN string + driverCfg := mysql.NewConfig() + driverCfg.Params = make(map[string]string) + driverCfg.User = cfg.User + driverCfg.Passwd = cfg.Password + driverCfg.Net = "tcp" + driverCfg.Addr = net.JoinHostPort(cfg.Host, strconv.Itoa(cfg.Port)) + driverCfg.Params["charset"] = "utf8mb4" + if len(cfg.Snapshot) != 0 { log.Info("create connection with snapshot", zap.String("snapshot", cfg.Snapshot)) - dbDSN = fmt.Sprintf("%s:%s@tcp(%s:%d)/?charset=utf8mb4&tidb_snapshot=%s", cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Snapshot) - } else { - dbDSN = fmt.Sprintf("%s:%s@tcp(%s:%d)/?charset=utf8mb4", cfg.User, cfg.Password, cfg.Host, cfg.Port) + driverCfg.Params["tidb_snapshot"] = cfg.Snapshot } for key, val := range vars { // key='val'. add single quote for better compatibility. - dbDSN += fmt.Sprintf("&%s=%%27%s%%27", key, url.QueryEscape(val)) + driverCfg.Params[key] = fmt.Sprintf("'%s'", val) } - dbConn, err := sql.Open("mysql", dbDSN) + c, err := mysql.NewConnector(driverCfg) if err != nil { return nil, errors.Trace(err) } - - err = dbConn.Ping() - return dbConn, errors.Trace(err) + db := sql.OpenDB(c) + err = db.Ping() + return db, errors.Trace(err) } // CloseDB closes the mysql fd From 4e2849e6a9ec83fa29a92a564f1a7d993f0bca8a Mon Sep 17 00:00:00 2001 From: lance6716 Date: Wed, 19 Oct 2022 14:17:57 +0800 Subject: [PATCH 2/2] fix conflict Signed-off-by: lance6716 --- dumpling/export/sql.go | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/dumpling/export/sql.go b/dumpling/export/sql.go index 87b9c24a452b0..4682df7ed7c27 100644 --- a/dumpling/export/sql.go +++ b/dumpling/export/sql.go @@ -876,12 +876,6 @@ func resetDBWithSessionParams(tctx *tcontext.Context, db *sql.DB, cfg *mysql.Con cfg.Params[k] = s } -<<<<<<< HEAD - newDB, err := sql.Open("mysql", dsn) - if err == nil { - db.Close() -======= - db.Close() c, err := mysql.NewConnector(cfg) if err != nil { return nil, errors.Trace(err) @@ -889,9 +883,8 @@ func resetDBWithSessionParams(tctx *tcontext.Context, db *sql.DB, cfg *mysql.Con newDB := sql.OpenDB(c) // ping to make sure all session parameters are set correctly err = newDB.PingContext(tctx) - if err != nil { - newDB.Close() ->>>>>>> d0376379d6 (*: don't use DSN to avoid some security problems (#38342)) + if err == nil { + db.Close() } return newDB, nil }