From 5427a8d66217a7251b22769cf841916922a2b4f1 Mon Sep 17 00:00:00 2001 From: dvilaverde <dvilaverde@adobe.com> Date: Fri, 26 Apr 2024 14:16:14 -0400 Subject: [PATCH 01/19] allow setting the collation in auth handshake --- client/auth.go | 12 +++++++++++- client/client_test.go | 21 ++++++++++++++++++++- client/conn.go | 16 ++++++++++++++++ 3 files changed, 47 insertions(+), 2 deletions(-) diff --git a/client/auth.go b/client/auth.go index e4fa908d3..7392f8fdd 100644 --- a/client/auth.go +++ b/client/auth.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "encoding/binary" "fmt" + "github.com/pingcap/tidb/pkg/parser/charset" . "github.com/go-mysql-org/go-mysql/mysql" "github.com/go-mysql-org/go-mysql/packet" @@ -269,7 +270,16 @@ func (c *Conn) writeAuthHandshake() error { // Charset [1 byte] // use default collation id 33 here, is utf-8 - data[12] = DEFAULT_COLLATION_ID + collationName := c.collation + if len(collationName) == 0 { + collationName = DEFAULT_COLLATION_NAME + } + collation, err := charset.GetCollationByName(collationName) + if err != nil { + return fmt.Errorf("invalid collation name %s", collationName) + } + + data[12] = byte(collation.ID) // SSL Connection Request Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest diff --git a/client/client_test.go b/client/client_test.go index c47c795ef..b27c4c669 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -31,7 +31,11 @@ func TestClientSuite(t *testing.T) { func (s *clientTestSuite) SetupSuite() { var err error addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port) - s.c, err = Connect(addr, *testUser, *testPassword, "") + s.c, err = Connect(addr, *testUser, *testPassword, "", func(conn *Conn) { + // test the collation logic, but this is essentially a no-op since + // the collation set is the default value + _ = conn.SetCollation(mysql.DEFAULT_COLLATION_NAME) + }) require.NoError(s.T(), err) var result *mysql.Result @@ -228,6 +232,21 @@ func (s *clientTestSuite) TestConn_SetCharset() { require.NoError(s.T(), err) } +func (s *clientTestSuite) TestConn_SetCollationAfterConnect() { + err := s.c.SetCollation("latin1_swedish_ci") + require.Error(s.T(), err) +} + +func (s *clientTestSuite) TestConn_SetCollation() { + addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port) + _, err := Connect(addr, *testUser, *testPassword, "", func(conn *Conn) { + // test the collation logic + _ = conn.SetCollation("invalid_collation") + }) + + require.Error(s.T(), err) +} + func (s *clientTestSuite) testStmt_DropTable() { str := `drop table if exists mixer_test_stmt` diff --git a/client/conn.go b/client/conn.go index b1f3e52d1..1db021762 100644 --- a/client/conn.go +++ b/client/conn.go @@ -37,6 +37,8 @@ type Conn struct { status uint16 charset string + // sets the collation to be set on the auth handshake, this does not issue a 'set names' command + collation string salt []byte authPluginName string @@ -357,6 +359,20 @@ func (c *Conn) SetCharset(charset string) error { } } +func (c *Conn) SetCollation(collation string) error { + if c.status == 0 { + c.collation = collation + } else { + return errors.Trace(errors.Errorf("cannot set collation after connection is established")) + } + + return nil +} + +func (c *Conn) GetCollation() string { + return c.collation +} + func (c *Conn) FieldList(table string, wildcard string) ([]*Field, error) { if err := c.writeCommandStrStr(COM_FIELD_LIST, table, wildcard); err != nil { return nil, errors.Trace(err) From 10339ddc0ad003634f447d92547195bea02c544d Mon Sep 17 00:00:00 2001 From: dvilaverde <dvilaverde@adobe.com> Date: Fri, 26 Apr 2024 16:42:52 -0400 Subject: [PATCH 02/19] Allow connect with context in order to provide configurable connect timeouts --- client/conn.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/client/conn.go b/client/conn.go index 1db021762..9d7014951 100644 --- a/client/conn.go +++ b/client/conn.go @@ -69,15 +69,19 @@ func Connect(addr string, user string, password string, dbName string, options . ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() - dialer := &net.Dialer{} + return ConnectWithContext(ctx, addr, user, password, dbName, options...) +} +// ConnectWithContext to a MySQL addr using the provided context. +func ConnectWithContext(ctx context.Context, addr string, user string, password string, dbName string, options ...func(*Conn)) (*Conn, error) { + dialer := &net.Dialer{} return ConnectWithDialer(ctx, "", addr, user, password, dbName, dialer.DialContext, options...) } // Dialer connects to the address on the named network using the provided context. type Dialer func(ctx context.Context, network, address string) (net.Conn, error) -// Connect to a MySQL server using the given Dialer. +// ConnectWithDialer to a MySQL server using the given Dialer. func ConnectWithDialer(ctx context.Context, network string, addr string, user string, password string, dbName string, dialer Dialer, options ...func(*Conn)) (*Conn, error) { c := new(Conn) From 1c658b96c145e3319e05e9f623ecab7ef5a40edb Mon Sep 17 00:00:00 2001 From: David Vilaverde <dvilaverde@gmail.com> Date: Sat, 27 Apr 2024 17:33:03 -0400 Subject: [PATCH 03/19] add driver arguments --- driver/driver.go | 69 ++++++++++++++++++++++++++++++++++-------------- 1 file changed, 49 insertions(+), 20 deletions(-) diff --git a/driver/driver.go b/driver/driver.go index b86c4b374..93291e800 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -10,6 +10,7 @@ import ( "io" "net/url" "regexp" + "strconv" "sync" "github.com/go-mysql-org/go-mysql/client" @@ -21,7 +22,10 @@ import ( var customTLSMutex sync.Mutex // Map of dsn address (makes more sense than full dsn?) to tls Config -var customTLSConfigMap = make(map[string]*tls.Config) +var ( + customTLSConfigMap = make(map[string]*tls.Config) + options = make(map[string]connOption) +) type driver struct { } @@ -35,6 +39,8 @@ type connInfo struct { params url.Values } +type connOption func(c *client.Conn, value string) + // ParseDSN takes a DSN string and splits it up into struct containing addr, // user, password and db. // It returns an error if unable to parse. @@ -92,27 +98,41 @@ func (d driver) Open(dsn string) (sqldriver.Conn, error) { } if ci.standardDSN { - if ci.params["ssl"] != nil { - tlsConfigName := ci.params.Get("ssl") - switch tlsConfigName { - case "true": - // This actually does insecureSkipVerify - // But not even sure if it makes sense to handle false? According to - // client_test.go it doesn't - it'd result in an error - c, err = client.Connect(ci.addr, ci.user, ci.password, ci.db, func(c *client.Conn) { c.UseSSL(true) }) - case "custom": - // I was too concerned about mimicking what go-sql-driver/mysql does which will - // allow any name for a custom tls profile and maps the query parameter value to - // that TLSConfig variable... there is no need to be that clever. - // Instead of doing that, let's store required custom TLSConfigs in a map that - // uses the DSN address as the key - c, err = client.Connect(ci.addr, ci.user, ci.password, ci.db, func(c *client.Conn) { c.SetTLSConfig(customTLSConfigMap[ci.addr]) }) - default: - return nil, errors.Errorf("Supported options are ssl=true or ssl=custom") + configuredOptions := make([]func(*client.Conn), 0, len(ci.params)) + for key, value := range ci.params { + if key == "ssl" { + tlsConfigName := ci.params.Get("ssl") + switch tlsConfigName { + case "true": + // This actually does insecureSkipVerify + // But not even sure if it makes sense to handle false? According to + // client_test.go it doesn't - it'd result in an error + configuredOptions = append(configuredOptions, func(c *client.Conn) { c.UseSSL(true) }) + case "custom": + // I was too concerned about mimicking what go-sql-driver/mysql does which will + // allow any name for a custom tls profile and maps the query parameter value to + // that TLSConfig variable... there is no need to be that clever. + // Instead of doing that, let's store required custom TLSConfigs in a map that + // uses the DSN address as the key + configuredOptions = append(configuredOptions, func(c *client.Conn) { c.SetTLSConfig(customTLSConfigMap[ci.addr]) }) + default: + return nil, errors.Errorf("Supported options are ssl=true or ssl=custom") + } + } else { + if option, ok := options[key]; ok { + opt := func(o connOption, v string) func(c *client.Conn) { + return func(c *client.Conn) { + o(c, v) + } + }(option, value[0]) + configuredOptions = append(configuredOptions, opt) + } else { + return nil, errors.Errorf("unsupported connection option: %s", key) + } } - } else { - c, err = client.Connect(ci.addr, ci.user, ci.password, ci.db) } + + c, err = client.Connect(ci.addr, ci.user, ci.password, ci.db, configuredOptions...) } else { // No more processing here. Let's only support url parameters with the newer style DSN c, err = client.Connect(ci.addr, ci.user, ci.password, ci.db) @@ -296,6 +316,15 @@ func (r *rows) Next(dest []sqldriver.Value) error { } func init() { + options["compress"] = func(c *client.Conn, value string) { + if b, err := strconv.ParseBool(value); err == nil && b { + c.SetCapability(mysql.CLIENT_COMPRESS) + } + } + options["collation"] = func(c *client.Conn, value string) { + c.SetCollation(value) + } + sql.Register("mysql", driver{}) } From 877bc05e40f38d21f88c06ad4760011b80035a18 Mon Sep 17 00:00:00 2001 From: David Vilaverde <dvilaverde@gmail.com> Date: Sat, 27 Apr 2024 21:42:49 -0400 Subject: [PATCH 04/19] check for empty ssl value when setting conn options --- driver/driver.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/driver/driver.go b/driver/driver.go index 93291e800..80bde7a49 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -100,8 +100,8 @@ func (d driver) Open(dsn string) (sqldriver.Conn, error) { if ci.standardDSN { configuredOptions := make([]func(*client.Conn), 0, len(ci.params)) for key, value := range ci.params { - if key == "ssl" { - tlsConfigName := ci.params.Get("ssl") + if key == "ssl" && len(value) > 0 { + tlsConfigName := value[0] switch tlsConfigName { case "true": // This actually does insecureSkipVerify From 3deb7dc77e8f0b50c2a96b9d70b2c36b65f599d2 Mon Sep 17 00:00:00 2001 From: David Vilaverde <dvilaverde@gmail.com> Date: Tue, 30 Apr 2024 07:55:22 -0400 Subject: [PATCH 05/19] allow setting the collation in auth handshake (#860) * Allow connect with context in order to provide configurable connect timeouts * support collations IDs greater than 255 on the auth handshake --------- Co-authored-by: dvilaverde <dvilaverde@adobe.com> --- client/auth.go | 22 +++++++++--- client/auth_test.go | 78 ++++++++++++++++++++++++++++++++++++++++++- client/client_test.go | 1 + client/conn.go | 5 ++- 4 files changed, 97 insertions(+), 9 deletions(-) diff --git a/client/auth.go b/client/auth.go index 7392f8fdd..1f4d7c1de 100644 --- a/client/auth.go +++ b/client/auth.go @@ -5,11 +5,11 @@ import ( "crypto/tls" "encoding/binary" "fmt" - "github.com/pingcap/tidb/pkg/parser/charset" . "github.com/go-mysql-org/go-mysql/mysql" "github.com/go-mysql-org/go-mysql/packet" "github.com/pingcap/errors" + "github.com/pingcap/tidb/pkg/parser/charset" ) const defaultAuthPluginName = AUTH_NATIVE_PASSWORD @@ -269,7 +269,7 @@ func (c *Conn) writeAuthHandshake() error { data[11] = 0x00 // Charset [1 byte] - // use default collation id 33 here, is utf-8 + // use default collation id 33 here, is `utf8mb3_general_ci` collationName := c.collation if len(collationName) == 0 { collationName = DEFAULT_COLLATION_NAME @@ -279,7 +279,15 @@ func (c *Conn) writeAuthHandshake() error { return fmt.Errorf("invalid collation name %s", collationName) } - data[12] = byte(collation.ID) + // the MySQL protocol calls for the collation id to be sent as 1, where only the + // lower 8 bits are used in this field. But wireshark shows that the first byte of + // the 23 bytes of filler is used to send the right middle 8 bits of the collation id. + // see https://github.com/mysql/mysql-server/pull/541 + data[12] = byte(collation.ID & 0xff) + // if the collation ID is <= 255 the middle 8 bits are 0s so this is the equivalent of + // padding the filler with a 0. If ID is > 255 then the first byte of filler will contain + // the right middle 8 bits of the collation ID. + data[13] = byte((collation.ID & 0xff00) >> 8) // SSL Connection Request Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest @@ -301,8 +309,12 @@ func (c *Conn) writeAuthHandshake() error { } // Filler [23 bytes] (all 0x00) - pos := 13 - for ; pos < 13+23; pos++ { + // the filler starts at position 13, but the first byte of the filler + // has been set with the collation id earlier, so position 13 at this point + // will be either 0x00, or the right middle 8 bits of the collation id. + // Therefore, we start at position 14 and fill the remaining 22 bytes with 0x00. + pos := 14 + for ; pos < 14+22; pos++ { data[pos] = 0 } diff --git a/client/auth_test.go b/client/auth_test.go index 85dba1e98..0837f1767 100644 --- a/client/auth_test.go +++ b/client/auth_test.go @@ -1,10 +1,14 @@ package client import ( + "net" "testing" - "github.com/go-mysql-org/go-mysql/mysql" + "github.com/pingcap/tidb/pkg/parser/charset" "github.com/stretchr/testify/require" + + "github.com/go-mysql-org/go-mysql/mysql" + "github.com/go-mysql-org/go-mysql/packet" ) func TestConnGenAttributes(t *testing.T) { @@ -34,3 +38,75 @@ func TestConnGenAttributes(t *testing.T) { require.Subset(t, data, fixt) } } + +func TestConnCollation(t *testing.T) { + collations := []string{ + "big5_chinese_ci", + "utf8_general_ci", + "utf8mb4_0900_ai_ci", + "utf8mb4_de_pb_0900_ai_ci", + "utf8mb4_ja_0900_as_cs", + "utf8mb4_0900_bin", + "utf8mb4_zh_pinyin_tidb_as_cs", + } + + // test all supported collations by calling writeAuthHandshake() and reading the bytes + // sent to the server to ensure the collation id is set correctly + for _, c := range collations { + collation, err := charset.GetCollationByName(c) + require.NoError(t, err) + server := sendAuthResponse(t, collation.Name) + // read the all the bytes of the handshake response so that client goroutine can complete without blocking + // on the server read. + handShakeResponse := make([]byte, 128) + _, err = server.Read(handShakeResponse) + require.NoError(t, err) + + // validate the collation id is set correctly + // if the collation ID is <= 255 the collation ID is stored in the 12th byte + if collation.ID <= 255 { + require.Equal(t, byte(collation.ID), handShakeResponse[12]) + // the 13th byte should always be 0x00 + require.Equal(t, byte(0x00), handShakeResponse[13]) + } else { + // if the collation ID is > 255 the collation ID is stored in the 12th and 13th bytes + require.Equal(t, byte(collation.ID&0xff), handShakeResponse[12]) + require.Equal(t, byte(collation.ID>>8), handShakeResponse[13]) + } + + // sanity check: validate the 22 bytes of filler with value 0x00 are set correctly + for i := 14; i < 14+22; i++ { + require.Equal(t, byte(0x00), handShakeResponse[i]) + } + + // and finally the username + username := string(handShakeResponse[36:40]) + require.Equal(t, "test", username) + + require.NoError(t, server.Close()) + } +} + +func sendAuthResponse(t *testing.T, collation string) net.Conn { + server, client := net.Pipe() + c := &Conn{ + Conn: &packet.Conn{ + Conn: client, + }, + authPluginName: "mysql_native_password", + user: "test", + db: "test", + password: "test", + proto: "tcp", + collation: collation, + salt: ([]byte)("123456781234567812345678"), + } + + go func() { + err := c.writeAuthHandshake() + require.NoError(t, err) + err = c.Close() + require.NoError(t, err) + }() + return server +} diff --git a/client/client_test.go b/client/client_test.go index b27c4c669..10515e622 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -235,6 +235,7 @@ func (s *clientTestSuite) TestConn_SetCharset() { func (s *clientTestSuite) TestConn_SetCollationAfterConnect() { err := s.c.SetCollation("latin1_swedish_ci") require.Error(s.T(), err) + require.ErrorContains(s.T(), err, "cannot set collation after connection is established") } func (s *clientTestSuite) TestConn_SetCollation() { diff --git a/client/conn.go b/client/conn.go index 9d7014951..9fc7faf16 100644 --- a/client/conn.go +++ b/client/conn.go @@ -364,12 +364,11 @@ func (c *Conn) SetCharset(charset string) error { } func (c *Conn) SetCollation(collation string) error { - if c.status == 0 { - c.collation = collation - } else { + if len(c.serverVersion) != 0 { return errors.Trace(errors.Errorf("cannot set collation after connection is established")) } + c.collation = collation return nil } From 509acf4c9147ac30fd4aeee25b25ff2b1687a649 Mon Sep 17 00:00:00 2001 From: David Vilaverde <dvilaverde@gmail.com> Date: Mon, 20 May 2024 19:22:36 -0400 Subject: [PATCH 06/19] refactored and added more driver args --- Makefile | 3 ++- README.md | 52 +++++++++++++++++++++++++++++++++++++ canal/canal.go | 7 ++--- client/client_test.go | 20 ++++++++------ client/conn.go | 30 ++++++++++++++------- client/conn_test.go | 3 ++- client/pool.go | 2 +- client/pool_options.go | 4 +-- driver/driver.go | 44 +++++++++++++++++-------------- driver/driver_options.go | 50 +++++++++++++++++++++++++++++++++++ driver/driver_test.go | 16 +++++++----- packet/conn.go | 43 +++++++++++++++++++++++++++--- replication/backup_test.go | 3 ++- replication/binlogsyncer.go | 5 ++-- 14 files changed, 225 insertions(+), 57 deletions(-) create mode 100644 driver/driver_options.go diff --git a/Makefile b/Makefile index 480b6a874..8da0d7081 100644 --- a/Makefile +++ b/Makefile @@ -13,9 +13,10 @@ test: MYSQL_VERSION ?= 8.0 test-local: - docker run --rm -d --network=host --name go-mysql-server \ + docker run --rm --name go-mysql-server \ -e MYSQL_ALLOW_EMPTY_PASSWORD=true \ -e MYSQL_DATABASE=test \ + -p 3306:3306 \ -v $${PWD}/docker/resources/replication.cnf:/etc/mysql/conf.d/replication.cnf \ mysql:$(MYSQL_VERSION) docker/resources/waitfor.sh 127.0.0.1 3306 \ diff --git a/README.md b/README.md index 926c1532d..74cf3b077 100644 --- a/README.md +++ b/README.md @@ -360,6 +360,58 @@ func main() { } ``` +### Driver Options + +Configuration options can be provided by the standard DSN (Data Source Name). + +``` +[user[:password]@]addr[/db[?param=X]] +``` + +#### `compress` + +Enable zlib compression between the client and the server. + +| Type | Default | Example | +| --------- | --------- | ------------------------------------------- | +| boolean | false | user:pass@localhost/mydb?compress=true | + +#### `readTimeout` + +I/O read timeout. + +| Type | Default | Example | +| --------- | --------- | ------------------------------------------- | +| duration | 0 | user:pass@localhost/mydb?readTimeout=10s | + +#### `ssl` + +Enable TLS between client and server. + +| Type | Default | Allowed Values | +| --------- | --------- | ------------------------------------------- | +| string | | `true` or `custom` | + + +#### `timeout` + +Timeout is the maximum amount of time a dial will wait for a connect to complete. + +| Type | Default | Example | +| --------- | --------- | ------------------------------------------- | +| duration | 0 | user:pass@localhost/mydb?timeout=1m | + +#### `writeTimeout` + +I/O write timeout. + +| Type | Default | Example | +| --------- | --------- | ----------------------------------------------- | +| duration | 0 | user:pass@localhost/mydb?writeTimeout=1m30s | + + + + We pass all tests in https://github.com/bradfitz/go-sql-test using go-mysql driver. :-) ## Donate diff --git a/canal/canal.go b/canal/canal.go index 0108f3c27..20e09952e 100644 --- a/canal/canal.go +++ b/canal/canal.go @@ -499,7 +499,7 @@ func (c *Canal) prepareSyncer() error { return nil } -func (c *Canal) connect(options ...func(*client.Conn)) (*client.Conn, error) { +func (c *Canal) connect(options ...client.Option) (*client.Conn, error) { ctx, cancel := context.WithTimeout(c.ctx, time.Second*10) defer cancel() @@ -511,10 +511,11 @@ func (c *Canal) connect(options ...func(*client.Conn)) (*client.Conn, error) { func (c *Canal) Execute(cmd string, args ...interface{}) (rr *mysql.Result, err error) { c.connLock.Lock() defer c.connLock.Unlock() - argF := make([]func(*client.Conn), 0) + argF := make([]client.Option, 0) if c.cfg.TLSConfig != nil { - argF = append(argF, func(conn *client.Conn) { + argF = append(argF, func(conn *client.Conn) error { conn.SetTLSConfig(c.cfg.TLSConfig) + return nil }) } diff --git a/client/client_test.go b/client/client_test.go index aaf72ff42..3917db3f5 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -31,10 +31,10 @@ func TestClientSuite(t *testing.T) { func (s *clientTestSuite) SetupSuite() { var err error addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port) - s.c, err = Connect(addr, *testUser, *testPassword, "", func(conn *Conn) { + s.c, err = Connect(addr, *testUser, *testPassword, "", func(conn *Conn) error { // test the collation logic, but this is essentially a no-op since // the collation set is the default value - _ = conn.SetCollation(mysql.DEFAULT_COLLATION_NAME) + return conn.SetCollation(mysql.DEFAULT_COLLATION_NAME) }) require.NoError(s.T(), err) @@ -91,8 +91,9 @@ func (s *clientTestSuite) TestConn_Ping() { func (s *clientTestSuite) TestConn_Compress() { addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port) - conn, err := Connect(addr, *testUser, *testPassword, "", func(conn *Conn) { + conn, err := Connect(addr, *testUser, *testPassword, "", func(conn *Conn) error { conn.SetCapability(mysql.CLIENT_COMPRESS) + return nil }) require.NoError(s.T(), err) @@ -142,8 +143,9 @@ func (s *clientTestSuite) TestConn_TLS_Verify() { // Verify that the provided tls.Config is used when attempting to connect to mysql. // An empty tls.Config will result in a connection error. addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port) - _, err := Connect(addr, *testUser, *testPassword, *testDB, func(c *Conn) { + _, err := Connect(addr, *testUser, *testPassword, *testDB, func(c *Conn) error { c.UseSSL(false) + return nil }) expected := "either ServerName or InsecureSkipVerify must be specified in the tls.Config" @@ -153,8 +155,9 @@ func (s *clientTestSuite) TestConn_TLS_Verify() { func (s *clientTestSuite) TestConn_TLS_Skip_Verify() { // An empty tls.Config will result in a connection error but we can configure to skip it. addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port) - _, err := Connect(addr, *testUser, *testPassword, *testDB, func(c *Conn) { + _, err := Connect(addr, *testUser, *testPassword, *testDB, func(c *Conn) error { c.UseSSL(true) + return nil }) require.NoError(s.T(), err) } @@ -165,8 +168,9 @@ func (s *clientTestSuite) TestConn_TLS_Certificate() { // "x509: certificate is valid for MySQL_Server_8.0.12_Auto_Generated_Server_Certificate, not not-a-valid-name" tlsConfig := NewClientTLSConfig(test_keys.CaPem, test_keys.CertPem, test_keys.KeyPem, false, "not-a-valid-name") addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port) - _, err := Connect(addr, *testUser, *testPassword, *testDB, func(c *Conn) { + _, err := Connect(addr, *testUser, *testPassword, *testDB, func(c *Conn) error { c.SetTLSConfig(tlsConfig) + return nil }) require.Error(s.T(), err) if !strings.Contains(errors.ErrorStack(err), "certificate is not valid for any names") && @@ -251,9 +255,9 @@ func (s *clientTestSuite) TestConn_SetCollationAfterConnect() { func (s *clientTestSuite) TestConn_SetCollation() { addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port) - _, err := Connect(addr, *testUser, *testPassword, "", func(conn *Conn) { + _, err := Connect(addr, *testUser, *testPassword, "", func(conn *Conn) error { // test the collation logic - _ = conn.SetCollation("invalid_collation") + return conn.SetCollation("invalid_collation") }) require.Error(s.T(), err) diff --git a/client/conn.go b/client/conn.go index bef9b2de9..6ff61a786 100644 --- a/client/conn.go +++ b/client/conn.go @@ -18,6 +18,8 @@ import ( "github.com/go-mysql-org/go-mysql/utils" ) +type Option func(*Conn) error + type Conn struct { *packet.Conn @@ -27,6 +29,10 @@ type Conn struct { tlsConfig *tls.Config proto string + // Connection read and write timeouts to set on the connection + ReadTimeout time.Duration + WriteTimeout time.Duration + serverVersion string // server capabilities capability uint32 @@ -66,16 +72,18 @@ func getNetProto(addr string) string { // Connect to a MySQL server, addr can be ip:port, or a unix socket domain like /var/sock. // Accepts a series of configuration functions as a variadic argument. -func Connect(addr string, user string, password string, dbName string, options ...func(*Conn)) (*Conn, error) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() +func Connect(addr, user, password, dbName string, options ...Option) (*Conn, error) { + return ConnectWithTimeout(addr, user, password, dbName, time.Second*10, options...) +} - return ConnectWithContext(ctx, addr, user, password, dbName, options...) +// ConnectWithTimeout to a MySQL address using a timeout. +func ConnectWithTimeout(addr, user, password, dbName string, timeout time.Duration, options ...Option) (*Conn, error) { + return ConnectWithContext(context.Background(), addr, user, password, dbName, time.Second*10, options...) } // ConnectWithContext to a MySQL addr using the provided context. -func ConnectWithContext(ctx context.Context, addr string, user string, password string, dbName string, options ...func(*Conn)) (*Conn, error) { - dialer := &net.Dialer{} +func ConnectWithContext(ctx context.Context, addr, user, password, dbName string, timeout time.Duration, options ...Option) (*Conn, error) { + dialer := &net.Dialer{Timeout: timeout} return ConnectWithDialer(ctx, "", addr, user, password, dbName, dialer.DialContext, options...) } @@ -83,7 +91,7 @@ func ConnectWithContext(ctx context.Context, addr string, user string, password type Dialer func(ctx context.Context, network, address string) (net.Conn, error) // ConnectWithDialer to a MySQL server using the given Dialer. -func ConnectWithDialer(ctx context.Context, network string, addr string, user string, password string, dbName string, dialer Dialer, options ...func(*Conn)) (*Conn, error) { +func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbName string, dialer Dialer, options ...Option) (*Conn, error) { c := new(Conn) c.attributes = map[string]string{ @@ -108,19 +116,21 @@ func ConnectWithDialer(ctx context.Context, network string, addr string, user st c.password = password c.db = dbName c.proto = network - c.Conn = packet.NewConn(conn) // use default charset here, utf-8 c.charset = DEFAULT_CHARSET // Apply configuration functions. for i := range options { - options[i](c) + if err := options[i](c); err != nil { + return nil, err + } } + c.Conn = packet.NewConnWithTimeout(conn, c.ReadTimeout, c.WriteTimeout) if c.tlsConfig != nil { seq := c.Conn.Sequence - c.Conn = packet.NewTLSConn(conn) + c.Conn = packet.NewTLSConnWithTimeout(conn, c.ReadTimeout, c.WriteTimeout) c.Conn.Sequence = seq } diff --git a/client/conn_test.go b/client/conn_test.go index e2091d50e..55ea973d6 100644 --- a/client/conn_test.go +++ b/client/conn_test.go @@ -28,10 +28,11 @@ func TestConnSuite(t *testing.T) { func (s *connTestSuite) SetupSuite() { var err error addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port) - s.c, err = Connect(addr, *testUser, *testPassword, "", func(c *Conn) { + s.c, err = Connect(addr, *testUser, *testPassword, "", func(c *Conn) error { // required for the ExecuteMultiple test c.SetCapability(mysql.CLIENT_MULTI_STATEMENTS) c.SetAttributes(map[string]string{"attrtest": "attrvalue"}) + return nil }) require.NoError(s.T(), err) diff --git a/client/pool.go b/client/pool.go index 91341a537..6e5d6dc21 100644 --- a/client/pool.go +++ b/client/pool.go @@ -166,7 +166,7 @@ func NewPool( user string, password string, dbName string, - options ...func(conn *Conn), + options ...Option, ) *Pool { pool, err := NewPoolWithOptions( addr, diff --git a/client/pool_options.go b/client/pool_options.go index f47b00716..90bf5bd0d 100644 --- a/client/pool_options.go +++ b/client/pool_options.go @@ -17,7 +17,7 @@ type ( password string dbName string - connOptions []func(conn *Conn) + connOptions []Option newPoolPingTimeout time.Duration } @@ -46,7 +46,7 @@ func WithLogFunc(f LogFunc) PoolOption { } } -func WithConnOptions(options ...func(conn *Conn)) PoolOption { +func WithConnOptions(options ...Option) PoolOption { return func(o *poolOptions) { o.connOptions = append(o.connOptions, options...) } diff --git a/driver/driver.go b/driver/driver.go index 80bde7a49..377b87d5d 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -10,8 +10,8 @@ import ( "io" "net/url" "regexp" - "strconv" "sync" + "time" "github.com/go-mysql-org/go-mysql/client" "github.com/go-mysql-org/go-mysql/mysql" @@ -24,7 +24,7 @@ var customTLSMutex sync.Mutex // Map of dsn address (makes more sense than full dsn?) to tls Config var ( customTLSConfigMap = make(map[string]*tls.Config) - options = make(map[string]connOption) + options = make(map[string]DriverOption) ) type driver struct { @@ -39,8 +39,6 @@ type connInfo struct { params url.Values } -type connOption func(c *client.Conn, value string) - // ParseDSN takes a DSN string and splits it up into struct containing addr, // user, password and db. // It returns an error if unable to parse. @@ -98,7 +96,8 @@ func (d driver) Open(dsn string) (sqldriver.Conn, error) { } if ci.standardDSN { - configuredOptions := make([]func(*client.Conn), 0, len(ci.params)) + var timeout time.Duration + configuredOptions := make([]client.Option, 0, len(ci.params)) for key, value := range ci.params { if key == "ssl" && len(value) > 0 { tlsConfigName := value[0] @@ -107,22 +106,29 @@ func (d driver) Open(dsn string) (sqldriver.Conn, error) { // This actually does insecureSkipVerify // But not even sure if it makes sense to handle false? According to // client_test.go it doesn't - it'd result in an error - configuredOptions = append(configuredOptions, func(c *client.Conn) { c.UseSSL(true) }) + configuredOptions = append(configuredOptions, UseSslOption) case "custom": // I was too concerned about mimicking what go-sql-driver/mysql does which will // allow any name for a custom tls profile and maps the query parameter value to // that TLSConfig variable... there is no need to be that clever. // Instead of doing that, let's store required custom TLSConfigs in a map that // uses the DSN address as the key - configuredOptions = append(configuredOptions, func(c *client.Conn) { c.SetTLSConfig(customTLSConfigMap[ci.addr]) }) + configuredOptions = append(configuredOptions, func(c *client.Conn) error { + c.SetTLSConfig(customTLSConfigMap[ci.addr]) + return nil + }) default: return nil, errors.Errorf("Supported options are ssl=true or ssl=custom") } + } else if key == "timeout" && len(value) > 0 { + if timeout, err = time.ParseDuration(value[0]); err != nil { + return nil, errors.Wrap(err, "invalid duration value for timeout option") + } } else { if option, ok := options[key]; ok { - opt := func(o connOption, v string) func(c *client.Conn) { - return func(c *client.Conn) { - o(c, v) + opt := func(o DriverOption, v string) client.Option { + return func(c *client.Conn) error { + return o(c, v) } }(option, value[0]) configuredOptions = append(configuredOptions, opt) @@ -132,7 +138,11 @@ func (d driver) Open(dsn string) (sqldriver.Conn, error) { } } - c, err = client.Connect(ci.addr, ci.user, ci.password, ci.db, configuredOptions...) + if timeout > 0 { + c, err = client.ConnectWithTimeout(ci.addr, ci.user, ci.password, ci.db, timeout, configuredOptions...) + } else { + c, err = client.Connect(ci.addr, ci.user, ci.password, ci.db, configuredOptions...) + } } else { // No more processing here. Let's only support url parameters with the newer style DSN c, err = client.Connect(ci.addr, ci.user, ci.password, ci.db) @@ -316,14 +326,10 @@ func (r *rows) Next(dest []sqldriver.Value) error { } func init() { - options["compress"] = func(c *client.Conn, value string) { - if b, err := strconv.ParseBool(value); err == nil && b { - c.SetCapability(mysql.CLIENT_COMPRESS) - } - } - options["collation"] = func(c *client.Conn, value string) { - c.SetCollation(value) - } + options["compress"] = CompressOption + options["collation"] = CollationOption + options["readTimeout"] = ReadTimeoutOption + options["writeTimeout"] = WriteTimeoutOption sql.Register("mysql", driver{}) } diff --git a/driver/driver_options.go b/driver/driver_options.go new file mode 100644 index 000000000..37949a68e --- /dev/null +++ b/driver/driver_options.go @@ -0,0 +1,50 @@ +package driver + +import ( + "strconv" + "time" + + "github.com/go-mysql-org/go-mysql/client" + "github.com/go-mysql-org/go-mysql/mysql" + "github.com/pingcap/errors" +) + +// DriverOption sets configuration on a client connection before the MySQL handshake. +// The value represents the query string parameter value supplied by in the DNS. +type DriverOption func(c *client.Conn, value string) error + +func UseSslOption(c *client.Conn) error { + c.UseSSL(true) + return nil +} + +func CollationOption(c *client.Conn, value string) error { + return c.SetCollation(value) +} + +func ReadTimeoutOption(c *client.Conn, value string) error { + var err error + c.ReadTimeout, err = time.ParseDuration(value) + return errors.Wrap(err, "invalid duration value for readTimeout option") +} + +func WriteTimeoutOption(c *client.Conn, value string) error { + var err error + c.WriteTimeout, err = time.ParseDuration(value) + return errors.Wrap(err, "invalid duration value for writeTimeout option") +} + +func CompressOption(c *client.Conn, value string) error { + var ( + b bool + err error + ) + if b, err = strconv.ParseBool(value); err != nil { + return errors.Errorf("invalid boolean value '%s' for compress option", value) + } + if b { + c.SetCapability(mysql.CLIENT_COMPRESS) + } + + return nil +} diff --git a/driver/driver_test.go b/driver/driver_test.go index 3f7575613..5df71cf5b 100644 --- a/driver/driver_test.go +++ b/driver/driver_test.go @@ -82,12 +82,16 @@ func TestParseDSN(t *testing.T) { // Use different numbered domains to more readily see what has failed - since we // test in a loop we get the same line number on error testDSNs := map[string]connInfo{ - "user:password@localhost?db": {standardDSN: false, addr: "localhost", user: "user", password: "password", db: "db", params: url.Values{}}, - "user@1.domain.com?db": {standardDSN: false, addr: "1.domain.com", user: "user", password: "", db: "db", params: url.Values{}}, - "user:password@2.domain.com/db": {standardDSN: true, addr: "2.domain.com", user: "user", password: "password", db: "db", params: url.Values{}}, - "user:password@3.domain.com/db?ssl=true": {standardDSN: true, addr: "3.domain.com", user: "user", password: "password", db: "db", params: url.Values{"ssl": []string{"true"}}}, - "user:password@4.domain.com/db?ssl=custom": {standardDSN: true, addr: "4.domain.com", user: "user", password: "password", db: "db", params: url.Values{"ssl": []string{"custom"}}}, - "user:password@5.domain.com/db?unused=param": {standardDSN: true, addr: "5.domain.com", user: "user", password: "password", db: "db", params: url.Values{"unused": []string{"param"}}}, + "user:password@localhost?db": {standardDSN: false, addr: "localhost", user: "user", password: "password", db: "db", params: url.Values{}}, + "user@1.domain.com?db": {standardDSN: false, addr: "1.domain.com", user: "user", password: "", db: "db", params: url.Values{}}, + "user:password@2.domain.com/db": {standardDSN: true, addr: "2.domain.com", user: "user", password: "password", db: "db", params: url.Values{}}, + "user:password@3.domain.com/db?ssl=true": {standardDSN: true, addr: "3.domain.com", user: "user", password: "password", db: "db", params: url.Values{"ssl": []string{"true"}}}, + "user:password@4.domain.com/db?ssl=custom": {standardDSN: true, addr: "4.domain.com", user: "user", password: "password", db: "db", params: url.Values{"ssl": []string{"custom"}}}, + "user:password@5.domain.com/db?unused=param": {standardDSN: true, addr: "5.domain.com", user: "user", password: "password", db: "db", params: url.Values{"unused": []string{"param"}}}, + "user:password@5.domain.com/db?timeout=1s": {standardDSN: true, addr: "5.domain.com", user: "user", password: "password", db: "db", params: url.Values{"timeout": []string{"1s"}}}, + "user:password@5.domain.com/db?readTimeout=1m": {standardDSN: true, addr: "5.domain.com", user: "user", password: "password", db: "db", params: url.Values{"readTimeout": []string{"1m"}}}, + "user:password@5.domain.com/db?writeTimeout=1m": {standardDSN: true, addr: "5.domain.com", user: "user", password: "password", db: "db", params: url.Values{"writeTimeout": []string{"1m"}}}, + "user:password@5.domain.com/db?compress=true": {standardDSN: true, addr: "5.domain.com", user: "user", password: "password", db: "db", params: url.Values{"compress": []string{"true"}}}, } for supplied, expected := range testDSNs { diff --git a/packet/conn.go b/packet/conn.go index 6096d4f06..7b6c2614e 100644 --- a/packet/conn.go +++ b/packet/conn.go @@ -13,6 +13,7 @@ import ( "io" "net" "sync" + "time" . "github.com/go-mysql-org/go-mysql/mysql" "github.com/go-mysql-org/go-mysql/utils" @@ -46,6 +47,8 @@ func (b *BufPool) Return(buf *bytes.Buffer) { // Conn is the base class to handle MySQL protocol. type Conn struct { net.Conn + readTimeout time.Duration + writeTimeout time.Duration // we removed the buffer reader because it will cause the SSLRequest to block (tls connection handshake won't be // able to read the "Client Hello" data since it has been buffered into the buffer reader) @@ -84,6 +87,13 @@ func NewConn(conn net.Conn) *Conn { return c } +func NewConnWithTimeout(conn net.Conn, readTimeout, writeTimeout time.Duration) *Conn { + c := NewConn(conn) + c.readTimeout = readTimeout + c.writeTimeout = writeTimeout + return c +} + func NewTLSConn(conn net.Conn) *Conn { c := new(Conn) c.Conn = conn @@ -96,6 +106,13 @@ func NewTLSConn(conn net.Conn) *Conn { return c } +func NewTLSConnWithTimeout(conn net.Conn, readTimeout, writeTimeout time.Duration) *Conn { + c := NewTLSConn(conn) + c.readTimeout = readTimeout + c.writeTimeout = writeTimeout + return c +} + func (c *Conn) ReadPacket() ([]byte, error) { return c.ReadPacketReuseMem(nil) } @@ -152,6 +169,11 @@ func (c *Conn) ReadPacketReuseMem(dst []byte) ([]byte, error) { // newCompressedPacketReader creates a new compressed packet reader. func (c *Conn) newCompressedPacketReader() (io.Reader, error) { + if c.readTimeout != 0 { + if err := c.SetReadDeadline(time.Now().Add(c.readTimeout)); err != nil { + return nil, err + } + } if _, err := io.ReadFull(c.reader, c.compressedHeader[:7]); err != nil { return nil, errors.Wrapf(ErrBadConn, "io.ReadFull(compressedHeader) failed. err %v", err) } @@ -197,6 +219,11 @@ func (c *Conn) copyN(dst io.Writer, n int64) (int64, error) { // Call ReadAtLeast with the currentPacketReader as it may change on every iteration // of this loop. + if c.readTimeout != 0 { + if err := c.SetReadDeadline(time.Now().Add(c.readTimeout)); err != nil { + return written, err + } + } rd, err := io.ReadAtLeast(c.currentPacketReader(), buf, bcap) n -= int64(rd) @@ -291,7 +318,7 @@ func (c *Conn) WritePacket(data []byte) error { data[3] = c.Sequence - if n, err := c.Write(data[:4+MaxPayloadLen]); err != nil { + if n, err := c.writeWithTimeout(data[:4+MaxPayloadLen]); err != nil { return errors.Wrapf(ErrBadConn, "Write(payload portion) failed. err %v", err) } else if n != (4 + MaxPayloadLen) { return errors.Wrapf(ErrBadConn, "Write(payload portion) failed. only %v bytes written, while %v expected", n, 4+MaxPayloadLen) @@ -309,7 +336,7 @@ func (c *Conn) WritePacket(data []byte) error { switch c.Compression { case MYSQL_COMPRESS_NONE: - if n, err := c.Write(data); err != nil { + if n, err := c.writeWithTimeout(data); err != nil { return errors.Wrapf(ErrBadConn, "Write failed. err %v", err) } else if n != len(data) { return errors.Wrapf(ErrBadConn, "Write failed. only %v bytes written, while %v expected", n, len(data)) @@ -330,6 +357,16 @@ func (c *Conn) WritePacket(data []byte) error { return nil } +func (c *Conn) writeWithTimeout(b []byte) (n int, err error) { + if c.writeTimeout != 0 { + if err := c.SetWriteDeadline(time.Now().Add(c.writeTimeout)); err != nil { + return n, err + } + } + + return c.Write(b) +} + func (c *Conn) writeCompressed(data []byte) (n int, err error) { var compressedLength, uncompressedLength int var payload, compressedPacket bytes.Buffer @@ -388,7 +425,7 @@ func (c *Conn) writeCompressed(data []byte) (n int, err error) { return 0, err } - _, err = c.Write(compressedPacket.Bytes()) + _, err = c.writeWithTimeout(compressedPacket.Bytes()) if err != nil { return 0, err } diff --git a/replication/backup_test.go b/replication/backup_test.go index 1f77e8e3e..abefd3f8d 100644 --- a/replication/backup_test.go +++ b/replication/backup_test.go @@ -38,7 +38,8 @@ func (t *testSyncerSuite) TestStartBackupEndInGivenTime() { done <- true }() failTimeout := 5 * timeout - ctx, _ := context.WithTimeout(context.Background(), failTimeout) + ctx, cancel := context.WithTimeout(context.Background(), failTimeout) + defer cancel() select { case <-done: return diff --git a/replication/binlogsyncer.go b/replication/binlogsyncer.go index 72a22c45c..4c82034c8 100644 --- a/replication/binlogsyncer.go +++ b/replication/binlogsyncer.go @@ -897,12 +897,13 @@ func (b *BinlogSyncer) newConnection(ctx context.Context) (*client.Conn, error) defer cancel() return client.ConnectWithDialer(timeoutCtx, "", addr, b.cfg.User, b.cfg.Password, - "", b.cfg.Dialer, func(c *client.Conn) { + "", b.cfg.Dialer, func(c *client.Conn) error { c.SetTLSConfig(b.cfg.TLSConfig) c.SetAttributes(map[string]string{"_client_role": "binary_log_listener"}) if b.cfg.ReadTimeout > 0 { - _ = c.SetReadDeadline(time.Now().Add(b.cfg.ReadTimeout)) + return c.SetReadDeadline(time.Now().Add(b.cfg.ReadTimeout)) } + return nil }) } From 948ed700974828c891e7200f954d87c8ac6216a9 Mon Sep 17 00:00:00 2001 From: David Vilaverde <dvilaverde@gmail.com> Date: Mon, 20 May 2024 19:24:25 -0400 Subject: [PATCH 07/19] revert change to Makefile --- Makefile | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 8da0d7081..480b6a874 100644 --- a/Makefile +++ b/Makefile @@ -13,10 +13,9 @@ test: MYSQL_VERSION ?= 8.0 test-local: - docker run --rm --name go-mysql-server \ + docker run --rm -d --network=host --name go-mysql-server \ -e MYSQL_ALLOW_EMPTY_PASSWORD=true \ -e MYSQL_DATABASE=test \ - -p 3306:3306 \ -v $${PWD}/docker/resources/replication.cnf:/etc/mysql/conf.d/replication.cnf \ mysql:$(MYSQL_VERSION) docker/resources/waitfor.sh 127.0.0.1 3306 \ From 92a8ac8b3db9775e58f34240cda6c733aea31483 Mon Sep 17 00:00:00 2001 From: David Vilaverde <dvilaverde@gmail.com> Date: Sun, 2 Jun 2024 15:22:07 -0400 Subject: [PATCH 08/19] added tests for timeouts --- driver/driver_options_test.go | 205 ++++++++++++++++++++++++++++++++++ 1 file changed, 205 insertions(+) create mode 100644 driver/driver_options_test.go diff --git a/driver/driver_options_test.go b/driver/driver_options_test.go new file mode 100644 index 000000000..57dd24648 --- /dev/null +++ b/driver/driver_options_test.go @@ -0,0 +1,205 @@ +package driver + +import ( + "context" + "database/sql" + "fmt" + "net" + "strings" + "testing" + "time" + + "github.com/go-mysql-org/go-mysql/mysql" + "github.com/go-mysql-org/go-mysql/server" + "github.com/pingcap/errors" + "github.com/siddontang/go/log" + "github.com/stretchr/testify/require" +) + +var _ server.Handler = &mockHandler{} + +type testServer struct { + *server.Server + + listener net.Listener + handler *mockHandler +} + +type mockHandler struct { +} + +func TestDriverOptions_ConnectTimeout(t *testing.T) { + log.SetLevel(log.LevelDebug) + srv := CreateMockServer(t) + defer srv.Stop() + + conn, err := sql.Open("mysql", "root@127.0.0.1:3307/test?timeout=1s") + require.NoError(t, err) + + rows, err := conn.QueryContext(context.TODO(), "select * from table;") + require.NotNil(t, rows) + require.NoError(t, err) + + conn.Close() +} + +func TestDriverOptions_ReadTimeout(t *testing.T) { + log.SetLevel(log.LevelDebug) + srv := CreateMockServer(t) + defer srv.Stop() + + conn, err := sql.Open("mysql", "root@127.0.0.1:3307/test?readTimeout=1s") + require.NoError(t, err) + + rows, err := conn.QueryContext(context.TODO(), "select * from slow;") + require.Nil(t, rows) + require.Error(t, err) + + rows, err = conn.QueryContext(context.TODO(), "select * from fast;") + require.NotNil(t, rows) + require.NoError(t, err) + + conn.Close() +} + +func TestDriverOptions_writeTimeout(t *testing.T) { + log.SetLevel(log.LevelDebug) + srv := CreateMockServer(t) + defer srv.Stop() + + conn, err := sql.Open("mysql", "root@127.0.0.1:3307/test?writeTimeout=10") + require.NoError(t, err) + + result, err := conn.ExecContext(context.TODO(), "insert into slow(a,b) values(1,2);") + require.Nil(t, result) + require.Error(t, err) + + conn.Close() +} + +func CreateMockServer(t *testing.T) *testServer { + inMemProvider := server.NewInMemoryProvider() + inMemProvider.AddUser(*testUser, *testPassword) + defaultServer := server.NewDefaultServer() + + l, err := net.Listen("tcp", "127.0.0.1:3307") + require.NoError(t, err) + + handler := &mockHandler{} + + go func() { + for { + conn, err := l.Accept() + if err != nil { + return + } + + go func() { + co, err := server.NewCustomizedConn(conn, defaultServer, inMemProvider, handler) + require.NoError(t, err) + for { + err = co.HandleCommand() + if err != nil { + return + } + } + }() + } + }() + + return &testServer{ + Server: defaultServer, + listener: l, + handler: handler, + } +} + +func (s *testServer) Stop() { + s.listener.Close() +} + +func (h *mockHandler) UseDB(dbName string) error { + return nil +} + +func (h *mockHandler) handleQuery(query string, binary bool) (*mysql.Result, error) { + ss := strings.Split(query, " ") + switch strings.ToLower(ss[0]) { + case "select": + var r *mysql.Resultset + var err error + //for handle go mysql driver select @@max_allowed_packet + if strings.Contains(strings.ToLower(query), "max_allowed_packet") { + r, err = mysql.BuildSimpleResultset([]string{"@@max_allowed_packet"}, [][]interface{}{ + {mysql.MaxPayloadLen}, + }, binary) + } else { + if strings.Contains(query, "slow") { + time.Sleep(time.Second * 5) + } + + r, err = mysql.BuildSimpleResultset([]string{"a", "b"}, [][]interface{}{ + {1, "hello world"}, + }, binary) + } + + if err != nil { + return nil, errors.Trace(err) + } else { + return &mysql.Result{ + Status: 0, + Warnings: 0, + InsertId: 0, + AffectedRows: 0, + Resultset: r, + }, nil + } + case "insert": + return &mysql.Result{ + Status: 0, + Warnings: 0, + InsertId: 1, + AffectedRows: 0, + Resultset: nil, + }, nil + default: + return nil, fmt.Errorf("invalid query %s", query) + } +} + +func (h *mockHandler) HandleQuery(query string) (*mysql.Result, error) { + return h.handleQuery(query, false) +} + +func (h *mockHandler) HandleFieldList(table string, fieldWildcard string) ([]*mysql.Field, error) { + return nil, nil +} + +func (h *mockHandler) HandleStmtPrepare(query string) (params int, columns int, context interface{}, err error) { + params = 1 + columns = 0 + return params, columns, nil, nil +} + +func (h *mockHandler) HandleStmtExecute(context interface{}, query string, args []interface{}) (*mysql.Result, error) { + + if strings.HasPrefix(strings.ToLower(query), "select") { + return h.HandleQuery(query) + } + + return &mysql.Result{ + Status: 0, + Warnings: 0, + InsertId: 1, + AffectedRows: 0, + Resultset: nil, + }, nil +} + +func (h *mockHandler) HandleStmtClose(context interface{}) error { + return nil +} + +func (h *mockHandler) HandleOtherCommand(cmd byte, data []byte) error { + return nil +} From 74f45e1bd0d0e0db0a3ee03d65b9f00c42b6fa5c Mon Sep 17 00:00:00 2001 From: David Vilaverde <dvilaverde@gmail.com> Date: Sun, 2 Jun 2024 21:00:30 -0400 Subject: [PATCH 09/19] adding more tests --- client/conn.go | 5 +++++ driver/driver_options.go | 2 ++ driver/driver_options_test.go | 18 ++++++++++++++++++ 3 files changed, 25 insertions(+) diff --git a/client/conn.go b/client/conn.go index 6ff61a786..9d8392053 100644 --- a/client/conn.go +++ b/client/conn.go @@ -216,6 +216,11 @@ func (c *Conn) UnsetCapability(cap uint32) { c.ccaps &= ^cap } +// HasCapability returns true if the connection has the specific capability +func (c *Conn) HasCapability(cap uint32) bool { + return c.ccaps&cap > 0 +} + // UseSSL: use default SSL // pass to options when connect func (c *Conn) UseSSL(insecureSkipVerify bool) { diff --git a/driver/driver_options.go b/driver/driver_options.go index 37949a68e..a98d47b1e 100644 --- a/driver/driver_options.go +++ b/driver/driver_options.go @@ -44,6 +44,8 @@ func CompressOption(c *client.Conn, value string) error { } if b { c.SetCapability(mysql.CLIENT_COMPRESS) + } else { + c.UnsetCapability(mysql.CLIENT_COMPRESS) } return nil diff --git a/driver/driver_options_test.go b/driver/driver_options_test.go index 57dd24648..b12af20ff 100644 --- a/driver/driver_options_test.go +++ b/driver/driver_options_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/go-mysql-org/go-mysql/client" "github.com/go-mysql-org/go-mysql/mysql" "github.com/go-mysql-org/go-mysql/server" "github.com/pingcap/errors" @@ -28,6 +29,23 @@ type testServer struct { type mockHandler struct { } +func TestDriverOptions_SetCollation(t *testing.T) { + c := &client.Conn{} + CollationOption(c, "latin2_bin") + require.Equal(t, "latin2_bin", c.GetCollation()) +} + +func TestDriverOptions_SetCompression(t *testing.T) { + c := &client.Conn{} + CompressOption(c, "true") + require.True(t, c.HasCapability(mysql.CLIENT_COMPRESS)) + + CompressOption(c, "false") + require.False(t, c.HasCapability(mysql.CLIENT_COMPRESS)) + + require.Error(t, CompressOption(c, "foo")) +} + func TestDriverOptions_ConnectTimeout(t *testing.T) { log.SetLevel(log.LevelDebug) srv := CreateMockServer(t) From a1e459f09dff5cf84961d44b1c33d3725ccfb79f Mon Sep 17 00:00:00 2001 From: David Vilaverde <dvilaverde@gmail.com> Date: Sun, 2 Jun 2024 21:24:58 -0400 Subject: [PATCH 10/19] fixing linting issues --- client/conn.go | 4 ++-- driver/driver_options_test.go | 11 +++++++---- mysql/error.go | 2 +- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/client/conn.go b/client/conn.go index 9d8392053..568ca16ca 100644 --- a/client/conn.go +++ b/client/conn.go @@ -121,8 +121,8 @@ func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbNam c.charset = DEFAULT_CHARSET // Apply configuration functions. - for i := range options { - if err := options[i](c); err != nil { + for _, option := range options { + if err := option(c); err != nil { return nil, err } } diff --git a/driver/driver_options_test.go b/driver/driver_options_test.go index b12af20ff..23199ddfa 100644 --- a/driver/driver_options_test.go +++ b/driver/driver_options_test.go @@ -31,16 +31,20 @@ type mockHandler struct { func TestDriverOptions_SetCollation(t *testing.T) { c := &client.Conn{} - CollationOption(c, "latin2_bin") + err := CollationOption(c, "latin2_bin") + require.NoError(t, err) require.Equal(t, "latin2_bin", c.GetCollation()) } func TestDriverOptions_SetCompression(t *testing.T) { + var err error c := &client.Conn{} - CompressOption(c, "true") + err = CompressOption(c, "true") + require.NoError(t, err) require.True(t, c.HasCapability(mysql.CLIENT_COMPRESS)) - CompressOption(c, "false") + err = CompressOption(c, "false") + require.NoError(t, err) require.False(t, c.HasCapability(mysql.CLIENT_COMPRESS)) require.Error(t, CompressOption(c, "foo")) @@ -200,7 +204,6 @@ func (h *mockHandler) HandleStmtPrepare(query string) (params int, columns int, } func (h *mockHandler) HandleStmtExecute(context interface{}, query string, args []interface{}) (*mysql.Result, error) { - if strings.HasPrefix(strings.ToLower(query), "select") { return h.HandleQuery(query) } diff --git a/mysql/error.go b/mysql/error.go index abda6dea0..e9915779b 100644 --- a/mysql/error.go +++ b/mysql/error.go @@ -61,6 +61,6 @@ func NewError(errCode uint16, message string) *MyError { func ErrorCode(errMsg string) (code int) { var tmpStr string // golang scanf doesn't support %*,so I used a temporary variable - fmt.Sscanf(errMsg, "%s%d", &tmpStr, &code) + _, _ = fmt.Sscanf(errMsg, "%s%d", &tmpStr, &code) return } From e9f8359f3d6c5ae9911afd11e1cf779aec4244d4 Mon Sep 17 00:00:00 2001 From: David Vilaverde <dvilaverde@gmail.com> Date: Sun, 2 Jun 2024 21:34:03 -0400 Subject: [PATCH 11/19] avoiding panic on test complete --- driver/driver_options_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/driver/driver_options_test.go b/driver/driver_options_test.go index 23199ddfa..c1dc7b488 100644 --- a/driver/driver_options_test.go +++ b/driver/driver_options_test.go @@ -118,7 +118,9 @@ func CreateMockServer(t *testing.T) *testServer { go func() { co, err := server.NewCustomizedConn(conn, defaultServer, inMemProvider, handler) - require.NoError(t, err) + if err != nil { + return + } for { err = co.HandleCommand() if err != nil { From e48458f0ce5a0c013da02dc92b0f74cd8d8bf368 Mon Sep 17 00:00:00 2001 From: David Vilaverde <dvilaverde@gmail.com> Date: Sun, 2 Jun 2024 21:40:53 -0400 Subject: [PATCH 12/19] revert returning set readtimeout error in binlogsyncer --- replication/binlogsyncer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/replication/binlogsyncer.go b/replication/binlogsyncer.go index 4c82034c8..89772736d 100644 --- a/replication/binlogsyncer.go +++ b/replication/binlogsyncer.go @@ -901,7 +901,7 @@ func (b *BinlogSyncer) newConnection(ctx context.Context) (*client.Conn, error) c.SetTLSConfig(b.cfg.TLSConfig) c.SetAttributes(map[string]string{"_client_role": "binary_log_listener"}) if b.cfg.ReadTimeout > 0 { - return c.SetReadDeadline(time.Now().Add(b.cfg.ReadTimeout)) + _ = c.SetReadDeadline(time.Now().Add(b.cfg.ReadTimeout)) } return nil }) From 1c4d22b075ac69c3c7aa73087c3dc84770945a81 Mon Sep 17 00:00:00 2001 From: David Vilaverde <dvilaverde@gmail.com> Date: Sun, 2 Jun 2024 21:50:14 -0400 Subject: [PATCH 13/19] fixing nil violation when connection with timeout from binlogsyncer --- replication/binlogsyncer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/replication/binlogsyncer.go b/replication/binlogsyncer.go index 89772736d..39e5749ea 100644 --- a/replication/binlogsyncer.go +++ b/replication/binlogsyncer.go @@ -901,7 +901,7 @@ func (b *BinlogSyncer) newConnection(ctx context.Context) (*client.Conn, error) c.SetTLSConfig(b.cfg.TLSConfig) c.SetAttributes(map[string]string{"_client_role": "binary_log_listener"}) if b.cfg.ReadTimeout > 0 { - _ = c.SetReadDeadline(time.Now().Add(b.cfg.ReadTimeout)) + c.ReadTimeout = b.cfg.ReadTimeout } return nil }) From 09fe1b416eea5019d37972538231f7d5ab6aa568 Mon Sep 17 00:00:00 2001 From: David Vilaverde <dvilaverde@gmail.com> Date: Tue, 4 Jun 2024 07:19:56 -0400 Subject: [PATCH 14/19] Update README.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Daniƫl van Eeden <github@myname.nl> --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 74cf3b077..3a2735ccd 100644 --- a/README.md +++ b/README.md @@ -378,7 +378,7 @@ Enable zlib compression between the client and the server. #### `readTimeout` -I/O read timeout. +I/O read timeout. 0 means no timeout. | Type | Default | Example | | --------- | --------- | ------------------------------------------- | From 19f5963e9966ab60c6b73088bd164584c43495c1 Mon Sep 17 00:00:00 2001 From: David Vilaverde <dvilaverde@gmail.com> Date: Tue, 4 Jun 2024 07:30:36 -0400 Subject: [PATCH 15/19] addressing pull request feedback --- README.md | 22 ++++++++++------------ driver/driver.go | 5 +++-- driver/driver_options.go | 18 ++++++++---------- driver/driver_options_test.go | 9 +++++++-- driver/driver_test.go | 4 +++- 5 files changed, 31 insertions(+), 27 deletions(-) diff --git a/README.md b/README.md index 3a2735ccd..b904e4ab3 100644 --- a/README.md +++ b/README.md @@ -370,11 +370,11 @@ Configuration options can be provided by the standard DSN (Data Source Name). #### `compress` -Enable zlib compression between the client and the server. +Enable compression between the client and the server. Valid values are 'zstd','zlib','uncompressed'. -| Type | Default | Example | -| --------- | --------- | ------------------------------------------- | -| boolean | false | user:pass@localhost/mydb?compress=true | +| Type | Default | Example | +| --------- | ------------- | --------------------------------------- | +| string | uncompressed | user:pass@localhost/mydb?compress=zlib | #### `readTimeout` @@ -384,18 +384,17 @@ I/O read timeout. 0 means no timeout. | --------- | --------- | ------------------------------------------- | | duration | 0 | user:pass@localhost/mydb?readTimeout=10s | -#### `ssl` +#### `tls` -Enable TLS between client and server. +Enable TLS between client and server. Valid values are `true` or `custom`. -| Type | Default | Allowed Values | +| Type | Default | Example | | --------- | --------- | ------------------------------------------- | -| string | | `true` or `custom` | - +| string | | user:pass@localhost/mydb?tls=true | #### `timeout` -Timeout is the maximum amount of time a dial will wait for a connect to complete. +Timeout is the maximum amount of time a dial will wait for a connect to complete. 0 means no timeout. | Type | Default | Example | | --------- | --------- | ------------------------------------------- | @@ -403,7 +402,7 @@ Timeout is the maximum amount of time a dial will wait for a connect to complete #### `writeTimeout` -I/O write timeout. +I/O write timeout. 0 means no timeout. | Type | Default | Example | | --------- | --------- | ----------------------------------------------- | @@ -411,7 +410,6 @@ I/O write timeout. - We pass all tests in https://github.com/bradfitz/go-sql-test using go-mysql driver. :-) ## Donate diff --git a/driver/driver.go b/driver/driver.go index 377b87d5d..2124e57e9 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -99,7 +99,8 @@ func (d driver) Open(dsn string) (sqldriver.Conn, error) { var timeout time.Duration configuredOptions := make([]client.Option, 0, len(ci.params)) for key, value := range ci.params { - if key == "ssl" && len(value) > 0 { + // the key ssl has been deprecated in favor of tls + if (key == "ssl" || key == "tls") && len(value) > 0 { tlsConfigName := value[0] switch tlsConfigName { case "true": @@ -118,7 +119,7 @@ func (d driver) Open(dsn string) (sqldriver.Conn, error) { return nil }) default: - return nil, errors.Errorf("Supported options are ssl=true or ssl=custom") + return nil, errors.Errorf("Supported options are tls=true or tls=custom") } } else if key == "timeout" && len(value) > 0 { if timeout, err = time.ParseDuration(value[0]); err != nil { diff --git a/driver/driver_options.go b/driver/driver_options.go index a98d47b1e..605e68f81 100644 --- a/driver/driver_options.go +++ b/driver/driver_options.go @@ -1,7 +1,6 @@ package driver import ( - "strconv" "time" "github.com/go-mysql-org/go-mysql/client" @@ -35,17 +34,16 @@ func WriteTimeoutOption(c *client.Conn, value string) error { } func CompressOption(c *client.Conn, value string) error { - var ( - b bool - err error - ) - if b, err = strconv.ParseBool(value); err != nil { - return errors.Errorf("invalid boolean value '%s' for compress option", value) - } - if b { + switch value { + case "zlib": c.SetCapability(mysql.CLIENT_COMPRESS) - } else { + case "zstd": + c.SetCapability(mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM) + case "uncompressed": c.UnsetCapability(mysql.CLIENT_COMPRESS) + c.UnsetCapability(mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM) + default: + return errors.Errorf("invalid compression algorithm '%s', valid values are 'zstd','zlib','uncompressed'", value) } return nil diff --git a/driver/driver_options_test.go b/driver/driver_options_test.go index c1dc7b488..32431932a 100644 --- a/driver/driver_options_test.go +++ b/driver/driver_options_test.go @@ -39,13 +39,18 @@ func TestDriverOptions_SetCollation(t *testing.T) { func TestDriverOptions_SetCompression(t *testing.T) { var err error c := &client.Conn{} - err = CompressOption(c, "true") + err = CompressOption(c, "zlib") require.NoError(t, err) require.True(t, c.HasCapability(mysql.CLIENT_COMPRESS)) - err = CompressOption(c, "false") + err = CompressOption(c, "zstd") + require.NoError(t, err) + require.True(t, c.HasCapability(mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM)) + + err = CompressOption(c, "uncompressed") require.NoError(t, err) require.False(t, c.HasCapability(mysql.CLIENT_COMPRESS)) + require.False(t, c.HasCapability(mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM)) require.Error(t, CompressOption(c, "foo")) } diff --git a/driver/driver_test.go b/driver/driver_test.go index 5df71cf5b..c345cb2e0 100644 --- a/driver/driver_test.go +++ b/driver/driver_test.go @@ -86,12 +86,14 @@ func TestParseDSN(t *testing.T) { "user@1.domain.com?db": {standardDSN: false, addr: "1.domain.com", user: "user", password: "", db: "db", params: url.Values{}}, "user:password@2.domain.com/db": {standardDSN: true, addr: "2.domain.com", user: "user", password: "password", db: "db", params: url.Values{}}, "user:password@3.domain.com/db?ssl=true": {standardDSN: true, addr: "3.domain.com", user: "user", password: "password", db: "db", params: url.Values{"ssl": []string{"true"}}}, + "user:password@3.domain.com/db?tls=true": {standardDSN: true, addr: "3.domain.com", user: "user", password: "password", db: "db", params: url.Values{"tls": []string{"true"}}}, "user:password@4.domain.com/db?ssl=custom": {standardDSN: true, addr: "4.domain.com", user: "user", password: "password", db: "db", params: url.Values{"ssl": []string{"custom"}}}, + "user:password@4.domain.com/db?tls=custom": {standardDSN: true, addr: "4.domain.com", user: "user", password: "password", db: "db", params: url.Values{"tls": []string{"custom"}}}, "user:password@5.domain.com/db?unused=param": {standardDSN: true, addr: "5.domain.com", user: "user", password: "password", db: "db", params: url.Values{"unused": []string{"param"}}}, "user:password@5.domain.com/db?timeout=1s": {standardDSN: true, addr: "5.domain.com", user: "user", password: "password", db: "db", params: url.Values{"timeout": []string{"1s"}}}, "user:password@5.domain.com/db?readTimeout=1m": {standardDSN: true, addr: "5.domain.com", user: "user", password: "password", db: "db", params: url.Values{"readTimeout": []string{"1m"}}}, "user:password@5.domain.com/db?writeTimeout=1m": {standardDSN: true, addr: "5.domain.com", user: "user", password: "password", db: "db", params: url.Values{"writeTimeout": []string{"1m"}}}, - "user:password@5.domain.com/db?compress=true": {standardDSN: true, addr: "5.domain.com", user: "user", password: "password", db: "db", params: url.Values{"compress": []string{"true"}}}, + "user:password@5.domain.com/db?compress=zlib": {standardDSN: true, addr: "5.domain.com", user: "user", password: "password", db: "db", params: url.Values{"compress": []string{"zlib"}}}, } for supplied, expected := range testDSNs { From 4124a7ec5562f6cac58d813b456dcc18220cbf23 Mon Sep 17 00:00:00 2001 From: David Vilaverde <dvilaverde@gmail.com> Date: Tue, 4 Jun 2024 13:55:07 -0400 Subject: [PATCH 16/19] revert rename driver arg ssl to tls --- README.md | 4 ++-- driver/driver.go | 5 ++--- driver/driver_test.go | 2 -- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index b904e4ab3..fd1ce6055 100644 --- a/README.md +++ b/README.md @@ -384,13 +384,13 @@ I/O read timeout. 0 means no timeout. | --------- | --------- | ------------------------------------------- | | duration | 0 | user:pass@localhost/mydb?readTimeout=10s | -#### `tls` +#### `ssl` Enable TLS between client and server. Valid values are `true` or `custom`. | Type | Default | Example | | --------- | --------- | ------------------------------------------- | -| string | | user:pass@localhost/mydb?tls=true | +| string | | user:pass@localhost/mydb?ssl=true | #### `timeout` diff --git a/driver/driver.go b/driver/driver.go index 2124e57e9..377b87d5d 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -99,8 +99,7 @@ func (d driver) Open(dsn string) (sqldriver.Conn, error) { var timeout time.Duration configuredOptions := make([]client.Option, 0, len(ci.params)) for key, value := range ci.params { - // the key ssl has been deprecated in favor of tls - if (key == "ssl" || key == "tls") && len(value) > 0 { + if key == "ssl" && len(value) > 0 { tlsConfigName := value[0] switch tlsConfigName { case "true": @@ -119,7 +118,7 @@ func (d driver) Open(dsn string) (sqldriver.Conn, error) { return nil }) default: - return nil, errors.Errorf("Supported options are tls=true or tls=custom") + return nil, errors.Errorf("Supported options are ssl=true or ssl=custom") } } else if key == "timeout" && len(value) > 0 { if timeout, err = time.ParseDuration(value[0]); err != nil { diff --git a/driver/driver_test.go b/driver/driver_test.go index c345cb2e0..1c21bd0e3 100644 --- a/driver/driver_test.go +++ b/driver/driver_test.go @@ -86,9 +86,7 @@ func TestParseDSN(t *testing.T) { "user@1.domain.com?db": {standardDSN: false, addr: "1.domain.com", user: "user", password: "", db: "db", params: url.Values{}}, "user:password@2.domain.com/db": {standardDSN: true, addr: "2.domain.com", user: "user", password: "password", db: "db", params: url.Values{}}, "user:password@3.domain.com/db?ssl=true": {standardDSN: true, addr: "3.domain.com", user: "user", password: "password", db: "db", params: url.Values{"ssl": []string{"true"}}}, - "user:password@3.domain.com/db?tls=true": {standardDSN: true, addr: "3.domain.com", user: "user", password: "password", db: "db", params: url.Values{"tls": []string{"true"}}}, "user:password@4.domain.com/db?ssl=custom": {standardDSN: true, addr: "4.domain.com", user: "user", password: "password", db: "db", params: url.Values{"ssl": []string{"custom"}}}, - "user:password@4.domain.com/db?tls=custom": {standardDSN: true, addr: "4.domain.com", user: "user", password: "password", db: "db", params: url.Values{"tls": []string{"custom"}}}, "user:password@5.domain.com/db?unused=param": {standardDSN: true, addr: "5.domain.com", user: "user", password: "password", db: "db", params: url.Values{"unused": []string{"param"}}}, "user:password@5.domain.com/db?timeout=1s": {standardDSN: true, addr: "5.domain.com", user: "user", password: "password", db: "db", params: url.Values{"timeout": []string{"1s"}}}, "user:password@5.domain.com/db?readTimeout=1m": {standardDSN: true, addr: "5.domain.com", user: "user", password: "password", db: "db", params: url.Values{"readTimeout": []string{"1m"}}}, From b353467a1d7ab8fb5e22c5600e5275a6bfacdbf1 Mon Sep 17 00:00:00 2001 From: David Vilaverde <dvilaverde@gmail.com> Date: Wed, 5 Jun 2024 07:31:47 -0400 Subject: [PATCH 17/19] addressing PR feedback --- README.md | 18 ++++++++++++++---- client/conn.go | 5 +++++ driver/driver.go | 16 ++++++++++++++++ 3 files changed, 35 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index fd1ce6055..f199c059f 100644 --- a/README.md +++ b/README.md @@ -378,7 +378,10 @@ Enable compression between the client and the server. Valid values are 'zstd','z #### `readTimeout` -I/O read timeout. 0 means no timeout. +I/O read timeout. The time unit is specified in the argument value using +golang's [ParseDuration](https://pkg.go.dev/time#ParseDuration) format. + +0 means no timeout. | Type | Default | Example | | --------- | --------- | ------------------------------------------- | @@ -386,7 +389,8 @@ I/O read timeout. 0 means no timeout. #### `ssl` -Enable TLS between client and server. Valid values are `true` or `custom`. +Enable TLS between client and server. Valid values are `true` or `custom`. When using `custom`, +the connection will use the TLS configuration set by SetCustomTLSConfig matching the host. | Type | Default | Example | | --------- | --------- | ------------------------------------------- | @@ -394,7 +398,10 @@ Enable TLS between client and server. Valid values are `true` or `custom`. #### `timeout` -Timeout is the maximum amount of time a dial will wait for a connect to complete. 0 means no timeout. +Timeout is the maximum amount of time a dial will wait for a connect to complete. +The time unit is specified in the argument value using golang's [ParseDuration](https://pkg.go.dev/time#ParseDuration) format. + +0 means no timeout. | Type | Default | Example | | --------- | --------- | ------------------------------------------- | @@ -402,7 +409,10 @@ Timeout is the maximum amount of time a dial will wait for a connect to complete #### `writeTimeout` -I/O write timeout. 0 means no timeout. +I/O write timeout. The time unit is specified in the argument value using +golang's [ParseDuration](https://pkg.go.dev/time#ParseDuration) format. + +0 means no timeout. | Type | Default | Example | | --------- | --------- | ----------------------------------------------- | diff --git a/client/conn.go b/client/conn.go index 568ca16ca..c7be06b85 100644 --- a/client/conn.go +++ b/client/conn.go @@ -123,6 +123,8 @@ func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbNam // Apply configuration functions. for _, option := range options { if err := option(c); err != nil { + // must close the connection in the event the provided configuration is not valid + _ = conn.Close() return nil, err } } @@ -135,6 +137,7 @@ func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbNam } if err = c.handshake(); err != nil { + // in the event of an error c.handshake() will close the connection return nil, errors.Trace(err) } @@ -149,11 +152,13 @@ func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbNam if len(c.collation) != 0 { collation, err := charset.GetCollationByName(c.collation) if err != nil { + c.Close() return nil, errors.Trace(fmt.Errorf("invalid collation name %s", c.collation)) } if collation.ID > 255 { if _, err := c.exec(fmt.Sprintf("SET NAMES %s COLLATE %s", c.charset, c.collation)); err != nil { + c.Close() return nil, errors.Trace(err) } } diff --git a/driver/driver.go b/driver/driver.go index 377b87d5d..adc04f860 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -359,3 +359,19 @@ func SetCustomTLSConfig(dsn string, caPem []byte, certPem []byte, keyPem []byte, return nil } + +// SetDsnOptions sets custom options to the driver that allows modifications to the connection. +// It requires a full import of the driver (not by side-effects only). +// Example of supplying a custom option: +// +// driver.SetDsnOptions(map[string]DriverOption{ +// "my_option": func(c *client.Conn, value string) error { +// c.SetCapability(mysql.CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS) +// return nil +// }, +// }) +func SetDsnOptions(customOptions map[string]DriverOption) { + for o, f := range customOptions { + options[o] = f + } +} From 3ef00d7621f844ae8b8b15e056bac267f8cadab9 Mon Sep 17 00:00:00 2001 From: David Vilaverde <dvilaverde@gmail.com> Date: Wed, 5 Jun 2024 13:07:09 -0400 Subject: [PATCH 18/19] write compressed packet using writeWithTimeout --- packet/conn.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packet/conn.go b/packet/conn.go index e12a316a1..9901e34be 100644 --- a/packet/conn.go +++ b/packet/conn.go @@ -412,7 +412,7 @@ func (c *Conn) writeCompressed(data []byte) (n int, err error) { if err != nil { return 0, err } - if _, err = c.Write(compressedPacket.Bytes()); err != nil { + if _, err = c.writeWithTimeout(compressedPacket.Bytes()); err != nil { return 0, err } From 2934745237d0580e11da0a983d5513b8c5c2cefc Mon Sep 17 00:00:00 2001 From: David Vilaverde <dvilaverde@gmail.com> Date: Thu, 6 Jun 2024 06:29:31 -0400 Subject: [PATCH 19/19] updated README.md --- README.md | 29 +++++++++++++++++++++++++++++ driver/driver.go | 6 +++--- 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index f199c059f..e108469e0 100644 --- a/README.md +++ b/README.md @@ -418,6 +418,35 @@ golang's [ParseDuration](https://pkg.go.dev/time#ParseDuration) format. | --------- | --------- | ----------------------------------------------- | | duration | 0 | user:pass@localhost/mydb?writeTimeout=1m30s | +### Custom Driver Options + +The driver package exposes the function `SetDSNOptions`, allowing for modification of the +connection by adding custom driver options. +It requires a full import of the driver (not by side-effects only). + +Example of defining a custom option: + +```golang +import ( + "database/sql" + + "github.com/go-mysql-org/go-mysql/driver" +) + +func main() { + driver.SetDSNOptions(map[string]DriverOption{ + "no_metadata": func(c *client.Conn, value string) error { + c.SetCapability(mysql.CLIENT_OPTIONAL_RESULTSET_METADATA) + return nil + }, + }) + + // dsn format: "user:password@addr/dbname?" + dsn := "root@127.0.0.1:3306/test?no_metadata=true" + db, _ := sql.Open(dsn) + db.Close() +} +``` We pass all tests in https://github.com/bradfitz/go-sql-test using go-mysql driver. :-) diff --git a/driver/driver.go b/driver/driver.go index adc04f860..8f132d2b3 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -360,17 +360,17 @@ func SetCustomTLSConfig(dsn string, caPem []byte, certPem []byte, keyPem []byte, return nil } -// SetDsnOptions sets custom options to the driver that allows modifications to the connection. +// SetDSNOptions sets custom options to the driver that allows modifications to the connection. // It requires a full import of the driver (not by side-effects only). // Example of supplying a custom option: // -// driver.SetDsnOptions(map[string]DriverOption{ +// driver.SetDSNOptions(map[string]DriverOption{ // "my_option": func(c *client.Conn, value string) error { // c.SetCapability(mysql.CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS) // return nil // }, // }) -func SetDsnOptions(customOptions map[string]DriverOption) { +func SetDSNOptions(customOptions map[string]DriverOption) { for o, f := range customOptions { options[o] = f }