Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: don't use DSN to avoid some security problems (#38342) #38543

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions br/pkg/lightning/checkpoints/checkpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,15 @@ func OpenCheckpointsDB(ctx context.Context, cfg *config.Config) (DB, error) {

switch cfg.Checkpoint.Driver {
case config.CheckpointDriverMySQL:
db, err := common.ConnectMySQL(cfg.Checkpoint.DSN)
var (
db *sql.DB
err error
)
if cfg.Checkpoint.MySQLParam != nil {
db, err = cfg.Checkpoint.MySQLParam.Connect()
} else {
db, err = sql.Open("mysql", cfg.Checkpoint.DSN)
}
if err != nil {
return nil, errors.Trace(err)
}
Expand Down Expand Up @@ -545,7 +553,15 @@ func IsCheckpointsDBExists(ctx context.Context, cfg *config.Config) (bool, error
}
switch cfg.Checkpoint.Driver {
case config.CheckpointDriverMySQL:
db, err := sql.Open("mysql", cfg.Checkpoint.DSN)
var (
db *sql.DB
err error
)
if cfg.Checkpoint.MySQLParam != nil {
db, err = cfg.Checkpoint.MySQLParam.Connect()
} else {
db, err = sql.Open("mysql", cfg.Checkpoint.DSN)
}
if err != nil {
return false, errors.Trace(err)
}
Expand Down
51 changes: 28 additions & 23 deletions br/pkg/lightning/common/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"io"
"net"
"net/http"
"net/url"
"os"
"strconv"
"strings"
Expand Down Expand Up @@ -58,28 +57,38 @@ type MySQLConnectParam struct {
Vars map[string]string
}

func (param *MySQLConnectParam) ToDSN() string {
hostPort := net.JoinHostPort(param.Host, strconv.Itoa(param.Port))
dsn := fmt.Sprintf("%s:%s@tcp(%s)/?charset=utf8mb4&sql_mode='%s'&maxAllowedPacket=%d&tls=%s",
param.User, param.Password, hostPort,
param.SQLMode, param.MaxAllowedPacket, param.TLS)
func (param *MySQLConnectParam) ToDriverConfig() *mysql.Config {
cfg := mysql.NewConfig()
cfg.Params = make(map[string]string)

cfg.User = param.User
cfg.Passwd = param.Password
cfg.Net = "tcp"
cfg.Addr = net.JoinHostPort(param.Host, strconv.Itoa(param.Port))
cfg.Params["charset"] = "utf8mb4"
cfg.Params["sql_mode"] = fmt.Sprintf("'%s'", param.SQLMode)
cfg.MaxAllowedPacket = int(param.MaxAllowedPacket)
cfg.TLSConfig = param.TLS

for k, v := range param.Vars {
dsn += fmt.Sprintf("&%s='%s'", k, url.QueryEscape(v))
cfg.Params[k] = fmt.Sprintf("'%s'", v)
}

return dsn
return cfg
}

func tryConnectMySQL(dsn string) (*sql.DB, error) {
driverName := "mysql"
failpoint.Inject("MockMySQLDriver", func(val failpoint.Value) {
driverName = val.(string)
func tryConnectMySQL(cfg *mysql.Config) (*sql.DB, error) {
failpoint.Inject("MustMySQLPassword", func(val failpoint.Value) {
pwd := val.(string)
if cfg.Passwd != pwd {
failpoint.Return(nil, &mysql.MySQLError{Number: tmysql.ErrAccessDenied, Message: "access denied"})
}
failpoint.Return(nil, nil)
})
db, err := sql.Open(driverName, dsn)
c, err := mysql.NewConnector(cfg)
if err != nil {
return nil, errors.Trace(err)
}
db := sql.OpenDB(c)
if err = db.Ping(); err != nil {
_ = db.Close()
return nil, errors.Trace(err)
Expand All @@ -89,13 +98,9 @@ func tryConnectMySQL(dsn string) (*sql.DB, error) {

// ConnectMySQL connects MySQL with the dsn. If access is denied and the password is a valid base64 encoding,
// we will try to connect MySQL with the base64 decoding of the password.
func ConnectMySQL(dsn string) (*sql.DB, error) {
cfg, err := mysql.ParseDSN(dsn)
if err != nil {
return nil, errors.Trace(err)
}
func ConnectMySQL(cfg *mysql.Config) (*sql.DB, error) {
// Try plain password first.
db, firstErr := tryConnectMySQL(dsn)
db, firstErr := tryConnectMySQL(cfg)
if firstErr == nil {
return db, nil
}
Expand All @@ -104,9 +109,9 @@ func ConnectMySQL(dsn string) (*sql.DB, error) {
// If password is encoded by base64, try the decoded string as well.
if password, decodeErr := base64.StdEncoding.DecodeString(cfg.Passwd); decodeErr == nil && string(password) != cfg.Passwd {
cfg.Passwd = string(password)
db, err = tryConnectMySQL(cfg.FormatDSN())
db2, err := tryConnectMySQL(cfg)
if err == nil {
return db, nil
return db2, nil
}
}
}
Expand All @@ -115,7 +120,7 @@ func ConnectMySQL(dsn string) (*sql.DB, error) {
}

func (param *MySQLConnectParam) Connect() (*sql.DB, error) {
db, err := ConnectMySQL(param.ToDSN())
db, err := ConnectMySQL(param.ToDriverConfig())
if err != nil {
return nil, errors.Trace(err)
}
Expand Down
69 changes: 5 additions & 64 deletions br/pkg/lightning/common/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,12 @@ package common_test

import (
"context"
"database/sql"
"database/sql/driver"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"math/rand"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"time"

Expand All @@ -35,7 +31,6 @@ import (
"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/br/pkg/lightning/common"
"github.com/pingcap/tidb/br/pkg/lightning/log"
tmysql "github.com/pingcap/tidb/errno"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -85,66 +80,14 @@ func TestGetJSON(t *testing.T) {
require.Regexp(t, ".*http status code != 200.*", err.Error())
}

func TestToDSN(t *testing.T) {
param := common.MySQLConnectParam{
Host: "127.0.0.1",
Port: 4000,
User: "root",
Password: "123456",
SQLMode: "strict",
MaxAllowedPacket: 1234,
TLS: "cluster",
Vars: map[string]string{
"tidb_distsql_scan_concurrency": "1",
},
}
require.Equal(t, "root:123456@tcp(127.0.0.1:4000)/?charset=utf8mb4&sql_mode='strict'&maxAllowedPacket=1234&tls=cluster&tidb_distsql_scan_concurrency='1'", param.ToDSN())

param.Host = "::1"
require.Equal(t, "root:123456@tcp([::1]:4000)/?charset=utf8mb4&sql_mode='strict'&maxAllowedPacket=1234&tls=cluster&tidb_distsql_scan_concurrency='1'", param.ToDSN())
}

type mockDriver struct {
driver.Driver
plainPsw string
}

func (m *mockDriver) Open(dsn string) (driver.Conn, error) {
cfg, err := mysql.ParseDSN(dsn)
if err != nil {
return nil, err
}
accessDenied := cfg.Passwd != m.plainPsw
return &mockConn{accessDenied: accessDenied}, nil
}

type mockConn struct {
driver.Conn
driver.Pinger
accessDenied bool
}

func (c *mockConn) Ping(ctx context.Context) error {
if c.accessDenied {
return &mysql.MySQLError{Number: tmysql.ErrAccessDenied, Message: "access denied"}
}
return nil
}

func (c *mockConn) Close() error {
return nil
}

func TestConnect(t *testing.T) {
plainPsw := "dQAUoDiyb1ucWZk7"
driverName := "mysql-mock-" + strconv.Itoa(rand.Int())
sql.Register(driverName, &mockDriver{plainPsw: plainPsw})

require.NoError(t, failpoint.Enable(
"github.com/pingcap/tidb/br/pkg/lightning/common/MockMySQLDriver",
fmt.Sprintf("return(\"%s\")", driverName)))
"github.com/pingcap/tidb/br/pkg/lightning/common/MustMySQLPassword",
fmt.Sprintf("return(\"%s\")", plainPsw)))
defer func() {
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/br/pkg/lightning/common/MockMySQLDriver"))
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/br/pkg/lightning/common/MustMySQLPassword"))
}()

param := common.MySQLConnectParam{
Expand All @@ -155,13 +98,11 @@ func TestConnect(t *testing.T) {
SQLMode: "strict",
MaxAllowedPacket: 1234,
}
db, err := param.Connect()
_, err := param.Connect()
require.NoError(t, err)
require.NoError(t, db.Close())
param.Password = base64.StdEncoding.EncodeToString([]byte(plainPsw))
db, err = param.Connect()
_, err = param.Connect()
require.NoError(t, err)
require.NoError(t, db.Close())
}

func TestIsContextCanceledError(t *testing.T) {
Expand Down
13 changes: 7 additions & 6 deletions br/pkg/lightning/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -545,11 +545,12 @@ type TikvImporter struct {
}

type Checkpoint struct {
Schema string `toml:"schema" json:"schema"`
DSN string `toml:"dsn" json:"-"` // DSN may contain password, don't expose this to JSON.
Driver string `toml:"driver" json:"driver"`
Enable bool `toml:"enable" json:"enable"`
KeepAfterSuccess CheckpointKeepStrategy `toml:"keep-after-success" json:"keep-after-success"`
Schema string `toml:"schema" json:"schema"`
DSN string `toml:"dsn" json:"-"` // DSN may contain password, don't expose this to JSON.
MySQLParam *common.MySQLConnectParam `toml:"-" json:"-"` // For some security reason, we use MySQLParam instead of DSN.
Driver string `toml:"driver" json:"driver"`
Enable bool `toml:"enable" json:"enable"`
KeepAfterSuccess CheckpointKeepStrategy `toml:"keep-after-success" json:"keep-after-success"`
}

type Cron struct {
Expand Down Expand Up @@ -1126,7 +1127,7 @@ func (cfg *Config) AdjustCheckPoint() {
MaxAllowedPacket: defaultMaxAllowedPacket,
TLS: cfg.TiDB.TLS,
}
cfg.Checkpoint.DSN = param.ToDSN()
cfg.Checkpoint.MySQLParam = &param
case CheckpointDriverFile:
cfg.Checkpoint.DSN = "/tmp/" + cfg.Checkpoint.Schema + ".pb"
}
Expand Down
5 changes: 3 additions & 2 deletions br/pkg/lightning/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import (

"github.com/BurntSushi/toml"
"github.com/pingcap/tidb/br/pkg/lightning/config"
"github.com/pingcap/tidb/parser/mysql"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -625,7 +624,9 @@ func TestLoadConfig(t *testing.T) {
taskCfg.TiDB.DistSQLScanConcurrency = 1
err = taskCfg.Adjust(context.Background())
require.NoError(t, err)
require.Equal(t, "guest:12345@tcp(172.16.30.11:4001)/?charset=utf8mb4&sql_mode='"+mysql.DefaultSQLMode+"'&maxAllowedPacket=67108864&tls=false", taskCfg.Checkpoint.DSN)
equivalentDSN := taskCfg.Checkpoint.MySQLParam.ToDriverConfig().FormatDSN()
expectedDSN := "guest:12345@tcp(172.16.30.11:4001)/?tls=false&maxAllowedPacket=67108864&charset=utf8mb4&sql_mode=%27ONLY_FULL_GROUP_BY%2CSTRICT_TRANS_TABLES%2CNO_ZERO_IN_DATE%2CNO_ZERO_DATE%2CERROR_FOR_DIVISION_BY_ZERO%2CNO_AUTO_CREATE_USER%2CNO_ENGINE_SUBSTITUTION%27"
require.Equal(t, expectedDSN, equivalentDSN)

result := taskCfg.String()
require.Regexp(t, `.*"pd-addr":"172.16.30.11:2379,172.16.30.12:2379".*`, result)
Expand Down
15 changes: 10 additions & 5 deletions cmd/importer/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (
"strconv"
"strings"

_ "github.com/go-sql-driver/mysql"
mysql2 "github.com/go-sql-driver/mysql"
"github.com/pingcap/errors"
"github.com/pingcap/log"
"github.com/pingcap/tidb/parser/mysql"
Expand Down Expand Up @@ -318,13 +318,18 @@ func execSQL(db *sql.DB, sql string) error {
}

func createDB(cfg DBConfig) (*sql.DB, error) {
dbDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8", cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Name)
db, err := sql.Open("mysql", dbDSN)
driverCfg := mysql2.NewConfig()
driverCfg.User = cfg.User
driverCfg.Passwd = cfg.Password
driverCfg.Net = "tcp"
driverCfg.Addr = cfg.Host + ":" + strconv.Itoa(cfg.Port)
driverCfg.DBName = cfg.Name

c, err := mysql2.NewConnector(driverCfg)
if err != nil {
return nil, errors.Trace(err)
}

return db, nil
return sql.OpenDB(c), nil
}

func closeDB(db *sql.DB) error {
Expand Down
25 changes: 25 additions & 0 deletions dumpling/export/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,31 @@ func (conf *Config) GetDSN(db string) string {
return dsn
}

// GetDriverConfig returns the MySQL driver config from Config.
func (conf *Config) GetDriverConfig(db string) *mysql.Config {
driverCfg := mysql.NewConfig()
// maxAllowedPacket=0 can be used to automatically fetch the max_allowed_packet variable from server on every connection.
// https://github.com/go-sql-driver/mysql#maxallowedpacket
hostPort := net.JoinHostPort(conf.Host, strconv.Itoa(conf.Port))
driverCfg.User = conf.User
driverCfg.Passwd = conf.Password
driverCfg.Net = "tcp"
driverCfg.Addr = hostPort
driverCfg.DBName = db
driverCfg.Collation = "utf8mb4_general_ci"
driverCfg.ReadTimeout = conf.ReadTimeout
driverCfg.WriteTimeout = 30 * time.Second
driverCfg.InterpolateParams = true
driverCfg.MaxAllowedPacket = 0
if conf.Security.DriveTLSName != "" {
driverCfg.TLSConfig = conf.Security.DriveTLSName
}
if conf.AllowCleartextPasswords {
driverCfg.AllowCleartextPasswords = true
}
return driverCfg
}

func timestampDirName() string {
return fmt.Sprintf("./export-%s", time.Now().Format(time.RFC3339))
}
Expand Down
Loading