Skip to content

Commit

Permalink
Merge pull request petoju#37 from petoju/feature/connect-mysql-ad-hoc
Browse files Browse the repository at this point in the history
Connect to MySQL only when necessary
  • Loading branch information
petoju authored Sep 6, 2022
2 parents 6d022ea + 994baa1 commit 14931ca
Show file tree
Hide file tree
Showing 11 changed files with 190 additions and 92 deletions.
5 changes: 4 additions & 1 deletion mysql/data_source_tables.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ func dataSourceTables() *schema.Resource {
}

func ShowTables(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
db := getDatabaseFromMeta(meta)
db, err := getDatabaseFromMeta(ctx, meta)
if err != nil {
return diag.FromErr(err)
}

database := d.Get("database").(string)
pattern := d.Get("pattern").(string)
Expand Down
61 changes: 34 additions & 27 deletions mysql/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,25 +29,28 @@ const (
unknownUserErrCode = 1396
)

type OneConnection struct {
Db *sql.DB
Version *version.Version
}

type MySQLConfiguration struct {
Config *mysql.Config
Db *sql.DB
MaxConnLifetime time.Duration
MaxOpenConns int
ConnectRetryTimeoutSec time.Duration
Version *version.Version
}

var (
connectionCacheMtx sync.Mutex
connectionCache map[string]*sql.DB
connectionCache map[string]*OneConnection
)

func init() {
connectionCacheMtx.Lock()
defer connectionCacheMtx.Unlock()

connectionCache = map[string]*sql.DB{}
connectionCache = map[string]*OneConnection{}
}

func Provider() *schema.Provider {
Expand Down Expand Up @@ -194,41 +197,29 @@ func providerConfigure(ctx context.Context, d *schema.ResourceData) (interface{}
ConnectRetryTimeoutSec: time.Duration(d.Get("connect_retry_timeout_sec").(int)) * time.Second,
}

db, err := connectToMySQL(ctx, mysqlConf)

if err != nil {
return nil, diag.Errorf("failed to connect to MySQL: %v", err)
}

mysqlConf.Db = db
if err := afterConnect(ctx, mysqlConf, db); err != nil {
return nil, diag.Errorf("failed running after connect command: %v", err)
}

return mysqlConf, nil
}

func afterConnect(ctx context.Context, mysqlConf *MySQLConfiguration, db *sql.DB) error {
func afterConnectVersion(ctx context.Context, mysqlConf *MySQLConfiguration, db *sql.DB) (*version.Version, error) {
// Set up env so that we won't create users randomly.
fmt.Printf("AAA Running after connect\n")
currentVersion, err := serverVersion(db)
if err != nil {
return fmt.Errorf("Failed getting server version: %v", err)
return nil, fmt.Errorf("Failed getting server version: %v", err)
}

mysqlConf.Version = currentVersion

versionMinInclusive, _ := version.NewVersion("5.7.5")
versionMaxExclusive, _ := version.NewVersion("8.0.0")
if mysqlConf.Version.GreaterThanOrEqual(versionMinInclusive) &&
mysqlConf.Version.LessThan(versionMaxExclusive) {
if currentVersion.GreaterThanOrEqual(versionMinInclusive) &&
currentVersion.LessThan(versionMaxExclusive) {
// CONCAT and setting works even if there is no value.
_, err = db.ExecContext(ctx, `SET SESSION sql_mode=CONCAT(@@sql_mode, ',NO_AUTO_CREATE_USER')`)
if err != nil {
return fmt.Errorf("failed setting SQL mode: %v", err)
return nil, fmt.Errorf("failed setting SQL mode: %v", err)
}
}

return nil
return currentVersion, nil
}

var identQuoteReplacer = strings.NewReplacer("`", "``")
Expand Down Expand Up @@ -276,8 +267,15 @@ func serverVersionString(db *sql.DB) (string, error) {

return versionString, nil
}

func connectToMySQL(ctx context.Context, conf *MySQLConfiguration) (*sql.DB, error) {
conn, err := connectToMySQLInternal(ctx, conf)
if err != nil {
return nil, err
}
return conn.Db, nil
}

func connectToMySQLInternal(ctx context.Context, conf *MySQLConfiguration) (*OneConnection, error) {
// This is fine - we'll connect serially, but we don't expect more than
// 1 or 2 connections starting at once.
connectionCacheMtx.Lock()
Expand Down Expand Up @@ -316,12 +314,21 @@ func connectToMySQL(ctx context.Context, conf *MySQLConfiguration) (*sql.DB, err
})

if retryError != nil {
return nil, fmt.Errorf("Could not connect to server: %s", retryError)
return nil, fmt.Errorf("could not connect to server: %s", retryError)
}
connectionCache[dsn] = db
db.SetConnMaxLifetime(conf.MaxConnLifetime)
db.SetMaxOpenConns(conf.MaxOpenConns)
return db, nil

currentVersion, err := afterConnectVersion(ctx, conf, db)
if err != nil {
return nil, fmt.Errorf("failed running after connect command: %v", err)
}

connectionCache[dsn] = &OneConnection{
Db: db,
Version: currentVersion,
}
return connectionCache[dsn], nil
}

// 0 == not mysql error or not error at all.
Expand Down
30 changes: 21 additions & 9 deletions mysql/resource_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,15 @@ func resourceDatabase() *schema.Resource {
}

func CreateDatabase(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
db := getDatabaseFromMeta(meta)
db, err := getDatabaseFromMeta(ctx, meta)
if err != nil {
return diag.FromErr(err)
}

stmtSQL := databaseConfigSQL("CREATE", d)
log.Println("Executing statement:", stmtSQL)

_, err := db.ExecContext(ctx, stmtSQL)
_, err = db.ExecContext(ctx, stmtSQL)
if err != nil {
return diag.Errorf("failed running SQL to create DB: %v", err)
}
Expand All @@ -65,12 +68,15 @@ func CreateDatabase(ctx context.Context, d *schema.ResourceData, meta interface{
}

func UpdateDatabase(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
db := getDatabaseFromMeta(meta)
db, err := getDatabaseFromMeta(ctx, meta)
if err != nil {
return diag.FromErr(err)
}

stmtSQL := databaseConfigSQL("ALTER", d)
log.Println("Executing statement:", stmtSQL)

_, err := db.ExecContext(ctx, stmtSQL)
_, err = db.ExecContext(ctx, stmtSQL)
if err != nil {
return diag.Errorf("failed updating DB: %v", err)
}
Expand All @@ -79,7 +85,10 @@ func UpdateDatabase(ctx context.Context, d *schema.ResourceData, meta interface{
}

func ReadDatabase(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
db := getDatabaseFromMeta(meta)
db, err := getDatabaseFromMeta(ctx, meta)
if err != nil {
return diag.FromErr(err)
}

// This is kinda flimsy-feeling, since it depends on the formatting
// of the SHOW CREATE DATABASE output... but this data doesn't seem
Expand All @@ -91,7 +100,7 @@ func ReadDatabase(ctx context.Context, d *schema.ResourceData, meta interface{})

log.Println("Executing query:", stmtSQL)
var createSQL, _database string
err := db.QueryRowContext(ctx, stmtSQL).Scan(&_database, &createSQL)
err = db.QueryRowContext(ctx, stmtSQL).Scan(&_database, &createSQL)
if err != nil {
if mysqlErr, ok := err.(*mysql.MySQLError); ok {
if mysqlErr.Number == unknownDatabaseErrCode {
Expand Down Expand Up @@ -121,7 +130,7 @@ func ReadDatabase(ctx context.Context, d *schema.ResourceData, meta interface{})

// MySQL 8 returns more data in a row.
var res error
if !strings.Contains(serverVersionString, "MariaDB") && getVersionFromMeta(meta).GreaterThan(requiredVersion) {
if !strings.Contains(serverVersionString, "MariaDB") && getVersionFromMeta(ctx, meta).GreaterThan(requiredVersion) {
res = db.QueryRow(stmtSQL, defaultCharset).Scan(&defaultCollation, &empty, &empty, &empty, &empty, &empty, &empty)
} else {
res = db.QueryRow(stmtSQL, defaultCharset).Scan(&defaultCollation, &empty, &empty, &empty, &empty, &empty)
Expand All @@ -144,13 +153,16 @@ func ReadDatabase(ctx context.Context, d *schema.ResourceData, meta interface{})
}

func DeleteDatabase(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
db := getDatabaseFromMeta(meta)
db, err := getDatabaseFromMeta(ctx, meta)
if err != nil {
return diag.FromErr(err)
}

name := d.Id()
stmtSQL := "DROP DATABASE " + quoteIdentifier(name)
log.Println("Executing statement:", stmtSQL)

_, err := db.ExecContext(ctx, stmtSQL)
_, err = db.ExecContext(ctx, stmtSQL)
if err != nil {
return diag.Errorf("failed deleting DB: %v", err)
}
Expand Down
19 changes: 14 additions & 5 deletions mysql/resource_global_variable.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@ func resourceGlobalVariable() *schema.Resource {
func CreateOrUpdateGlobalVariable(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
var sql string

db := getDatabaseFromMeta(meta)
db, err := getDatabaseFromMeta(ctx, meta)
if err != nil {
return diag.FromErr(err)
}
name := d.Get("name").(string)
value := d.Get("value").(string)

Expand All @@ -61,7 +64,7 @@ func CreateOrUpdateGlobalVariable(ctx context.Context, d *schema.ResourceData, m

log.Printf("[DEBUG] SQL: %s", sql)

_, err := db.ExecContext(ctx, sql)
_, err = db.ExecContext(ctx, sql)
if err != nil {
return diag.Errorf("error setting value: %s", err)
}
Expand All @@ -72,7 +75,10 @@ func CreateOrUpdateGlobalVariable(ctx context.Context, d *schema.ResourceData, m
}

func ReadGlobalVariable(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
db := getDatabaseFromMeta(meta)
db, err := getDatabaseFromMeta(ctx, meta)
if err != nil {
return diag.FromErr(err)
}

stmt, err := db.Prepare("SHOW GLOBAL VARIABLES WHERE VARIABLE_NAME = ?")
if err != nil {
Expand All @@ -94,13 +100,16 @@ func ReadGlobalVariable(ctx context.Context, d *schema.ResourceData, meta interf
}

func DeleteGlobalVariable(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
db := getDatabaseFromMeta(meta)
db, err := getDatabaseFromMeta(ctx, meta)
if err != nil {
return diag.FromErr(err)
}
name := d.Get("name").(string)

sql := fmt.Sprintf("SET GLOBAL %s = DEFAULT", quoteIdentifier(name))
log.Printf("[DEBUG] SQL: %s", sql)

_, err := db.ExecContext(ctx, sql)
_, err = db.ExecContext(ctx, sql)
if err != nil {
log.Printf("[WARN] Variable_name (%s) not found; removing from state", d.Id())
d.SetId("")
Expand Down
38 changes: 26 additions & 12 deletions mysql/resource_grant.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,18 +148,21 @@ func userOrRole(user string, host string, role string, hasRoles bool) (string, b
}
}

func supportsRoles(meta interface{}) (bool, error) {
currentVersion := getVersionFromMeta(meta)
func supportsRoles(ctx context.Context, meta interface{}) (bool, error) {
currentVersion := getVersionFromMeta(ctx, meta)

requiredVersion, _ := version.NewVersion("8.0.0")
hasRoles := currentVersion.GreaterThan(requiredVersion)
return hasRoles, nil
}

func CreateGrant(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
db := getDatabaseFromMeta(meta)
db, err := getDatabaseFromMeta(ctx, meta)
if err != nil {
return diag.FromErr(err)
}

hasRoles, err := supportsRoles(meta)
hasRoles, err := supportsRoles(ctx, meta)
if err != nil {
return diag.Errorf("failed getting role support: %v", err)
}
Expand Down Expand Up @@ -260,9 +263,12 @@ func CreateGrant(ctx context.Context, d *schema.ResourceData, meta interface{})
}

func ReadGrant(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
db := getDatabaseFromMeta(meta)
db, err := getDatabaseFromMeta(ctx, meta)
if err != nil {
return diag.FromErr(err)
}

hasRoles, err := supportsRoles(meta)
hasRoles, err := supportsRoles(ctx, meta)
if err != nil {
return diag.Errorf("failed getting role support: %v", err)
}
Expand Down Expand Up @@ -320,9 +326,12 @@ func ReadGrant(ctx context.Context, d *schema.ResourceData, meta interface{}) di
}

func UpdateGrant(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
db := getDatabaseFromMeta(meta)
db, err := getDatabaseFromMeta(ctx, meta)
if err != nil {
return diag.FromErr(err)
}

hasRoles, err := supportsRoles(meta)
hasRoles, err := supportsRoles(ctx, meta)

if err != nil {
return diag.Errorf("failed getting role support: %v", err)
Expand Down Expand Up @@ -394,13 +403,15 @@ func updatePrivileges(ctx context.Context, d *schema.ResourceData, db *sql.DB, u
}

func DeleteGrant(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
db := getDatabaseFromMeta(meta)
db, err := getDatabaseFromMeta(ctx, meta)
if err != nil {
return diag.FromErr(err)
}

database := formatDatabaseName(d.Get("database").(string))

table := formatTableName(d.Get("table").(string))

hasRoles, err := supportsRoles(meta)
hasRoles, err := supportsRoles(ctx, meta)
if err != nil {
return diag.Errorf("failed getting role support: %v", err)
}
Expand Down Expand Up @@ -480,7 +491,10 @@ func ImportGrant(ctx context.Context, d *schema.ResourceData, meta interface{})
database := userHostDatabaseTable[2]
table := userHostDatabaseTable[3]

db := getDatabaseFromMeta(meta)
db, err := getDatabaseFromMeta(ctx, meta)
if err != nil {
return nil, err
}

grants, err := showGrants(ctx, db, fmt.Sprintf("'%s'@'%s'", user, host), database, table)

Expand Down
Loading

0 comments on commit 14931ca

Please sign in to comment.