Skip to content

Commit b13191f

Browse files
dvilaverdedvilaverdedveeden
authored
Additional Driver args for compression and connection read/write timeouts (#885)
* allow setting the collation in auth handshake * Allow connect with context in order to provide configurable connect timeouts * add driver arguments * check for empty ssl value when setting conn options * allow setting the collation in auth handshake (#860) * Allow connect with context in order to provide configurable connect timeouts * support collations IDs greater than 255 on the auth handshake --------- Co-authored-by: dvilaverde <dvilaverde@adobe.com> * refactored and added more driver args * revert change to Makefile * added tests for timeouts * adding more tests * fixing linting issues * avoiding panic on test complete * revert returning set readtimeout error in binlogsyncer * fixing nil violation when connection with timeout from binlogsyncer * Update README.md Co-authored-by: Daniël van Eeden <github@myname.nl> * addressing pull request feedback * revert rename driver arg ssl to tls * addressing PR feedback * write compressed packet using writeWithTimeout * updated README.md --------- Co-authored-by: dvilaverde <dvilaverde@adobe.com> Co-authored-by: Daniël van Eeden <github@myname.nl>
1 parent 6c99b4b commit b13191f

15 files changed

+551
-58
lines changed

README.md

+89
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,95 @@ func main() {
360360
}
361361
```
362362

363+
### Driver Options
364+
365+
Configuration options can be provided by the standard DSN (Data Source Name).
366+
367+
```
368+
[user[:password]@]addr[/db[?param=X]]
369+
```
370+
371+
#### `compress`
372+
373+
Enable compression between the client and the server. Valid values are 'zstd','zlib','uncompressed'.
374+
375+
| Type | Default | Example |
376+
| --------- | ------------- | --------------------------------------- |
377+
| string | uncompressed | user:pass@localhost/mydb?compress=zlib |
378+
379+
#### `readTimeout`
380+
381+
I/O read timeout. The time unit is specified in the argument value using
382+
golang's [ParseDuration](https://pkg.go.dev/time#ParseDuration) format.
383+
384+
0 means no timeout.
385+
386+
| Type | Default | Example |
387+
| --------- | --------- | ------------------------------------------- |
388+
| duration | 0 | user:pass@localhost/mydb?readTimeout=10s |
389+
390+
#### `ssl`
391+
392+
Enable TLS between client and server. Valid values are `true` or `custom`. When using `custom`,
393+
the connection will use the TLS configuration set by SetCustomTLSConfig matching the host.
394+
395+
| Type | Default | Example |
396+
| --------- | --------- | ------------------------------------------- |
397+
| string | | user:pass@localhost/mydb?ssl=true |
398+
399+
#### `timeout`
400+
401+
Timeout is the maximum amount of time a dial will wait for a connect to complete.
402+
The time unit is specified in the argument value using golang's [ParseDuration](https://pkg.go.dev/time#ParseDuration) format.
403+
404+
0 means no timeout.
405+
406+
| Type | Default | Example |
407+
| --------- | --------- | ------------------------------------------- |
408+
| duration | 0 | user:pass@localhost/mydb?timeout=1m |
409+
410+
#### `writeTimeout`
411+
412+
I/O write timeout. The time unit is specified in the argument value using
413+
golang's [ParseDuration](https://pkg.go.dev/time#ParseDuration) format.
414+
415+
0 means no timeout.
416+
417+
| Type | Default | Example |
418+
| --------- | --------- | ----------------------------------------------- |
419+
| duration | 0 | user:pass@localhost/mydb?writeTimeout=1m30s |
420+
421+
### Custom Driver Options
422+
423+
The driver package exposes the function `SetDSNOptions`, allowing for modification of the
424+
connection by adding custom driver options.
425+
It requires a full import of the driver (not by side-effects only).
426+
427+
Example of defining a custom option:
428+
429+
```golang
430+
import (
431+
"database/sql"
432+
433+
"github.com/go-mysql-org/go-mysql/driver"
434+
)
435+
436+
func main() {
437+
driver.SetDSNOptions(map[string]DriverOption{
438+
"no_metadata": func(c *client.Conn, value string) error {
439+
c.SetCapability(mysql.CLIENT_OPTIONAL_RESULTSET_METADATA)
440+
return nil
441+
},
442+
})
443+
444+
// dsn format: "user:password@addr/dbname?"
445+
dsn := "root@127.0.0.1:3306/test?no_metadata=true"
446+
db, _ := sql.Open(dsn)
447+
db.Close()
448+
}
449+
```
450+
451+
363452
We pass all tests in https://github.com/bradfitz/go-sql-test using go-mysql driver. :-)
364453

365454
## Donate

canal/canal.go

+4-3
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,7 @@ func (c *Canal) prepareSyncer() error {
499499
return nil
500500
}
501501

502-
func (c *Canal) connect(options ...func(*client.Conn)) (*client.Conn, error) {
502+
func (c *Canal) connect(options ...client.Option) (*client.Conn, error) {
503503
ctx, cancel := context.WithTimeout(c.ctx, time.Second*10)
504504
defer cancel()
505505

@@ -511,10 +511,11 @@ func (c *Canal) connect(options ...func(*client.Conn)) (*client.Conn, error) {
511511
func (c *Canal) Execute(cmd string, args ...interface{}) (rr *mysql.Result, err error) {
512512
c.connLock.Lock()
513513
defer c.connLock.Unlock()
514-
argF := make([]func(*client.Conn), 0)
514+
argF := make([]client.Option, 0)
515515
if c.cfg.TLSConfig != nil {
516-
argF = append(argF, func(conn *client.Conn) {
516+
argF = append(argF, func(conn *client.Conn) error {
517517
conn.SetTLSConfig(c.cfg.TLSConfig)
518+
return nil
518519
})
519520
}
520521

client/client_test.go

+12-8
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@ func TestClientSuite(t *testing.T) {
3131
func (s *clientTestSuite) SetupSuite() {
3232
var err error
3333
addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port)
34-
s.c, err = Connect(addr, *testUser, *testPassword, "", func(conn *Conn) {
34+
s.c, err = Connect(addr, *testUser, *testPassword, "", func(conn *Conn) error {
3535
// test the collation logic, but this is essentially a no-op since
3636
// the collation set is the default value
37-
_ = conn.SetCollation(mysql.DEFAULT_COLLATION_NAME)
37+
return conn.SetCollation(mysql.DEFAULT_COLLATION_NAME)
3838
})
3939
require.NoError(s.T(), err)
4040

@@ -91,8 +91,9 @@ func (s *clientTestSuite) TestConn_Ping() {
9191

9292
func (s *clientTestSuite) TestConn_Compress() {
9393
addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port)
94-
conn, err := Connect(addr, *testUser, *testPassword, "", func(conn *Conn) {
94+
conn, err := Connect(addr, *testUser, *testPassword, "", func(conn *Conn) error {
9595
conn.SetCapability(mysql.CLIENT_COMPRESS)
96+
return nil
9697
})
9798
require.NoError(s.T(), err)
9899

@@ -142,8 +143,9 @@ func (s *clientTestSuite) TestConn_TLS_Verify() {
142143
// Verify that the provided tls.Config is used when attempting to connect to mysql.
143144
// An empty tls.Config will result in a connection error.
144145
addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port)
145-
_, err := Connect(addr, *testUser, *testPassword, *testDB, func(c *Conn) {
146+
_, err := Connect(addr, *testUser, *testPassword, *testDB, func(c *Conn) error {
146147
c.UseSSL(false)
148+
return nil
147149
})
148150
expected := "either ServerName or InsecureSkipVerify must be specified in the tls.Config"
149151

@@ -153,8 +155,9 @@ func (s *clientTestSuite) TestConn_TLS_Verify() {
153155
func (s *clientTestSuite) TestConn_TLS_Skip_Verify() {
154156
// An empty tls.Config will result in a connection error but we can configure to skip it.
155157
addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port)
156-
_, err := Connect(addr, *testUser, *testPassword, *testDB, func(c *Conn) {
158+
_, err := Connect(addr, *testUser, *testPassword, *testDB, func(c *Conn) error {
157159
c.UseSSL(true)
160+
return nil
158161
})
159162
require.NoError(s.T(), err)
160163
}
@@ -165,8 +168,9 @@ func (s *clientTestSuite) TestConn_TLS_Certificate() {
165168
// "x509: certificate is valid for MySQL_Server_8.0.12_Auto_Generated_Server_Certificate, not not-a-valid-name"
166169
tlsConfig := NewClientTLSConfig(test_keys.CaPem, test_keys.CertPem, test_keys.KeyPem, false, "not-a-valid-name")
167170
addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port)
168-
_, err := Connect(addr, *testUser, *testPassword, *testDB, func(c *Conn) {
171+
_, err := Connect(addr, *testUser, *testPassword, *testDB, func(c *Conn) error {
169172
c.SetTLSConfig(tlsConfig)
173+
return nil
170174
})
171175
require.Error(s.T(), err)
172176
if !strings.Contains(errors.ErrorStack(err), "certificate is not valid for any names") &&
@@ -251,9 +255,9 @@ func (s *clientTestSuite) TestConn_SetCollationAfterConnect() {
251255

252256
func (s *clientTestSuite) TestConn_SetCollation() {
253257
addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port)
254-
_, err := Connect(addr, *testUser, *testPassword, "", func(conn *Conn) {
258+
_, err := Connect(addr, *testUser, *testPassword, "", func(conn *Conn) error {
255259
// test the collation logic
256-
_ = conn.SetCollation("invalid_collation")
260+
return conn.SetCollation("invalid_collation")
257261
})
258262

259263
require.Error(s.T(), err)

client/conn.go

+31-11
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ import (
1818
"github.com/go-mysql-org/go-mysql/utils"
1919
)
2020

21+
type Option func(*Conn) error
22+
2123
type Conn struct {
2224
*packet.Conn
2325

@@ -27,6 +29,10 @@ type Conn struct {
2729
tlsConfig *tls.Config
2830
proto string
2931

32+
// Connection read and write timeouts to set on the connection
33+
ReadTimeout time.Duration
34+
WriteTimeout time.Duration
35+
3036
serverVersion string
3137
// server capabilities
3238
capability uint32
@@ -66,24 +72,26 @@ func getNetProto(addr string) string {
6672

6773
// Connect to a MySQL server, addr can be ip:port, or a unix socket domain like /var/sock.
6874
// Accepts a series of configuration functions as a variadic argument.
69-
func Connect(addr string, user string, password string, dbName string, options ...func(*Conn)) (*Conn, error) {
70-
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
71-
defer cancel()
75+
func Connect(addr, user, password, dbName string, options ...Option) (*Conn, error) {
76+
return ConnectWithTimeout(addr, user, password, dbName, time.Second*10, options...)
77+
}
7278

73-
return ConnectWithContext(ctx, addr, user, password, dbName, options...)
79+
// ConnectWithTimeout to a MySQL address using a timeout.
80+
func ConnectWithTimeout(addr, user, password, dbName string, timeout time.Duration, options ...Option) (*Conn, error) {
81+
return ConnectWithContext(context.Background(), addr, user, password, dbName, time.Second*10, options...)
7482
}
7583

7684
// ConnectWithContext to a MySQL addr using the provided context.
77-
func ConnectWithContext(ctx context.Context, addr string, user string, password string, dbName string, options ...func(*Conn)) (*Conn, error) {
78-
dialer := &net.Dialer{}
85+
func ConnectWithContext(ctx context.Context, addr, user, password, dbName string, timeout time.Duration, options ...Option) (*Conn, error) {
86+
dialer := &net.Dialer{Timeout: timeout}
7987
return ConnectWithDialer(ctx, "", addr, user, password, dbName, dialer.DialContext, options...)
8088
}
8189

8290
// Dialer connects to the address on the named network using the provided context.
8391
type Dialer func(ctx context.Context, network, address string) (net.Conn, error)
8492

8593
// ConnectWithDialer to a MySQL server using the given Dialer.
86-
func ConnectWithDialer(ctx context.Context, network string, addr string, user string, password string, dbName string, dialer Dialer, options ...func(*Conn)) (*Conn, error) {
94+
func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbName string, dialer Dialer, options ...Option) (*Conn, error) {
8795
c := new(Conn)
8896

8997
c.attributes = map[string]string{
@@ -108,23 +116,28 @@ func ConnectWithDialer(ctx context.Context, network string, addr string, user st
108116
c.password = password
109117
c.db = dbName
110118
c.proto = network
111-
c.Conn = packet.NewConn(conn)
112119

113120
// use default charset here, utf-8
114121
c.charset = DEFAULT_CHARSET
115122

116123
// Apply configuration functions.
117-
for i := range options {
118-
options[i](c)
124+
for _, option := range options {
125+
if err := option(c); err != nil {
126+
// must close the connection in the event the provided configuration is not valid
127+
_ = conn.Close()
128+
return nil, err
129+
}
119130
}
120131

132+
c.Conn = packet.NewConnWithTimeout(conn, c.ReadTimeout, c.WriteTimeout)
121133
if c.tlsConfig != nil {
122134
seq := c.Conn.Sequence
123-
c.Conn = packet.NewTLSConn(conn)
135+
c.Conn = packet.NewTLSConnWithTimeout(conn, c.ReadTimeout, c.WriteTimeout)
124136
c.Conn.Sequence = seq
125137
}
126138

127139
if err = c.handshake(); err != nil {
140+
// in the event of an error c.handshake() will close the connection
128141
return nil, errors.Trace(err)
129142
}
130143

@@ -139,11 +152,13 @@ func ConnectWithDialer(ctx context.Context, network string, addr string, user st
139152
if len(c.collation) != 0 {
140153
collation, err := charset.GetCollationByName(c.collation)
141154
if err != nil {
155+
c.Close()
142156
return nil, errors.Trace(fmt.Errorf("invalid collation name %s", c.collation))
143157
}
144158

145159
if collation.ID > 255 {
146160
if _, err := c.exec(fmt.Sprintf("SET NAMES %s COLLATE %s", c.charset, c.collation)); err != nil {
161+
c.Close()
147162
return nil, errors.Trace(err)
148163
}
149164
}
@@ -206,6 +221,11 @@ func (c *Conn) UnsetCapability(cap uint32) {
206221
c.ccaps &= ^cap
207222
}
208223

224+
// HasCapability returns true if the connection has the specific capability
225+
func (c *Conn) HasCapability(cap uint32) bool {
226+
return c.ccaps&cap > 0
227+
}
228+
209229
// UseSSL: use default SSL
210230
// pass to options when connect
211231
func (c *Conn) UseSSL(insecureSkipVerify bool) {

client/conn_test.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,11 @@ func TestConnSuite(t *testing.T) {
2828
func (s *connTestSuite) SetupSuite() {
2929
var err error
3030
addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port)
31-
s.c, err = Connect(addr, *testUser, *testPassword, "", func(c *Conn) {
31+
s.c, err = Connect(addr, *testUser, *testPassword, "", func(c *Conn) error {
3232
// required for the ExecuteMultiple test
3333
c.SetCapability(mysql.CLIENT_MULTI_STATEMENTS)
3434
c.SetAttributes(map[string]string{"attrtest": "attrvalue"})
35+
return nil
3536
})
3637
require.NoError(s.T(), err)
3738

client/pool.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ func NewPool(
166166
user string,
167167
password string,
168168
dbName string,
169-
options ...func(conn *Conn),
169+
options ...Option,
170170
) *Pool {
171171
pool, err := NewPoolWithOptions(
172172
addr,

client/pool_options.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ type (
1717
password string
1818
dbName string
1919

20-
connOptions []func(conn *Conn)
20+
connOptions []Option
2121

2222
newPoolPingTimeout time.Duration
2323
}
@@ -46,7 +46,7 @@ func WithLogFunc(f LogFunc) PoolOption {
4646
}
4747
}
4848

49-
func WithConnOptions(options ...func(conn *Conn)) PoolOption {
49+
func WithConnOptions(options ...Option) PoolOption {
5050
return func(o *poolOptions) {
5151
o.connOptions = append(o.connOptions, options...)
5252
}

0 commit comments

Comments
 (0)