diff --git a/mysql/provider.go b/mysql/provider.go index c2e5a11e..6a9a6c4d 100644 --- a/mysql/provider.go +++ b/mysql/provider.go @@ -1,8 +1,10 @@ package mysql import ( + "context" "database/sql" "fmt" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" "net" "net/url" "regexp" @@ -143,14 +145,13 @@ func Provider() *schema.Provider { "mysql_ti_config": resourceTiConfigVariable(), }, - ConfigureFunc: providerConfigure, + ConfigureContextFunc: providerConfigure, } } -func providerConfigure(d *schema.ResourceData) (interface{}, error) { - +func providerConfigure(ctx context.Context, d *schema.ResourceData) (interface{}, diag.Diagnostics) { var endpoint = d.Get("endpoint").(string) - var conn_params = make(map[string]string) + var connParams = make(map[string]string) proto := "tcp" if len(endpoint) > 0 && endpoint[0] == '/' { @@ -160,9 +161,9 @@ func providerConfigure(d *schema.ResourceData) (interface{}, error) { for k, vint := range d.Get("conn_params").(map[string]interface{}) { v, ok := vint.(string) if !ok { - return nil, fmt.Errorf("Cannot convert connection parameters to string") + return nil, diag.Errorf("cannot convert connection parameters to string") } - conn_params[k] = v + connParams[k] = v } conf := mysql.Config{ @@ -174,15 +175,15 @@ func providerConfigure(d *schema.ResourceData) (interface{}, error) { AllowNativePasswords: d.Get("authentication_plugin").(string) == nativePasswords, AllowCleartextPasswords: d.Get("authentication_plugin").(string) == cleartextPasswords, InterpolateParams: true, - Params: conn_params, + Params: connParams, } dialer, err := makeDialer(d) if err != nil { - return nil, err + return nil, diag.Errorf("failed making dialer: %v", err) } - mysql.RegisterDial("tcp", func(network string) (net.Conn, error) { + mysql.RegisterDialContext("tcp", func(ctx context.Context, network string) (net.Conn, error) { return dialer.Dial("tcp", network) }) @@ -193,21 +194,21 @@ func providerConfigure(d *schema.ResourceData) (interface{}, error) { ConnectRetryTimeoutSec: time.Duration(d.Get("connect_retry_timeout_sec").(int)) * time.Second, } - db, err := connectToMySQL(mysqlConf) + db, err := connectToMySQL(ctx, mysqlConf) if err != nil { - return nil, err + return nil, diag.Errorf("failed to connect to MySQL: %v", err) } mysqlConf.Db = db - if err := afterConnect(mysqlConf, db); err != nil { - return nil, fmt.Errorf("Failed running after connect command: %v", err) + if err := afterConnect(ctx, mysqlConf, db); err != nil { + return nil, diag.Errorf("failed running after connect command: %v", err) } return mysqlConf, nil } -func afterConnect(mysqlConf *MySQLConfiguration, db *sql.DB) error { +func afterConnect(ctx context.Context, mysqlConf *MySQLConfiguration, db *sql.DB) error { // Set up env so that we won't create users randomly. currentVersion, err := serverVersion(db) if err != nil { @@ -221,9 +222,9 @@ func afterConnect(mysqlConf *MySQLConfiguration, db *sql.DB) error { if mysqlConf.Version.GreaterThanOrEqual(versionMinInclusive) && mysqlConf.Version.LessThan(versionMaxExclusive) { // CONCAT and setting works even if there is no value. - _, err := db.Exec(`SET SESSION sql_mode=CONCAT(@@sql_mode, ',NO_AUTO_CREATE_USER')`) + _, err = db.ExecContext(ctx, `SET SESSION sql_mode=CONCAT(@@sql_mode, ',NO_AUTO_CREATE_USER')`) if err != nil { - return fmt.Errorf("Failed setting SQL mode: %v", err) + return fmt.Errorf("failed setting SQL mode: %v", err) } } @@ -241,12 +242,12 @@ func makeDialer(d *schema.ResourceData) (proxy.Dialer, error) { if err != nil { return nil, err } - proxy, err := proxy.FromURL(proxyURL, proxy.Direct) + proxyDialer, err := proxy.FromURL(proxyURL, proxy.Direct) if err != nil { return nil, err } - return proxy, nil + return proxyDialer, nil } return proxyFromEnv, nil @@ -276,7 +277,7 @@ func serverVersionString(db *sql.DB) (string, error) { return versionString, nil } -func connectToMySQL(conf *MySQLConfiguration) (*sql.DB, error) { +func connectToMySQL(ctx context.Context, conf *MySQLConfiguration) (*sql.DB, error) { // This is fine - we'll connect serially, but we don't expect more than // 1 or 2 connections starting at once. connectionCacheMtx.Lock() @@ -293,18 +294,18 @@ func connectToMySQL(conf *MySQLConfiguration) (*sql.DB, error) { // when Terraform thinks it's available and when it is actually available. // This is particularly acute when provisioning a server and then immediately // trying to provision a database on it. - retryError := resource.Retry(conf.ConnectRetryTimeoutSec, func() *resource.RetryError { + retryError := resource.RetryContext(ctx, conf.ConnectRetryTimeoutSec, func() *resource.RetryError { db, err = sql.Open("mysql", dsn) if err != nil { - if mysqlErrorNumber(err) == unknownVarErrCode { + if mysqlErrorNumber(err) == unknownVarErrCode || ctx.Err() != nil { return resource.NonRetryableError(err) } return resource.RetryableError(err) } - err = db.Ping() + err = db.PingContext(ctx) if err != nil { - if mysqlErrorNumber(err) == unknownVarErrCode { + if mysqlErrorNumber(err) == unknownVarErrCode || ctx.Err() != nil { return resource.NonRetryableError(err) } diff --git a/mysql/provider_test.go b/mysql/provider_test.go index 73a61024..7f602a9d 100644 --- a/mysql/provider_test.go +++ b/mysql/provider_test.go @@ -70,7 +70,9 @@ func testAccPreCheck(t *testing.T) { func testAccPreCheckSkipTiDB(t *testing.T) { testAccPreCheck(t) - db, err := connectToMySQL(testAccProvider.Meta().(*MySQLConfiguration)) + + ctx := context.Background() + db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration)) if err != nil { return } @@ -87,7 +89,9 @@ func testAccPreCheckSkipTiDB(t *testing.T) { func testAccPreCheckSkipMariaDB(t *testing.T) { testAccPreCheck(t) - db, err := connectToMySQL(testAccProvider.Meta().(*MySQLConfiguration)) + + ctx := context.Background() + db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration)) if err != nil { return } @@ -104,7 +108,9 @@ func testAccPreCheckSkipMariaDB(t *testing.T) { func testAccPreCheckSkipNotTiDB(t *testing.T) { testAccPreCheck(t) - db, err := connectToMySQL(testAccProvider.Meta().(*MySQLConfiguration)) + + ctx := context.Background() + db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration)) if err != nil { return } diff --git a/mysql/resource_database.go b/mysql/resource_database.go index 9c97a669..23ae8afc 100644 --- a/mysql/resource_database.go +++ b/mysql/resource_database.go @@ -91,7 +91,7 @@ func ReadDatabase(ctx context.Context, d *schema.ResourceData, meta interface{}) log.Println("Executing query:", stmtSQL) var createSQL, _database string - err := db.QueryRow(stmtSQL).Scan(&_database, &createSQL) + err := db.QueryRowContext(ctx, stmtSQL).Scan(&_database, &createSQL) if err != nil { if mysqlErr, ok := err.(*mysql.MySQLError); ok { if mysqlErr.Number == unknownDatabaseErrCode { diff --git a/mysql/resource_database_test.go b/mysql/resource_database_test.go index 6b48a2c2..cb03643f 100644 --- a/mysql/resource_database_test.go +++ b/mysql/resource_database_test.go @@ -1,6 +1,7 @@ package mysql import ( + "context" "fmt" "strings" "testing" @@ -37,6 +38,7 @@ func TestAccDatabase_collationChange(t *testing.T) { collation2 := "utf8mb4_general_ci" resourceName := "mysql_database.test" + ctx := context.Background() resource.Test(t, resource.TestCase{ PreCheck: func() { testAccPreCheckSkipTiDB(t) }, @@ -56,7 +58,7 @@ func TestAccDatabase_collationChange(t *testing.T) { }, { PreConfig: func() { - db, err := connectToMySQL(testAccProvider.Meta().(*MySQLConfiguration)) + db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration)) if err != nil { return } @@ -87,7 +89,8 @@ func testAccDatabaseCheck_full(rn string, name string, charset string, collation return fmt.Errorf("database id not set") } - db, err := connectToMySQL(testAccProvider.Meta().(*MySQLConfiguration)) + ctx := context.Background() + db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration)) if err != nil { return err } @@ -111,7 +114,8 @@ func testAccDatabaseCheck_full(rn string, name string, charset string, collation func testAccDatabaseCheckDestroy(name string) resource.TestCheckFunc { return func(s *terraform.State) error { - db, err := connectToMySQL(testAccProvider.Meta().(*MySQLConfiguration)) + ctx := context.Background() + db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration)) if err != nil { return err } diff --git a/mysql/resource_global_variable_test.go b/mysql/resource_global_variable_test.go index a5f6b69b..9b57b13b 100644 --- a/mysql/resource_global_variable_test.go +++ b/mysql/resource_global_variable_test.go @@ -1,6 +1,7 @@ package mysql import ( + "context" "database/sql" "fmt" "regexp" @@ -119,7 +120,8 @@ func TestAccGlobalVar_parseBoolean(t *testing.T) { func testAccGlobalVarExists(varName, varExpected string) resource.TestCheckFunc { return func(s *terraform.State) error { - db, err := connectToMySQL(testAccProvider.Meta().(*MySQLConfiguration)) + ctx := context.Background() + db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration)) if err != nil { return err } @@ -157,7 +159,8 @@ func testAccGetGlobalVar(varName string, db *sql.DB) (string, error) { func testAccGlobalVarCheckDestroy(varName, varExpected string) resource.TestCheckFunc { return func(s *terraform.State) error { - db, err := connectToMySQL(testAccProvider.Meta().(*MySQLConfiguration)) + ctx := context.Background() + db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration)) if err != nil { return err } diff --git a/mysql/resource_grant.go b/mysql/resource_grant.go index e8c0a23f..ec42267e 100644 --- a/mysql/resource_grant.go +++ b/mysql/resource_grant.go @@ -139,7 +139,7 @@ func userOrRole(user string, host string, role string, hasRoles bool) (string, b return fmt.Sprintf("'%s'@'%s'", user, host), false, nil } else if len(role) > 0 { if !hasRoles { - return "", false, fmt.Errorf("Roles are only supported on MySQL 8 and above") + return "", false, fmt.Errorf("roles are only supported on MySQL 8 and above") } return fmt.Sprintf("'%s'", role), true, nil diff --git a/mysql/resource_grant_test.go b/mysql/resource_grant_test.go index ac1ea18a..aa68291f 100644 --- a/mysql/resource_grant_test.go +++ b/mysql/resource_grant_test.go @@ -151,7 +151,7 @@ func TestAccGrantComplex(t *testing.T) { Steps: []resource.TestStep{ { // Create table first - Config: testAccGrantConfig_nogrant(dbName), + Config: testAccGrantConfigNoGrant(dbName), Check: resource.ComposeTestCheckFunc( prepareTable(dbName), ), @@ -234,7 +234,8 @@ func TestAccGrant_role(t *testing.T) { resource.Test(t, resource.TestCase{ PreCheck: func() { testAccPreCheck(t) - db, err := connectToMySQL(testAccProvider.Meta().(*MySQLConfiguration)) + ctx := context.Background() + db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration)) if err != nil { return } @@ -281,7 +282,8 @@ func TestAccGrant_roleToUser(t *testing.T) { resource.Test(t, resource.TestCase{ PreCheck: func() { testAccPreCheck(t) - db, err := connectToMySQL(testAccProvider.Meta().(*MySQLConfiguration)) + ctx := context.Background() + db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration)) if err != nil { return } @@ -313,7 +315,8 @@ func TestAccGrant_roleToUser(t *testing.T) { func prepareTable(dbname string) resource.TestCheckFunc { return func(s *terraform.State) error { - db, err := connectToMySQL(testAccProvider.Meta().(*MySQLConfiguration)) + ctx := context.Background() + db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration)) if err != nil { return err } @@ -336,7 +339,8 @@ func testAccPrivilege(rn string, privilege string, expectExists bool) resource.T return fmt.Errorf("grant id not set") } - db, err := connectToMySQL(testAccProvider.Meta().(*MySQLConfiguration)) + ctx := context.Background() + db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration)) if err != nil { return err } @@ -385,7 +389,8 @@ func testAccPrivilege(rn string, privilege string, expectExists bool) resource.T } func testAccGrantCheckDestroy(s *terraform.State) error { - db, err := connectToMySQL(testAccProvider.Meta().(*MySQLConfiguration)) + ctx := context.Background() + db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration)) if err != nil { return err } @@ -415,16 +420,16 @@ func testAccGrantCheckDestroy(s *terraform.State) error { return fmt.Errorf("error reading grant: %s", err) } - defer rows.Close() if rows.Next() { return fmt.Errorf("grant still exists for: %s", userOrRole) } + rows.Close() } return nil } -func testAccGrantConfig_nogrant(dbName string) string { +func testAccGrantConfigNoGrant(dbName string) string { return fmt.Sprintf(` resource "mysql_database" "test" { name = "%s" diff --git a/mysql/resource_role_test.go b/mysql/resource_role_test.go index d20d9e37..4a7bee1e 100644 --- a/mysql/resource_role_test.go +++ b/mysql/resource_role_test.go @@ -1,6 +1,7 @@ package mysql import ( + "context" "database/sql" "fmt" "testing" @@ -17,7 +18,8 @@ func TestAccRole_basic(t *testing.T) { resource.Test(t, resource.TestCase{ PreCheck: func() { testAccPreCheck(t) - db, err := connectToMySQL(testAccProvider.Meta().(*MySQLConfiguration)) + ctx := context.Background() + db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration)) if err != nil { return } @@ -48,7 +50,8 @@ func TestAccRole_basic(t *testing.T) { func testAccRoleExists(roleName string) resource.TestCheckFunc { return func(s *terraform.State) error { - db, err := connectToMySQL(testAccProvider.Meta().(*MySQLConfiguration)) + ctx := context.Background() + db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration)) if err != nil { return err } @@ -85,7 +88,8 @@ func testAccGetRoleGrantCount(roleName string, db *sql.DB) (int, error) { func testAccRoleCheckDestroy(roleName string) resource.TestCheckFunc { return func(s *terraform.State) error { - db, err := connectToMySQL(testAccProvider.Meta().(*MySQLConfiguration)) + ctx := context.Background() + db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration)) if err != nil { return err } diff --git a/mysql/resource_ti_config_variable_test.go b/mysql/resource_ti_config_variable_test.go index ab50f466..f20bff4e 100644 --- a/mysql/resource_ti_config_variable_test.go +++ b/mysql/resource_ti_config_variable_test.go @@ -1,6 +1,7 @@ package mysql import ( + "context" "database/sql" "fmt" "os" @@ -126,7 +127,8 @@ func TestTiKvConfigVar_basic(t *testing.T) { func testAccConfigVarExists(varName string, varValue string, varType string) resource.TestCheckFunc { return func(s *terraform.State) error { - db, err := connectToMySQL(testAccProvider.Meta().(*MySQLConfiguration)) + ctx := context.Background() + db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration)) if err != nil { return err } @@ -153,7 +155,8 @@ func getGetInstance(varType string, t *testing.T) string { t.Skip("Skip on MySQL") } - db, err := connectToMySQL(testAccProvider.Meta().(*MySQLConfiguration)) + ctx := context.Background() + db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration)) if err != nil { return err.Error() } diff --git a/mysql/resource_user_test.go b/mysql/resource_user_test.go index 28097efb..482ee4ee 100644 --- a/mysql/resource_user_test.go +++ b/mysql/resource_user_test.go @@ -1,6 +1,7 @@ package mysql import ( + "context" "database/sql" "fmt" "log" @@ -126,7 +127,8 @@ func testAccUserExists(rn string) resource.TestCheckFunc { return fmt.Errorf("user id not set") } - db, err := connectToMySQL(testAccProvider.Meta().(*MySQLConfiguration)) + ctx := context.Background() + db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration)) if err != nil { return err } @@ -157,7 +159,8 @@ func testAccUserAuthExists(rn string) resource.TestCheckFunc { return fmt.Errorf("user id not set") } - db, err := connectToMySQL(testAccProvider.Meta().(*MySQLConfiguration)) + ctx := context.Background() + db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration)) if err != nil { return err } @@ -178,7 +181,8 @@ func testAccUserAuthExists(rn string) resource.TestCheckFunc { } func testAccUserCheckDestroy(s *terraform.State) error { - db, err := connectToMySQL(testAccProvider.Meta().(*MySQLConfiguration)) + ctx := context.Background() + db, err := connectToMySQL(ctx, testAccProvider.Meta().(*MySQLConfiguration)) if err != nil { return err }