diff --git a/database/mysql/mysql.go b/database/mysql/mysql.go index 403eec0f0..428fcb8b3 100644 --- a/database/mysql/mysql.go +++ b/database/mysql/mysql.go @@ -1,6 +1,9 @@ +// +build go1.9 + package mysql import ( + "context" "crypto/tls" "crypto/x509" "database/sql" @@ -35,7 +38,9 @@ type Config struct { } type Mysql struct { - db *sql.DB + // mysql RELEASE_LOCK must be called from the same conn, so + // just do everything over a single conn anyway. + conn *sql.Conn isLocked bool config *Config @@ -67,8 +72,13 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { config.MigrationsTable = DefaultMigrationsTable } + conn, err := instance.Conn(context.Background()) + if err != nil { + return nil, err + } + mx := &Mysql{ - db: instance, + conn: conn, config: config, } @@ -148,7 +158,7 @@ func (m *Mysql) Open(url string) (database.Driver, error) { } func (m *Mysql) Close() error { - return m.db.Close() + return m.conn.Close() } func (m *Mysql) Lock() error { @@ -162,9 +172,9 @@ func (m *Mysql) Lock() error { return err } - query := "SELECT GET_LOCK(?, 1)" + query := "SELECT GET_LOCK(?, 10)" var success bool - if err := m.db.QueryRow(query, aid).Scan(&success); err != nil { + if err := m.conn.QueryRowContext(context.Background(), query, aid).Scan(&success); err != nil { return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} } @@ -188,10 +198,14 @@ func (m *Mysql) Unlock() error { } query := `SELECT RELEASE_LOCK(?)` - if _, err := m.db.Exec(query, aid); err != nil { + if _, err := m.conn.ExecContext(context.Background(), query, aid); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } + // NOTE: RELEASE_LOCK could return NULL or (or 0 if the code is changed), + // in which case isLocked should be true until the timeout expires -- synchronizing + // these states is likely not worth trying to do; reconsider the necessity of isLocked. + m.isLocked = false return nil } @@ -203,7 +217,7 @@ func (m *Mysql) Run(migration io.Reader) error { } query := string(migr[:]) - if _, err := m.db.Exec(query); err != nil { + if _, err := m.conn.ExecContext(context.Background(), query); err != nil { return database.Error{OrigErr: err, Err: "migration failed", Query: migr} } @@ -211,19 +225,20 @@ func (m *Mysql) Run(migration io.Reader) error { } func (m *Mysql) SetVersion(version int, dirty bool) error { - tx, err := m.db.Begin() + tx, err := m.conn.BeginTx(context.Background(), &sql.TxOptions{}) if err != nil { return &database.Error{OrigErr: err, Err: "transaction start failed"} } query := "TRUNCATE `" + m.config.MigrationsTable + "`" - if _, err := m.db.Exec(query); err != nil { + if _, err := tx.ExecContext(context.Background(), query); err != nil { + tx.Rollback() return &database.Error{OrigErr: err, Query: []byte(query)} } if version >= 0 { query := "INSERT INTO `" + m.config.MigrationsTable + "` (version, dirty) VALUES (?, ?)" - if _, err := m.db.Exec(query, version, dirty); err != nil { + if _, err := tx.ExecContext(context.Background(), query, version, dirty); err != nil { tx.Rollback() return &database.Error{OrigErr: err, Query: []byte(query)} } @@ -238,7 +253,7 @@ func (m *Mysql) SetVersion(version int, dirty bool) error { func (m *Mysql) Version() (version int, dirty bool, err error) { query := "SELECT version, dirty FROM `" + m.config.MigrationsTable + "` LIMIT 1" - err = m.db.QueryRow(query).Scan(&version, &dirty) + err = m.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) switch { case err == sql.ErrNoRows: return database.NilVersion, false, nil @@ -259,7 +274,7 @@ func (m *Mysql) Version() (version int, dirty bool, err error) { func (m *Mysql) Drop() error { // select all tables query := `SHOW TABLES LIKE '%'` - tables, err := m.db.Query(query) + tables, err := m.conn.QueryContext(context.Background(), query) if err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } @@ -281,7 +296,7 @@ func (m *Mysql) Drop() error { // delete one by one ... for _, t := range tableNames { query = "DROP TABLE IF EXISTS `" + t + "` CASCADE" - if _, err := m.db.Exec(query); err != nil { + if _, err := m.conn.ExecContext(context.Background(), query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } } @@ -297,7 +312,7 @@ func (m *Mysql) ensureVersionTable() error { // check if migration table exists var result string query := `SHOW TABLES LIKE "` + m.config.MigrationsTable + `"` - if err := m.db.QueryRow(query).Scan(&result); err != nil { + if err := m.conn.QueryRowContext(context.Background(), query).Scan(&result); err != nil { if err != sql.ErrNoRows { return &database.Error{OrigErr: err, Query: []byte(query)} } @@ -307,7 +322,7 @@ func (m *Mysql) ensureVersionTable() error { // if not, create the empty migration table query = "CREATE TABLE `" + m.config.MigrationsTable + "` (version bigint not null primary key, dirty boolean not null)" - if _, err := m.db.Exec(query); err != nil { + if _, err := m.conn.ExecContext(context.Background(), query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } return nil diff --git a/database/mysql/mysql_test.go b/database/mysql/mysql_test.go index 5fdb75670..ae2c9567f 100644 --- a/database/mysql/mysql_test.go +++ b/database/mysql/mysql_test.go @@ -63,3 +63,37 @@ func Test(t *testing.T) { } }) } + +func TestLockWorks(t *testing.T) { + mt.ParallelTest(t, versions, isReady, + func(t *testing.T, i mt.Instance) { + p := &Mysql{} + addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", i.Host(), i.Port()) + d, err := p.Open(addr) + if err != nil { + t.Fatalf("%v", err) + } + dt.Test(t, d, []byte("SELECT 1")) + + ms := d.(*Mysql) + + err = ms.Lock() + if err != nil { + t.Fatal(err) + } + err = ms.Unlock() + if err != nil { + t.Fatal(err) + } + + // make sure the 2nd lock works (RELEASE_LOCK is very finicky) + err = ms.Lock() + if err != nil { + t.Fatal(err) + } + err = ms.Unlock() + if err != nil { + t.Fatal(err) + } + }) +}