Skip to content

Commit

Permalink
Let database.Open() use schemeFromURL as well (#271)
Browse files Browse the repository at this point in the history
* Let database.Open() use schemeFromURL as well

Otherwise it will fail on MySQL DSNs.

Moved schemeFromURL into the database package. Also removed databaseSchemeFromURL
and sourceSchemeFromURL as they were just calling schemeFromURL.

Fixes golang-migrate/migrate#265 (comment)

* Moved url functions into internal/url

Also merged the test cases.

* Add some database tests to improve coverage

* Fix suggestions
  • Loading branch information
FPiety0521 authored and FPiety0521 committed Aug 20, 2019
1 parent c07a48d commit 69aa53a
Show file tree
Hide file tree
Showing 7 changed files with 191 additions and 152 deletions.
16 changes: 6 additions & 10 deletions database/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ package database
import (
"fmt"
"io"
nurl "net/url"
"sync"

iurl "github.com/golang-migrate/migrate/v4/internal/url"
)

var (
Expand Down Expand Up @@ -81,21 +82,16 @@ type Driver interface {

// Open returns a new driver instance.
func Open(url string) (Driver, error) {
u, err := nurl.Parse(url)
scheme, err := iurl.SchemeFromURL(url)
if err != nil {
return nil, fmt.Errorf("Unable to parse URL. Did you escape all reserved URL characters? "+
"See: https://github.com/golang-migrate/migrate#database-urls Error: %v", err)
}

if u.Scheme == "" {
return nil, fmt.Errorf("database driver: invalid URL scheme")
return nil, err
}

driversMu.RLock()
d, ok := drivers[u.Scheme]
d, ok := drivers[scheme]
driversMu.RUnlock()
if !ok {
return nil, fmt.Errorf("database driver: unknown driver %v (forgotten import?)", u.Scheme)
return nil, fmt.Errorf("database driver: unknown driver %v (forgotten import?)", scheme)
}

return d.Open(url)
Expand Down
107 changes: 107 additions & 0 deletions database/driver_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,115 @@
package database

import (
"io"
"testing"
)

func ExampleDriver() {
// see database/stub for an example

// database/stub/stub.go has the driver implementation
// database/stub/stub_test.go runs database/testing/test.go:Test
}

// Using database/stub here is not possible as it
// results in an import cycle.
type mockDriver struct {
url string
}

func (m *mockDriver) Open(url string) (Driver, error) {
return &mockDriver{
url: url,
}, nil
}

func (m *mockDriver) Close() error {
return nil
}

func (m *mockDriver) Lock() error {
return nil
}

func (m *mockDriver) Unlock() error {
return nil
}

func (m *mockDriver) Run(migration io.Reader) error {
return nil
}

func (m *mockDriver) SetVersion(version int, dirty bool) error {
return nil
}

func (m *mockDriver) Version() (version int, dirty bool, err error) {
return 0, false, nil
}

func (m *mockDriver) Drop() error {
return nil
}

func TestRegisterTwice(t *testing.T) {
Register("mock", &mockDriver{})

var err interface{}
func() {
defer func() {
err = recover()
}()
Register("mock", &mockDriver{})
}()

if err == nil {
t.Fatal("expected a panic when calling Register twice")
}
}

func TestOpen(t *testing.T) {
// Make sure the driver is registered.
// But if the previous test already registered it just ignore the panic.
// If we don't do this it will be impossible to run this test standalone.
func() {
defer func() {
_ = recover()
}()
Register("mock", &mockDriver{})
}()

cases := []struct {
url string
err bool
}{
{
"mock://user:pass@tcp(host:1337)/db",
false,
},
{
"unknown://bla",
true,
},
}

for _, c := range cases {
t.Run(c.url, func(t *testing.T) {
d, err := Open(c.url)

if err == nil {
if c.err {
t.Fatal("expected an error for an unknown driver")
} else {
if md, ok := d.(*mockDriver); !ok {
t.Fatalf("expected *mockDriver got %T", d)
} else if md.url != c.url {
t.Fatalf("expected %q got %q", c.url, md.url)
}
}
} else if !c.err {
t.Fatalf("did not expect %q", err)
}
})
}
}
25 changes: 25 additions & 0 deletions internal/url/url.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package url

import (
"errors"
"strings"
)

var errNoScheme = errors.New("no scheme")
var errEmptyURL = errors.New("URL cannot be empty")

// schemeFromURL returns the scheme from a URL string
func SchemeFromURL(url string) (string, error) {
if url == "" {
return "", errEmptyURL
}

i := strings.Index(url, ":")

// No : or : is the first character.
if i < 1 {
return "", errNoScheme
}

return url[0:i], nil
}
48 changes: 48 additions & 0 deletions internal/url/url_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package url

import (
"testing"
)

func TestSchemeFromUrl(t *testing.T) {
cases := []struct {
name string
urlStr string
expected string
expectErr error
}{
{
name: "Simple",
urlStr: "protocol://path",
expected: "protocol",
},
{
// See issue #264
name: "MySQLWithPort",
urlStr: "mysql://user:pass@tcp(host:1337)/db",
expected: "mysql",
},
{
name: "Empty",
urlStr: "",
expectErr: errEmptyURL,
},
{
name: "NoScheme",
urlStr: "hello",
expectErr: errNoScheme,
},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
s, err := SchemeFromURL(tc.urlStr)
if err != tc.expectErr {
t.Fatalf("expected %q, but received %q", tc.expectErr, err)
}
if s != tc.expected {
t.Fatalf("expected %q, but received %q", tc.expected, s)
}
})
}
}
9 changes: 5 additions & 4 deletions migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"time"

"github.com/golang-migrate/migrate/v4/database"
iurl "github.com/golang-migrate/migrate/v4/internal/url"
"github.com/golang-migrate/migrate/v4/source"
)

Expand Down Expand Up @@ -85,13 +86,13 @@ type Migrate struct {
func New(sourceURL, databaseURL string) (*Migrate, error) {
m := newCommon()

sourceName, err := sourceSchemeFromURL(sourceURL)
sourceName, err := iurl.SchemeFromURL(sourceURL)
if err != nil {
return nil, err
}
m.sourceName = sourceName

databaseName, err := databaseSchemeFromURL(databaseURL)
databaseName, err := iurl.SchemeFromURL(databaseURL)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -119,7 +120,7 @@ func New(sourceURL, databaseURL string) (*Migrate, error) {
func NewWithDatabaseInstance(sourceURL string, databaseName string, databaseInstance database.Driver) (*Migrate, error) {
m := newCommon()

sourceName, err := schemeFromURL(sourceURL)
sourceName, err := iurl.SchemeFromURL(sourceURL)
if err != nil {
return nil, err
}
Expand All @@ -145,7 +146,7 @@ func NewWithDatabaseInstance(sourceURL string, databaseName string, databaseInst
func NewWithSourceInstance(sourceName string, sourceInstance source.Driver, databaseURL string) (*Migrate, error) {
m := newCommon()

databaseName, err := schemeFromURL(databaseURL)
databaseName, err := iurl.SchemeFromURL(databaseURL)
if err != nil {
return nil, err
}
Expand Down
36 changes: 0 additions & 36 deletions util.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package migrate

import (
"errors"
"fmt"
nurl "net/url"
"strings"
Expand Down Expand Up @@ -49,41 +48,6 @@ func suint(n int) uint {
return uint(n)
}

var errNoScheme = errors.New("no scheme")
var errEmptyURL = errors.New("URL cannot be empty")

func sourceSchemeFromURL(url string) (string, error) {
u, err := schemeFromURL(url)
if err != nil {
return "", fmt.Errorf("source: %v", err)
}
return u, nil
}

func databaseSchemeFromURL(url string) (string, error) {
u, err := schemeFromURL(url)
if err != nil {
return "", fmt.Errorf("database: %v", err)
}
return u, nil
}

// schemeFromURL returns the scheme from a URL string
func schemeFromURL(url string) (string, error) {
if url == "" {
return "", errEmptyURL
}

i := strings.Index(url, ":")

// No : or : is the first character.
if i < 1 {
return "", errNoScheme
}

return url[0:i], nil
}

// FilterCustomQuery filters all query values starting with `x-`
func FilterCustomQuery(u *nurl.URL) *nurl.URL {
ux := *u
Expand Down
Loading

0 comments on commit 69aa53a

Please sign in to comment.