Skip to content

Commit

Permalink
Fix in the URL parser with go 1.12.8 and github.com/go-sql-driver/mys…
Browse files Browse the repository at this point in the history
…ql (#265)

* Fix in the URL parser with go 1.12.8 and github.com/go-sql-driver/mysql

Change schemeFromURL to just split the url by :// to find the scheme.
It's not required to parse the whole URL. MySQL DSNs aren't valid URLs.

Fixes #264

* The mysql driver itself also used net/url.Parse

* Also fix TestPasswordUnencodedReservedURLChars

* Keep backwards compatibility with url encoded username and passwords

* Fix suggestions

* Reuse old function names
  • Loading branch information
erikdubbelboer authored and dhui committed Aug 17, 2019
1 parent b071731 commit eb7d0dd
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 63 deletions.
60 changes: 28 additions & 32 deletions database/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
)

import (
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database"
)

Expand Down Expand Up @@ -98,43 +97,35 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
return mx, nil
}

// urlToMySQLConfig takes a net/url URL and returns a go-sql-driver/mysql Config.
// Manually sets username and password to avoid net/url from url-encoding the reserved URL characters
func urlToMySQLConfig(u nurl.URL) (*mysql.Config, error) {
origUserInfo := u.User
u.User = nil

c, err := mysql.ParseDSN(strings.TrimPrefix(u.String(), "mysql://"))
func urlToMySQLConfig(url string) (*mysql.Config, error) {
config, err := mysql.ParseDSN(strings.TrimPrefix(url, "mysql://"))
if err != nil {
return nil, err
}
if origUserInfo != nil {
c.User = origUserInfo.Username()
if p, ok := origUserInfo.Password(); ok {
c.Passwd = p
}
}
return c, nil
}

func (m *Mysql) Open(url string) (database.Driver, error) {
purl, err := nurl.Parse(url)
config.MultiStatements = true

// Keep backwards compatibility from when we used net/url.Parse() to parse the DSN.
// net/url.Parse() would automatically unescape it for us.
// See: https://play.golang.org/p/q9j1io-YICQ
user, err := nurl.QueryUnescape(config.User)
if err != nil {
return nil, err
}
config.User = user

q := purl.Query()
q.Set("multiStatements", "true")
purl.RawQuery = q.Encode()

migrationsTable := purl.Query().Get("x-migrations-table")
password, err := nurl.QueryUnescape(config.Passwd)
if err != nil {
return nil, err
}
config.Passwd = password

// use custom TLS?
ctls := purl.Query().Get("tls")
ctls := config.TLSConfig
if len(ctls) > 0 {
if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" {
rootCertPool := x509.NewCertPool()
pem, err := ioutil.ReadFile(purl.Query().Get("x-tls-ca"))
pem, err := ioutil.ReadFile(config.Params["x-tls-ca"])
if err != nil {
return nil, err
}
Expand All @@ -144,7 +135,7 @@ func (m *Mysql) Open(url string) (database.Driver, error) {
}

clientCert := make([]tls.Certificate, 0, 1)
if ccert, ckey := purl.Query().Get("x-tls-cert"), purl.Query().Get("x-tls-key"); ccert != "" || ckey != "" {
if ccert, ckey := config.Params["x-tls-cert"], config.Params["x-tls-key"]; ccert != "" || ckey != "" {
if ccert == "" || ckey == "" {
return nil, ErrTLSCertKeyConfig
}
Expand All @@ -156,8 +147,8 @@ func (m *Mysql) Open(url string) (database.Driver, error) {
}

insecureSkipVerify := false
if len(purl.Query().Get("x-tls-insecure-skip-verify")) > 0 {
x, err := strconv.ParseBool(purl.Query().Get("x-tls-insecure-skip-verify"))
if len(config.Params["x-tls-insecure-skip-verify"]) > 0 {
x, err := strconv.ParseBool(config.Params["x-tls-insecure-skip-verify"])
if err != nil {
return nil, err
}
Expand All @@ -175,18 +166,23 @@ func (m *Mysql) Open(url string) (database.Driver, error) {
}
}

c, err := urlToMySQLConfig(*migrate.FilterCustomQuery(purl))
return config, nil
}

func (m *Mysql) Open(url string) (database.Driver, error) {
config, err := urlToMySQLConfig(url)
if err != nil {
return nil, err
}
db, err := sql.Open("mysql", c.FormatDSN())

db, err := sql.Open("mysql", config.FormatDSN())
if err != nil {
return nil, err
}

mx, err := WithInstance(db, &Config{
DatabaseName: purl.Path,
MigrationsTable: migrationsTable,
DatabaseName: config.DBName,
MigrationsTable: config.Params["x-migrations-table"],
})
if err != nil {
return nil, err
Expand Down
15 changes: 4 additions & 11 deletions database/mysql/mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"log"

"github.com/golang-migrate/migrate/v4"
"net/url"
"testing"
)

Expand Down Expand Up @@ -210,19 +209,13 @@ func TestURLToMySQLConfig(t *testing.T) {
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
u, err := url.Parse(tc.urlStr)
config, err := urlToMySQLConfig(tc.urlStr)
if err != nil {
t.Fatal("Failed to parse url string:", tc.urlStr, "error:", err)
}
if config, err := urlToMySQLConfig(*u); err == nil {
dsn := config.FormatDSN()
if dsn != tc.expectedDSN {
t.Error("Got unexpected DSN:", dsn, "!=", tc.expectedDSN)
}
} else {
if tc.expectedDSN != "" {
t.Error("Got unexpected error:", err, "urlStr:", tc.urlStr)
}
dsn := config.FormatDSN()
if dsn != tc.expectedDSN {
t.Error("Got unexpected DSN:", dsn, "!=", tc.expectedDSN)
}
})
}
Expand Down
9 changes: 3 additions & 6 deletions database/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,7 @@ func TestPasswordUnencodedReservedURLChars(t *testing.T) {
}{
{char: "!", parses: true, expectedUsername: username, expectedPassword: basePassword + "!",
encodedURL: schemeAndUsernameAndSep + basePassword + "%21" + urlSuffixAndSep},
{char: "#", parses: true, expectedUsername: "", expectedPassword: "",
encodedURL: schemeAndUsernameAndSep + basePassword + "#" + urlSuffixAndSep},
{char: "#", parses: false},
{char: "$", parses: true, expectedUsername: username, expectedPassword: basePassword + "$",
encodedURL: schemeAndUsernameAndSep + basePassword + "$" + urlSuffixAndSep},
{char: "%", parses: false},
Expand All @@ -158,16 +157,14 @@ func TestPasswordUnencodedReservedURLChars(t *testing.T) {
encodedURL: schemeAndUsernameAndSep + basePassword + "+" + urlSuffixAndSep},
{char: ",", parses: true, expectedUsername: username, expectedPassword: "password,",
encodedURL: schemeAndUsernameAndSep + basePassword + "," + urlSuffixAndSep},
{char: "/", parses: true, expectedUsername: "", expectedPassword: "",
encodedURL: schemeAndUsernameAndSep + basePassword + "/" + urlSuffixAndSep},
{char: "/", parses: false},
{char: ":", parses: true, expectedUsername: username, expectedPassword: "password:",
encodedURL: schemeAndUsernameAndSep + basePassword + "%3A" + urlSuffixAndSep},
{char: ";", parses: true, expectedUsername: username, expectedPassword: "password;",
encodedURL: schemeAndUsernameAndSep + basePassword + ";" + urlSuffixAndSep},
{char: "=", parses: true, expectedUsername: username, expectedPassword: "password=",
encodedURL: schemeAndUsernameAndSep + basePassword + "=" + urlSuffixAndSep},
{char: "?", parses: true, expectedUsername: "", expectedPassword: "",
encodedURL: schemeAndUsernameAndSep + basePassword + "?" + urlSuffixAndSep},
{char: "?", parses: false},
{char: "@", parses: true, expectedUsername: username, expectedPassword: "password@",
encodedURL: schemeAndUsernameAndSep + basePassword + "%40" + urlSuffixAndSep},
{char: "[", parses: false},
Expand Down
11 changes: 5 additions & 6 deletions util.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,14 @@ func schemeFromURL(url string) (string, error) {
return "", errEmptyURL
}

u, err := nurl.Parse(url)
if err != nil {
return "", err
}
if len(u.Scheme) == 0 {
i := strings.Index(url, ":")

// No : or : is the first character.
if i < 1 {
return "", errNoScheme
}

return u.Scheme, nil
return url[0:i], nil
}

// FilterCustomQuery filters all query values starting with `x-`
Expand Down
35 changes: 27 additions & 8 deletions util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,34 @@ func TestSourceSchemeFromUrlFailure(t *testing.T) {
}

func TestDatabaseSchemeFromUrlSuccess(t *testing.T) {
urlStr := "protocol://path"
expected := "protocol"

u, err := databaseSchemeFromURL(urlStr)
if err != nil {
t.Fatalf("expected no error, but received %q", err)
cases := []struct {
name string
urlStr string
expected string
}{
{
name: "Simple",
urlStr: "protocol://path",
expected: "protocol",
},
{
// See issue #264
name: "MySQLWithPort",
urlStr: "mysql://user:pass@tcp(host:1337)/db",
expected: "mysql",
},
}
if u != expected {
t.Fatalf("expected %q, but received %q", expected, u)

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
u, err := databaseSchemeFromURL(tc.urlStr)
if err != nil {
t.Fatalf("expected no error, but received %q", err)
}
if u != tc.expected {
t.Fatalf("expected %q, but received %q", tc.expected, u)
}
})
}
}

Expand Down

0 comments on commit eb7d0dd

Please sign in to comment.