Skip to content

Commit

Permalink
The mysql driver itself also used net/url.Parse
Browse files Browse the repository at this point in the history
  • Loading branch information
erikdubbelboer committed Aug 14, 2019
1 parent b5f8430 commit 2e39eb9
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 90 deletions.
47 changes: 10 additions & 37 deletions database/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"fmt"
"io"
"io/ioutil"
nurl "net/url"
"strconv"
"strings"
)
Expand All @@ -21,7 +20,6 @@ import (
)

import (
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database"
)

Expand Down Expand Up @@ -98,43 +96,22 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
return mx, nil
}

// urlToMySQLConfig takes a net/url URL and returns a go-sql-driver/mysql Config.
// Manually sets username and password to avoid net/url from url-encoding the reserved URL characters
func urlToMySQLConfig(u nurl.URL) (*mysql.Config, error) {
origUserInfo := u.User
u.User = nil

c, err := mysql.ParseDSN(strings.TrimPrefix(u.String(), "mysql://"))
if err != nil {
return nil, err
}
if origUserInfo != nil {
c.User = origUserInfo.Username()
if p, ok := origUserInfo.Password(); ok {
c.Passwd = p
}
}
return c, nil
}

func (m *Mysql) Open(url string) (database.Driver, error) {
purl, err := nurl.Parse(url)
config, err := mysql.ParseDSN(strings.TrimPrefix(url, "mysql://"))
if err != nil {
return nil, err
}

q := purl.Query()
q.Set("multiStatements", "true")
purl.RawQuery = q.Encode()
config.MultiStatements = true

migrationsTable := purl.Query().Get("x-migrations-table")
migrationsTable := config.Params["x-migrations-table"]

// use custom TLS?
ctls := purl.Query().Get("tls")
ctls := config.TLSConfig
if len(ctls) > 0 {
if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" {
rootCertPool := x509.NewCertPool()
pem, err := ioutil.ReadFile(purl.Query().Get("x-tls-ca"))
pem, err := ioutil.ReadFile(config.Params["x-tls-ca"])
if err != nil {
return nil, err
}
Expand All @@ -144,7 +121,7 @@ func (m *Mysql) Open(url string) (database.Driver, error) {
}

clientCert := make([]tls.Certificate, 0, 1)
if ccert, ckey := purl.Query().Get("x-tls-cert"), purl.Query().Get("x-tls-key"); ccert != "" || ckey != "" {
if ccert, ckey := config.Params["x-tls-cert"], config.Params["x-tls-key"]; ccert != "" || ckey != "" {
if ccert == "" || ckey == "" {
return nil, ErrTLSCertKeyConfig
}
Expand All @@ -156,8 +133,8 @@ func (m *Mysql) Open(url string) (database.Driver, error) {
}

insecureSkipVerify := false
if len(purl.Query().Get("x-tls-insecure-skip-verify")) > 0 {
x, err := strconv.ParseBool(purl.Query().Get("x-tls-insecure-skip-verify"))
if len(config.Params["x-tls-insecure-skip-verify"]) > 0 {
x, err := strconv.ParseBool(config.Params["x-tls-insecure-skip-verify"])
if err != nil {
return nil, err
}
Expand All @@ -175,17 +152,13 @@ func (m *Mysql) Open(url string) (database.Driver, error) {
}
}

c, err := urlToMySQLConfig(*migrate.FilterCustomQuery(purl))
if err != nil {
return nil, err
}
db, err := sql.Open("mysql", c.FormatDSN())
db, err := sql.Open("mysql", config.FormatDSN())
if err != nil {
return nil, err
}

mx, err := WithInstance(db, &Config{
DatabaseName: purl.Path,
DatabaseName: config.DBName,
MigrationsTable: migrationsTable,
})
if err != nil {
Expand Down
53 changes: 0 additions & 53 deletions database/mysql/mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"log"

"github.com/golang-migrate/migrate/v4"
"net/url"
"testing"
)

Expand Down Expand Up @@ -175,55 +174,3 @@ func TestLockWorks(t *testing.T) {
}
})
}

func TestURLToMySQLConfig(t *testing.T) {
testcases := []struct {
name string
urlStr string
expectedDSN string // empty string signifies that an error is expected
}{
{name: "no user/password", urlStr: "mysql://tcp(127.0.0.1:3306)/myDB?multiStatements=true",
expectedDSN: "tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
{name: "only user", urlStr: "mysql://username@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
expectedDSN: "username@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
{name: "only user - with encoded :",
urlStr: "mysql://username%3A@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
expectedDSN: "username:@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
{name: "only user - with encoded @",
urlStr: "mysql://username%40@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
expectedDSN: "username@@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
{name: "user/password", urlStr: "mysql://username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
expectedDSN: "username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
// Not supported yet: https://github.com/go-sql-driver/mysql/issues/591
// {name: "user/password - user with encoded :",
// urlStr: "mysql://username%3A:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
// expectedDSN: "username::pasword@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
{name: "user/password - user with encoded @",
urlStr: "mysql://username%40:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
expectedDSN: "username@:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
{name: "user/password - password with encoded :",
urlStr: "mysql://username:password%3A@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
expectedDSN: "username:password:@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
{name: "user/password - password with encoded @",
urlStr: "mysql://username:password%40@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
expectedDSN: "username:password@@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
u, err := url.Parse(tc.urlStr)
if err != nil {
t.Fatal("Failed to parse url string:", tc.urlStr, "error:", err)
}
if config, err := urlToMySQLConfig(*u); err == nil {
dsn := config.FormatDSN()
if dsn != tc.expectedDSN {
t.Error("Got unexpected DSN:", dsn, "!=", tc.expectedDSN)
}
} else {
if tc.expectedDSN != "" {
t.Error("Got unexpected error:", err, "urlStr:", tc.urlStr)
}
}
})
}
}

0 comments on commit 2e39eb9

Please sign in to comment.