Skip to content

Commit

Permalink
Encode connection attribute only once.
Browse files Browse the repository at this point in the history
  • Loading branch information
methane committed May 19, 2023
1 parent 4e1c200 commit 5ba21b7
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 40 deletions.
1 change: 1 addition & 0 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type mysqlConn struct {
affectedRows uint64
insertId uint64
cfg *Config
connector *connector
maxAllowedPacket int
maxWriteSize int
writeTimeout time.Duration
Expand Down
46 changes: 45 additions & 1 deletion connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,54 @@ package mysql
import (
"context"
"database/sql/driver"
"fmt"
"net"
"os"
"strconv"
"strings"
)

type connector struct {
cfg *Config // immutable private copy.
cfg *Config // immutable private copy.
encodedAttributes string // Encoded connection attributes.
}

func encodeConnectionAttributes(textAttributes string) string {
connAttrsBuf := make([]byte, 0, 251)

// default connection attributes
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientName)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientNameValue)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOS)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOSValue)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatform)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatformValue)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPid)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, strconv.Itoa(os.Getpid()))

// user-defined connection attributes
for _, connAttr := range strings.Split(textAttributes, ",") {
attr := strings.SplitN(connAttr, ":", 2)
if len(attr) != 2 {
continue
}
for _, v := range attr {
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, v)
}
}

return string(connAttrsBuf)
}

func newConnector(cfg *Config) (*connector, error) {
encodedAttributes := encodeConnectionAttributes(cfg.ConnectionAttributes)
if len(encodedAttributes) > 250 {
return nil, fmt.Errorf("connection attributes are longer than 250 bytes: %dbytes (%q)", len(encodedAttributes), cfg.ConnectionAttributes)
}
return &connector{
cfg: cfg,
encodedAttributes: encodedAttributes,
}, nil
}

// Connect implements driver.Connector interface.
Expand All @@ -29,6 +72,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
maxWriteSize: maxPacketSize - 1,
closech: make(chan struct{}),
cfg: c.cfg,
connector: c,
}
mc.parseTime = mc.cfg.ParseTime

Expand Down
9 changes: 6 additions & 3 deletions connector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@ import (
)

func TestConnectorReturnsTimeout(t *testing.T) {
connector := &connector{&Config{
connector, err := newConnector(&Config{
Net: "tcp",
Addr: "1.1.1.1:1234",
Timeout: 10 * time.Millisecond,
}}
})
if err != nil {
t.Fatal(err)
}

_, err := connector.Connect(context.Background())
_, err = connector.Connect(context.Background())
if err == nil {
t.Fatal("error expected")
}
Expand Down
11 changes: 5 additions & 6 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,9 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
if err != nil {
return nil, err
}
c := &connector{
cfg: cfg,
c, err := newConnector(cfg)
if err != nil {
return nil, err
}
return c.Connect(context.Background())
}
Expand All @@ -103,7 +104,7 @@ func NewConnector(cfg *Config) (driver.Connector, error) {
if err := cfg.normalize(); err != nil {
return nil, err
}
return &connector{cfg: cfg}, nil
return newConnector(cfg)
}

// OpenConnector implements driver.DriverContext.
Expand All @@ -112,7 +113,5 @@ func (d MySQLDriver) OpenConnector(dsn string) (driver.Connector, error) {
if err != nil {
return nil, err
}
return &connector{
cfg: cfg,
}, nil
return newConnector(cfg)
}
36 changes: 7 additions & 29 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@ import (
"fmt"
"io"
"math"
"os"
"strconv"
"strings"
"time"
)

Expand Down Expand Up @@ -322,31 +319,12 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
pktLen += n + 1
}

connAttrsBuf := make([]byte, 0, 100)

// default connection attributes
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientName)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientNameValue)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOS)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOSValue)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatform)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatformValue)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPid)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, strconv.Itoa(os.Getpid()))

// user-defined connection attributes
for _, connAttr := range strings.Split(mc.cfg.ConnectionAttributes, ",") {
attr := strings.Split(connAttr, ":")
if len(attr) != 2 {
continue
}
for _, v := range attr {
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, v)
}
}

// 1 byte to store length of all key-values
pktLen += len(connAttrsBuf) + 1
// NOTE: Actually, this is length encoded integer.
// But we support only len(connAttrBuf) < 251 for now because takeSmallBuffer
// doesn't support buffer size more than 4096 bytes.
// TODO(methane): Rewrite buffer management.
pktLen += 1 + len(mc.connector.encodedAttributes)

// Calculate packet length and get buffer with that size
data, err := mc.buf.takeSmallBuffer(pktLen + 4)
Expand Down Expand Up @@ -425,9 +403,9 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
pos++

// Connection Attributes
data[pos] = byte(len(connAttrsBuf))
data[pos] = byte(len(mc.connector.encodedAttributes))
pos++
pos += copy(data[pos:], connAttrsBuf)
pos += copy(data[pos:], []byte(mc.connector.encodedAttributes))

// Send Auth packet
return mc.writePacket(data[:pos])
Expand Down
7 changes: 6 additions & 1 deletion packets_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,14 @@ var _ net.Conn = new(mockConn)

func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) {
conn := new(mockConn)
connector, err := newConnector(NewConfig())
if err != nil {
panic(err)
}
mc := &mysqlConn{
buf: newBuffer(conn),
cfg: NewConfig(),
cfg: connector.cfg,
connector: connector,
netConn: conn,
closech: make(chan struct{}),
maxAllowedPacket: defaultMaxAllowedPacket,
Expand Down

0 comments on commit 5ba21b7

Please sign in to comment.