Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for MySQL Connection Attributes #1241

Closed
wants to merge 18 commits into from
Closed
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
71151ad
Support for sending connection attributes
Vasfed Jan 18, 2018
9057689
Rename connectionAttributes to connectAttrs and mention performance_s…
Vasfed Jun 15, 2018
71a1948
Fix formatting
Vasfed Jun 15, 2018
7038d92
Fix excessive null-termination for auth data in handshake
Vasfed Jun 16, 2018
151e4dc
Use performance_schema.session_connect_attrs in test
Vasfed Jun 16, 2018
26b97ba
Skip connectAttrs test if performance_schema is disabled
Vasfed Jun 16, 2018
7d54949
Merge branch 'master' into feature/connection-attributes
andygrunwald Jul 31, 2021
e8f4d0d
Fix ./packets.go:308: undefined: insecureAuth
andygrunwald Jul 31, 2021
aa0e7a0
Fix unit test TestDSNParser for connection attributes
andygrunwald Jul 31, 2021
2f9b253
Make SQL keywords uppercase
andygrunwald Jul 31, 2021
3ab2571
Changed "_client_name" attribute from "go-mysql-driver" to "github.co…
andygrunwald Dec 17, 2021
b5117fc
Allow overrwriting connection attribute "_client_name"
andygrunwald Dec 17, 2021
9042bbd
Error when connection attributes that start with "_" are used.
andygrunwald Dec 21, 2021
7c87a56
Set additional (internal) connection attributes: _os, _platform, _pid
andygrunwald Dec 21, 2021
1058830
Removed check if `clientSecureConn` is set inside clientFlags, becaus…
andygrunwald Dec 31, 2021
00a09d8
Add Andy Grunwald to AUTHORS file
andygrunwald Dec 31, 2021
f6f3d84
Merge branch 'master' into connection-attributes
andygrunwald Dec 31, 2021
406e2a1
Close rows after execution
andygrunwald Jan 10, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Achille Roussel <achille.roussel at gmail.com>
Alex Snast <alexsn at fb.com>
Alexey Palazhchenko <alexey.palazhchenko at gmail.com>
Andrew Reid <andrew.reid at tixtrack.com>
Andy Grunwald <andygrunwald at gmail.com>
Animesh Ray <mail.rayanimesh at gmail.com>
Arne Hormann <arnehormann at gmail.com>
Ariel Mashraki <ariel at mashraki.co.il>
Expand Down Expand Up @@ -95,6 +96,7 @@ Tan Jinhua <312841925 at qq.com>
Thomas Wodarek <wodarekwebpage at gmail.com>
Tim Ruffles <timruffles at gmail.com>
Tom Jenkinson <tom at tjenkinson.me>
Vasily Fedoseyev <vasilyfedoseyev at gmail.com>
Vladimir Kovpak <cn007b at gmail.com>
Vladyslav Zhelezniak <zhvladi at gmail.com>
Xiangyu Hu <xiangyu.hu at outlook.com>
Expand Down
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,16 @@ SELECT u.id FROM users as u

will return `u.id` instead of just `id` if `columnsWithAlias=true`.

##### `connectAttrs`

```
Type: map
Valid Values: comma-separated list of attribute:value pairs
Default: empty
```

Allows setting of connection attributes, for example `connectAttrs=program_name:YourProgramName` will show `YourProgramName` in `Program` field of connections list of Mysql Workbench, if your server supports it (requires `performance_schema` to be supported and enabled).

##### `interpolateParams`

