Skip to content

Commit

Permalink
Merge pull request petoju#36 from petoju/feature/add-context
Browse files Browse the repository at this point in the history
Add context to the rest of functions
  • Loading branch information
petoju authored Sep 6, 2022
2 parents 118fe86 + b7a5a38 commit 6d022ea
Show file tree
Hide file tree
Showing 10 changed files with 79 additions and 49 deletions.
47 changes: 24 additions & 23 deletions mysql/provider.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package mysql

import (
"context"
"database/sql"
"fmt"
"github.com/hashicorp/terraform-plugin-sdk/v2/diag"
"net"
"net/url"
"regexp"
Expand Down Expand Up @@ -143,14 +145,13 @@ func Provider() *schema.Provider {
"mysql_ti_config": resourceTiConfigVariable(),
},

ConfigureFunc: providerConfigure,
ConfigureContextFunc: providerConfigure,
}
}

func providerConfigure(d *schema.ResourceData) (interface{}, error) {

func providerConfigure(ctx context.Context, d *schema.ResourceData) (interface{}, diag.Diagnostics) {
var endpoint = d.Get("endpoint").(string)
var conn_params = make(map[string]string)
var connParams = make(map[string]string)

proto := "tcp"
if len(endpoint) > 0 && endpoint[0] == '/' {
Expand All @@ -160,9 +161,9 @@ func providerConfigure(d *schema.ResourceData) (interface{}, error) {
for k, vint := range d.Get("conn_params").(map[string]interface{}) {
v, ok := vint.(string)
if !ok {
return nil, fmt.Errorf("Cannot convert connection parameters to string")
return nil, diag.Errorf("cannot convert connection parameters to string")
}
conn_params[k] = v
connParams[k] = v
}

conf := mysql.Config{
Expand All @@ -174,15 +175,15 @@ func providerConfigure(d *schema.ResourceData) (interface{}, error) {
AllowNativePasswords: d.Get("authentication_plugin").(string) == nativePasswords,
AllowCleartextPasswords: d.Get("authentication_plugin").(string) == cleartextPasswords,
InterpolateParams: true,
Params: conn_params,
Params: connParams,
}

dialer, err := makeDialer(d)
if err != nil {
return nil, err
return nil, diag.Errorf("failed making dialer: %v", err)
}

mysql.RegisterDial("tcp", func(network string) (net.Conn, error) {
mysql.RegisterDialContext("tcp", func(ctx context.Context, network string) (net.Conn, error) {
return dialer.Dial("tcp", network)
})

Expand All @@ -193,21 +194,21 @@ func providerConfigure(d *schema.ResourceData) (interface{}, error) {
ConnectRetryTimeoutSec: time.Duration(d.Get("connect_retry_timeout_sec").(int)) * time.Second,
}

db, err := connectToMySQL(mysqlConf)
db, err := connectToMySQL(ctx, mysqlConf)

if err != nil {
return nil, err
return nil, diag.Errorf("failed to connect to MySQL: %v", err)
}

mysqlConf.Db = db
if err := afterConnect(mysqlConf, db); err != nil {
return nil, fmt.Errorf("Failed running after connect command: %v", err)
if err := afterConnect(ctx, mysqlConf, db); err != nil {
return nil, diag.Errorf("failed running after connect command: %v", err)
}

return mysqlConf, nil
}

func afterConnect(mysqlConf *MySQLConfiguration, db *sql.DB) error {
func afterConnect(ctx context.Context, mysqlConf *MySQLConfiguration, db *sql.DB) error {
// Set up env so that we won't create users randomly.
currentVersion, err := serverVersion(db)
if err != nil {
Expand All @@ -221,9 +222,9 @@ func afterConnect(mysqlConf *MySQLConfiguration, db *sql.DB) error {
if mysqlConf.Version.GreaterThanOrEqual(versionMinInclusive) &&
mysqlConf.Version.LessThan(versionMaxExclusive) {
// CONCAT and setting works even if there is no value.
_, err := db.Exec(`SET SESSION sql_mode=CONCAT(@@sql_mode, ',NO_AUTO_CREATE_USER')`)
_, err = db.ExecContext(ctx, `SET SESSION sql_mode=CONCAT(@@sql_mode, ',NO_AUTO_CREATE_USER')`)
if err != nil {
return fmt.Errorf("Failed setting SQL mode: %v", err)
return fmt.Errorf("failed setting SQL mode: %v", err)
}
}

Expand All @@ -241,12 +242,12 @@ func makeDialer(d *schema.ResourceData) (proxy.Dialer, error) {
if err != nil {
return nil, err
}
proxy, err := proxy.FromURL(proxyURL, proxy.Direct)
proxyDialer, err := proxy.FromURL(proxyURL, proxy.Direct)
if err != nil {
return nil, err
}

return proxy, nil
return proxyDialer, nil
}

return proxyFromEnv, nil
Expand Down Expand Up @@ -276,7 +277,7 @@ func serverVersionString(db *sql.DB) (string, error) {
return versionString, nil
}

func connectToMySQL(conf *MySQLConfiguration) (*sql.DB, error) {
func connectToMySQL(ctx context.Context, conf *MySQLConfiguration) (*sql.DB, error) {
// This is fine - we'll connect serially, but we don't expect more than
// 1 or 2 connections starting at once.
connectionCacheMtx.Lock()
Expand All @@ -293,18 +294,18 @@ func connectToMySQL(conf *MySQLConfiguration) (*sql.DB, error) {
// when Terraform thinks it's available and when it is actually available.
// This is particularly acute when provisioning a server and then immediately
// trying to provision a database on it.
retryError := resource.Retry(conf.ConnectRetryTimeoutSec, func() *resource.RetryError {
retryError := resource.RetryContext(ctx, conf.ConnectRetryTimeoutSec, func() *resource.RetryError {
db, err = sql.Open("mysql", dsn)
if err != nil {
if mysqlErrorNumber(err) == unknownVarErrCode {
if mysqlErrorNumber(err) == unknownVarErrCode || ctx.Err() != nil {
return resource.NonRetryableError(err)
}
return resource.RetryableError(err)
}

err = db.Ping()
err = db.PingContext(ctx)
if err != nil {
if mysqlErrorNumber(err) == unknownVarErrCode {
if mysqlErrorNumber(err) == unknownVarErrCode || ctx.Err() != nil {
return resource.NonRetryableError(err)
}

Expand Down
12 changes: 9 additions & 3 deletions mysql/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ func testAccPreCheck(t *testing.T) {

func testAccPreCheckSkipTiDB(t *testing.T) {
testAccPreCheck(t)
db, err := connectToMySQL(testAccProvider.Meta().(*MySQLConfiguration))

ctx := context.Background()
db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration))
if err != nil {
return
}
Expand All @@ -87,7 +89,9 @@ func testAccPreCheckSkipTiDB(t *testing.T) {

func testAccPreCheckSkipMariaDB(t *testing.T) {
testAccPreCheck(t)
db, err := connectToMySQL(testAccProvider.Meta().(*MySQLConfiguration))

ctx := context.Background()
db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration))
if err != nil {
return
}
Expand All @@ -104,7 +108,9 @@ func testAccPreCheckSkipMariaDB(t *testing.T) {

func testAccPreCheckSkipNotTiDB(t *testing.T) {
testAccPreCheck(t)
db, err := connectToMySQL(testAccProvider.Meta().(*MySQLConfiguration))

ctx := context.Background()
db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration))
if err != nil {
return
}
Expand Down
2 changes: 1 addition & 1 deletion mysql/resource_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func ReadDatabase(ctx context.Context, d *schema.ResourceData, meta interface{})

log.Println("Executing query:", stmtSQL)
var createSQL, _database string
err := db.QueryRow(stmtSQL).Scan(&_database, &createSQL)
err := db.QueryRowContext(ctx, stmtSQL).Scan(&_database, &createSQL)
if err != nil {
if mysqlErr, ok := err.(*mysql.MySQLError); ok {
if mysqlErr.Number == unknownDatabaseErrCode {
Expand Down
10 changes: 7 additions & 3 deletions mysql/resource_database_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package mysql

import (
"context"
"fmt"
"strings"
"testing"
Expand Down Expand Up @@ -37,6 +38,7 @@ func TestAccDatabase_collationChange(t *testing.T) {
collation2 := "utf8mb4_general_ci"

resourceName := "mysql_database.test"
ctx := context.Background()

resource.Test(t, resource.TestCase{
PreCheck: func() { testAccPreCheckSkipTiDB(t) },
Expand All @@ -56,7 +58,7 @@ func TestAccDatabase_collationChange(t *testing.T) {
},
{
PreConfig: func() {
db, err := connectToMySQL(testAccProvider.Meta().(*MySQLConfiguration))
db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration))
if err != nil {
return
}
Expand Down Expand Up @@ -87,7 +89,8 @@ func testAccDatabaseCheck_full(rn string, name string, charset string, collation
return fmt.Errorf("database id not set")
}

db, err := connectToMySQL(testAccProvider.Meta().(*MySQLConfiguration))
ctx := context.Background()
db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration))
if err != nil {
return err
}
Expand All @@ -111,7 +114,8 @@ func testAccDatabaseCheck_full(rn string, name string, charset string, collation

func testAccDatabaseCheckDestroy(name string) resource.TestCheckFunc {
return func(s *terraform.State) error {
db, err := connectToMySQL(testAccProvider.Meta().(*MySQLConfiguration))
ctx := context.Background()
db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration))
if err != nil {
return err
}
Expand Down
7 changes: 5 additions & 2 deletions mysql/resource_global_variable_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package mysql

import (
"context"
"database/sql"
"fmt"
"regexp"
Expand Down Expand Up @@ -119,7 +120,8 @@ func TestAccGlobalVar_parseBoolean(t *testing.T) {

func testAccGlobalVarExists(varName, varExpected string) resource.TestCheckFunc {
return func(s *terraform.State) error {
db, err := connectToMySQL(testAccProvider.Meta().(*MySQLConfiguration))
ctx := context.Background()
db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration))
if err != nil {
return err
}
Expand Down Expand Up @@ -157,7 +159,8 @@ func testAccGetGlobalVar(varName string, db *sql.DB) (string, error) {

func testAccGlobalVarCheckDestroy(varName, varExpected string) resource.TestCheckFunc {
return func(s *terraform.State) error {
db, err := connectToMySQL(testAccProvider.Meta().(*MySQLConfiguration))
ctx := context.Background()
db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration))
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion mysql/resource_grant.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func userOrRole(user string, host string, role string, hasRoles bool) (string, b
return fmt.Sprintf("'%s'@'%s'", user, host), false, nil
} else if len(role) > 0 {
if !hasRoles {
return "", false, fmt.Errorf("Roles are only supported on MySQL 8 and above")
return "", false, fmt.Errorf("roles are only supported on MySQL 8 and above")
}

return fmt.Sprintf("'%s'", role), true, nil
Expand Down
21 changes: 13 additions & 8 deletions mysql/resource_grant_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func TestAccGrantComplex(t *testing.T) {
Steps: []resource.TestStep{
{
// Create table first
Config: testAccGrantConfig_nogrant(dbName),
Config: testAccGrantConfigNoGrant(dbName),
Check: resource.ComposeTestCheckFunc(
prepareTable(dbName),
),
Expand Down Expand Up @@ -234,7 +234,8 @@ func TestAccGrant_role(t *testing.T) {
resource.Test(t, resource.TestCase{
PreCheck: func() {
testAccPreCheck(t)
db, err := connectToMySQL(testAccProvider.Meta().(*MySQLConfiguration))
ctx := context.Background()
db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration))
if err != nil {
return
}
Expand Down Expand Up @@ -281,7 +282,8 @@ func TestAccGrant_roleToUser(t *testing.T) {
resource.Test(t, resource.TestCase{
PreCheck: func() {
testAccPreCheck(t)
db, err := connectToMySQL(testAccProvider.Meta().(*MySQLConfiguration))
ctx := context.Background()
db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration))
if err != nil {
return
}
Expand Down Expand Up @@ -313,7 +315,8 @@ func TestAccGrant_roleToUser(t *testing.T) {

func prepareTable(dbname string) resource.TestCheckFunc {
return func(s *terraform.State) error {
db, err := connectToMySQL(testAccProvider.Meta().(*MySQLConfiguration))
ctx := context.Background()
db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration))
if err != nil {
return err
}
Expand All @@ -336,7 +339,8 @@ func testAccPrivilege(rn string, privilege string, expectExists bool) resource.T
return fmt.Errorf("grant id not set")
}

db, err := connectToMySQL(testAccProvider.Meta().(*MySQLConfiguration))
ctx := context.Background()
db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration))
if err != nil {
return err
}
Expand Down Expand Up @@ -385,7 +389,8 @@ func testAccPrivilege(rn string, privilege string, expectExists bool) resource.T
}

func testAccGrantCheckDestroy(s *terraform.State) error {
db, err := connectToMySQL(testAccProvider.Meta().(*MySQLConfiguration))
ctx := context.Background()
db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration))
if err != nil {
return err
}
Expand Down Expand Up @@ -415,16 +420,16 @@ func testAccGrantCheckDestroy(s *terraform.State) error {

return fmt.Errorf("error reading grant: %s", err)
}
defer rows.Close()

if rows.Next() {
return fmt.Errorf("grant still exists for: %s", userOrRole)
}
rows.Close()
}
return nil
}

func testAccGrantConfig_nogrant(dbName string) string {
func testAccGrantConfigNoGrant(dbName string) string {
return fmt.Sprintf(`
resource "mysql_database" "test" {
name = "%s"
Expand Down
10 changes: 7 additions & 3 deletions mysql/resource_role_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package mysql

import (
"context"
"database/sql"
"fmt"
"testing"
Expand All @@ -17,7 +18,8 @@ func TestAccRole_basic(t *testing.T) {
resource.Test(t, resource.TestCase{
PreCheck: func() {
testAccPreCheck(t)
db, err := connectToMySQL(testAccProvider.Meta().(*MySQLConfiguration))
ctx := context.Background()
db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration))
if err != nil {
return
}
Expand Down Expand Up @@ -48,7 +50,8 @@ func TestAccRole_basic(t *testing.T) {

func testAccRoleExists(roleName string) resource.TestCheckFunc {
return func(s *terraform.State) error {
db, err := connectToMySQL(testAccProvider.Meta().(*MySQLConfiguration))
ctx := context.Background()
db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration))
if err != nil {
return err
}
Expand Down Expand Up @@ -85,7 +88,8 @@ func testAccGetRoleGrantCount(roleName string, db *sql.DB) (int, error) {

func testAccRoleCheckDestroy(roleName string) resource.TestCheckFunc {
return func(s *terraform.State) error {
db, err := connectToMySQL(testAccProvider.Meta().(*MySQLConfiguration))
ctx := context.Background()
db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration))
if err != nil {
return err
}
Expand Down
Loading

0 comments on commit 6d022ea

Please sign in to comment.