```
Expand Down
4 changes: 3 additions & 1 deletion const.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ const (
clientIgnoreSIGPIPE
clientTransactions
clientReserved
clientSecureConn
clientSecureConn // reserved2 in 8.0
clientMultiStatements
clientMultiResults
clientPSMultiResults
Expand All @@ -56,6 +56,8 @@ const (
clientCanHandleExpiredPasswords
clientSessionTrack
clientDeprecateEOF
clientSslVerifyServerCert clientFlag = 1 << 30
clientRememberOptions clientFlag = 1 << 31
)

const (
Expand Down
50 changes: 50 additions & 0 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2189,6 +2189,56 @@ func TestEmptyPassword(t *testing.T) {
}
}

func TestConnectAttrs(t *testing.T) {
if !available {
t.Skipf("MySQL server not running on %s", netAddr)
}

db, err := sql.Open("mysql", dsn+"&connectAttrs=program_name:GoTest,foo:bar")
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
}
defer db.Close()
dbt := &DBTest{t, db}

rows := dbt.mustQuery("SHOW VARIABLES LIKE 'performance_schema'")
if rows.Next() {
var var_name, value string
rows.Scan(&var_name, &value)
if value != "ON" {
t.Skip("performance_schema is disabled")
}
} else {
t.Skip("no performance_schema variable in mysql")
}

rows, err = dbt.db.Query("SELECT attr_value FROM performance_schema.session_connect_attrs WHERE processlist_id=CONNECTION_ID() AND attr_name='program_name'")
if err != nil {
dbt.Skipf("server probably does not support performance_schema.session_connect_attrs: %s", err)
}

if rows.Next() {
var str string
rows.Scan(&str)
if "GoTest" != str {
dbt.Errorf("GoTest != %s", str)
}
} else {
dbt.Error("no data for program_name")
}

andygrunwald marked this conversation as resolved.
Show resolved Hide resolved
rows = dbt.mustQuery("SELECT attr_value FROM performance_schema.session_connect_attrs WHERE processlist_id=CONNECTION_ID() AND attr_name='foo'")
if rows.Next() {
var str string
rows.Scan(&str)
if "bar" != str {
dbt.Errorf("bar != %s", str)
}
} else {
dbt.Error("no data for custom attribute")
}
andygrunwald marked this conversation as resolved.
Show resolved Hide resolved
}

// static interface implementation checks of mysqlConn
var (
_ driver.ConnBeginTx = &mysqlConn{}
Expand Down
43 changes: 43 additions & 0 deletions dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ type Config struct {
Addr string // Network address (requires Net)
DBName string // Database name
Params map[string]string // Connection parameters
ConnectAttrs map[string]string // Connection attributes
Collation string // Connection collation
Loc *time.Location // Location for time.Time values
MaxAllowedPacket int // Max packet size allowed
Expand Down Expand Up @@ -272,6 +273,30 @@ func (cfg *Config) FormatDSN() string {
writeDSNParam(&buf, &hasParam, "maxAllowedPacket", strconv.Itoa(cfg.MaxAllowedPacket))
}

if len(cfg.ConnectAttrs) > 0 {
// connectAttrs=program_name:Login Server,other_name:other
if hasParam {
buf.WriteString("&connectAttrs=")
} else {
hasParam = true
buf.WriteString("?connectAttrs=")
}

var attr_names []string
for attr_name := range cfg.ConnectAttrs {
attr_names = append(attr_names, attr_name)
}
sort.Strings(attr_names)
for index, attr_name := range attr_names {
if index > 0 {
buf.WriteByte(',')
}
buf.WriteString(attr_name)
buf.WriteByte(':')
buf.WriteString(url.QueryEscape(cfg.ConnectAttrs[attr_name]))
}
}

// other params
if cfg.Params != nil {
var params []string
Expand Down Expand Up @@ -536,6 +561,24 @@ func parseDSNParams(cfg *Config, params string) (err error) {
if err != nil {
return
}
case "connectAttrs":
if cfg.ConnectAttrs == nil {
cfg.ConnectAttrs = make(map[string]string)
}

var ConnectAttrs string
if ConnectAttrs, err = url.QueryUnescape(value); err != nil {
return
}

// program_name:Name,foo:bar
for _, attr_str := range strings.Split(ConnectAttrs, ",") {
attr := strings.SplitN(attr_str, ":", 2)
if len(attr) != 2 {
continue
}
cfg.ConnectAttrs[attr[0]] = attr[1]
}
default:
// lazy init
if cfg.Params == nil {
Expand Down
17 changes: 17 additions & 0 deletions dsn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ var testDSNs = []struct {
}, {
"tcp(de:ad:be:ef::ca:fe)/dbname",
&Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true},
}, {
"tcp(127.0.0.1)/dbname?connectAttrs=program_name:SomeService",
&Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", ConnectAttrs: map[string]string{"program_name": "SomeService"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true},
},
}

Expand Down Expand Up @@ -403,6 +406,20 @@ func TestNormalizeTLSConfig(t *testing.T) {
}
}

func TestAttributesAreSorted(t *testing.T) {
expected := "/dbname?connectAttrs=p1:v1,p2:v2"
cfg := NewConfig()
cfg.DBName = "dbname"
cfg.ConnectAttrs = map[string]string{
"p2": "v2",
"p1": "v1",
}
actual := cfg.FormatDSN()
if actual != expected {
t.Errorf("generic Config.ConnectAttrs were not sorted: want %#v, got %#v", expected, actual)
}
}

func BenchmarkParseDSN(b *testing.B) {
b.ReportAllocs()

Expand Down
77 changes: 69 additions & 8 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ import (
"fmt"
"io"
"math"
"os"
"runtime"
"strconv"
"strings"
"time"
)

Expand Down Expand Up @@ -235,10 +239,15 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
if len(data) > pos {
// character set [1 byte]
// status flags [2 bytes]
pos += 1 + 2

// capability flags (upper 2 bytes) [2 bytes]
mc.flags |= clientFlag(uint32(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16)
pos += 2

// length of auth-plugin-data [1 byte]
// reserved (all [00]) [10 bytes]
pos += 1 + 2 + 2 + 1 + 10
pos += 1 + 10

// second part of the password cipher [mininum 13 bytes],
// where len=MAX(13, length of auth-plugin-data - 8)
Expand Down Expand Up @@ -312,9 +321,42 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
}

pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1
if clientFlags&clientSecureConn == 0 || clientFlags&clientPluginAuthLenEncClientData == 0 {
pktLen++
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this if block needed?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@andygrunwald – any context on this one? Tests pass correctly without this block. I'll dig in deeper if you don't remember why this was needed.


connectAttrsBuf := make([]byte, 0, 100)
if mc.flags&clientConnectAttrs != 0 {
clientFlags |= clientConnectAttrs

// Set default connection attributes
// See https://dev.mysql.com/doc/refman/8.0/en/performance-schema-connection-attribute-tables.html#performance-schema-connection-attributes-available
connectAttrsBuf = appendLengthEncodedString(connectAttrsBuf, []byte("_client_name"))
connectAttrsBuf = appendLengthEncodedString(connectAttrsBuf, []byte("github.com/go-sql-driver/mysql"))

connectAttrsBuf = appendLengthEncodedString(connectAttrsBuf, []byte("_os"))
connectAttrsBuf = appendLengthEncodedString(connectAttrsBuf, []byte(runtime.GOOS))

connectAttrsBuf = appendLengthEncodedString(connectAttrsBuf, []byte("_platform"))
connectAttrsBuf = appendLengthEncodedString(connectAttrsBuf, []byte(runtime.GOARCH))

connectAttrsBuf = appendLengthEncodedString(connectAttrsBuf, []byte("_pid"))
connectAttrsBuf = appendLengthEncodedString(connectAttrsBuf, []byte(strconv.Itoa(os.Getpid())))

for k, v := range mc.cfg.ConnectAttrs {
if strings.HasPrefix(k, "_") {
return errors.New("connection attributes cannot start with '_'. They are reserved for internal usage")
}

connectAttrsBuf = appendLengthEncodedString(connectAttrsBuf, []byte(k))
connectAttrsBuf = appendLengthEncodedString(connectAttrsBuf, []byte(v))
}
connectAttrsBuf = appendLengthEncodedString(make([]byte, 0, 100), connectAttrsBuf)
pktLen += len(connectAttrsBuf)
}

// To specify a db name
if n := len(mc.cfg.DBName); n > 0 {
if n := len(mc.cfg.DBName); mc.flags&clientConnectWithDB != 0 && n > 0 {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

revert this.

clientFlags |= clientConnectWithDB
pktLen += n + 1
}
Expand Down Expand Up @@ -380,20 +422,39 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
data[pos] = 0x00
pos++

// Auth Data [length encoded integer]
pos += copy(data[pos:], authRespLEI)
// Auth Data [length encoded integer + data] if clientPluginAuthLenEncClientData
// clientSecureConn => 1 byte len + data
// else null-terminated
if clientFlags&clientPluginAuthLenEncClientData != 0 {
pos += copy(data[pos:], authRespLEI)
} else if clientFlags&clientSecureConn != 0 {
data[pos] = uint8(len(authResp))
pos++
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this needed?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@andygrunwald – do you remember any context on this one? I tried removing this block and confirmed many tests fail. I'll dig in deeper tomorrow, but wanted to start off seeing if you remembered this code.

pos += copy(data[pos:], authResp)
if clientFlags&clientSecureConn == 0 && clientFlags&clientPluginAuthLenEncClientData == 0 {
data[pos] = 0x00
pos++
}

// Databasename [null terminated string]
if len(mc.cfg.DBName) > 0 {
if clientFlags&clientConnectWithDB != 0 {
pos += copy(data[pos:], mc.cfg.DBName)
data[pos] = 0x00
pos++
}

pos += copy(data[pos:], plugin)
data[pos] = 0x00
pos++
// auth plugin name [null terminated string]
if clientFlags&clientPluginAuth != 0 {
pos += copy(data[pos:], plugin)
data[pos] = 0x00
pos++
}

// connection attributes [lenenc-int total + lenenc-str key-value pairs]
if clientFlags&clientConnectAttrs != 0 {
pos += copy(data[pos:], connectAttrsBuf)
}

// Send Auth packet
return mc.writePacket(data[:pos])
Expand Down
6 changes: 6 additions & 0 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,12 @@ func skipLengthEncodedString(b []byte) (int, error) {
return n, io.EOF
}

// encodes a bytes slice with prepended length-encoded size and appends it to the given bytes slice
func appendLengthEncodedString(b []byte, str []byte) []byte {
b = appendLengthEncodedInteger(b, uint64(len(str)))
return append(b, str...)
}

// returns the number read, whether the value is NULL and the number of bytes read
func readLengthEncodedInteger(b []byte) (uint64, bool, int) {
// See issue #349
Expand Down